Coverage for apps/inners/use_cases/vector_stores/base_milvus_vector_store.py: 77%
52 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-22 19:03 +0000
1from abc import abstractmethod, ABC
2from typing import List, Dict, Any, Optional
4from pymilvus import Hits, Collection
5from pymilvus.client.types import LoadState
6from pymilvus.orm import utility
class BaseMilvusVectorStore(ABC):
    """Abstract base for a vector store backed by one Milvus collection.

    Construction immediately calls ``initialize_collection``, which binds to
    (or creates) ``collection_name`` on the connection ``alias``, optionally
    drops and recreates it, applies ``collection_properties``, calls
    ``_create_index`` and loads the collection if it is not loaded yet.

    Subclasses supply the schema (``_create_collection``), the index
    (``_create_index``), ingestion (``embed_texts``) and querying (``search``).
    """

    # Fallback search parameters keyed by Milvus index type. Every entry uses
    # inner-product ("IP") similarity. Consulted only when the caller did not
    # pass explicit ``search_params``.
    _DEFAULT_SEARCH_PARAMS: Dict[str, Dict[str, Any]] = {
        index_type: {"metric_type": "IP"}
        for index_type in (
            "SPARSE_INVERTED_INDEX",
            "IVF_FLAT",
            "IVF_SQ8",
            "IVF_PQ",
            "HNSW",
            "RHNSW_FLAT",
            "RHNSW_SQ",
            "RHNSW_PQ",
            "IVF_HNSW",
            "ANNOY",
            "SCANN",
            "AUTOINDEX",
            "GPU_CAGRA",
            "GPU_IVF_FLAT",
            "GPU_IVF_PQ",
        )
    }

    def __init__(
        self,
        collection_name: str,
        vector_field_dimensions: Dict[str, Any],
        alias: Optional[str] = None,
        consistency_level: str = "Strong",
        collection_properties: Optional[Dict[str, Any]] = None,
        drop_old_collection: bool = False,
        id_field_name: str = "id",
        search_params: Optional[Dict[str, Any]] = None,
    ):
        """Store configuration and initialise the backing collection.

        :param collection_name: Name of the Milvus collection to use.
        :param vector_field_dimensions: Stored for subclasses; presumably maps
            vector field names to their dimensions (not read in this class).
        :param alias: pymilvus connection alias. ``None`` selects pymilvus's
            ``"default"`` connection.
        :param consistency_level: Stored for subclasses (not read here).
        :param collection_properties: Properties applied to the collection
            after it is bound or created, before indexing and loading.
        :param drop_old_collection: If true, an existing collection is dropped
            and recreated via ``_create_collection``.
        :param id_field_name: Stored for subclasses (not read here).
        :param search_params: If given, returned by ``get_search_params`` for
            every index type instead of the built-in defaults.
        """
        self.collection_name = collection_name
        self.vector_field_dimensions = vector_field_dimensions
        # pymilvus resolves connections by string alias; map None to its
        # default alias instead of forwarding None as ``using=``.
        self.alias = alias if alias is not None else "default"
        self.consistency_level = consistency_level
        self.collection_properties = collection_properties
        self.drop_old_collection = drop_old_collection
        self.id_field_name = id_field_name
        self.search_params = search_params
        # Per-instance copy, so one instance's (or subclass's) edits do not
        # leak into other instances.
        self._default_search_params: Dict[str, Dict[str, Any]] = {
            index_type: dict(params)
            for index_type, params in self._DEFAULT_SEARCH_PARAMS.items()
        }
        self.collection: Optional[Collection] = None
        self.initialize_collection()

    def get_search_params(self, index_type: str) -> Dict[str, Any]:
        """Return the search parameters to use for ``index_type``.

        Explicit ``search_params`` from the constructor take precedence and
        are returned as-is, whatever the index type. Otherwise a copy of the
        built-in default for ``index_type`` is returned, so callers may
        mutate it freely.

        :raises KeyError: if no explicit params were given and ``index_type``
            has no built-in default.
        """
        if self.search_params is not None:
            return self.search_params
        try:
            return dict(self._default_search_params[index_type])
        except KeyError:
            raise KeyError(
                f"No default search params for index type {index_type!r}; "
                f"pass search_params explicitly or use one of "
                f"{sorted(self._default_search_params)}"
            ) from None

    def initialize_collection(self) -> None:
        """Bind, (re)create, configure, index and load the collection.

        Steps, in order:
          1. Bind to the existing collection, if any; drop it when
             ``drop_old_collection`` is set.
          2. Create it via ``_create_collection`` if nothing is bound.
          3. Apply ``collection_properties`` to whichever collection is now
             bound (existing or newly created).
          4. Call ``_create_index`` (this runs on every initialisation,
             including for pre-existing collections).
          5. Load the collection if Milvus reports it as not loaded.
        """
        if self.has_collection():
            self.collection = Collection(self.collection_name, using=self.alias)
            if self.drop_old_collection:
                self.drop_collection()
        else:
            self.collection = None

        if self.collection is None:
            self.collection = self._create_collection()

        # Applied after any drop/recreate so the properties land on the
        # collection actually in use; done before load, since some properties
        # presumably must be set on an unloaded collection (TODO confirm).
        if self.collection_properties is not None:
            self.collection.set_properties(self.collection_properties)

        self._create_index()

        if utility.load_state(self.collection_name, using=self.alias) == LoadState.NotLoad:
            self.collection.load()

    def drop_collection(self) -> None:
        """Drop the bound collection in Milvus and unbind it.

        This is a no-op when no collection is bound, e.g. after a previous
        call, so repeated calls are safe.
        """
        if self.collection is None:
            return
        self.collection.drop()
        self.collection = None

    def has_collection(self) -> bool:
        """Return whether ``collection_name`` exists on connection ``alias``."""
        return utility.has_collection(self.collection_name, using=self.alias)

    @abstractmethod
    def _create_index(self):
        """Create the index(es) on ``self.collection``.

        Called on every ``initialize_collection``, including when the
        collection already existed, so implementations presumably need to
        tolerate an existing index (verify per subclass).
        """
        pass

    @abstractmethod
    def _create_collection(self) -> Collection:
        """Create the collection with its schema and return it.

        The return value is assigned to ``self.collection``.
        """
        pass

    @abstractmethod
    def embed_texts(
        self,
        texts: List[str],
        ids: List[str],
        batch_size: int = 1000
    ):
        """Embed ``texts`` and insert them under ``ids``.

        ``batch_size`` presumably bounds the number of rows per insert
        request (subclass-defined).
        """
        pass

    @abstractmethod
    def search(self, query: str, top_k: int) -> Hits:
        """Return up to ``top_k`` hits most similar to ``query``."""
        pass