Coverage for apps/inners/use_cases/graphs/preparation_graph.py: 81%

124 statements  

coverage.py v7.6.1, created at 2024-09-22 19:03 +0000

import asyncio
import pickle
from typing import Dict, List, Any, Optional, Coroutine
from uuid import UUID

import litellm
from langchain_community.chat_models import ChatLiteLLM
from langgraph.graph import StateGraph
from starlette.datastructures import State
from unstructured.documents.elements import Element

from apps.inners.exceptions import use_case_exception
from apps.inners.models.daos.document import Document
from apps.inners.models.dtos.constants.document_type_constant import DocumentTypeConstant
from apps.inners.models.dtos.contracts.responses.managements.documents.file_document_response import \
    FileDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.text_document_response import \
    TextDocumentResponse
from apps.inners.models.dtos.contracts.responses.managements.documents.web_document_response import WebDocumentResponse
from apps.inners.models.dtos.document_category import DocumentCategory
from apps.inners.models.dtos.element_category import ElementCategory
from apps.inners.models.dtos.graph_state import PreparationGraphState
from apps.inners.use_cases.document_processor.category_document_processor import CategoryDocumentProcessor
from apps.inners.use_cases.document_processor.partition_document_processor import PartitionDocumentProcessor
from apps.outers.datastores.two_datastore import TwoDatastore
from apps.outers.settings.one_llm_setting import OneLlmSetting
from apps.outers.settings.two_llm_setting import TwoLlmSetting
from tools import cache_tool


class PreparationGraph:
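    """
    Preparation pipeline wired as a compiled LangGraph state graph: resolve the LLM
    model from the request's llm_setting, then partition, categorize, and cache the
    requested documents, keyed by content-derived hashes in the two_datastore
    key-value client.
    """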

    def __init__(
            self,
            one_llm_setting: OneLlmSetting,
            two_llm_setting: TwoLlmSetting,
            two_datastore: TwoDatastore,
            partition_document_processor: PartitionDocumentProcessor,
            category_document_processor: CategoryDocumentProcessor,
    ):
        self.one_llm_setting = one_llm_setting
        self.two_llm_setting = two_llm_setting
        self.two_datastore = two_datastore
        self.partition_document_processor = partition_document_processor
        self.category_document_processor = category_document_processor
        self.compiled_graph = self.compile()

    def node_get_llm_model(self, input_state: PreparationGraphState) -> PreparationGraphState:
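        """
        Graph node: set the LiteLLM provider API keys and build a streaming
        ChatLiteLLM instance from llm_setting, storing it back on the state.
        """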

        output_state: PreparationGraphState = input_state

        litellm.anthropic_key = self.one_llm_setting.LLM_ONE_ANTHROPIC_API_KEY_ONE
        litellm.openai_key = self.two_llm_setting.LLM_TWO_OPENAI_API_KEY_ONE
        llm_model: ChatLiteLLM = ChatLiteLLM(
            model=input_state["llm_setting"]["model_name"],
            max_tokens=input_state["llm_setting"]["max_token"],
            streaming=True,
            temperature=0
        )
        output_state["llm_setting"]["model"] = llm_model

        return output_state

    async def node_get_categorized_document_worker(self, input_state: PreparationGraphState,
                                                   document_id: UUID) -> PreparationGraphState:
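        """
        Process a single document: derive the element- and document-level cache keys,
        reuse pickled results from the key-value client when present (unless the
        corresponding force-refresh flag is set), otherwise partition and categorize
        the document and cache the outcome, then record it on the state.
        """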

        output_state: PreparationGraphState = input_state

        categorized_element_hash: str = await self.get_categorized_element_hash(
            state=input_state["state"],
            document_id=document_id,
            file_partition_strategy=input_state["preprocessor_setting"]["file_partition_strategy"]
        )
        categorized_document_hash: str = self.get_categorized_document_hash(
            categorized_element_hash=categorized_element_hash,
            summarization_model_name=input_state["llm_setting"]["model_name"],
            is_include_tables=input_state["preprocessor_setting"]["is_include_table"],
            is_include_images=input_state["preprocessor_setting"]["is_include_image"],
            chunk_size=input_state["preprocessor_setting"]["chunk_size"],
        )

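        # Element-level cache: EXISTS returns the number of matching keys, so 0/1 is
        # mapped to a boolean and any other count is rejected as invalid.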

        categorized_element_hashes: Optional[Dict[UUID, str]] = input_state["categorized_element_hashes"]
        if categorized_element_hashes is None:
            output_state["categorized_element_hashes"] = {}
        output_state["categorized_element_hashes"][document_id] = categorized_element_hash
        is_categorized_element_exist: int = await self.two_datastore.async_client.exists(
            categorized_element_hash
        )
        if is_categorized_element_exist == 0:
            is_categorized_element_exist: bool = False
        elif is_categorized_element_exist == 1:
            is_categorized_element_exist: bool = True
        else:
            raise use_case_exception.ExistingCategorizedElementHashInvalid()

        is_force_refresh_categorized_element: bool = input_state["preprocessor_setting"][
            "is_force_refresh_categorized_element"]
        if is_categorized_element_exist is False or is_force_refresh_categorized_element is True:
            elements: List[Element] = await self.partition_document_processor.partition(
                state=input_state["state"],
                document_id=document_id,
                file_partition_strategy=input_state["preprocessor_setting"]["file_partition_strategy"]
            )
            categorized_elements: ElementCategory = await self.category_document_processor.categorize_elements(
                elements=elements
            )
            await self.two_datastore.async_client.set(
                name=categorized_element_hash,
                value=pickle.dumps(categorized_elements)
            )
        else:
            found_categorized_element_bytes: bytes = await self.two_datastore.async_client.get(
                categorized_element_hash
            )
            categorized_elements: ElementCategory = pickle.loads(found_categorized_element_bytes)

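        # Document-level cache: same pattern as above, except that a forced refresh of
        # the underlying elements also forces the categorized document to be rebuilt.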

        categorized_document_hashes: Optional[Dict[UUID, str]] = input_state["categorized_document_hashes"]
        if categorized_document_hashes is None:
            output_state["categorized_document_hashes"] = {}
        output_state["categorized_document_hashes"][document_id] = categorized_document_hash
        existing_categorized_document_hash: int = await self.two_datastore.async_client.exists(
            categorized_document_hash
        )
        if existing_categorized_document_hash == 0:
            is_categorized_document_exist: bool = False
        elif existing_categorized_document_hash == 1:
            is_categorized_document_exist: bool = True
        else:
            raise use_case_exception.ExistingCategorizedDocumentHashInvalid()

        is_force_refresh_categorized_document: bool = input_state["preprocessor_setting"][
            "is_force_refresh_categorized_document"]
        if is_categorized_document_exist is False or is_force_refresh_categorized_document is True or is_force_refresh_categorized_element is True:
            categorized_document: DocumentCategory = await self.category_document_processor.get_categorized_documents(
                categorized_elements=categorized_elements,
                summarization_model=input_state["llm_setting"]["model"],
                is_include_table=input_state["preprocessor_setting"]["is_include_table"],
                is_include_image=input_state["preprocessor_setting"]["is_include_image"],
                chunk_size=input_state["preprocessor_setting"]["chunk_size"],
                overlap_size=input_state["preprocessor_setting"]["overlap_size"],
                metadata={
                    "document_id": document_id
                }
            )
            await self.two_datastore.async_client.set(
                name=categorized_document_hash,
                value=pickle.dumps(categorized_document)
            )
        else:
            found_categorized_document_bytes: bytes = await self.two_datastore.async_client.get(
                categorized_document_hash
            )
            categorized_document: DocumentCategory = pickle.loads(found_categorized_document_bytes)

        output_state["categorized_documents"][document_id] = categorized_document

        return output_state

    async def node_get_categorized_documents(self, input_state: PreparationGraphState) -> PreparationGraphState:
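        """
        Graph node: reset the hash and result maps on the state, run
        node_get_categorized_document_worker concurrently for every document_id via
        asyncio.gather, and merge the per-document results.
        """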

        output_state: PreparationGraphState = input_state

        document_ids: List[UUID] = input_state["document_ids"]
        output_state["categorized_element_hashes"] = {}
        output_state["categorized_document_hashes"] = {}
        output_state["categorized_documents"] = {}

        future_tasks: List[Coroutine] = []
        for document_id in document_ids:
            future_task = self.node_get_categorized_document_worker(
                input_state=input_state,
                document_id=document_id
            )
            future_tasks.append(future_task)

        for task_result in await asyncio.gather(*future_tasks):
            output_state["categorized_element_hashes"].update(task_result["categorized_element_hashes"])
            output_state["categorized_document_hashes"].update(task_result["categorized_document_hashes"])
            output_state["categorized_documents"].update(task_result["categorized_documents"])

        return output_state

    async def get_categorized_element_hash(
            self,
            state: State,
            document_id: UUID,
            file_partition_strategy: str
    ) -> str:
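        """
        Build the element-level cache key: load the document with authorization, pick
        the content hash that matches its type (file data, text content, or web URL),
        include the partition strategy for file documents, and hash the resulting dict
        under the "categorized_element/" namespace.
        """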

        data: Dict[str, Any] = {
            "document_id": document_id,
        }
        found_document: Document = await self.partition_document_processor.document_management.find_one_by_id_with_authorization(
            state=state,
            id=document_id
        )
        if found_document.document_type_id == DocumentTypeConstant.FILE:
            found_file_document: FileDocumentResponse = await self.partition_document_processor.file_document_management.find_one_by_id_with_authorization(
                state=state,
                id=document_id
            )
            data["document_detail_hash"] = found_file_document.file_data_hash
            data["file_partition_strategy"] = file_partition_strategy
        elif found_document.document_type_id == DocumentTypeConstant.TEXT:
            found_text_document: TextDocumentResponse = await self.partition_document_processor.text_document_management.find_one_by_id_with_authorization(
                state=state,
                id=document_id
            )
            data["document_detail_hash"] = found_text_document.text_content_hash
        elif found_document.document_type_id == DocumentTypeConstant.WEB:
            found_web_document: WebDocumentResponse = await self.partition_document_processor.web_document_management.find_one_by_id_with_authorization(
                state=state,
                id=document_id
            )
            data["document_detail_hash"] = found_web_document.web_url_hash
        else:
            raise use_case_exception.DocumentTypeNotSupported()

        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"categorized_element/{hashed_data}"

        return hashed_data

    def get_categorized_document_hash(
            self,
            categorized_element_hash: str,
            summarization_model_name: str,
            is_include_tables: bool,
            is_include_images: bool,
            chunk_size: int,
    ) -> str:
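        """
        Derive the document-level cache key from the element hash plus the settings
        that affect summarization and chunking, under the "categorized_document/"
        namespace.
        """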

        data: Dict[str, Any] = {
            "categorized_element_hash": categorized_element_hash,
            "summarization_model_name": summarization_model_name,
            "is_include_tables": is_include_tables,
            "is_include_images": is_include_images,
            "chunk_size": chunk_size,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"categorized_document/{hashed_data}"

        return hashed_data

    def compile(self):
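        """
        Build the linear two-node graph node_get_llm_model -> node_get_categorized_documents,
        mark them as the entry and finish points, and return the compiled graph.
        """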

        graph: StateGraph = StateGraph(PreparationGraphState)

        graph.add_node(
            node=self.node_get_llm_model.__name__,
            action=self.node_get_llm_model
        )
        graph.add_node(
            node=self.node_get_categorized_documents.__name__,
            action=self.node_get_categorized_documents
        )

        graph.set_entry_point(
            key=self.node_get_llm_model.__name__
        )
        graph.add_edge(
            start_key=self.node_get_llm_model.__name__,
            end_key=self.node_get_categorized_documents.__name__
        )
        graph.set_finish_point(
            key=self.node_get_categorized_documents.__name__
        )

        compiled_graph = graph.compile()

        return compiled_graph