Coverage for apps/inners/use_cases/document_processor/category_document_processor.py: 56%
66 statements
« prev ^ index » next — coverage.py v7.6.1, created at 2024-09-22 19:03 +0000
1import base64
2import os
3import uuid
4from typing import List, Tuple, Optional, Dict, Any
6import more_itertools
7from langchain_community.chat_models import ChatLiteLLM
8from langchain_core.documents import Document
9from unstructured.chunking.basic import chunk_elements
10from unstructured.documents.elements import Element, Text, NarrativeText, Table, Image, ListItem
12from apps.inners.models.dtos.document_category import DocumentCategory
13from apps.inners.models.dtos.element_category import ElementCategory
14from apps.inners.use_cases.document_processor.summary_document_processor import SummaryDocumentProcessor
class CategoryDocumentProcessor:
    """Buckets raw ``unstructured`` elements into text/table/image categories
    and converts each category into LangChain ``Document`` objects, optionally
    summarizing tables and images via the injected ``SummaryDocumentProcessor``.
    """

    def __init__(
            self,
            summary_document_processor: SummaryDocumentProcessor,
    ):
        # Delegate used to produce LLM summaries of tables and images.
        self.summary_document_processor = summary_document_processor

    async def categorize_elements(self, elements: List[Element]) -> ElementCategory:
        """Split *elements* into texts, tables, and images by concrete type.

        Image elements have their on-disk file read, base64-encoded into
        ``metadata.image_base64`` (mime type forced to ``image/jpeg``), and the
        file is then deleted. Elements of any other type are skipped with a
        console notice.

        :param elements: parsed ``unstructured`` elements to categorize.
        :return: an ``ElementCategory`` holding the three buckets.
        """
        categorized_elements: ElementCategory = ElementCategory(
            texts=[],
            tables=[],
            images=[]
        )
        # Exact class-name comparison (not isinstance) is deliberate: e.g.
        # NarrativeText subclasses Text, and isinstance would match an element
        # under several branches.
        text_type_names = {Text.__name__, NarrativeText.__name__, ListItem.__name__}
        for element in elements:
            type_name = type(element).__name__
            if type_name in text_type_names:
                categorized_elements.texts.append(element)
            elif type_name == Table.__name__:
                categorized_elements.tables.append(element)
            elif type_name == Image.__name__:
                # Context manager guarantees the handle is closed even if the
                # read or encode raises (the original open/close pair leaked
                # the handle and skipped os.remove on error).
                with open(element.metadata.image_path, "rb") as file_io:
                    element.metadata.image_mime_type = "image/jpeg"
                    element.metadata.image_base64 = base64.b64encode(file_io.read()).decode()
                os.remove(element.metadata.image_path)
                categorized_elements.images.append(element)
            else:
                print(f"BaseDocumentProcessor.categorize_elements: Ignoring element type {element.__class__.__name__}.")

        return categorized_elements

    async def get_categorized_documents(
            self,
            categorized_elements: ElementCategory,
            summarization_model: ChatLiteLLM,
            is_include_table: bool = False,
            is_include_image: bool = False,
            chunk_size: int = 400,
            overlap_size: int = 50,
            separators: Tuple[str, ...] = ("\n", " "),
            id_key: str = "id",
            metadata: Optional[Dict[str, Any]] = None,
    ) -> DocumentCategory:
        """Turn categorized elements into ``Document`` objects.

        Texts are chunked to ``chunk_size - overlap_size`` characters, then
        consecutive chunks are stitched into overlapping documents: each
        document is the tail of the previous chunk (cut at the last separator
        before the overlap window) plus the whole next chunk. Tables and
        images are included only when the corresponding flag is set, and their
        page content is an LLM summary rather than the raw element text.

        :param categorized_elements: output of :meth:`categorize_elements`.
        :param summarization_model: LLM used for table/image summaries.
        :param is_include_table: include summarized tables in the result.
        :param is_include_image: include summarized images in the result.
        :param chunk_size: target size of a stitched text document.
        :param overlap_size: characters of overlap carried between chunks.
        :param separators: candidates for a clean cut point in the overlap.
        :param id_key: metadata key under which a fresh UUID is stored.
        :param metadata: extra metadata merged into every document.
        :return: a ``DocumentCategory`` with texts, tables, and images.
        """
        if metadata is None:
            metadata = {}

        document_category: DocumentCategory = DocumentCategory(
            texts=[],
            tables=[],
            images=[],
            id_key=id_key
        )
        chunked_texts: List[Element] = chunk_elements(
            elements=categorized_elements.texts,
            include_orig_elements=True,
            max_characters=chunk_size - overlap_size,
        )
        if len(chunked_texts) < 2:
            # Zero or one chunk: no stitching possible; emit chunks as-is.
            for text in chunked_texts:
                orig_elements: List[Element] = text.metadata.orig_elements
                orig_element_metadata: List[Dict[str, Any]] = [
                    orig_element.metadata.to_dict() for orig_element in orig_elements
                ]
                document: Document = Document(
                    page_content=text.text,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "category": "text",
                        "origin_metadata": orig_element_metadata,
                        **metadata
                    }
                )
                document_category.texts.append(document)
        else:
            # Slide a window of two consecutive chunks; each document joins
            # the tail of the first with the whole of the second.
            for text_before, text_after in more_itertools.windowed(chunked_texts, n=2):
                # De-duplicate origin elements shared by the two chunks.
                orig_elements: List[Element] = []
                for orig_element in text_before.metadata.orig_elements + text_after.metadata.orig_elements:
                    if not any(orig_element.id == existing_orig_element.id for existing_orig_element in orig_elements):
                        orig_elements.append(orig_element)
                orig_elements_metadata: Dict[str, Any] = {
                    "origin_metadata": [
                        orig_element.metadata.to_dict() for orig_element in orig_elements
                    ],
                    "category": "text"
                }
                # Find the last separator occurring before the overlap window
                # so the stitched text starts at a token boundary.
                # NOTE(review): if text_before is shorter than overlap_size the
                # rfind end index goes negative and is interpreted relative to
                # the string's end — confirm this edge case is intended.
                last_index_of_separators: int = -1
                for separator in separators:
                    last_index_of_separator = text_before.text.rfind(separator, 0, len(text_before.text) - overlap_size)
                    last_index_of_separators = max(last_index_of_separators, last_index_of_separator)

                text = text_before.text[last_index_of_separators + 1:] + " " + text_after.text
                document: Document = Document(
                    page_content=text,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        **orig_elements_metadata,
                        **metadata,
                    }
                )
                document_category.texts.append(document)

        if is_include_table:
            summarized_tables: List[str] = await self.summary_document_processor.summarize_tables(
                tables=categorized_elements.tables,
                llm_model=summarization_model
            )
            # strict=True: a length mismatch between tables and summaries is a
            # bug and should raise rather than silently truncate.
            for table, summarized_table in zip(categorized_elements.tables, summarized_tables, strict=True):
                document: Document = Document(
                    page_content=summarized_table,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "category": "table",
                        **table.metadata.to_dict(),
                        **metadata
                    }
                )
                document_category.tables.append(document)

        if is_include_image:
            summarized_images: List[str] = await self.summary_document_processor.summarize_images(
                images=categorized_elements.images,
                llm_model=summarization_model
            )
            for image, summarized_image in zip(categorized_elements.images, summarized_images, strict=True):
                document: Document = Document(
                    page_content=summarized_image,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "category": "image",
                        **image.metadata.to_dict(),
                        **metadata
                    }
                )
                document_category.images.append(document)

        return document_category