# Coverage for apps/inners/use_cases/graphs/long_form_qa_graph.py: 69% (191 statements)
from typing import Dict, List, Any

from langchain_community.chat_models import ChatLiteLLM
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableSerializable
from langgraph.graph import StateGraph
from langgraph.graph.graph import CompiledGraph, END
from pydantic.v1 import Field

from apps.inners.exceptions import use_case_exception
from apps.inners.models.base_model import BaseModel
from apps.inners.models.dtos.graph_state import LongFormQaGraphState
from apps.inners.use_cases.graphs.passage_search_graph import PassageSearchGraph
from apps.inners.use_cases.retrievers.hybrid_milvus_retriever import HybridMilvusRetriever
from tools import cache_tool


class LongFormQaGraph(PassageSearchGraph):
    def __init__(
            self,
            *args: Any,
            **kwargs: Any
    ):
        super().__init__(*args, **kwargs)
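        # The LangGraph state machine is compiled once at construction and
        # reused for every invocation; see compile() at the bottom of the class.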
        self.compiled_graph: CompiledGraph = self.compile()

    async def node_generate_answer(self, input_state: LongFormQaGraphState) -> LongFormQaGraphState:
        output_state: LongFormQaGraphState = input_state

        re_ranked_documents: List[Document] = input_state["re_ranked_documents"]
        retriever: HybridMilvusRetriever = input_state["retriever_setting"]["retriever"]
        re_ranked_document_ids: List[str] = [document.metadata[retriever.id_key] for document in re_ranked_documents]
        generated_answer_hash: str = self.get_generated_answer_hash(
            re_ranked_document_ids=re_ranked_document_ids,
            question=input_state["question"],
            llm_model_name=input_state["llm_setting"]["model_name"],
            prompt=input_state["generator_setting"]["prompt"],
            max_token=input_state["llm_setting"]["max_token"],
        )
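        # Note: assuming two_datastore.async_client is a redis.asyncio-style
        # client, EXISTS returns the number of matching keys, so for a single
        # key any value other than 0 or 1 is treated as an invalid state below.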
        existing_generated_answer_hash: int = await self.two_datastore.async_client.exists(generated_answer_hash)
        if existing_generated_answer_hash == 0:
            is_generated_answer_exist: bool = False
        elif existing_generated_answer_hash == 1:
            is_generated_answer_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedAnswerHashInvalid()

        is_force_refresh_generated_answer: bool = input_state["generator_setting"][
            "is_force_refresh_generated_answer"]
        if is_generated_answer_exist is False or is_force_refresh_generated_answer is True:
            prompt: PromptTemplate = PromptTemplate(
                template=input_state["generator_setting"]["prompt"],
                template_format="jinja2",
                input_variables=["passages", "question"]
            )
            text: str = prompt.format(
                passages=re_ranked_documents,
                question=input_state["question"]
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: ChatLiteLLM = input_state["llm_setting"]["model"]
            chain: RunnableSerializable = llm_model | StrOutputParser()
            generated_answer: str = await chain.ainvoke(
                input=messages
            )
            await self.two_datastore.async_client.set(
                name=generated_answer_hash,
                value=generated_answer.encode()
            )
        else:
            generated_answer_byte: bytes = await self.two_datastore.async_client.get(generated_answer_hash)
            generated_answer: str = generated_answer_byte.decode()

        output_state["generated_answer"] = generated_answer
        output_state["generated_answer_hash"] = generated_answer_hash

        return output_state

    def get_generated_answer_hash(
            self,
            re_ranked_document_ids: List[str],
            question: str,
            llm_model_name: str,
            prompt: str,
            max_token: int,
    ) -> str:
        data: Dict[str, Any] = {
            "re_ranked_document_ids": re_ranked_document_ids,
            "question": question,
            "llm_model_name": llm_model_name,
            "prompt": prompt,
            "max_token": max_token,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"generated_answer/{hashed_data}"

        return hashed_data
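
    # The *_hash helpers in this class all follow the same pattern: hash every
    # input that determines the cached output via cache_tool.hash_by_dict, then
    # prefix a namespace, yielding keys such as "generated_answer/<digest>"
    # (the exact digest format depends on cache_tool.hash_by_dict, not shown here).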

    async def node_grade_hallucination(self, input_state: LongFormQaGraphState) -> LongFormQaGraphState:
        output_state: LongFormQaGraphState = input_state

        re_ranked_documents: List[Document] = input_state["re_ranked_documents"]

        class GradeTool(BaseModel):
            """Binary score for support check."""
            binary_score: bool = Field(
                description="Binary support score: True if the answer is supported, False if it is not."
            )

        retriever: HybridMilvusRetriever = input_state["retriever_setting"]["retriever"]
        generated_hallucination_grade_hash: str = self.get_generated_hallucination_grade_hash(
            retrieved_document_ids=[document.metadata[retriever.id_key] for document in re_ranked_documents],
            generated_answer_hash=input_state["generated_answer_hash"]
        )
        existing_generated_hallucination_grade_hash: int = await self.two_datastore.async_client.exists(
            generated_hallucination_grade_hash)
        if existing_generated_hallucination_grade_hash == 0:
            is_generated_hallucination_grade_hash_exist: bool = False
        elif existing_generated_hallucination_grade_hash == 1:
            is_generated_hallucination_grade_hash_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedHallucinationGradeHashInvalid()

        is_force_refresh_generated_hallucination_grade: bool = input_state["generator_setting"][
            "is_force_refresh_generated_hallucination_grade"]
        if is_generated_hallucination_grade_hash_exist is False or is_force_refresh_generated_hallucination_grade is True:
            prompt: PromptTemplate = PromptTemplate(
                template="""
                <instruction>
                Assess whether a Large Language Model generated answer to the question is supported by the passages. Give one binary score of "True" or "False". "True" means that the answer to the question is supported by the passages. "False" means that the answer to the question is not supported by the passages.
                <instruction/>
                <passages>
                {% for passage in passages %}
                <passage_{{ loop.index }}>
                {{ passage.page_content }}
                <passage_{{ loop.index }}/>
                {% endfor %}
                <passages/>
                <question>
                {{ question }}
                <question/>
                <answer>
                {{ answer }}
                <answer/>
                """,
                template_format="jinja2",
                input_variables=["passages", "question", "answer"]
            )
            text: str = prompt.format(
                passages=re_ranked_documents,
                question=input_state["question"],
                answer=input_state["generated_answer"]
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: ChatLiteLLM = input_state["llm_setting"]["model"]
            tool_parser: PydanticToolsParser = PydanticToolsParser(
                tools=[GradeTool]
            )
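            # tool_choice="required" forces the model to emit a tool call, so
            # PydanticToolsParser should always receive a structured GradeTool
            # (this assumes an OpenAI-style tool-calling model behind LiteLLM).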
            chain: RunnableSerializable = llm_model.bind_tools(tools=tool_parser.tools,
                                                               tool_choice="required") | tool_parser
            generated_tools: List[GradeTool] = await chain.ainvoke(
                input=messages
            )
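            # GradeTool scores *support*; the hallucination grade is its
            # negation, serialized as the string "True"/"False" for caching.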
            generated_hallucination_grade: str = str(not generated_tools[0].binary_score)
            await self.two_datastore.async_client.set(
                name=generated_hallucination_grade_hash,
                value=generated_hallucination_grade.encode()
            )
        else:
            generated_hallucination_grade_byte: bytes = await self.two_datastore.async_client.get(
                generated_hallucination_grade_hash)
            generated_hallucination_grade: str = generated_hallucination_grade_byte.decode()

        output_state["generated_hallucination_grade"] = generated_hallucination_grade
        output_state["generated_hallucination_grade_hash"] = generated_hallucination_grade_hash

        return output_state

    def get_generated_hallucination_grade_hash(
            self,
            retrieved_document_ids: List[str],
            generated_answer_hash: str,
    ) -> str:
        data: Dict[str, Any] = {
            "retrieved_document_ids": retrieved_document_ids,
            "generated_answer_hash": generated_answer_hash,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"generated_hallucination_grade/{hashed_data}"

        return hashed_data

    async def node_grade_answer_relevancy(self, input_state: LongFormQaGraphState) -> LongFormQaGraphState:
        output_state: LongFormQaGraphState = input_state

        class GradeTool(BaseModel):
            """Binary score for resolution check."""
            binary_score: bool = Field(
                description="Binary resolution score: True if the answer resolves the question, False if it does not."
            )

        generated_answer_relevancy_grade_hash: str = self.get_generated_answer_relevancy_grade_hash(
            question=input_state["question"],
            generated_answer_hash=input_state["generated_answer_hash"]
        )
        existing_generated_answer_relevancy_grade_hash: int = await self.two_datastore.async_client.exists(
            generated_answer_relevancy_grade_hash)
        if existing_generated_answer_relevancy_grade_hash == 0:
            is_generated_answer_relevancy_grade_hash_exist: bool = False
        elif existing_generated_answer_relevancy_grade_hash == 1:
            is_generated_answer_relevancy_grade_hash_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedAnswerRelevancyGradeHashInvalid()

        is_force_refresh_generated_answer_relevancy_grade: bool = input_state["generator_setting"][
            "is_force_refresh_generated_answer_relevancy_grade"]
        if is_generated_answer_relevancy_grade_hash_exist is False or is_force_refresh_generated_answer_relevancy_grade is True:
            prompt: PromptTemplate = PromptTemplate(
                template="""
                <instruction>
                Assess whether a Large Language Model generated answer resolves a question. Give one binary score of "True" or "False". "True" means that the answer resolves the question. "False" means that the answer does not resolve the question.
                <instruction/>
                <question>
                {{ question }}
                <question/>
                <answer>
                {{ answer }}
                <answer/>
                """,
                template_format="jinja2",
                input_variables=["question", "answer"]
            )
            text: str = prompt.format(
                question=input_state["question"],
                answer=input_state["generated_answer"],
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: ChatLiteLLM = input_state["llm_setting"]["model"]
            tool_parser: PydanticToolsParser = PydanticToolsParser(
                tools=[GradeTool]
            )
            chain: RunnableSerializable = llm_model.bind_tools(tools=tool_parser.tools,
                                                               tool_choice="required") | tool_parser
            generated_tools: List[GradeTool] = await chain.ainvoke(
                input=messages
            )
            generated_answer_relevancy_grade: str = str(generated_tools[0].binary_score)
            await self.two_datastore.async_client.set(
                name=generated_answer_relevancy_grade_hash,
                value=generated_answer_relevancy_grade.encode()
            )
        else:
            generated_answer_relevancy_grade_byte: bytes = await self.two_datastore.async_client.get(
                generated_answer_relevancy_grade_hash
            )
            generated_answer_relevancy_grade: str = generated_answer_relevancy_grade_byte.decode()

        output_state["generated_answer_relevancy_grade"] = generated_answer_relevancy_grade
        output_state["generated_answer_relevancy_grade_hash"] = generated_answer_relevancy_grade_hash

        return output_state

    def get_generated_answer_relevancy_grade_hash(
            self,
            question: str,
            generated_answer_hash: str,
    ) -> str:
        data: Dict[str, Any] = {
            "question": question,
            "generated_answer_hash": generated_answer_hash,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"generated_answer_relevancy_grade/{hashed_data}"

        return hashed_data

    def node_decide_transform_question_or_grade_answer_relevancy(self, input_state: LongFormQaGraphState) -> str:
        output_state: LongFormQaGraphState = input_state

        generated_hallucination_grade: str = input_state["generated_hallucination_grade"]
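        # Grades are cached as the strings "True"/"False", hence the string
        # comparison here rather than a boolean check.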
        if generated_hallucination_grade == "False":
            return "GRADE_ANSWER_RELEVANCY"

        transform_question_max_retry: int = input_state["transform_question_max_retry"]
        transform_question_current_retry: int = input_state["state"].transform_question_current_retry

        if transform_question_current_retry >= transform_question_max_retry:
            return "MAX_RETRY"

        output_state["state"].transform_question_current_retry += 1

        return "TRANSFORM_QUESTION"

    def node_decide_transform_question_or_provide_answer(self, input_state: LongFormQaGraphState) -> str:
        output_state: LongFormQaGraphState = input_state

        generated_answer_relevancy_grade: str = input_state["generated_answer_relevancy_grade"]
        if generated_answer_relevancy_grade == "True":
            return "PROVIDE_ANSWER"

        transform_question_max_retry: int = input_state["transform_question_max_retry"]
        transform_question_current_retry: int = input_state["state"].transform_question_current_retry
        if transform_question_current_retry >= transform_question_max_retry:
            return "MAX_RETRY"

        output_state["state"].transform_question_current_retry += 1

        return "TRANSFORM_QUESTION"

    async def node_transform_question(self, input_state: LongFormQaGraphState) -> LongFormQaGraphState:
        output_state: LongFormQaGraphState = input_state

        generated_question_hash: str = self.get_transformed_question_hash(
            question=input_state["question"]
        )
        existing_generated_question_hash: int = await self.two_datastore.async_client.exists(
            generated_question_hash
        )
        if existing_generated_question_hash == 0:
            is_generated_question_exist: bool = False
        elif existing_generated_question_hash == 1:
            is_generated_question_exist: bool = True
        else:
            raise use_case_exception.ExistingGeneratedQuestionHashInvalid()

        is_force_refresh_generated_question: bool = input_state["generator_setting"][
            "is_force_refresh_generated_question"]
        if is_generated_question_exist is False or is_force_refresh_generated_question is True:
            prompt: PromptTemplate = PromptTemplate(
                template="""
                <instruction>
                Convert the question to a better version that is optimized for vector store retrieval. Observe the question and try to reason about the underlying semantics. Ensure the output is only the question, without re-explaining the instruction.
                <instruction/>
                <question>
                {question}
                <question/>
                """,
                input_variables=["question"]
            )
            text: str = prompt.format(
                question=input_state["question"]
            )
            messages: List[BaseMessage] = [
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": text
                        }
                    ]
                )
            ]
            llm_model: ChatLiteLLM = input_state["llm_setting"]["model"]
            chain: RunnableSerializable = llm_model | StrOutputParser()
            generated_question: str = await chain.ainvoke(
                input=messages
            )
            await self.two_datastore.async_client.set(
                name=generated_question_hash,
                value=generated_question.encode()
            )
        else:
            generated_question_byte: bytes = await self.two_datastore.async_client.get(
                generated_question_hash
            )
            generated_question: str = generated_question_byte.decode()

        output_state["question"] = generated_question
        output_state["generated_question_hash"] = generated_question_hash

        return output_state
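
    # Note: after a question is transformed, compile() routes execution back to
    # node_get_relevant_documents, so retrieval re-runs with the new question.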

    def get_transformed_question_hash(
            self,
            question: str,
    ) -> str:
        data: Dict[str, Any] = {
            "question": question,
        }
        hashed_data: str = cache_tool.hash_by_dict(
            data=data
        )
        hashed_data = f"transformed_question/{hashed_data}"

        return hashed_data

    def compile(self) -> CompiledGraph:
        graph: StateGraph = StateGraph(LongFormQaGraphState)

        graph.add_node(
            node=self.node_get_llm_model.__name__,
            action=self.node_get_llm_model
        )
        graph.add_node(
            node=self.node_get_categorized_documents.__name__,
            action=self.node_get_categorized_documents
        )
        graph.add_node(
            node=self.node_embed.__name__,
            action=self.node_embed
        )
        graph.add_node(
            node=self.node_get_relevant_documents.__name__,
            action=self.node_get_relevant_documents
        )
        graph.add_node(
            node=self.node_get_re_ranked_documents.__name__,
            action=self.node_get_re_ranked_documents
        )
        graph.add_node(
            node=self.node_generate_answer.__name__,
            action=self.node_generate_answer
        )
        graph.add_node(
            node=self.node_grade_hallucination.__name__,
            action=self.node_grade_hallucination
        )
        graph.add_node(
            node=self.node_grade_answer_relevancy.__name__,
            action=self.node_grade_answer_relevancy
        )
        graph.add_node(
            node=self.node_transform_question.__name__,
            action=self.node_transform_question
        )

        graph.set_entry_point(
            key=self.node_get_llm_model.__name__
        )
        graph.add_edge(
            start_key=self.node_get_llm_model.__name__,
            end_key=self.node_get_categorized_documents.__name__
        )
        graph.add_edge(
            start_key=self.node_get_categorized_documents.__name__,
            end_key=self.node_embed.__name__
        )
        graph.add_edge(
            start_key=self.node_embed.__name__,
            end_key=self.node_get_relevant_documents.__name__
        )
        graph.add_edge(
            start_key=self.node_get_relevant_documents.__name__,
            end_key=self.node_get_re_ranked_documents.__name__
        )
        graph.add_edge(
            start_key=self.node_get_re_ranked_documents.__name__,
            end_key=self.node_generate_answer.__name__
        )
        graph.add_edge(
            start_key=self.node_generate_answer.__name__,
            end_key=self.node_grade_hallucination.__name__
        )
        graph.add_conditional_edges(
            source=self.node_grade_hallucination.__name__,
            path=self.node_decide_transform_question_or_grade_answer_relevancy,
            path_map={
                "MAX_RETRY": END,
                "GRADE_ANSWER_RELEVANCY": self.node_grade_answer_relevancy.__name__,
                "TRANSFORM_QUESTION": self.node_transform_question.__name__
            }
        )
        graph.add_conditional_edges(
            source=self.node_grade_answer_relevancy.__name__,
            path=self.node_decide_transform_question_or_provide_answer,
            path_map={
                "MAX_RETRY": END,
                "PROVIDE_ANSWER": END,
                "TRANSFORM_QUESTION": self.node_transform_question.__name__
            }
        )
        graph.add_edge(
            start_key=self.node_transform_question.__name__,
            end_key=self.node_get_relevant_documents.__name__
        )

        compiled_graph: CompiledGraph = graph.compile()

        return compiled_graph
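

# Usage sketch (hypothetical): constructor arguments come from
# PassageSearchGraph and the exact state fields from LongFormQaGraphState,
# neither of which is shown here, so the names below are illustrative
# assumptions rather than a confirmed API.
#
#     import asyncio
#
#     async def main() -> None:
#         graph = LongFormQaGraph(...)  # dependencies required by PassageSearchGraph
#         initial_state = {...}  # question, llm_setting, retriever_setting, generator_setting, ...
#         final_state = await graph.compiled_graph.ainvoke(input=initial_state)
#         print(final_state["generated_answer"])
#
#     asyncio.run(main())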