Shing Yee commited on
Commit
70c7861
Β·
unverified Β·
1 Parent(s): 885748b

update models

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py CHANGED
@@ -7,15 +7,12 @@ from utils import (
7
  embeddings_predict_relevance,
8
  stsb_model,
9
  stsb_tokenizer,
10
- ms_model,
11
- ms_tokenizer,
12
  cross_encoder_predict_relevance
13
  )
14
 
15
  def predict(system_prompt, user_prompt):
16
  predicted_label_jina, probabilities_jina = embeddings_predict_relevance(system_prompt, user_prompt, jina_model, jina_tokenizer, device)
17
  predicted_label_stsb, probabilities_stsb = cross_encoder_predict_relevance(system_prompt, user_prompt, stsb_model, stsb_tokenizer, device)
18
- predicted_label_ms, probabilities_ms = cross_encoder_predict_relevance(system_prompt, user_prompt, ms_model, ms_tokenizer, device)
19
 
20
  result = f"""
21
  **Prediction Summary**
@@ -27,10 +24,6 @@ def predict(system_prompt, user_prompt):
27
  **2. Model: cross-encoder/stsb-roberta-base**
28
  - **Prediction**: {"πŸŸ₯ Off-topic" if predicted_label_stsb==1 else "🟩 On-topic"}
29
  - **Probability of being off-topic**: {probabilities_stsb[0][1]:.2%}
30
-
31
- **3. Model: cross-encoder/ms-marco-MiniLM-L-6-v2**
32
- - **Prediction**: {"πŸŸ₯ Off-topic" if predicted_label_ms==1 else "🟩 On-topic"}
33
- - **Probability of being off-topic**: {probabilities_ms[0][1]:.2%}
34
  """
35
 
36
  return result
 
7
  embeddings_predict_relevance,
8
  stsb_model,
9
  stsb_tokenizer,
 
 
10
  cross_encoder_predict_relevance
11
  )
12
 
13
  def predict(system_prompt, user_prompt):
14
  predicted_label_jina, probabilities_jina = embeddings_predict_relevance(system_prompt, user_prompt, jina_model, jina_tokenizer, device)
15
  predicted_label_stsb, probabilities_stsb = cross_encoder_predict_relevance(system_prompt, user_prompt, stsb_model, stsb_tokenizer, device)
 
16
 
17
  result = f"""
18
  **Prediction Summary**
 
24
  **2. Model: cross-encoder/stsb-roberta-base**
25
  - **Prediction**: {"πŸŸ₯ Off-topic" if predicted_label_stsb==1 else "🟩 On-topic"}
26
  - **Probability of being off-topic**: {probabilities_stsb[0][1]:.2%}
 
 
 
 
27
  """
28
 
29
  return result
models/cross-encoder-ms-marco-MiniLM-L-6-v2-CrossEncoder-OffTopic-Classifier-20240918-090615.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:78a99fac3bc5b4729fee844d2154ea625aa9ceac2928cd648984ee1da5b8a203
3
- size 91236352
 
 
 
 
models/cross-encoder-stsb-roberta-base-CrossEncoder-OffTopic-Classifier-20240920-174009.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1e90752828e92bc2f8ec567b85b3de5a0c8c5ddc331c1907d4dfa950624f71ce
3
- size 500085976
 
 
 
 
models/jinaai-jina-embeddings-v2-small-en-TwinEncoder-OffTopic-Classifier-20240915-151858.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:223687abc28cf0fa198d326d2786374000396d841e66d684c022941da2ca9628
3
- size 144076480
 
 
 
 
utils.py CHANGED
@@ -1,8 +1,11 @@
 
1
  import torch
2
  from torch import nn
3
  from safetensors.torch import load_file
4
  from transformers import AutoModel, AutoTokenizer
 
5
 
 
6
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
 
8
  # Load the model state_dict from safetensors
@@ -13,9 +16,9 @@ def load_model_safetensors(model, load_path="model.safetensors"):
13
  model.load_state_dict(state_dict)
14
  return model
15
 
16
- ##########################
17
  # JINA EMBEDDINGS
18
- ##########################
19
 
20
  # Jina Configs
21
  JINA_CONTEXT_LEN = 1024
@@ -101,7 +104,7 @@ class CrossEncoderWithSharedBase(nn.Module):
101
  logits = self.classifier(projected)
102
  return logits
103
 
104
- # Prediction function
105
  def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device):
106
  model.eval()
107
  inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
@@ -117,23 +120,32 @@ def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device)
117
  predicted_label = torch.argmax(probabilities, dim=1).item()
118
  return predicted_label, probabilities.cpu().numpy()
119
 
120
- # Jina model
121
- JINA_MODEL_NAME = "jinaai/jina-embeddings-v2-small-en"
 
 
 
 
 
 
 
 
 
122
  jina_tokenizer = AutoTokenizer.from_pretrained(JINA_MODEL_NAME)
123
  jina_base_model = AutoModel.from_pretrained(JINA_MODEL_NAME)
124
  jina_model = CrossEncoderWithSharedBase(jina_base_model, num_labels=2)
125
- jina_model = load_model_safetensors(jina_model, load_path="models/jinaai-jina-embeddings-v2-small-en-TwinEncoder-OffTopic-Classifier-20240915-151858.safetensors")
126
 
127
- ##########################
 
 
 
 
128
  # CROSS-ENCODER
129
- ##########################
130
 
131
- # STSB Configs
132
  STSB_CONTEXT_LEN = 512
133
 
134
- # ms-macro Configs
135
- MS_CONTEXT_LEN = 512
136
-
137
  class CrossEncoderWithMLP(nn.Module):
138
  def __init__(self, base_model, num_labels=2):
