File size: 4,008 Bytes
012eff7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2167f0
 
 
 
012eff7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa61f7e
012eff7
 
 
 
 
 
 
 
 
 
 
 
 
 
583265a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, T5EncoderModel
import os
from typing import List
import re

# Width of the embedding vector fed into the first layer of the classifier head.
FRIDA_EMB_DIM = 1536

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def pool(hidden_state, mask, pooling_method="cls"):
    """Collapse per-token hidden states into one vector per sequence.

    Args:
        hidden_state: Tensor indexed as ``[batch, token, feature]`` (the code
            takes token 0 for "cls" and sums over dimension 1 for "mean").
        mask: Attention mask indexed as ``[batch, token]``; non-zero entries
            mark tokens that contribute to the mean.
        pooling_method: ``"cls"`` returns the hidden state of the first token;
            ``"mean"`` returns the mask-weighted average over tokens.

    Returns:
        Tensor with one pooled vector per batch row.

    Raises:
        ValueError: If ``pooling_method`` is neither ``"cls"`` nor ``"mean"``
            (previously this case silently returned ``None``).
    """
    if pooling_method == "cls":
        return hidden_state[:, 0]

    if pooling_method == "mean":
        weights = mask.unsqueeze(-1).float()
        summed = torch.sum(hidden_state * weights, dim=1)
        # Clamp so that a fully masked row yields zeros instead of NaN (0 / 0).
        # Rows with at least one real token have a count >= 1 and are unaffected.
        counts = mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
        return summed / counts

    raise ValueError(
        f"unknown pooling_method {pooling_method!r}; expected 'cls' or 'mean'"
    )


