Coverage for apps/inners/use_cases/graphs/long_form_qa_graph.py: 69%

191 statements  

coverage.py v7.6.1, created at 2024-09-22 19:03 +0000

from typing import Dict, List, Any

from langchain_community.chat_models import ChatLiteLLM
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableSerializable
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph, END
from pydantic.v1 import Field

from apps.inners.exceptions import use_case_exception
from apps.inners.models.base_model import BaseModel
from apps.inners.models.dtos.graph_state import LongFormQaGraphState
from apps.inners.use_cases.graphs.passage_search_graph import PassageSearchGraph
from apps.inners.use_cases.retrievers.hybrid_milvus_retriever import HybridMilvusRetriever
from tools import cache_tool


class LongFormQaGraph(PassageSearchGraph):
    def __init__(
            self,
            *args: Any,
            **kwargs: Any
    ):
        super().__init__(*args, **kwargs)
        self.compiled_graph: CompiledGraph = self.compile()
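
    # Graph flow (wired up in compile() below): node_get_llm_model -> node_get_categorized_documents
    # -> node_embed -> node_get_relevant_documents -> node_get_re_ranked_documents
    # -> node_generate_answer -> node_grade_hallucination. From there the graph either grades
    # answer relevancy, transforms the question and loops back to retrieval, or ends once
    # transform_question_max_retry is exhausted.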

    async def node_generate_answer(self, input_state: LongFormQaGraphState) -> LongFormQaGraphState:
        output_state: LongFormQaGraphState = input_state

        re_ranked_documents: List[Document] = input_state["re_ranked_documents"]
        retriever: HybridMilvusRetriever = input_state["retriever_setting"]["retriever"]
        re_ranked_document_ids: List[str] = [document.metadata[retriever.id_key] for document in re_ranked_documents]
        generated_answer_hash: str = self.get_generated_answer_hash(
            re_ranked_document_ids=re_ranked_document_ids,
            question=input_state["question"],
            llm_model_name=input_state["llm_setting"]["model_name"],
            prompt=input_state["generator_setting"]["prompt"],
            max_token=input_state["llm_setting"]["max_token"],
        )
        # The (Redis-style) exists() call returns how many of the given keys exist, so a single
        # key should yield 0 or 1; anything else is rejected as an invalid cache state.
        existing_generated_answer_hash: int = await self.two_datastore.async_client.exists(generated_answer_hash)
        if existing_generated_answer_hash == 0:
            is_generated_answer_exist: bool = False
        elif existing_generated_answer_hash == 1:
            is_generated_answer_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedAnswerHashInvalid()

        is_force_refresh_generated_answer: bool = input_state["generator_setting"][
            "is_force_refresh_generated_answer"]
        if is_generated_answer_exist is False or is_force_refresh_generated_answer is True:
            prompt: PromptTemplate = PromptTemplate(
                template=input_state["generator_setting"]["prompt"],
                template_format="jinja2",
                input_variables=["passages", "question"]
            )
            text: str = prompt.format(
                passages=re_ranked_documents,
                question=input_state["question"]
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: ChatLiteLLM = input_state["llm_setting"]["model"]
            chain: RunnableSerializable = llm_model | StrOutputParser()
            generated_answer: str = await chain.ainvoke(
                input=messages
            )
            await self.two_datastore.async_client.set(
                name=generated_answer_hash,
                value=generated_answer.encode()
            )
        else:
            generated_answer_byte: bytes = await self.two_datastore.async_client.get(generated_answer_hash)
            generated_answer: str = generated_answer_byte.decode()

        output_state["generated_answer"] = generated_answer
        output_state["generated_answer_hash"] = generated_answer_hash

        return output_state

    def get_generated_answer_hash(
            self,
            re_ranked_document_ids: List[str],
            question: str,
            llm_model_name: str,
            prompt: str,
            max_token: int,
    ) -> str:
        data: Dict[str, Any] = {
            "re_ranked_document_ids": re_ranked_document_ids,
            "question": question,
            "llm_model_name": llm_model_name,
            "prompt": prompt,
            "max_token": max_token,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"generated_answer/{hashed_data}"

        return hashed_data
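
    # The resulting cache key has the form "generated_answer/<digest>", where the digest comes
    # from cache_tool.hash_by_dict (its exact encoding lives in tools.cache_tool and is not shown
    # here). The other *_hash helpers below follow the same "<prefix>/<digest>" convention.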

    async def node_grade_hallucination(self, input_state: LongFormQaGraphState) -> LongFormQaGraphState:
        output_state: LongFormQaGraphState = input_state

        re_ranked_documents: List[Document] = input_state["re_ranked_documents"]

        class GradeTool(BaseModel):
            """Binary score for support check."""
            binary_score: bool = Field(
                description="Is supported binary score, either True if supported or False if not supported."
            )

        retriever: HybridMilvusRetriever = input_state["retriever_setting"]["retriever"]
        generated_hallucination_grade_hash: str = self.get_generated_hallucination_grade_hash(
            retrieved_document_ids=[document.metadata[retriever.id_key] for document in re_ranked_documents],
            generated_answer_hash=input_state["generated_answer_hash"]
        )
        existing_generated_hallucination_grade_hash: int = await self.two_datastore.async_client.exists(
            generated_hallucination_grade_hash)
        if existing_generated_hallucination_grade_hash == 0:
            is_generated_hallucination_grade_hash_exist: bool = False
        elif existing_generated_hallucination_grade_hash == 1:
            is_generated_hallucination_grade_hash_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedHallucinationGradeHashInvalid()

        is_force_refresh_generated_hallucination_grade: bool = input_state["generator_setting"][
            "is_force_refresh_generated_hallucination_grade"]
        if is_generated_hallucination_grade_hash_exist is False or is_force_refresh_generated_hallucination_grade is True:
            prompt: PromptTemplate = PromptTemplate(
                template="""
                <instruction>
                Assess whether a Large Language Model generated answer to the question is supported by the passages. Give one binary score of "True" or "False". "True" means that the answer to the question is supported by the passages. "False" means that the answer to the question is not supported by the passages.
                <instruction/>
                <passages>
                {% for passage in passages %}
                <passage_{{ loop.index }}>
                {{ passage.page_content }}
                <passage_{{ loop.index }}/>
                {% endfor %}
                <passages/>
                <question>
                {{ question }}
                <question/>
                <answer>
                {{ answer }}
                <answer/>
                """,
                template_format="jinja2",
                input_variables=["passages", "question", "answer"]
            )
            text: str = prompt.format(
                passages=re_ranked_documents,
                question=input_state["question"],
                answer=input_state["generated_answer"]
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: ChatLiteLLM = input_state["llm_setting"]["model"]
            tool_parser: PydanticToolsParser = PydanticToolsParser(
                tools=[GradeTool]
            )
            chain: RunnableSerializable = llm_model.bind_tools(tools=tool_parser.tools,
                                                               tool_choice="required") | tool_parser
            generated_tools: List[GradeTool] = await chain.ainvoke(
                input=messages
            )
            # The tool reports "is supported"; the hallucination grade is its negation.
            generated_hallucination_grade: str = str(not generated_tools[0].binary_score)
            await self.two_datastore.async_client.set(
                name=generated_hallucination_grade_hash,
                value=generated_hallucination_grade.encode()
            )
        else:
            generated_hallucination_grade_byte: bytes = await self.two_datastore.async_client.get(
                generated_hallucination_grade_hash)
            generated_hallucination_grade: str = generated_hallucination_grade_byte.decode()

        output_state["generated_hallucination_grade"] = generated_hallucination_grade
        output_state["generated_hallucination_grade_hash"] = generated_hallucination_grade_hash

        return output_state

    def get_generated_hallucination_grade_hash(
            self,
            retrieved_document_ids: List[str],
            generated_answer_hash: str,
    ) -> str:
        data: Dict[str, Any] = {
            "retrieved_document_ids": retrieved_document_ids,
            "generated_answer_hash": generated_answer_hash,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"generated_hallucination_grade/{hashed_data}"

        return hashed_data

    async def node_grade_answer_relevancy(self, input_state: LongFormQaGraphState) -> LongFormQaGraphState:
        output_state: LongFormQaGraphState = input_state

        class GradeTool(BaseModel):
            """Binary score for resolution check."""
            binary_score: bool = Field(
                description="Is resolved binary score, either True if resolved or False if not resolved."
            )

        generated_answer_relevancy_grade_hash: str = self.get_generated_answer_relevancy_grade_hash(
            question=input_state["question"],
            generated_answer_hash=input_state["generated_answer_hash"]
        )
        existing_generated_answer_relevancy_grade_hash: int = await self.two_datastore.async_client.exists(
            generated_answer_relevancy_grade_hash)
        if existing_generated_answer_relevancy_grade_hash == 0:
            is_generated_answer_relevancy_grade_hash_exist: bool = False
        elif existing_generated_answer_relevancy_grade_hash == 1:
            is_generated_answer_relevancy_grade_hash_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedAnswerRelevancyGradeHashInvalid()

        is_force_refresh_generated_answer_relevancy_grade: bool = input_state["generator_setting"][
            "is_force_refresh_generated_answer_relevancy_grade"]
        if is_generated_answer_relevancy_grade_hash_exist is False or is_force_refresh_generated_answer_relevancy_grade is True:
            prompt: PromptTemplate = PromptTemplate(
                template="""
                <instruction>
                Assess whether a Large Language Model generated answer resolves a question. Give one binary score of "True" or "False". "True" means that the answer resolves the question. "False" means that the answer does not resolve the question.
                <instruction/>
                <question>
                {{ question }}
                <question/>
                <answer>
                {{ answer }}
                <answer/>
                """,
                template_format="jinja2",
                input_variables=["question", "answer"]
            )
            text: str = prompt.format(
                question=input_state["question"],
                answer=input_state["generated_answer"],
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: ChatLiteLLM = input_state["llm_setting"]["model"]
            tool_parser: PydanticToolsParser = PydanticToolsParser(
                tools=[GradeTool]
            )
            chain: RunnableSerializable = llm_model.bind_tools(tools=tool_parser.tools,
                                                               tool_choice="required") | tool_parser
            generated_tools: List[GradeTool] = await chain.ainvoke(
                input=messages
            )
            generated_answer_relevancy_grade: str = str(generated_tools[0].binary_score)
            await self.two_datastore.async_client.set(
                name=generated_answer_relevancy_grade_hash,
                value=generated_answer_relevancy_grade.encode()
            )
        else:
            generated_answer_relevancy_grade_byte: bytes = await self.two_datastore.async_client.get(
                generated_answer_relevancy_grade_hash
            )
            generated_answer_relevancy_grade: str = generated_answer_relevancy_grade_byte.decode()

        output_state["generated_answer_relevancy_grade"] = generated_answer_relevancy_grade
        output_state["generated_answer_relevancy_grade_hash"] = generated_answer_relevancy_grade_hash

        return output_state

    def get_generated_answer_relevancy_grade_hash(
            self,
            question: str,
            generated_answer_hash: str,
    ) -> str:
        data: Dict[str, Any] = {
            "question": question,
            "generated_answer_hash": generated_answer_hash,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"generated_answer_relevancy_grade/{hashed_data}"

        return hashed_data

    def node_decide_transform_question_or_grade_answer_relevancy(self, input_state: LongFormQaGraphState) -> str:
        output_state: LongFormQaGraphState = input_state

        generated_hallucination_grade: str = input_state["generated_hallucination_grade"]
        if generated_hallucination_grade == "False":
            return "GRADE_ANSWER_RELEVANCY"

        transform_question_max_retry: int = input_state["transform_question_max_retry"]
        transform_question_current_retry: int = input_state["state"].transform_question_current_retry

        if transform_question_current_retry >= transform_question_max_retry:
            return "MAX_RETRY"

        output_state["state"].transform_question_current_retry += 1

        return "TRANSFORM_QUESTION"
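
    # The labels returned by this decision node and the one below ("GRADE_ANSWER_RELEVANCY",
    # "PROVIDE_ANSWER", "MAX_RETRY", "TRANSFORM_QUESTION") must match the path_map keys wired
    # up for them in compile().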

    def node_decide_transform_question_or_provide_answer(self, input_state: LongFormQaGraphState) -> str:
        output_state: LongFormQaGraphState = input_state

        generated_answer_relevancy_grade: str = input_state["generated_answer_relevancy_grade"]
        if generated_answer_relevancy_grade == "True":
            return "PROVIDE_ANSWER"

        transform_question_max_retry: int = input_state["transform_question_max_retry"]
        transform_question_current_retry: int = input_state["state"].transform_question_current_retry
        if transform_question_current_retry >= transform_question_max_retry:
            return "MAX_RETRY"

        output_state["state"].transform_question_current_retry += 1

        return "TRANSFORM_QUESTION"

    async def node_transform_question(self, input_state: LongFormQaGraphState) -> LongFormQaGraphState:
        output_state: LongFormQaGraphState = input_state

        generated_question_hash: str = self.get_transformed_question_hash(
            question=input_state["question"]
        )
        existing_generated_question_hash: int = await self.two_datastore.async_client.exists(
            generated_question_hash
        )
        if existing_generated_question_hash == 0:
            is_generated_question_exist: bool = False
        elif existing_generated_question_hash == 1:
            is_generated_question_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedQuestionHashInvalid()

        is_force_refresh_generated_question: bool = input_state["generator_setting"][
            "is_force_refresh_generated_question"]
        if is_generated_question_exist is False or is_force_refresh_generated_question is True:
            prompt: PromptTemplate = PromptTemplate(
                template="""
                <instruction>
                Convert the question to a better version that is optimized for vector store retrieval. Observe the question and try to reason about the underlying semantics. Ensure the output is only the question, without re-explaining the instruction.
                <instruction/>
                <question>
                {question}
                <question/>
                """,
                input_variables=["question"]
            )
            text: str = prompt.format(
                question=input_state["question"]
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: ChatLiteLLM = input_state["llm_setting"]["model"]
            chain: RunnableSerializable = llm_model | StrOutputParser()
            generated_question: str = await chain.ainvoke(
                input=messages
            )
            await self.two_datastore.async_client.set(
                name=generated_question_hash,
                value=generated_question.encode()
            )
        else:
            generated_question_byte: bytes = await self.two_datastore.async_client.get(
                generated_question_hash
            )
            generated_question: str = generated_question_byte.decode()

        output_state["question"] = generated_question
        output_state["generated_question_hash"] = generated_question_hash

        return output_state
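
    # Because output_state["question"] is overwritten with the transformed question, the edge
    # from node_transform_question back to node_get_relevant_documents in compile() reruns
    # retrieval, re-ranking, and answer generation with the new question.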

    def get_transformed_question_hash(
            self,
            question: str,
    ) -> str:
        data: Dict[str, Any] = {
            "question": question,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"transformed_question/{hashed_data}"

        return hashed_data

    def compile(self) -> CompiledGraph:
        graph: StateGraph = StateGraph(LongFormQaGraphState)

        graph.add_node(
            node=self.node_get_llm_model.__name__,
            action=self.node_get_llm_model
        )
        graph.add_node(
            node=self.node_get_categorized_documents.__name__,
            action=self.node_get_categorized_documents
        )
        graph.add_node(
            node=self.node_embed.__name__,
            action=self.node_embed
        )
        graph.add_node(
            node=self.node_get_relevant_documents.__name__,
            action=self.node_get_relevant_documents
        )
        graph.add_node(
            node=self.node_get_re_ranked_documents.__name__,
            action=self.node_get_re_ranked_documents
        )
        graph.add_node(
            node=self.node_generate_answer.__name__,
            action=self.node_generate_answer
        )
        graph.add_node(
            node=self.node_grade_hallucination.__name__,
            action=self.node_grade_hallucination
        )
        graph.add_node(
            node=self.node_grade_answer_relevancy.__name__,
            action=self.node_grade_answer_relevancy
        )
        graph.add_node(
            node=self.node_transform_question.__name__,
            action=self.node_transform_question
        )

        graph.set_entry_point(
            key=self.node_get_llm_model.__name__
        )
        graph.add_edge(
            start_key=self.node_get_llm_model.__name__,
            end_key=self.node_get_categorized_documents.__name__
        )
        graph.add_edge(
            start_key=self.node_get_categorized_documents.__name__,
            end_key=self.node_embed.__name__
        )
        graph.add_edge(
            start_key=self.node_embed.__name__,
            end_key=self.node_get_relevant_documents.__name__
        )
        graph.add_edge(
            start_key=self.node_get_relevant_documents.__name__,
            end_key=self.node_get_re_ranked_documents.__name__
        )
        graph.add_edge(
            start_key=self.node_get_re_ranked_documents.__name__,
            end_key=self.node_generate_answer.__name__
        )
        graph.add_edge(
            start_key=self.node_generate_answer.__name__,
            end_key=self.node_grade_hallucination.__name__
        )
        graph.add_conditional_edges(
            source=self.node_grade_hallucination.__name__,
            path=self.node_decide_transform_question_or_grade_answer_relevancy,
            path_map={
                "MAX_RETRY": END,
                "GRADE_ANSWER_RELEVANCY": self.node_grade_answer_relevancy.__name__,
                "TRANSFORM_QUESTION": self.node_transform_question.__name__
            }
        )
        graph.add_conditional_edges(
            source=self.node_grade_answer_relevancy.__name__,
            path=self.node_decide_transform_question_or_provide_answer,
            path_map={
                "MAX_RETRY": END,
                "PROVIDE_ANSWER": END,
                "TRANSFORM_QUESTION": self.node_transform_question.__name__
            }
        )
        graph.add_edge(
            start_key=self.node_transform_question.__name__,
            end_key=self.node_get_relevant_documents.__name__
        )

        compiled_graph: CompiledGraph = graph.compile()

        return compiled_graph
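

# Hedged usage sketch (illustration only, not part of the covered module). It assumes the
# constructor takes the same dependencies as PassageSearchGraph and that the caller builds a
# LongFormQaGraphState with the keys read by the nodes above ("question", "llm_setting",
# "retriever_setting", "generator_setting", "transform_question_max_retry", "state", ...).
# Under those assumptions the compiled graph could be driven roughly like this:
#
#     long_form_qa_graph = LongFormQaGraph(...)  # dependency wiring omitted
#     output_state: LongFormQaGraphState = await long_form_qa_graph.compiled_graph.ainvoke(
#         input=initial_state
#     )
#     answer: str = output_state["generated_answer"]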