Coverage for apps/inners/use_cases/embeddings/bge_m3_embedding.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-22 19:03 +0000

1from typing import Optional, List, Dict, Any 

2 

3from milvus_model.hybrid import BGEM3EmbeddingFunction 

4 

5from apps.inners.use_cases.embeddings.base_embedding import BaseEmbedding 

6 

7 

class BgeM3Embedding(BaseEmbedding):
    """BGE-M3 embedding backend that delegates to ``BGEM3EmbeddingFunction``.

    Each constructor argument is kept on the instance under the same public
    attribute name. All of them are forwarded unchanged to
    ``BGEM3EmbeddingFunction``, which is created eagerly inside ``__init__``.
    ``encode_documents``, ``encode_queries`` and ``dimensions`` are pure
    pass-throughs to that wrapped model object.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        batch_size: int = 16,
        device: Optional[str] = None,
        normalize_embeddings: bool = True,
        use_fp16: bool = True,
        return_dense: bool = True,
        return_sparse: bool = True,
        return_colbert_vecs: bool = True,
    ):
        """Record the configuration and construct the wrapped model.

        Args:
            model_name: Model identifier forwarded to ``BGEM3EmbeddingFunction``.
            batch_size: Batch size forwarded to ``BGEM3EmbeddingFunction``.
            device: Device forwarded as-is. ``None`` is passed through
                unchanged. NOTE(review): confirm how the library resolves
                ``None`` (auto-selection vs. a fixed default).
            normalize_embeddings: Forwarded as-is to the wrapped model.
            use_fp16: Forwarded as-is to the wrapped model. Presumably this
                enables half precision, which may be unsuitable on CPU;
                verify against the library.
            return_dense: Forwarded as-is. Presumably toggles dense output.
            return_sparse: Forwarded as-is. Presumably toggles sparse output.
            return_colbert_vecs: Forwarded as-is. Presumably toggles
                ColBERT multi-vector output.
        """
        # Public copies of the configuration, readable by callers.
        self.model_name = model_name
        self.batch_size = batch_size
        self.device = device
        self.normalize_embeddings = normalize_embeddings
        self.use_fp16 = use_fp16
        self.return_dense = return_dense
        self.return_sparse = return_sparse
        self.return_colbert_vecs = return_colbert_vecs

        # Collect the exact keyword set the wrapped model receives. It is
        # gathered here so the constructor call below stays on one line.
        model_kwargs: Dict[str, Any] = dict(
            model_name=model_name,
            batch_size=batch_size,
            device=device,
            normalize_embeddings=normalize_embeddings,
            use_fp16=use_fp16,
            return_dense=return_dense,
            return_sparse=return_sparse,
            return_colbert_vecs=return_colbert_vecs,
        )
        # The model is built eagerly, so constructing this class does
        # whatever loading BGEM3EmbeddingFunction performs on init.
        self._embedding_model: BGEM3EmbeddingFunction = BGEM3EmbeddingFunction(**model_kwargs)

    def encode_documents(self, texts: List[str]) -> Dict[str, Any]:
        """Encode ``texts`` as documents via the wrapped model.

        Returns:
            The wrapped model's ``encode_documents`` result, unmodified.
        """
        model = self._embedding_model
        return model.encode_documents(texts)

    def encode_queries(self, texts: List[str]) -> Dict[str, Any]:
        """Encode ``texts`` as queries via the wrapped model.

        Returns:
            The wrapped model's ``encode_queries`` result, unmodified.
        """
        model = self._embedding_model
        return model.encode_queries(texts)

    @property
    def dimensions(self) -> Dict[str, Any]:
        """Expose the wrapped model's ``dim`` attribute unchanged.

        NOTE(review): presumably this maps each output type (dense / sparse /
        colbert) to its dimensionality; confirm against milvus_model.
        """
        model = self._embedding_model
        return model.dim