Coverage for apps/inners/use_cases/vector_stores/bge_m3_milvus_vector_store.py: 98%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-22 19:03 +0000

1from typing import Any, Dict, List 

2 

3from pymilvus import FieldSchema, DataType, CollectionSchema, Hits, AnnSearchRequest, SearchResult, RRFRanker, \ 

4 Collection 

5 

6from apps.inners.use_cases.embeddings.bge_m3_embedding import BgeM3Embedding 

7from apps.inners.use_cases.vector_stores.base_milvus_vector_store import BaseMilvusVectorStore 

8 

9 

class BgeM3MilvusVectorStore(BaseMilvusVectorStore):
    """Milvus vector store that keeps one sparse and one dense vector per id.

    Vectors are produced by a ``BgeM3Embedding`` model. Searches query both
    vector fields and fuse the two result lists with ``RRFRanker`` through
    ``Collection.hybrid_search``.
    """

    def __init__(
            self,
            embedding_model: BgeM3Embedding,
            *args: Any,
            sparse_vector_field_name: str = "sparse_vector",
            dense_vector_field_name: str = "dense_vector",
            sparse_vector_index_type: str = "SPARSE_INVERTED_INDEX",
            dense_vector_index_type: str = "GPU_CAGRA",
            **kwargs: Any
    ):
        """Store the embedding/field configuration, then initialise the base store.

        Args:
            embedding_model: Model used to encode documents and queries; its
                ``dimensions`` mapping must provide ``"sparse"`` and ``"dense"``.
            *args: Positional arguments forwarded to ``BaseMilvusVectorStore``.
            sparse_vector_field_name: Collection field holding sparse vectors.
            dense_vector_field_name: Collection field holding dense vectors.
            sparse_vector_index_type: Index type created on the sparse field.
            dense_vector_index_type: Index type created on the dense field.
            **kwargs: Keyword arguments forwarded to ``BaseMilvusVectorStore``.
                Any caller-supplied ``vector_field_dimensions`` entry is
                overwritten with the dimensions taken from ``embedding_model``.
        """
        vector_field_dimensions: Dict[str, Any] = {
            sparse_vector_field_name: embedding_model.dimensions["sparse"],
            dense_vector_field_name: embedding_model.dimensions["dense"]
        }
        kwargs["vector_field_dimensions"] = vector_field_dimensions
        # These attributes are assigned before the base initialiser runs;
        # presumably it invokes _create_collection / _create_index, which
        # read them (verify against BaseMilvusVectorStore).
        self.embedding_model = embedding_model
        self.sparse_vector_field_name = sparse_vector_field_name
        self.dense_vector_field_name = dense_vector_field_name
        self.sparse_vector_index_type = sparse_vector_index_type
        self.dense_vector_index_type = dense_vector_index_type
        super().__init__(*args, **kwargs)

    def _create_index(self):
        """Create one index on the sparse field, then one on the dense field."""
        indexed_fields = (
            (self.sparse_vector_field_name, self.sparse_vector_index_type),
            (self.dense_vector_field_name, self.dense_vector_index_type),
        )
        for field_name, index_type in indexed_fields:
            # Build a new dict instead of mutating the one returned by
            # get_search_params, so that a shared/cached params dict (if the
            # base class returns one) never picks up an "index_type" key that
            # would later be sent along with search requests.
            index_params: Dict[str, Any] = {
                **self.get_search_params(index_type),
                "index_type": index_type,
            }
            self.collection.create_index(
                field_name=field_name,
                index_params=index_params
            )

    def _create_collection(self) -> Collection:
        """Build the collection (id, sparse vector, dense vector) and return it.

        Returns:
            The newly constructed ``Collection``, with ``collection_properties``
            applied when they are not ``None``.
        """
        fields: List[FieldSchema] = [
            FieldSchema(
                name=self.id_field_name,
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=65535
            ),
            FieldSchema(
                name=self.sparse_vector_field_name,
                dtype=DataType.SPARSE_FLOAT_VECTOR,
            ),
            FieldSchema(
                name=self.dense_vector_field_name,
                dtype=DataType.FLOAT_VECTOR,
                dim=self.vector_field_dimensions[self.dense_vector_field_name]
            )
        ]

        schema = CollectionSchema(
            fields=fields,
        )

        collection: Collection = Collection(
            name=self.collection_name,
            schema=schema,
            using=self.alias,
            consistency_level=self.consistency_level,
        )
        if self.collection_properties is not None:
            # Apply the properties to the collection created just above.
            # The previous code went through self.collection, which is not
            # the object built here and is presumably not assigned yet while
            # the collection is still being created.
            collection.set_properties(self.collection_properties)

        return collection

    def embed_texts(
            self,
            texts: List[str],
            ids: List[str],
            batch_size: int = 1000
    ):
        """Encode ``texts`` and insert them into the collection in batches.

        Args:
            texts: Documents to encode with the embedding model.
            ids: Primary keys; row ``i`` of the embeddings is inserted under
                ``ids[i]``, so this is expected to align with ``texts``.
            batch_size: Maximum number of rows per ``insert`` call.
        """
        embeddings: Dict[str, Any] = self.embedding_model.encode_documents(texts)

        total_count: int = len(ids)
        for start_index in range(0, total_count, batch_size):
            end_index: int = min(start_index + batch_size, total_count)
            # Column-ordered data: id, sparse vector, dense vector - the same
            # order in which _create_collection declares the fields.
            data: List[Any] = [
                ids[start_index:end_index],
                embeddings["sparse"][start_index:end_index],
                embeddings["dense"][start_index:end_index],
            ]
            self.collection.insert(data)

    def search(self, query: str, top_k: int) -> Hits:
        """Run a hybrid sparse + dense search for a single query.

        Args:
            query: Query text to encode.
            top_k: Limit applied to each per-field request and to the fused
                result.

        Returns:
            The hits for ``query`` (first entry of the search result), with
            the id field as the only output field.
        """
        embeddings: Dict[str, Any] = self.embedding_model.encode_queries(texts=[query])

        output_fields = [
            self.id_field_name,
        ]
        search_requests: List[AnnSearchRequest] = [
            AnnSearchRequest(
                data=embeddings["sparse"],
                anns_field=self.sparse_vector_field_name,
                limit=top_k,
                param=self.get_search_params(self.sparse_vector_index_type)
            ),
            AnnSearchRequest(
                data=embeddings["dense"],
                anns_field=self.dense_vector_field_name,
                limit=top_k,
                param=self.get_search_params(self.dense_vector_index_type)
            )
        ]
        search_result: SearchResult = self.collection.hybrid_search(
            reqs=search_requests,
            output_fields=output_fields,
            rerank=RRFRanker(),
            limit=top_k,
        )
        # Only one query was submitted, so its hits are the first entry.
        outputs: Hits = search_result[0]

        return outputs