Coverage for apps/inners/use_cases/vector_stores/base_milvus_vector_store.py: 77%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-22 19:03 +0000

1from abc import abstractmethod, ABC 

2from typing import List, Dict, Any, Optional 

3 

4from pymilvus import Hits, Collection 

5from pymilvus.client.types import LoadState 

6from pymilvus.orm import utility 

7 

8 

class BaseMilvusVectorStore(ABC):
    """Abstract base for a vector store backed by one Milvus collection.

    On construction the base class opens the named collection, creating it
    through ``_create_collection`` when it does not exist (or after dropping
    it when ``drop_old_collection`` is set). It then calls ``_create_index``,
    applies ``collection_properties`` and loads the collection if Milvus
    reports it as not loaded.

    Subclasses supply the schema (``_create_collection``), the indexes
    (``_create_index``), ingestion (``embed_texts``) and querying
    (``search``).
    """

    # Metric used by every built-in default search configuration, and by the
    # fallback for index types that have no explicit entry.
    _DEFAULT_METRIC_TYPE = "IP"

    # Index types that have a built-in default search configuration.
    _DEFAULT_INDEX_TYPES = (
        "SPARSE_INVERTED_INDEX",
        "IVF_FLAT",
        "IVF_SQ8",
        "IVF_PQ",
        "HNSW",
        "RHNSW_FLAT",
        "RHNSW_SQ",
        "RHNSW_PQ",
        "IVF_HNSW",
        "ANNOY",
        "SCANN",
        "AUTOINDEX",
        "GPU_CAGRA",
        "GPU_IVF_FLAT",
        "GPU_IVF_PQ",
    )

    def __init__(
            self,
            collection_name: str,
            vector_field_dimensions: Dict[str, Any],
            alias: Optional[str] = None,
            consistency_level: str = "Strong",
            collection_properties: Optional[Dict[str, Any]] = None,
            drop_old_collection: bool = False,
            id_field_name: str = "id",
            search_params: Optional[Dict[str, Any]] = None,
    ):
        """Store the configuration and open/create the collection.

        :param collection_name: Name of the Milvus collection.
        :param vector_field_dimensions: Stored for subclasses; presumably maps
            vector field names to their dimensions — confirm against
            ``_create_collection`` implementations.
        :param alias: Milvus connection alias passed as ``using=``; ``None``
            lets pymilvus pick its default connection.
        :param consistency_level: Stored for subclasses (not used here).
        :param collection_properties: Properties applied with
            ``Collection.set_properties`` once the collection exists.
        :param drop_old_collection: Drop and recreate an existing collection.
        :param id_field_name: Stored for subclasses (not used here).
        :param search_params: When given, returned by ``get_search_params``
            regardless of the index type.
        """
        self.collection_name = collection_name
        self.vector_field_dimensions = vector_field_dimensions
        self.alias = alias
        self.consistency_level = consistency_level
        self.collection_properties = collection_properties
        self.drop_old_collection = drop_old_collection
        self.id_field_name = id_field_name
        self.search_params = search_params
        # One fresh dict per index type, so mutating one entry does not
        # affect the others or other instances.
        self._default_search_params: Dict[str, Dict[str, Any]] = {
            index_type: {"metric_type": self._DEFAULT_METRIC_TYPE}
            for index_type in self._DEFAULT_INDEX_TYPES
        }
        self.collection: Optional[Collection] = None
        self.initialize_collection()

    def get_search_params(self, index_type: str) -> Dict[str, Any]:
        """Return the search parameters for an index of type ``index_type``.

        Explicit ``search_params`` given at construction always take
        precedence. Otherwise the built-in default for ``index_type`` is
        returned. Index types without an entry (e.g. ``"FLAT"``) fall back to
        an inner-product search instead of raising ``KeyError``.
        """
        if self.search_params is not None:
            return self.search_params
        default_params = self._default_search_params.get(index_type)
        if default_params is None:
            # NOTE(review): assumes subclasses index with the same metric as
            # every table entry (IP) — confirm for any non-IP index.
            return {"metric_type": self._DEFAULT_METRIC_TYPE}
        return default_params

    def initialize_collection(self) -> None:
        """Open or create the collection, then configure and load it."""
        if self.has_collection():
            self.collection = Collection(
                self.collection_name,
                using=self.alias,
            )
            if self.drop_old_collection:
                self.drop_collection()
        else:
            self.collection = None

        if self.collection is None:
            self.collection = self._create_collection()

        # Properties are applied only after the drop/create step, so they
        # also reach a freshly created collection. They are applied before
        # loading because some properties can only change while the
        # collection is released.
        if self.collection_properties is not None:
            self.collection.set_properties(self.collection_properties)

        # NOTE(review): the original indentation was lost. This assumes
        # _create_index runs for existing collections too, so implementations
        # must tolerate an already-present index — confirm.
        self._create_index()

        if utility.load_state(self.collection_name, using=self.alias) == LoadState.NotLoad:
            self.collection.load()

    def drop_collection(self) -> None:
        """Drop the opened collection and clear the handle.

        Does nothing when no collection handle is held.
        """
        if self.collection is None:
            return
        self.collection.drop()
        self.collection = None

    def has_collection(self) -> bool:
        """Return whether Milvus (on ``self.alias``) has ``collection_name``."""
        return utility.has_collection(self.collection_name, using=self.alias)

    @abstractmethod
    def _create_index(self) -> None:
        """Create the indexes the collection's searches rely on."""

    @abstractmethod
    def _create_collection(self) -> "Collection":
        """Create the collection with its schema and return the handle."""

    @abstractmethod
    def embed_texts(
            self,
            texts: List[str],
            ids: List[str],
            batch_size: int = 1000
    ):
        """Embed ``texts`` and store them under ``ids`` in batches."""

    @abstractmethod
    def search(self, query: str, top_k: int) -> "Hits":
        """Return up to ``top_k`` hits for ``query``."""