import logging
import os

# Must be set before transformers is imported, or the default cache path has
# already been resolved.
os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel, pipeline

logging.basicConfig(
    format='%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s',
    level=logging.DEBUG,
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Zero-shot classifier served by /classify; use_fast=False forces the slow
# (pure-Python) tokenizer.
classifier = pipeline("zero-shot-classification", model="models/classificator", use_fast=False)


def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence, ignoring padded positions."""
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
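
# Toy sanity check for mean_pooling (a sketch; shapes follow the usual
# (batch, seq_len, hidden) convention, and model_output[0] is the last hidden state):
#   emb = torch.ones(1, 3, 4); mask = torch.tensor([[1, 1, 0]])
#   mean_pooling((emb,), mask)  # -> all ones: the masked third token is ignored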


# Sentence-embedding model used by the /similarity endpoint.
tokenizer = AutoTokenizer.from_pretrained('models/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('models/all-MiniLM-L6-v2')

app = FastAPI()


class RequestData(BaseModel):
    multiLabel: bool
    sequence: str
    labels: list[str]


class ResponseData(BaseModel):
    sequence: str
    labels: list[str]
    scores: list[float]
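
# Illustrative request/response shapes for both POST endpoints (values are
# made up for demonstration):
#   request:  {"multiLabel": false, "sequence": "great battery life",
#              "labels": ["positive", "negative"]}
#   response: {"sequence": "great battery life",
#              "labels": ["positive", "negative"], "scores": [0.97, 0.03]}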


def classify(data: RequestData):
    # The zero-shot pipeline returns a dict with "sequence", "labels" and
    # "scores" keys, which already matches ResponseData.
    return classifier(data.sequence, data.labels, multi_label=data.multiLabel)


def similarity(data: RequestData):
    # Embed the query sequence together with the candidate labels in one batch.
    sentences = [data.sequence]
    sentences.extend(data.labels)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        model_output = model(**encoded_input)

    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # After L2 normalization, the dot product of row 0 (the sequence) with each
    # remaining row (the labels) is their cosine similarity.
    cosine_scores = sentence_embeddings[:1] @ sentence_embeddings[1:].T
    return cosine_scores.tolist()[0]


@app.post("/classify", response_model=ResponseData, tags=["Classificator"])
async def classify_text(data: RequestData):
    result = classify(data)
    logging.info(result)
    return result


@app.post("/similarity", response_model=ResponseData, tags=["Similarity"])
async def similarity_text(data: RequestData):
    result = similarity(data)
    logging.info(result)
    return ResponseData.model_validate({
        "sequence": data.sequence,
        "labels": data.labels,
        "scores": result
    })


@app.get("/ping", tags=["TEST"])
async def ping():
    return "pong"


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
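
# Example calls against a locally running instance (sketches assuming the
# host/port configured above):
#   curl http://127.0.0.1:8000/ping
#   curl -X POST http://127.0.0.1:8000/classify -H 'Content-Type: application/json' \
#        -d '{"multiLabel": false, "sequence": "great battery life", "labels": ["positive", "negative"]}'
#   curl -X POST http://127.0.0.1:8000/similarity -H 'Content-Type: application/json' \
#        -d '{"multiLabel": false, "sequence": "great battery life", "labels": ["positive", "negative"]}'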