Coverage for apps/inners/use_cases/rerankers/bge_reranker.py: 48%
48 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-22 19:03 +0000
1from typing import Optional, List, Dict, Any
3from FlagEmbedding import FlagLLMReranker, LayerWiseFlagLLMReranker, FlagReranker
4from milvus_model.base import RerankResult, BaseRerankFunction
6from apps.inners.use_cases.rerankers.base_reranker import BaseReranker
class BgeRerankFunction(BaseRerankFunction):
    """Milvus-style rerank function backed by FlagEmbedding rerankers.

    Selects the concrete FlagEmbedding implementation from the model name
    suffix: "m3" -> FlagReranker, "gemma" -> FlagLLMReranker,
    "minicpm-layerwise" -> LayerWiseFlagLLMReranker.
    """

    def __init__(
            self,
            model_name: str = "BAAI/bge-reranker-v2-m3",
            use_fp16: bool = True,
            batch_size: int = 32,
            normalize: bool = True,
            device: Optional[str] = None,
            cutoff_layers: Optional[List[int]] = None,
    ):
        """
        Args:
            model_name: HuggingFace id of a supported bge-reranker-v2 model.
            use_fp16: Load model weights in fp16 where supported.
            batch_size: Number of (query, document) pairs scored per batch.
            normalize: Whether compute_score should normalize scores to [0, 1].
            device: Optional device string forwarded to the reranker.
            cutoff_layers: Layer cutoffs for the layerwise model; defaults to [28].

        Raises:
            ValueError: If model_name does not match any supported variant.
        """
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize = normalize
        self.device = device
        # Bug fix: always assign cutoff_layers. Previously an explicitly
        # provided value was never stored (only the None case assigned the
        # attribute), so __call__ crashed with AttributeError for the
        # layerwise model whenever the caller supplied cutoff_layers.
        self.cutoff_layers = cutoff_layers if cutoff_layers is not None else [28]
        if self.model_name.endswith("m3"):
            self.reranker = FlagReranker(self.model_name, use_fp16=use_fp16, device=self.device)
        elif self.model_name.endswith("gemma"):
            self.reranker = FlagLLMReranker(self.model_name, use_fp16=use_fp16, device=self.device)
        elif self.model_name.endswith("minicpm-layerwise"):
            self.reranker = LayerWiseFlagLLMReranker(self.model_name, use_fp16=use_fp16, device=self.device)
        else:
            # Bug fix: previously an unsupported name silently left
            # self.reranker unset and the failure surfaced later in __call__.
            raise ValueError(f"Unsupported model name: {self.model_name}")

    def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
        """Split texts into consecutive chunks of at most batch_size items."""
        return [texts[i: i + batch_size] for i in range(0, len(texts), batch_size)]

    def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
        """Score documents against query and return the top_k results.

        The returned RerankResult.index refers to the position of the text
        in the original documents list. A falsy top_k (e.g. 0) returns all
        results, preserving the original behavior.
        """
        batched_texts = self._batchify(documents, self.batch_size)
        scores: List[float] = []
        for batched_text in batched_texts:
            query_document_pairs = [(query, text) for text in batched_text]
            # isinstance instead of `type(...) is` — idiomatic type check.
            if isinstance(self.reranker, LayerWiseFlagLLMReranker):
                batch_score = self.reranker.compute_score(
                    query_document_pairs,
                    normalize=self.normalize,
                    cutoff_layers=self.cutoff_layers
                )
            else:
                batch_score = self.reranker.compute_score(
                    query_document_pairs,
                    normalize=self.normalize
                )
            scores.extend(batch_score)
        ranked_order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        if top_k:
            ranked_order = ranked_order[:top_k]
        results = []
        for index in ranked_order:
            results.append(RerankResult(text=documents[index], score=scores[index], index=index))
        return results
class BgeReranker(BaseReranker):
    """
    For bge-reranker-v2-m3, bge-reranker-v2-gemma, bge-reranker-v2-minicpm-layerwise models only.
    """

    def __init__(
            self,
            model_name: str = "BAAI/bge-reranker-v2-m3",
            use_fp16: bool = True,
            batch_size: int = 32,
            normalize: bool = True,
            cutoff_layers: Optional[List[int]] = None,
            device: Optional[str] = None,
    ):
        """Store configuration and build the underlying BgeRerankFunction.

        Args:
            model_name: HuggingFace id of the bge-reranker-v2 model to load.
            use_fp16: Load model weights in fp16 where supported.
            batch_size: Pairs scored per batch by the underlying reranker.
            normalize: Whether scores are normalized by the reranker.
            cutoff_layers: Layer cutoffs for the layerwise model variant.
            device: Optional device string forwarded to the reranker.
        """
        self.model_name = model_name
        self.use_fp16 = use_fp16
        self.batch_size = batch_size
        self.normalize = normalize
        self.device = device
        self.model: BgeRerankFunction = BgeRerankFunction(
            model_name=model_name,
            use_fp16=use_fp16,
            batch_size=batch_size,
            normalize=normalize,
            cutoff_layers=cutoff_layers,
            device=device,
        )

    def rerank(self, query: str, texts: List[str], top_k: int) -> List[Dict[str, Any]]:
        """Rerank texts against query and return the top_k hits as dicts.

        Each dict is the to_dict() form of a RerankResult (text, score, and
        the text's index in the input list).
        """
        reranked: List[RerankResult] = self.model(query=query, documents=texts, top_k=top_k)
        return [hit.to_dict() for hit in reranked]