Coverage for apps/inners/use_cases/vector_stores/bge_m3_milvus_vector_store.py: 98%
42 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-22 19:03 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-22 19:03 +0000
1from typing import Any, Dict, List
3from pymilvus import FieldSchema, DataType, CollectionSchema, Hits, AnnSearchRequest, SearchResult, RRFRanker, \
4 Collection
6from apps.inners.use_cases.embeddings.bge_m3_embedding import BgeM3Embedding
7from apps.inners.use_cases.vector_stores.base_milvus_vector_store import BaseMilvusVectorStore
class BgeM3MilvusVectorStore(BaseMilvusVectorStore):
    """Milvus vector store that keeps BGE-M3 sparse and dense embeddings side by side.

    Each row has three fields: a VARCHAR primary key, a sparse float vector and a
    dense float vector. ``search`` issues one ANN request per vector field and fuses
    the two ranked candidate lists with reciprocal-rank fusion (``RRFRanker``).
    """

    def __init__(
            self,
            embedding_model: BgeM3Embedding,
            *args: Any,
            sparse_vector_field_name: str = "sparse_vector",
            dense_vector_field_name: str = "dense_vector",
            sparse_vector_index_type: str = "SPARSE_INVERTED_INDEX",
            dense_vector_index_type: str = "GPU_CAGRA",
            **kwargs: Any
    ):
        """Configure field names / index types and initialise the base store.

        :param embedding_model: BGE-M3 encoder; its ``dimensions`` mapping must
            provide ``"sparse"`` and ``"dense"`` entries.
        :param args: forwarded positionally to ``BaseMilvusVectorStore``.
        :param sparse_vector_field_name: name of the sparse vector field.
        :param dense_vector_field_name: name of the dense vector field.
        :param sparse_vector_index_type: Milvus index type for the sparse field.
        :param dense_vector_index_type: Milvus index type for the dense field.
        :param kwargs: forwarded to ``BaseMilvusVectorStore``; any caller-supplied
            ``vector_field_dimensions`` is overwritten with the values derived
            from ``embedding_model``.
        """
        vector_field_dimensions: Dict[str, Any] = {
            sparse_vector_field_name: embedding_model.dimensions["sparse"],
            dense_vector_field_name: embedding_model.dimensions["dense"],
        }
        kwargs["vector_field_dimensions"] = vector_field_dimensions
        # These are assigned before super().__init__(), presumably because the base
        # initialiser invokes _create_collection/_create_index, which read them.
        # NOTE(review): confirm against BaseMilvusVectorStore.__init__.
        self.embedding_model = embedding_model
        self.sparse_vector_field_name = sparse_vector_field_name
        self.dense_vector_field_name = dense_vector_field_name
        self.sparse_vector_index_type = sparse_vector_index_type
        self.dense_vector_index_type = dense_vector_index_type
        super().__init__(*args, **kwargs)

    def _create_index(self):
        """Create one index on the sparse field, then one on the dense field.

        Each index's parameters are ``get_search_params(<index type>)`` with an
        ``"index_type"`` entry added.
        """
        for field_name, index_type in (
                (self.sparse_vector_field_name, self.sparse_vector_index_type),
                (self.dense_vector_field_name, self.dense_vector_index_type),
        ):
            # Copy before adding "index_type" so a dict owned (or cached) by the
            # base class is never mutated in place.
            index_params: Dict[str, Any] = dict(self.get_search_params(index_type))
            index_params["index_type"] = index_type
            self.collection.create_index(
                field_name=field_name,
                index_params=index_params
            )

    def _create_collection(self):
        """Build and return a new Milvus collection with id, sparse and dense fields.

        The sparse field carries no ``dim``; the dense field's ``dim`` comes from
        ``vector_field_dimensions``. If ``collection_properties`` is set, they are
        applied to the newly created collection before it is returned.

        :returns: the created ``Collection``.
        """
        fields: List[FieldSchema] = [
            FieldSchema(
                name=self.id_field_name,
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=65535
            ),
            FieldSchema(
                name=self.sparse_vector_field_name,
                dtype=DataType.SPARSE_FLOAT_VECTOR,
            ),
            FieldSchema(
                name=self.dense_vector_field_name,
                dtype=DataType.FLOAT_VECTOR,
                dim=self.vector_field_dimensions[self.dense_vector_field_name]
            ),
        ]

        schema = CollectionSchema(
            fields=fields,
        )

        collection: Collection = Collection(
            name=self.collection_name,
            schema=schema,
            using=self.alias,
            consistency_level=self.consistency_level,
        )
        if self.collection_properties is not None:
            # Apply to the collection just created; self.collection is not (yet)
            # this object while it is being built.
            collection.set_properties(self.collection_properties)

        return collection

    def embed_texts(
            self,
            texts: List[str],
            ids: List[str],
            batch_size: int = 1000
    ):
        """Encode ``texts`` with BGE-M3 and insert them in batches of ``batch_size``.

        ``ids[i]`` becomes the primary key of the row built from ``texts[i]``.

        :param texts: documents to encode.
        :param ids: primary keys, one per text.
        :param batch_size: maximum number of rows per ``insert`` call.
        :raises ValueError: if ``texts`` and ``ids`` differ in length, or if
            ``batch_size`` is smaller than 1.
        """
        total_count: int = len(ids)
        if len(texts) != total_count:
            raise ValueError(
                f"texts and ids must have the same length, got {len(texts)} and {total_count}."
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")
        if total_count == 0:
            # Nothing to insert; avoid encoding an empty batch.
            return

        embeddings: Dict[str, Any] = self.embedding_model.encode_documents(texts)
        sparse_embeddings = embeddings["sparse"]
        dense_embeddings = embeddings["dense"]

        for start_index in range(0, total_count, batch_size):
            end_index: int = min(start_index + batch_size, total_count)
            # Column-based insert: one list per field, in the schema order
            # defined by _create_collection (id, sparse vector, dense vector).
            data: List[Any] = [
                ids[start_index:end_index],
                sparse_embeddings[start_index:end_index],
                dense_embeddings[start_index:end_index],
            ]
            self.collection.insert(data)

    def search(self, query: str, top_k: int) -> Hits:
        """Hybrid-search ``query`` over both vector fields and fuse with RRF.

        One ANN request is made on the sparse field and one on the dense field,
        each limited to ``top_k`` candidates; ``RRFRanker()`` fuses them and the
        fused result is limited to ``top_k``.

        :param query: query text.
        :param top_k: number of candidates per request and of final hits.
        :returns: the hits for the single query; only the id field is requested
            as an output field.
        """
        embeddings: Dict[str, Any] = self.embedding_model.encode_queries(texts=[query])

        output_fields = [
            self.id_field_name,
        ]
        search_requests: List[AnnSearchRequest] = [
            AnnSearchRequest(
                data=embeddings["sparse"],
                anns_field=self.sparse_vector_field_name,
                limit=top_k,
                param=self.get_search_params(self.sparse_vector_index_type)
            ),
            AnnSearchRequest(
                data=embeddings["dense"],
                anns_field=self.dense_vector_field_name,
                limit=top_k,
                param=self.get_search_params(self.dense_vector_index_type)
            ),
        ]
        search_result: SearchResult = self.collection.hybrid_search(
            reqs=search_requests,
            output_fields=output_fields,
            rerank=RRFRanker(),
            limit=top_k,
        )
        # Exactly one query was encoded, so the result holds a single Hits entry.
        outputs: Hits = search_result[0]

        return outputs