Cricles committed on
Commit 012eff7 · verified · 1 Parent(s): a804858

Update FRIDA/model.py

Files changed (1)
  1. FRIDA/model.py +119 -119
FRIDA/model.py CHANGED
@@ -1,120 +1,120 @@
-import torch
-import torch.nn.functional as F
-from transformers import AutoTokenizer, T5EncoderModel
-import os
-from typing import List
-import re
-
-FRIDA_EMB_DIM = 1536
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-
-def pool(hidden_state, mask, pooling_method="cls"):
-    if pooling_method == "mean":
-        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
-        d = mask.sum(axis=1, keepdim=True).float()
-        return s / d
-    elif pooling_method == "cls":
-        return hidden_state[:, 0]
-
-
-class FridaClassifier(torch.nn.Module):
-    def __init__(self):
-        super(FridaClassifier, self).__init__()
-        self.frida_embedder = T5EncoderModel.from_pretrained("ai-forever/FRIDA")
-        self._freeze_embedder_grad()
-        self.classifier = torch.nn.Sequential(
-            torch.nn.Linear(in_features=FRIDA_EMB_DIM, out_features=500),
-            torch.nn.Dropout(p=0.2),
-            torch.nn.SELU(),
-            torch.nn.Linear(in_features=500, out_features=100),
-            torch.nn.Dropout(p=0.1),
-            torch.nn.SELU(),
-            torch.nn.Linear(in_features=100, out_features=2)
-        )
-
-    def _freeze_embedder_grad(self):
-        for param in self.frida_embedder.parameters():
-            param.requires_grad = False
-
-    def forward(self, input_ids, attention_mask):
-        with torch.no_grad():  # no gradients calculation for frida embedder
-            outputs = self.frida_embedder(input_ids=input_ids, attention_mask=attention_mask)
-
-        embeddings = pool(
-            outputs.last_hidden_state,
-            attention_mask,
-            pooling_method="cls"  # or try "mean"
-        )
-        embeddings = F.normalize(embeddings, p=2, dim=1)
-        out = self.classifier(embeddings)
-
-        return out
-
-
-# return model and tokenizer
-def load_model(head_path: str):
-    if not os.path.isfile(head_path):
-        raise Exception(f'no model weights with path - {head_path}')
-    loaded_model = FridaClassifier()
-    loaded_model.classifier.load_state_dict(torch.load(head_path, map_location='cpu', weights_only=True))
-    loaded_model.eval()
-    loaded_model.to(device)
-    tokenizer = AutoTokenizer.from_pretrained("ai-forever/FRIDA")
-
-    return loaded_model, tokenizer
-
-
-def infer(model: FridaClassifier, tokenizer: AutoTokenizer, texts: List[str], device):
-    with torch.no_grad():
-        model.eval()
-        texts = ["categorize_sentiment: " + text for text in texts]
-        tokenized_data = tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors="pt")
-        input_ids, attention_masks = tokenized_data['input_ids'].type(torch.LongTensor).to(device), tokenized_data[
-            'attention_mask'].type(torch.LongTensor).to(device)
-        logits_tensor = model(input_ids, attention_masks)
-        sft_max = torch.nn.Softmax(dim=-1)
-        pred_probs = sft_max(logits_tensor)
-
-    return pred_probs
-
-
-labels = {0: 'non-toxic', 1: 'toxic'}
-
-
-print('loading model and tokenizer...')
-chkp_dir = 'models/'  # CHANGE ON YOUR DIR WITH HEAD WEIGHTS!
-model, tokenizer = load_model(os.path.join(chkp_dir, "classifier_head.pth"))
-print('loaded.')
-
-
-from typing import List
-from pydantic import BaseModel
-
-# Define DTOs
-class ToxicityPrediction(BaseModel):
-    text: str
-    label: str
-    toxicity_rate: float
-
-
-class ToxicityPredictionResponse(BaseModel):
-    predictions: List[ToxicityPrediction]
-
-
-def generate_resp(texts: List[str]):
-    probs = infer(model, tokenizer, texts, device)
-    probs_arr = probs.to('cpu').numpy()
-    predictions = torch.argmax(probs, dim=-1).int().to('cpu').numpy()
-    predicted_labels = [labels[label] for label in predictions]
-
-    predictions_list = [
-        ToxicityPrediction(
-            text=texts[i],
-            label=predicted_labels[i],
-            toxicity_rate=float(probs_arr[i][1])  # Ensure float type
-        )
-        for i in range(len(texts))
-    ]
-
-    return ToxicityPredictionResponse(predictions=predictions_list)
 
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer, T5EncoderModel
+import os
+from typing import List
+import re
+
+FRIDA_EMB_DIM = 1536
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def pool(hidden_state, mask, pooling_method="cls"):
+    if pooling_method == "mean":
+        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
+        d = mask.sum(axis=1, keepdim=True).float()
+        return s / d
+    elif pooling_method == "cls":
+        return hidden_state[:, 0]
+
+
+class FridaClassifier(torch.nn.Module):
+    def __init__(self):
+        super(FridaClassifier, self).__init__()
+        self.frida_embedder = T5EncoderModel.from_pretrained("ai-forever/FRIDA")
+        self._freeze_embedder_grad()
+        self.classifier = torch.nn.Sequential(
+            torch.nn.Linear(in_features=FRIDA_EMB_DIM, out_features=500),
+            torch.nn.Dropout(p=0.2),
+            torch.nn.SELU(),
+            torch.nn.Linear(in_features=500, out_features=100),
+            torch.nn.Dropout(p=0.1),
+            torch.nn.SELU(),
+            torch.nn.Linear(in_features=100, out_features=2)
+        )
+
+    def _freeze_embedder_grad(self):
+        for param in self.frida_embedder.parameters():
+            param.requires_grad = False
+
+    def forward(self, input_ids, attention_mask):
+        with torch.no_grad():  # no gradients calculation for frida embedder
+            outputs = self.frida_embedder(input_ids=input_ids, attention_mask=attention_mask)
+
+        embeddings = pool(
+            outputs.last_hidden_state,
+            attention_mask,
+            pooling_method="cls"  # or try "mean"
+        )
+        embeddings = F.normalize(embeddings, p=2, dim=1)
+        out = self.classifier(embeddings)
+
+        return out
+
+
+# return model and tokenizer
+def load_model(head_path: str):
+    if not os.path.isfile(head_path):
+        raise Exception(f'no model weights with path - {head_path}')
+    loaded_model = FridaClassifier()
+    loaded_model.classifier.load_state_dict(torch.load(head_path, map_location='cpu', weights_only=True))
+    loaded_model.eval()
+    loaded_model.to(device)
+    tokenizer = AutoTokenizer.from_pretrained("ai-forever/FRIDA")
+
+    return loaded_model, tokenizer
+
+
+def infer(model: FridaClassifier, tokenizer: AutoTokenizer, texts: List[str], device):
+    with torch.no_grad():
+        model.eval()
+        texts = ["categorize_sentiment: " + text for text in texts]
+        tokenized_data = tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors="pt")
+        input_ids, attention_masks = tokenized_data['input_ids'].type(torch.LongTensor).to(device), tokenized_data[
+            'attention_mask'].type(torch.LongTensor).to(device)
+        logits_tensor = model(input_ids, attention_masks)
+        sft_max = torch.nn.Softmax(dim=-1)
+        pred_probs = sft_max(logits_tensor)
+
+    return pred_probs
+
+
+labels = {0: 'non-toxic', 1: 'toxic'}
+
+
+print('loading model and tokenizer...')
+chkp_dir = './'  # CHANGE ON YOUR DIR WITH HEAD WEIGHTS!
+model, tokenizer = load_model(os.path.join(chkp_dir, "classifier_head.pth"))
+print('loaded.')
+
+
+from typing import List
+from pydantic import BaseModel
+
+# Define DTOs
+class ToxicityPrediction(BaseModel):
+    text: str
+    label: str
+    toxicity_rate: float
+
+
+class ToxicityPredictionResponse(BaseModel):
+    predictions: List[ToxicityPrediction]
+
+
+def generate_resp(texts: List[str]):
+    probs = infer(model, tokenizer, texts, device)
+    probs_arr = probs.to('cpu').numpy()
+    predictions = torch.argmax(probs, dim=-1).int().to('cpu').numpy()
+    predicted_labels = [labels[label] for label in predictions]
+
+    predictions_list = [
+        ToxicityPrediction(
+            text=texts[i],
+            label=predicted_labels[i],
+            toxicity_rate=float(probs_arr[i][1])  # Ensure float type
+        )
+        for i in range(len(texts))
+    ]
+
+    return ToxicityPredictionResponse(predictions=predictions_list)
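
For reference, a minimal usage sketch of the updated module, not part of this commit. It assumes the FRIDA/ directory is importable as a package, that classifier_head.pth is present in the checkpoint directory configured above so the module-level load_model() call succeeds on import, and that the sample strings are made up.

# Hypothetical usage sketch (assumptions as stated above).
# Importing the module triggers the module-level load_model() call, so the
# FRIDA backbone is downloaded and classifier_head.pth must be in place.
from FRIDA import model as toxicity_model  # assumes FRIDA/ is an importable package

resp = toxicity_model.generate_resp(["example text one", "example text two"])
for pred in resp.predictions:
    # Each prediction carries the original text, the argmax label, and P(toxic).
    print(pred.text, pred.label, f"{pred.toxicity_rate:.3f}")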