class FridaClassifier(torch.nn.Module):
    """Two-class classifier: frozen FRIDA encoder plus a trainable MLP head.

    The encoder ("ai-forever/FRIDA") only produces embeddings; its parameters
    have ``requires_grad = False`` and it is always run under
    ``torch.no_grad()``. Only ``self.classifier`` receives gradients.

    Args:
        pooling_method: How token states are reduced to a single embedding,
            ``"cls"`` (default, the original behaviour) or ``"mean"``.

    Raises:
        ValueError: If ``pooling_method`` is not one of the supported values.
    """

    _POOLING_METHODS = ("cls", "mean")

    def __init__(self, pooling_method: str = "cls"):
        # Fail fast, before the encoder weights are loaded.
        if pooling_method not in self._POOLING_METHODS:
            raise ValueError(
                f"unknown pooling_method {pooling_method!r}; "
                f"expected one of {self._POOLING_METHODS}"
            )
        super().__init__()
        self.pooling_method = pooling_method
        self.frida_embedder = T5EncoderModel.from_pretrained("ai-forever/FRIDA")
        self._freeze_embedder_grad()
        # NOTE: the layer order fixes the state-dict keys of the head
        # (Linear layers at indices 0, 3 and 6); keep it stable so that saved
        # head weights continue to load.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=FRIDA_EMB_DIM, out_features=500),
            torch.nn.Dropout(p=0.2),
            torch.nn.SELU(),
            torch.nn.Linear(in_features=500, out_features=100),
            torch.nn.Dropout(p=0.1),
            torch.nn.SELU(),
            torch.nn.Linear(in_features=100, out_features=2)
        )

    def _freeze_embedder_grad(self):
        """Disable gradient tracking for every encoder parameter."""
        # NOTE(review): this only turns off gradients. Calling .train() on the
        # whole module still propagates training mode to the encoder's
        # submodules -- confirm that is intended when training the head.
        for param in self.frida_embedder.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class scores (logits) for a tokenized batch.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder and to pooling.

        Returns:
            Output of the classifier head, with 2 scores per batch row.
        """
        with torch.no_grad():  # no gradients calculation for frida embedder
            outputs = self.frida_embedder(input_ids=input_ids, attention_mask=attention_mask)

            embeddings = pool(
                outputs.last_hidden_state,
                attention_mask,
                pooling_method=self.pooling_method
            )
            # L2-normalise each embedding before it enters the head.
            embeddings = F.normalize(embeddings, p=2, dim=1)

        # The head runs outside no_grad so that it can be trained.
        return self.classifier(embeddings)


# return model and tokenizer
def load_model(head_path: str):
    """Build a ``FridaClassifier`` with a saved head and load its tokenizer.

    Only the classifier head is restored from ``head_path``; the encoder and
    the tokenizer are obtained from "ai-forever/FRIDA" via ``from_pretrained``.
    The model is put in eval mode and moved to the module-level ``device``.

    Args:
        head_path: Path to the file holding the classifier head state dict.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        FileNotFoundError: If ``head_path`` is not an existing file. This is a
            subclass of ``Exception``, so callers catching ``Exception`` still
            work.
    """
    if not os.path.isfile(head_path):
        raise FileNotFoundError(f'no model weights with path - {head_path}')

    loaded_model = FridaClassifier()
    # Load onto the CPU first so the checkpoint does not require the device it
    # was saved from; weights_only=True restricts unpickling to tensor data.
    head_state = torch.load(head_path, map_location='cpu', weights_only=True)
    loaded_model.classifier.load_state_dict(head_state)
    loaded_model.eval()
    loaded_model.to(device)
    tokenizer = AutoTokenizer.from_pretrained("ai-forever/FRIDA")

    return loaded_model, tokenizer


def infer(model: FridaClassifier, tokenizer: AutoTokenizer, texts: List[str], device):
    """Return softmax class probabilities for a batch of raw texts.

    Each text is prefixed with the "categorize_sentiment: " prompt, tokenized
    (padded, truncated to 512 tokens), and passed through ``model`` in eval
    mode without gradient tracking.

    Args:
        model: Classifier called as ``model(input_ids, attention_mask)``.
        tokenizer: Callable tokenizer returning 'input_ids' and 'attention_mask'.
        texts: Raw input strings.
        device: Device the tokenized tensors are moved to before the forward pass.

    Returns:
        Tensor of probabilities obtained by a softmax over the last dimension
        of the model output.
    """
    model.eval()

    prompts = ["categorize_sentiment: " + text for text in texts]
    encoded = tokenizer(
        prompts,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Cast to int64 and move to the target device in a single step.
    input_ids = encoded['input_ids'].to(device=device, dtype=torch.long)
    attention_masks = encoded['attention_mask'].to(device=device, dtype=torch.long)

    with torch.no_grad():
        logits_tensor = model(input_ids, attention_masks)
        return torch.softmax(logits_tensor, dim=-1)


# Maps the predicted class index (argmax over the model output) to its name.
labels = {
    0: 'non-toxic',
    1: 'toxic',
}


# Example usage:
# print('loading model and tokenizer...')
# chkp_dir = './'    # change to the directory that contains the head weights
# model, tokenizer = load_model(os.path.join(chkp_dir, "classifier_head.pth"))
# print('loaded.')


from typing import List
from pydantic import BaseModel

# Define DTOs
class ToxicityPrediction(BaseModel):
    """Classification result for a single input text."""
    # Input text as passed by the caller (without the model prompt prefix).
    text: str
    # Name of the predicted class, taken from the module-level `labels` mapping.
    label: str
    # Probability the model assigns to class index 1 ('toxic').
    toxicity_rate: float


class ToxicityPredictionResponse(BaseModel):
    """Response wrapper holding one prediction per input text."""
    # Predictions in the same order as the input texts.
    predictions: List[ToxicityPrediction]


def generate_resp(texts: List[str], model, tokenizer):
    """Classify ``texts`` and package the results as a response DTO.

    Args:
        texts: Raw input strings to classify.
        model: Classifier passed through to ``infer``.
        tokenizer: Tokenizer passed through to ``infer``.

    Returns:
        ``ToxicityPredictionResponse`` with one ``ToxicityPrediction`` per
        input text, in input order. ``label`` is the argmax class name and
        ``toxicity_rate`` is the probability of class index 1. An empty
        ``texts`` yields an empty response.
    """
    # Nothing to classify: skip the tokenizer and the model entirely rather
    # than pushing an empty batch through them.
    if not texts:
        return ToxicityPredictionResponse(predictions=[])

    probs = infer(model, tokenizer, texts, device).to('cpu')
    class_ids = torch.argmax(probs, dim=-1).tolist()
    # Column 1 holds the 'toxic' probability; tolist() yields plain floats.
    toxic_probs = probs[:, 1].tolist()

    predictions_list = [
        ToxicityPrediction(
            text=text,
            label=labels[class_id],
            toxicity_rate=toxic_prob,
        )
        for text, class_id, toxic_prob in zip(texts, class_ids, toxic_probs)
    ]

    return ToxicityPredictionResponse(predictions=predictions_list)