Coverage for apps/inners/use_cases/rerankers/bge_reranker.py: 48%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-22 19:03 +0000

1from typing import Optional, List, Dict, Any 

2 

3from FlagEmbedding import FlagLLMReranker, LayerWiseFlagLLMReranker, FlagReranker 

4from milvus_model.base import RerankResult, BaseRerankFunction 

5 

6from apps.inners.use_cases.rerankers.base_reranker import BaseReranker 

7 

8 

class BgeRerankFunction(BaseRerankFunction):
    """Milvus-style rerank function wrapping the FlagEmbedding BGE v2 rerankers.

    The concrete backend is chosen from the model-name suffix:
    "m3" -> FlagReranker, "gemma" -> FlagLLMReranker,
    "minicpm-layerwise" -> LayerWiseFlagLLMReranker.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-v2-m3",
        use_fp16: bool = True,
        batch_size: int = 32,
        normalize: bool = True,
        device: Optional[str] = None,
        cutoff_layers: Optional[List[int]] = None,
    ):
        """Load the reranker model matching model_name.

        Args:
            model_name: HuggingFace id of a BGE v2 reranker model.
            use_fp16: load model weights in half precision.
            batch_size: number of (query, document) pairs scored per call.
            normalize: map raw scores to [0, 1] via sigmoid.
            device: device string forwarded to FlagEmbedding (e.g. "cuda:0").
            cutoff_layers: layerwise models only — which layers to score with;
                defaults to [28] when not given.

        Raises:
            ValueError: if model_name does not end in a recognized suffix.
        """
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize = normalize
        self.device = device
        # Bug fix: a caller-supplied cutoff_layers used to be ignored (only the
        # None case assigned the attribute), which raised AttributeError in
        # __call__ for layerwise models. Always set it; [28] is the old default.
        self.cutoff_layers = [28] if cutoff_layers is None else cutoff_layers
        if self.model_name.endswith("m3"):
            self.reranker = FlagReranker(self.model_name, use_fp16=use_fp16, device=self.device)
        elif self.model_name.endswith("gemma"):
            self.reranker = FlagLLMReranker(self.model_name, use_fp16=use_fp16, device=self.device)
        elif self.model_name.endswith("minicpm-layerwise"):
            self.reranker = LayerWiseFlagLLMReranker(self.model_name, use_fp16=use_fp16, device=self.device)
        else:
            # Bug fix: previously self.reranker was silently left unset here,
            # deferring the failure to an opaque AttributeError at first use.
            raise ValueError(f"Unsupported BGE reranker model: {self.model_name}")

    def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
        """Split texts into consecutive chunks of at most batch_size items."""
        return [texts[i: i + batch_size] for i in range(0, len(texts), batch_size)]

    def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
        """Score every document against query and return the best top_k.

        Args:
            query: the search query.
            documents: candidate texts to rerank.
            top_k: number of results to keep; a falsy value (0/None) keeps all.

        Returns:
            RerankResult items sorted by descending score; `index` refers to
            the position in the original documents list.
        """
        scores: List[float] = []
        for batch in self._batchify(documents, self.batch_size):
            pairs = [(query, text) for text in batch]
            # isinstance instead of `type(...) is ...` — layerwise rerankers
            # additionally need the cutoff_layers argument.
            if isinstance(self.reranker, LayerWiseFlagLLMReranker):
                batch_score = self.reranker.compute_score(
                    pairs,
                    normalize=self.normalize,
                    cutoff_layers=self.cutoff_layers,
                )
            else:
                batch_score = self.reranker.compute_score(pairs, normalize=self.normalize)
            scores.extend(batch_score)
        ranked_order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

        if top_k:
            ranked_order = ranked_order[:top_k]

        return [
            RerankResult(text=documents[index], score=scores[index], index=index)
            for index in ranked_order
        ]

61 

62 

class BgeReranker(BaseReranker):
    """Reranker backed by the BGE v2 family.

    For bge-reranker-v2-m3, bge-reranker-v2-gemma, bge-reranker-v2-minicpm-layerwise models only.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-v2-m3",
        use_fp16: bool = True,
        batch_size: int = 32,
        normalize: bool = True,
        cutoff_layers: Optional[List[int]] = None,
        device: Optional[str] = None,
    ):
        self.model_name = model_name
        self.use_fp16 = use_fp16
        self.batch_size = batch_size
        self.normalize = normalize
        self.device = device
        # All scoring work is delegated to the milvus-style rerank function.
        self.model: BgeRerankFunction = BgeRerankFunction(
            model_name=model_name,
            use_fp16=use_fp16,
            batch_size=batch_size,
            normalize=normalize,
            cutoff_layers=cutoff_layers,
            device=device,
        )

    def rerank(self, query: str, texts: List[str], top_k: int) -> List[Dict[str, Any]]:
        """Rerank texts against query and return the top_k results as dicts."""
        ranked: List[RerankResult] = self.model(
            query=query,
            documents=texts,
            top_k=top_k,
        )
        return [item.to_dict() for item in ranked]