dilawarm committed
Commit 2d483e1 · verified · Parent: f5538a8

Delete modeling_zeranker.py

Files changed (1):
  1. modeling_zeranker.py +0 -216
modeling_zeranker.py DELETED
@@ -1,216 +0,0 @@
import math
from typing import Any, cast

import torch
from sentence_transformers import CrossEncoder as _CE
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3ForCausalLM,
    Gemma3ForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.tokenization_utils_base import BatchEncoding
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false

MODEL_PATH = "zeroentropy/zerank-2"
PER_DEVICE_BATCH_SIZE_TOKENS = 15_000
global_device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
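
# Editor's note (not part of the original file): PER_DEVICE_BATCH_SIZE_TOKENS is
# a rough character budget rather than a true token count. predict() below
# closes a batch once (rows + 1) * longest_pair_chars would exceed it, so with
# ~1_500-character query+document pairs (plus the fixed 20-char margin):
#   (9 + 1) * 1_520 = 15_200 > 15_000  ->  about 9 pairs per batch.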


def format_pointwise_datapoints(
    tokenizer: PreTrainedTokenizerFast,
    query_documents: list[tuple[str, str]],
) -> BatchEncoding:
    """Render (query, document) pairs through the chat template into one padded batch."""
    input_texts: list[str] = []
    for query, document in query_documents:
        # The query becomes the system prompt and the document the user message;
        # the generation prompt positions the model to emit a Yes/No judgment.
        system_prompt = query.strip()
        user_message = document.strip()
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        assert isinstance(input_text, str)
        input_texts.append(input_text)

    batch_inputs = tokenizer(
        input_texts,
        padding=True,
        return_tensors="pt",
    )
    return batch_inputs
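
# Editor's sketch (assumption, not part of the original file): one way to eyeball
# the prompt this helper builds; the exact rendering depends on the checkpoint's
# chat template.
#
#   tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
#   enc = format_pointwise_datapoints(
#       tokenizer, [("capital of France?", "Paris is the capital of France.")]
#   )
#   print(tokenizer.decode(enc.input_ids[0]))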


def load_model(
    device: torch.device | None = None,
) -> tuple[
    PreTrainedTokenizerFast,
    LlamaForCausalLM
    | Gemma3ForConditionalGeneration
    | Gemma3ForCausalLM
    | Qwen3ForCausalLM,
]:
    if device is None:
        device = global_device

    config = AutoConfig.from_pretrained(MODEL_PATH)
    assert isinstance(config, PretrainedConfig)

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto",
        quantization_config=None,
        device_map={"": device},
    )
    if config.model_type == "llama":
        model.config.attn_implementation = "flash_attention_2"
    assert isinstance(
        model,
        LlamaForCausalLM
        | Gemma3ForConditionalGeneration
        | Gemma3ForCausalLM
        | Qwen3ForCausalLM,
    )

    tokenizer = cast(
        AutoTokenizer,
        AutoTokenizer.from_pretrained(
            MODEL_PATH,
            padding_side="right",
        ),
    )
    assert isinstance(tokenizer, PreTrainedTokenizerFast)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer, model
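
# Editor's sketch (not part of the original file): a smoke test, assuming the
# checkpoint fits in memory on the chosen device.
#
#   tokenizer, model = load_model(torch.device("cpu"))
#   assert tokenizer.pad_token is not None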


def predict(
    self,
    query_documents: list[tuple[str, str]] | None = None,
    *,
    sentences: Any = None,
    batch_size: Any = None,
    show_progress_bar: Any = None,
    activation_fn: Any = None,
    apply_softmax: Any = None,
    convert_to_numpy: Any = None,
    convert_to_tensor: Any = None,
) -> list[float]:
    if query_documents is None:
        if sentences is None:
            raise ValueError("query_documents or sentences must be provided")
        query_documents = [(sentence[0], sentence[1]) for sentence in sentences]

    # Lazily load the model and tokenizer on first call and cache them on self.
    if not hasattr(self, "inner_model"):
        self.inner_tokenizer, self.inner_model = load_model(global_device)
        self.inner_model.gradient_checkpointing_enable()
        self.inner_model.eval()
        self.inner_yes_token_id = self.inner_tokenizer.encode(
            "Yes", add_special_tokens=False
        )[0]

    model = self.inner_model
    tokenizer = self.inner_tokenizer

    # Truncate overlong inputs, then sort longest-first so each batch holds
    # similarly sized rows and wastes less padding.
    query_documents = [
        (query[:2_000], document[:10_000]) for query, document in query_documents
    ]
    permutation = list(range(len(query_documents)))
    permutation.sort(
        key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1])
    )
    query_documents = [query_documents[i] for i in permutation]

    # Greedily pack pairs into batches under the per-device size budget.
    max_length = 0
    batches: list[list[tuple[str, str]]] = []
    for query, document in query_documents:
        if (
            len(batches) == 0
            or (len(batches[-1]) + 1) * max(max_length, len(query) + len(document))
            > PER_DEVICE_BATCH_SIZE_TOKENS
        ):
            batches.append([])
            max_length = 0

        batches[-1].append((query, document))
        max_length = max(max_length, 20 + len(query) + len(document))

    # Run inference over each batch.
    all_logits: list[float] = []
    for batch in batches:
        batch_inputs = format_pointwise_datapoints(
            tokenizer,
            batch,
        )
        batch_inputs = batch_inputs.to(global_device)

        try:
            outputs = model(**batch_inputs, use_cache=False)
        except torch.OutOfMemoryError:
            print(f"GPU OOM! {torch.cuda.memory_reserved()}")
            torch.cuda.empty_cache()
            print(f"GPU After OOM Cache Clear: {torch.cuda.memory_reserved()}")
            outputs = model(**batch_inputs, use_cache=False)

        # Take the logits at each row's final non-padding position.
        logits = cast(torch.Tensor, outputs.logits)
        attention_mask = cast(torch.Tensor, batch_inputs.attention_mask)
        last_positions = attention_mask.sum(dim=1) - 1

        num_rows = logits.shape[0]
        batch_indices = torch.arange(num_rows, device=global_device)
        last_logits = logits[batch_indices, last_positions]

        # The relevance signal is the "Yes" token's logit, temperature-scaled.
        yes_logits = last_logits[:, self.inner_yes_token_id]
        all_logits.extend([float(logit) / 5.0 for logit in yes_logits])

    def sigmoid(x: float) -> float:
        return 1 / (1 + math.exp(-x))

    scores = [sigmoid(logit) for logit in all_logits]

    # Undo the length sort so scores line up with the caller's input order.
    scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]

    return scores
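
# Editor's note (not part of the original file), a worked example of the score
# mapping above: a raw "Yes" logit of 10.0 becomes
#   sigmoid(10.0 / 5.0) = 1 / (1 + e**-2.0) ≈ 0.88,
# i.e. the / 5.0 temperature flattens the logits before the sigmoid squashes
# them into [0, 1].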


def to_device(self: _CE, new_device: torch.device) -> None:
    global global_device
    global_device = new_device


# Patch sentence_transformers' CrossEncoder so that .predict() and .to()
# route through this module.
_CE.predict = predict
_CE.to = to_device

from transformers import Qwen3Config

ZEConfig = Qwen3Config
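
For context, a minimal usage sketch of the module this commit deletes (editor's addition: the trust_remote_code flag and example pairs are assumptions; the exact wiring depends on how the repo registers its remote code):

from sentence_transformers import CrossEncoder

model = CrossEncoder("zeroentropy/zerank-2", trust_remote_code=True)
scores = model.predict([
    ("what is the capital of France?", "Paris is the capital of France."),
    ("what is the capital of France?", "Berlin is the capital of Germany."),
])
print(scores)  # one relevance score in [0, 1] per (query, document) pair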