Coverage for apps/inners/use_cases/document_processor/category_document_processor.py: 56%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-22 19:03 +0000

1import base64 

2import os 

3import uuid 

4from typing import List, Tuple, Optional, Dict, Any 

5 

6import more_itertools 

7from langchain_community.chat_models import ChatLiteLLM 

8from langchain_core.documents import Document 

9from unstructured.chunking.basic import chunk_elements 

10from unstructured.documents.elements import Element, Text, NarrativeText, Table, Image, ListItem 

11 

12from apps.inners.models.dtos.document_category import DocumentCategory 

13from apps.inners.models.dtos.element_category import ElementCategory 

14from apps.inners.use_cases.document_processor.summary_document_processor import SummaryDocumentProcessor 

15 

16 

class CategoryDocumentProcessor:
    """Buckets parsed document elements by content type and converts each
    bucket into LangChain ``Document`` objects (optionally summarizing
    tables and images via the injected ``SummaryDocumentProcessor``)."""

    def __init__(
            self,
            summary_document_processor: SummaryDocumentProcessor,
    ):
        # Used to produce natural-language summaries of tables/images.
        self.summary_document_processor = summary_document_processor

    async def categorize_elements(self, elements: List[Element]) -> ElementCategory:
        """Split raw ``unstructured`` elements into text, table, and image buckets.

        Image elements have their on-disk file read, base64-encoded into
        ``metadata.image_base64`` (mime type forced to ``image/jpeg``), and the
        temporary image file is deleted afterwards. Elements of any other type
        are ignored with a message printed to stdout.

        :param elements: parsed elements from an unstructured partition run.
        :return: an ``ElementCategory`` with ``texts``, ``tables``, ``images``.
        """
        categorized_elements: ElementCategory = ElementCategory(
            texts=[],
            tables=[],
            images=[]
        )

        # Exact class-name matching is intentional: isinstance() would also
        # match subclasses (NarrativeText and ListItem derive from Text), which
        # would change which bucket an element lands in.
        text_type_names = {Text.__name__, NarrativeText.__name__, ListItem.__name__}
        for element in elements:
            element_type_name = element.__class__.__name__
            if element_type_name in text_type_names:
                categorized_elements.texts.append(element)
            elif element_type_name == Table.__name__:
                categorized_elements.tables.append(element)
            elif element_type_name == Image.__name__:
                # Context manager guarantees the handle is closed even if
                # read/encode raises (a bare open()/close() pair leaks it).
                with open(element.metadata.image_path, "rb") as file_io:
                    element.metadata.image_mime_type = "image/jpeg"
                    element.metadata.image_base64 = base64.b64encode(file_io.read()).decode()
                # The bytes now live in metadata; the temp file is no longer needed.
                os.remove(element.metadata.image_path)
                categorized_elements.images.append(element)
            else:
                print(f"BaseDocumentProcessor.categorize_elements: Ignoring element type {element.__class__.__name__}.")

        return categorized_elements

    async def get_categorized_documents(
            self,
            categorized_elements: ElementCategory,
            summarization_model: ChatLiteLLM,
            is_include_table: bool = False,
            is_include_image: bool = False,
            chunk_size: int = 400,
            overlap_size: int = 50,
            separators: Tuple[str, ...] = ("\n", " "),
            id_key: str = "id",
            metadata: Optional[Dict[str, Any]] = None,
    ) -> DocumentCategory:
        """Convert categorized elements into LangChain ``Document`` objects.

        Text elements are chunked to ``chunk_size - overlap_size`` characters,
        then adjacent chunks are re-joined pairwise so consecutive documents
        share an overlap region split on the nearest ``separators`` boundary.
        Tables/images are summarized with ``summarization_model`` when the
        corresponding ``is_include_*`` flag is set.

        :param categorized_elements: output of :meth:`categorize_elements`.
        :param summarization_model: LLM used for table/image summaries.
        :param is_include_table: summarize and include table documents.
        :param is_include_image: summarize and include image documents.
        :param chunk_size: target size of each overlapped text document.
        :param overlap_size: characters shared between consecutive documents.
        :param separators: candidates for a clean overlap split point,
            tried over the pre-overlap region of the earlier chunk.
        :param id_key: metadata key under which a fresh UUID4 is stored.
        :param metadata: extra metadata merged into every document
            (overrides the generated keys on collision).
        :return: a ``DocumentCategory`` with ``texts``, ``tables``, ``images``.
        """
        if metadata is None:
            # Never use a mutable default argument; create a fresh dict per call.
            metadata = {}

        document_category: DocumentCategory = DocumentCategory(
            texts=[],
            tables=[],
            images=[],
            id_key=id_key
        )
        # Chunk smaller than chunk_size so the pairwise join below lands
        # near chunk_size after the overlap is added back.
        chunked_texts: List[Element] = chunk_elements(
            elements=categorized_elements.texts,
            include_orig_elements=True,
            max_characters=chunk_size - overlap_size,
        )
        if len(chunked_texts) < 2:
            # Zero or one chunk: no pair to overlap, emit it as-is.
            for text in chunked_texts:
                orig_elements: List[Element] = text.metadata.orig_elements
                orig_element_metadata: List[Dict[str, Any]] = [
                    orig_element.metadata.to_dict() for orig_element in orig_elements
                ]
                document: Document = Document(
                    page_content=text.text,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "category": "text",
                        "origin_metadata": orig_element_metadata,
                        **metadata
                    }
                )
                document_category.texts.append(document)
        else:
            # Slide a window of two chunks; each emitted document is the tail
            # of the earlier chunk (from the last separator before the overlap
            # region) joined with the whole later chunk.
            for text_before, text_after in more_itertools.windowed(chunked_texts, n=2):
                # De-duplicate origin elements shared by the two chunks.
                orig_elements: List[Element] = []
                for orig_element in text_before.metadata.orig_elements + text_after.metadata.orig_elements:
                    if not any(orig_element.id == existing_orig_element.id for existing_orig_element in orig_elements):
                        orig_elements.append(orig_element)
                orig_elements_metadata: Dict[str, Any] = {
                    "origin_metadata": [
                        orig_element.metadata.to_dict() for orig_element in orig_elements
                    ],
                    "category": "text"
                }
                # Find the right-most separator before the overlap region so the
                # overlap starts on a clean boundary; -1 means "take everything".
                last_index_of_separators: int = -1
                for separator in separators:
                    last_index_of_separator = text_before.text.rfind(separator, 0, len(text_before.text) - overlap_size)
                    last_index_of_separators = max(last_index_of_separators, last_index_of_separator)

                text = text_before.text[last_index_of_separators + 1:] + " " + text_after.text
                document: Document = Document(
                    page_content=text,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        **orig_elements_metadata,
                        **metadata,
                    }
                )
                document_category.texts.append(document)

        if is_include_table:
            summarized_tables: List[str] = await self.summary_document_processor.summarize_tables(
                tables=categorized_elements.tables,
                llm_model=summarization_model
            )
            # strict=True: a mismatch between tables and summaries is a bug.
            for table, summarized_table in zip(categorized_elements.tables, summarized_tables, strict=True):
                document: Document = Document(
                    page_content=summarized_table,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "category": "table",
                        **table.metadata.to_dict(),
                        **metadata
                    }
                )
                document_category.tables.append(document)

        if is_include_image:
            summarized_images: List[str] = await self.summary_document_processor.summarize_images(
                images=categorized_elements.images,
                llm_model=summarization_model
            )
            for image, summarized_image in zip(categorized_elements.images, summarized_images, strict=True):
                document: Document = Document(
                    page_content=summarized_image,
                    metadata={
                        id_key: str(uuid.uuid4()),
                        "category": "image",
                        **image.metadata.to_dict(),
                        **metadata
                    }
                )
                document_category.images.append(document)

        return document_category