Coverage for apps/inners/use_cases/long_form_qas/process_long_form_qa.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-22 19:03 +0000

1from datetime import datetime, timezone 

2from typing import List 

3 

4from langchain_core.runnables import RunnableConfig 

5from starlette.datastructures import State 

6 

7from apps.inners.models.dtos.contracts.requests.long_form_qas.process_body import ProcessBody 

8from apps.inners.models.dtos.contracts.responses.long_form_qas.process_response import ProcessResponse 

9from apps.inners.models.dtos.contracts.responses.passage_searches.process_response import ReRankedDocument 

10from apps.inners.models.dtos.graph_state import LongFormQaGraphState 

11from apps.inners.use_cases.graphs.long_form_qa_graph import LongFormQaGraph 

12 

13 

class ProcessLongFormQa:
    """Use case that runs the long-form QA graph end-to-end for one request.

    Delegates all preprocessing/embedding/retrieval/re-ranking/generation work
    to the compiled LangGraph held by ``LongFormQaGraph`` and maps the graph's
    output state into a ``ProcessResponse`` contract object.
    """

    def __init__(
            self,
            long_form_qa_graph: LongFormQaGraph,
    ):
        # Graph whose compiled_graph is invoked once per process() call.
        self.long_form_qa_graph = long_form_qa_graph

    async def process(self, state: State, body: ProcessBody) -> ProcessResponse:
        """Execute the long-form QA pipeline for a single request.

        Args:
            state: Per-request mutable state object shared with graph nodes;
                its cursor/retry fields are reset here before invocation.
            body: Validated request body carrying every pipeline setting.

        Returns:
            ProcessResponse containing the re-ranked documents, the generated
            answer, both grading results, and UTC start/finish timestamps.
        """
        started_at: datetime = datetime.now(tz=timezone.utc)
        # Reset per-request cursors/counters that graph nodes read and mutate.
        state.next_document_id = None
        state.next_categorized_document = None
        state.transform_question_current_retry = 0

        input_state: LongFormQaGraphState = self._build_input_state(state, body)
        graph_config: RunnableConfig = {
            # The graph can loop (e.g. question-transformation retries), so a
            # generous recursion limit prevents LangGraph from aborting early.
            "recursion_limit": 1000,
        }
        output_state: LongFormQaGraphState = await self.long_form_qa_graph.compiled_graph.ainvoke(
            input=input_state,
            config=graph_config
        )

        # Convert the graph's document models into the response contract type.
        # NOTE(review): .dict() is the pydantic-v1 spelling; if the project is
        # on pydantic v2, .model_dump() is the non-deprecated equivalent.
        re_ranked_documents: List[ReRankedDocument] = [
            ReRankedDocument(**re_ranked_document.dict())
            for re_ranked_document in output_state["re_ranked_documents"]
        ]
        finished_at: datetime = datetime.now(tz=timezone.utc)
        process_response: ProcessResponse = ProcessResponse(
            re_ranked_documents=re_ranked_documents,
            generated_answer=output_state["generated_answer"],
            hallucination_grade=output_state["generated_hallucination_grade"],
            answer_relevancy_grade=output_state["generated_answer_relevancy_grade"],
            started_at=started_at,
            finished_at=finished_at,
        )

        return process_response

    @staticmethod
    def _build_input_state(state: State, body: ProcessBody) -> LongFormQaGraphState:
        """Map the request body's settings into the graph's initial state dict.

        Keys set to None (hashes, intermediate document collections) are
        placeholders that graph nodes fill in during execution.
        """
        input_setting = body.input_setting
        llm = input_setting.llm_setting
        preprocessor = input_setting.preprocessor_setting
        embedder = input_setting.embedder_setting
        retriever = input_setting.retriever_setting
        reranker = input_setting.reranker_setting
        generator = input_setting.generator_setting
        return {
            "state": state,
            "document_ids": input_setting.document_ids,
            "llm_setting": {
                "model_name": llm.model_name,
                "max_token": llm.max_token,
                # Presumably instantiated lazily by a graph node — confirm.
                "model": None,
            },
            "preprocessor_setting": {
                "is_force_refresh_categorized_element": preprocessor.is_force_refresh_categorized_element,
                "is_force_refresh_categorized_document": preprocessor.is_force_refresh_categorized_document,
                "file_partition_strategy": preprocessor.file_partition_strategy,
                "chunk_size": preprocessor.chunk_size,
                "overlap_size": preprocessor.overlap_size,
                "is_include_table": preprocessor.is_include_table,
                "is_include_image": preprocessor.is_include_image,
            },
            "categorized_element_hashes": None,
            "categorized_documents": None,
            "categorized_document_hashes": None,
            "embedder_setting": {
                "is_force_refresh_embedding": embedder.is_force_refresh_embedding,
                "is_force_refresh_document": embedder.is_force_refresh_document,
                "model_name": embedder.model_name,
                "query_instruction": embedder.query_instruction,
            },
            "retriever_setting": {
                "is_force_refresh_relevant_document": retriever.is_force_refresh_relevant_document,
                "top_k": retriever.top_k,
            },
            "reranker_setting": {
                "is_force_refresh_re_ranked_document": reranker.is_force_refresh_re_ranked_document,
                "model_name": reranker.model_name,
                "top_k": reranker.top_k,
            },
            "embedded_document_ids": None,
            "relevant_documents": None,
            "relevant_document_hash": None,
            "re_ranked_documents": None,
            "re_ranked_document_hash": None,
            "question": input_setting.question,
            "generator_setting": {
                "is_force_refresh_generated_answer": generator.is_force_refresh_generated_answer,
                "is_force_refresh_generated_question": generator.is_force_refresh_generated_question,
                "is_force_refresh_generated_hallucination_grade": generator.is_force_refresh_generated_hallucination_grade,
                "is_force_refresh_generated_answer_relevancy_grade": generator.is_force_refresh_generated_answer_relevancy_grade,
                "prompt": generator.prompt,
            },
            "transform_question_max_retry": input_setting.transform_question_max_retry,
            "generated_answer": None,
            "generated_answer_hash": None,
            "generated_question": None,
            "generated_question_hash": None,
            # Grades are string flags ("False"/"True") as consumed downstream.
            "generated_hallucination_grade": "False",
            "generated_hallucination_grade_hash": None,
            "generated_answer_relevancy_grade": "False",
            "generated_answer_relevancy_grade_hash": None,
        }