Coverage for apps/inners/use_cases/graphs/preparation_graph.py: 81% (124 statements)
coverage.py v7.6.1, created at 2024-09-22 19:03 +0000
import asyncio
import pickle
from typing import Dict, List, Any, Optional, Coroutine
from uuid import UUID

import litellm
from langchain_community.chat_models import ChatLiteLLM
from langgraph.graph import StateGraph
from starlette.datastructures import State
from unstructured.documents.elements import Element

from apps.inners.exceptions import use_case_exception
from apps.inners.models.daos.document import Document
from apps.inners.models.dtos.constants.document_type_constant import DocumentTypeConstant
from apps.inners.models.dtos.contracts.responses.managements.documents.file_document_response import \
    FileDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.text_document_response import \
    TextDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.web_document_response import WebDocumentResponse
from apps.inners.models.dtos.document_category import DocumentCategory
from apps.inners.models.dtos.element_category import ElementCategory
from apps.inners.models.dtos.graph_state import PreparationGraphState
from apps.inners.use_cases.document_processor.category_document_processor import CategoryDocumentProcessor
from apps.inners.use_cases.document_processor.partition_document_processor import PartitionDocumentProcessor
from apps.outers.datastores.two_datastore import TwoDatastore
from apps.outers.settings.one_llm_setting import OneLlmSetting
from apps.outers.settings.two_llm_setting import TwoLlmSetting
from tools import cache_tool


class PreparationGraph:
    def __init__(
            self,
            one_llm_setting: OneLlmSetting,
            two_llm_setting: TwoLlmSetting,
            two_datastore: TwoDatastore,
            partition_document_processor: PartitionDocumentProcessor,
            category_document_processor: CategoryDocumentProcessor,
    ):
        self.one_llm_setting = one_llm_setting
        self.two_llm_setting = two_llm_setting
        self.two_datastore = two_datastore
        self.partition_document_processor = partition_document_processor
        self.category_document_processor = category_document_processor
        self.compiled_graph = self.compile()
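
    # Graph node: set the LiteLLM API keys from the injected settings, build the ChatLiteLLM
    # model named in the request's llm_setting, and store it back into the state for later nodes.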
    def node_get_llm_model(self, input_state: PreparationGraphState) -> PreparationGraphState:
        output_state: PreparationGraphState = input_state

        litellm.anthropic_key = self.one_llm_setting.LLM_ONE_ANTHROPIC_API_KEY_ONE
        litellm.openai_key = self.two_llm_setting.LLM_TWO_OPENAI_API_KEY_ONE
        llm_model: ChatLiteLLM = ChatLiteLLM(
            model=input_state["llm_setting"]["model_name"],
            max_tokens=input_state["llm_setting"]["max_token"],
            streaming=True,
            temperature=0
        )
        output_state["llm_setting"]["model"] = llm_model

        return output_state
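
    # Per-document worker: derive cache keys for the document's categorized elements and
    # categorized document, reuse cached pickles when present, and re-partition/categorize
    # (and re-cache) on a cache miss or when a force-refresh flag is set.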
    async def node_get_categorized_document_worker(
            self,
            input_state: PreparationGraphState,
            document_id: UUID
    ) -> PreparationGraphState:
        output_state: PreparationGraphState = input_state

        categorized_element_hash: str = await self.get_categorized_element_hash(
            state=input_state["state"],
            document_id=document_id,
            file_partition_strategy=input_state["preprocessor_setting"]["file_partition_strategy"]
        )
        categorized_document_hash: str = self.get_categorized_document_hash(
            categorized_element_hash=categorized_element_hash,
            summarization_model_name=input_state["llm_setting"]["model_name"],
            is_include_tables=input_state["preprocessor_setting"]["is_include_table"],
            is_include_images=input_state["preprocessor_setting"]["is_include_image"],
            chunk_size=input_state["preprocessor_setting"]["chunk_size"],
        )

        categorized_element_hashes: Optional[Dict[UUID, str]] = input_state["categorized_element_hashes"]
        if categorized_element_hashes is None:
            output_state["categorized_element_hashes"] = {}
        output_state["categorized_element_hashes"][document_id] = categorized_element_hash
        existing_categorized_element_hash: int = await self.two_datastore.async_client.exists(
            categorized_element_hash
        )
        if existing_categorized_element_hash == 0:
            is_categorized_element_exist: bool = False
        elif existing_categorized_element_hash == 1:
            is_categorized_element_exist: bool = True
        else:
            raise use_case_exception.ExistingCategorizedElementHashInvalid()

        is_force_refresh_categorized_element: bool = input_state["preprocessor_setting"][
            "is_force_refresh_categorized_element"]
        if is_categorized_element_exist is False or is_force_refresh_categorized_element is True:
            elements: List[Element] = await self.partition_document_processor.partition(
                state=input_state["state"],
                document_id=document_id,
                file_partition_strategy=input_state["preprocessor_setting"]["file_partition_strategy"]
            )
            categorized_elements: ElementCategory = await self.category_document_processor.categorize_elements(
                elements=elements
            )
            await self.two_datastore.async_client.set(
                name=categorized_element_hash,
                value=pickle.dumps(categorized_elements)
            )
        else:
            found_categorized_element_bytes: bytes = await self.two_datastore.async_client.get(
                categorized_element_hash
            )
            categorized_elements: ElementCategory = pickle.loads(found_categorized_element_bytes)

        categorized_document_hashes: Optional[Dict[UUID, str]] = input_state["categorized_document_hashes"]
        if categorized_document_hashes is None:
            output_state["categorized_document_hashes"] = {}
        output_state["categorized_document_hashes"][document_id] = categorized_document_hash
        existing_categorized_document_hash: int = await self.two_datastore.async_client.exists(
            categorized_document_hash
        )
        if existing_categorized_document_hash == 0:
            is_categorized_document_exist: bool = False
        elif existing_categorized_document_hash == 1:
            is_categorized_document_exist: bool = True
        else:
            raise use_case_exception.ExistingCategorizedDocumentHashInvalid()

        is_force_refresh_categorized_document: bool = input_state["preprocessor_setting"][
            "is_force_refresh_categorized_document"]
        if (
                is_categorized_document_exist is False
                or is_force_refresh_categorized_document is True
                or is_force_refresh_categorized_element is True
        ):
            categorized_document: DocumentCategory = await self.category_document_processor.get_categorized_documents(
                categorized_elements=categorized_elements,
                summarization_model=input_state["llm_setting"]["model"],
                is_include_table=input_state["preprocessor_setting"]["is_include_table"],
                is_include_image=input_state["preprocessor_setting"]["is_include_image"],
                chunk_size=input_state["preprocessor_setting"]["chunk_size"],
                overlap_size=input_state["preprocessor_setting"]["overlap_size"],
                metadata={
                    "document_id": document_id
                }
            )
            await self.two_datastore.async_client.set(
                name=categorized_document_hash,
                value=pickle.dumps(categorized_document)
            )
        else:
            found_categorized_document_bytes: bytes = await self.two_datastore.async_client.get(
                categorized_document_hash
            )
            categorized_document: DocumentCategory = pickle.loads(found_categorized_document_bytes)

        output_state["categorized_documents"][document_id] = categorized_document

        return output_state
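
    # Graph node: fan out one worker coroutine per document id, run them concurrently with
    # asyncio.gather, and merge the per-document hashes and categorized documents into the state.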
    async def node_get_categorized_documents(self, input_state: PreparationGraphState) -> PreparationGraphState:
        output_state: PreparationGraphState = input_state

        document_ids: List[UUID] = input_state["document_ids"]
        output_state["categorized_element_hashes"] = {}
        output_state["categorized_document_hashes"] = {}
        output_state["categorized_documents"] = {}

        future_tasks: List[Coroutine] = []
        for document_id in document_ids:
            future_task = self.node_get_categorized_document_worker(
                input_state=input_state,
                document_id=document_id
            )
            future_tasks.append(future_task)

        for task_result in await asyncio.gather(*future_tasks):
            output_state["categorized_element_hashes"].update(task_result["categorized_element_hashes"])
            output_state["categorized_document_hashes"].update(task_result["categorized_document_hashes"])
            output_state["categorized_documents"].update(task_result["categorized_documents"])

        return output_state
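
    # Cache key for a document's categorized elements, derived from the document id and its
    # content hash (file data, text content, or web URL hash); for file documents the
    # partition strategy is included as well.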
    async def get_categorized_element_hash(
            self,
            state: State,
            document_id: UUID,
            file_partition_strategy: str
    ) -> str:
        data: Dict[str, Any] = {
            "document_id": document_id,
        }
        found_document: Document = await self.partition_document_processor.document_management.find_one_by_id_with_authorization(
            state=state,
            id=document_id
        )
        if found_document.document_type_id == DocumentTypeConstant.FILE:
            found_file_document: FileDocumentResponse = await self.partition_document_processor.file_document_management.find_one_by_id_with_authorization(
                state=state,
                id=document_id
            )
            data["document_detail_hash"] = found_file_document.file_data_hash
            data["file_partition_strategy"] = file_partition_strategy
        elif found_document.document_type_id == DocumentTypeConstant.TEXT:
            found_text_document: TextDocumentResponse = await self.partition_document_processor.text_document_management.find_one_by_id_with_authorization(
                state=state,
                id=document_id
            )
            data["document_detail_hash"] = found_text_document.text_content_hash
        elif found_document.document_type_id == DocumentTypeConstant.WEB:
            found_web_document: WebDocumentResponse = await self.partition_document_processor.web_document_management.find_one_by_id_with_authorization(
                state=state,
                id=document_id
            )
            data["document_detail_hash"] = found_web_document.web_url_hash
        else:
            raise use_case_exception.DocumentTypeNotSupported()

        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"categorized_element/{hashed_data}"

        return hashed_data
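
    # Cache key for a categorized (summarized and chunked) document, derived from the element
    # cache key plus the summarization model name and chunking settings.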
    def get_categorized_document_hash(
            self,
            categorized_element_hash: str,
            summarization_model_name: str,
            is_include_tables: bool,
            is_include_images: bool,
            chunk_size: int,
    ) -> str:
        data: Dict[str, Any] = {
            "categorized_element_hash": categorized_element_hash,
            "summarization_model_name": summarization_model_name,
            "is_include_tables": is_include_tables,
            "is_include_images": is_include_images,
            "chunk_size": chunk_size,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"categorized_document/{hashed_data}"

        return hashed_data
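
    # Build the two-node linear graph: node_get_llm_model -> node_get_categorized_documents.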
    def compile(self):
        graph: StateGraph = StateGraph(PreparationGraphState)

        graph.add_node(
            node=self.node_get_llm_model.__name__,
            action=self.node_get_llm_model
        )
        graph.add_node(
            node=self.node_get_categorized_documents.__name__,
            action=self.node_get_categorized_documents
        )

        graph.set_entry_point(
            key=self.node_get_llm_model.__name__
        )
        graph.add_edge(
            start_key=self.node_get_llm_model.__name__,
            end_key=self.node_get_categorized_documents.__name__
        )
        graph.set_finish_point(
            key=self.node_get_categorized_documents.__name__
        )

        compiled_graph = graph.compile()

        return compiled_graph
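

# Usage sketch (not part of the module): assuming PreparationGraphState is a TypedDict-style
# mapping with the keys read above, the compiled graph could be invoked roughly as below from
# an async context. The variable names and field values here are illustrative assumptions;
# only the state keys themselves come from the code above.
#
#     preparation_graph = PreparationGraph(
#         one_llm_setting=one_llm_setting,
#         two_llm_setting=two_llm_setting,
#         two_datastore=two_datastore,
#         partition_document_processor=partition_document_processor,
#         category_document_processor=category_document_processor,
#     )
#     output_state = await preparation_graph.compiled_graph.ainvoke({
#         "state": request_state,  # starlette State used for authorization checks
#         "document_ids": [document_id],
#         "llm_setting": {"model_name": "claude-3-5-sonnet-20240620", "max_token": 1024},
#         "preprocessor_setting": {
#             "file_partition_strategy": "auto",
#             "is_include_table": True,
#             "is_include_image": False,
#             "chunk_size": 1024,
#             "overlap_size": 128,
#             "is_force_refresh_categorized_element": False,
#             "is_force_refresh_categorized_document": False,
#         },
#     })
#     categorized_documents = output_state["categorized_documents"]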