139
  super(CrossEncoderWithMLP, self).__init__()
@@ -162,6 +174,7 @@ class CrossEncoderWithMLP(nn.Module):
162
  logits = self.classifier(mlp_output)
163
  return logits
164
 
 
165
  def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, device):
166
  model.eval()
167
  # Tokenize the pair of sentences
@@ -187,16 +200,20 @@ def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, devi
187
  predicted_label = torch.argmax(probabilities, dim=1).item()
188
  return predicted_label, probabilities.cpu().numpy()
189
 
190
- # STSB model
191
- STSB_MODEL_NAME = "cross-encoder/stsb-roberta-base"
 
 
 
 
 
 
 
 
192
  stsb_tokenizer = AutoTokenizer.from_pretrained(STSB_MODEL_NAME)
193
  stsb_base_model = AutoModel.from_pretrained(STSB_MODEL_NAME)
194
  stsb_model = CrossEncoderWithMLP(stsb_base_model, num_labels=2)
195
- stsb_model = load_model_safetensors(stsb_model, load_path="models/cross-encoder-stsb-roberta-base-CrossEncoder-OffTopic-Classifier-20240920-174009.safetensors")
196
-
197
- # MS model
198
- MS_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
199
- ms_tokenizer = AutoTokenizer.from_pretrained(MS_MODEL_NAME)
200
- ms_base_model = AutoModel.from_pretrained(MS_MODEL_NAME)
201
- ms_model = CrossEncoderWithMLP(ms_base_model, num_labels=2)
202
- ms_model = load_model_safetensors(ms_model, load_path="models/cross-encoder-ms-marco-MiniLM-L-6-v2-CrossEncoder-OffTopic-Classifier-20240918-090615.safetensors")
 
1
+ import json
2
  import torch
3
  from torch import nn
4
  from safetensors.torch import load_file
5
  from transformers import AutoModel, AutoTokenizer
6
+ from huggingface_hub import hf_hub_download
7
 
8
+ # Set device
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
11
  # Load the model state_dict from safetensors
 
16
  model.load_state_dict(state_dict)
17
  return model
18
 
19
+ ###################
20
  # JINA EMBEDDINGS
21
+ ###################
22
 
23
  # Jina Configs
24
  JINA_CONTEXT_LEN = 1024
 
104
  logits = self.classifier(projected)
105
  return logits
106
 
107
+ # Prediction function for embeddings relevance
108
  def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device):
109
  model.eval()
110
  inputs1 = tokenizer(sentence1, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
 
120
  predicted_label = torch.argmax(probabilities, dim=1).item()
121
  return predicted_label, probabilities.cpu().numpy()
122
 
123
+ # Load configuration file
124
+ jina_repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
125
+ jina_config_path = hf_hub_download(repo_id=jina_repo_path, filename="config.json")
126
+ with open(jina_config_path, 'r') as f:
127
+ jina_config = json.load(f)
128
+
129
+ # Load Jina model configuration
130
+ JINA_MODEL_NAME = jina_config['classifier']['embedding']['model_name']
131
+ jina_model_weights_fp = jina_config['classifier']['embedding']['model_weights_fp']
132
+
133
+ # Load tokenizer and model
134
  jina_tokenizer = AutoTokenizer.from_pretrained(JINA_MODEL_NAME)
135
  jina_base_model = AutoModel.from_pretrained(JINA_MODEL_NAME)
136
  jina_model = CrossEncoderWithSharedBase(jina_base_model, num_labels=2)
 
137
 
138
+ # Load model weights from safetensors
139
+ jina_model_weights_path = hf_hub_download(repo_id=jina_repo_path, filename=jina_model_weights_fp)
140
+ jina_model = load_model_safetensors(jina_model, jina_model_weights_path)
141
+
142
+ #################
143
  # CROSS-ENCODER
144
+ #################
145
 
146
+ # STSB Configuration
147
  STSB_CONTEXT_LEN = 512
148
 
 
 
 
149
  class CrossEncoderWithMLP(nn.Module):
150
  def __init__(self, base_model, num_labels=2):
151
  super(CrossEncoderWithMLP, self).__init__()
 
174
  logits = self.classifier(mlp_output)
175
  return logits
176
 
177
+ # Prediction function for cross-encoder
178
  def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, device):
179
  model.eval()
180
  # Tokenize the pair of sentences
 
200
  predicted_label = torch.argmax(probabilities, dim=1).item()
201
  return predicted_label, probabilities.cpu().numpy()
202
 
203
+ # Load STSB model configuration
204
+ stsb_repo_path = "govtech/stsb-roberta-base-off-topic"
205
+ stsb_config_path = hf_hub_download(repo_id=stsb_repo_path, filename="config.json")
206
+ with open(stsb_config_path, 'r') as f:
207
+ stsb_config = json.load(f)
208
+
209
+ STSB_MODEL_NAME = stsb_config['classifier']['embedding']['model_name']
210
+ stsb_model_weights_fp = stsb_config['classifier']['embedding']['model_weights_fp']
211
+
212
+ # Load STSB tokenizer and model
213
  stsb_tokenizer = AutoTokenizer.from_pretrained(STSB_MODEL_NAME)
214
  stsb_base_model = AutoModel.from_pretrained(STSB_MODEL_NAME)
215
  stsb_model = CrossEncoderWithMLP(stsb_base_model, num_labels=2)
216
+
217
+ # Load model weights from safetensors for STSB
218
+ stsb_model_weights_path = hf_hub_download(repo_id=stsb_repo_path, filename=stsb_model_weights_fp)
219
+ stsb_model = load_model_safetensors(stsb_model, stsb_model_weights_path)