Bagi4's picture
feat: new model
3bdc77f
raw
history blame
2.53 kB
import os

# Must run before `transformers` is imported: the library resolves its cache
# directory from TRANSFORMERS_CACHE when the package is imported, so setting
# the variable after the import (as this file previously did) is silently
# ignored.
os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

import logging

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer, pipeline
# Root-logger configuration for the whole process. DEBUG level means records
# from any library that propagates to the root logger are emitted as well.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Zero-shot classification pipeline loaded from the local model directory;
# use_fast=False asks for the non-fast tokenizer implementation.
classifier = pipeline(
    "zero-shot-classification",
    model="models/classificator",
    use_fast=False,
)
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each row, ignoring masked-out positions.

    Args:
        model_output: model output whose first element holds the per-token
            embeddings (indexed along dim 1 per row).
        attention_mask: 1 for real tokens, 0 for padding; broadcast over the
            embedding dimension.

    Returns:
        One pooled vector per row. A row whose mask is all zeros yields zeros,
        because the token count is clamped to 1e-9 instead of dividing by 0.
    """
    token_vectors = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_vectors.size()).float()
    weighted_sum = (token_vectors * weights).sum(1)
    token_counts = weights.sum(1).clamp(min=1e-9)
    return weighted_sum / token_counts
# Local directory holding the sentence-embedding model used by similarity().
_EMBEDDING_MODEL_DIR = 'models/all-MiniLM-L6-v2'

tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_DIR)
model = AutoModel.from_pretrained(_EMBEDDING_MODEL_DIR)

app = FastAPI()
class RequestData(BaseModel):
    # Request body shared by the /classify and /similarity endpoints.
    # Field names are the JSON keys clients send, hence the camelCase here.

    # Forwarded to the zero-shot pipeline as `multi_label` by classify();
    # similarity() does not read it.
    multiLabel: bool
    # Input text that is scored against the labels.
    sequence: str
    # Candidate labels to score the sequence against.
    labels: list[str]
class ResponseData(BaseModel):
    # Response body declared as `response_model` by both /classify and
    # /similarity.

    # Input text; the /similarity handler copies it from the request.
    sequence: str
    # Labels the scores refer to; /similarity returns the request's labels
    # in request order.
    labels: list[str]
    # One score per entry of `labels`. For /similarity these are cosine
    # similarities; for /classify they are whatever the zero-shot pipeline
    # reports.
    scores: list[float]
def classify(data: RequestData):
    """Run zero-shot classification of ``data.sequence`` over ``data.labels``.

    ``data.multiLabel`` is passed through as the pipeline's ``multi_label``
    flag. The return value is whatever the module-level ``classifier``
    pipeline produces; /classify serves it through ``ResponseData``.
    """
    text = data.sequence
    candidate_labels = data.labels
    return classifier(text, candidate_labels, multi_label=data.multiLabel)
def similarity(data: RequestData):
    """Return the cosine similarity between ``data.sequence`` and each label.

    The sequence and all labels are embedded together in one batch,
    mean-pooled over non-padding tokens and L2-normalised. After that, a dot
    product between the sequence vector and each label vector is their cosine
    similarity.

    Returns:
        list[float]: one score per entry of ``data.labels``, in the same
        order (empty when there are no labels).
    """
    texts = [data.sequence, *data.labels]
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**batch)

    pooled = mean_pooling(outputs, batch['attention_mask'])
    unit_vectors = F.normalize(pooled, p=2, dim=1)

    # Row 0 is the sequence; the remaining rows are the labels.
    query, candidates = unit_vectors[:1], unit_vectors[1:]
    scores = query @ candidates.T
    return scores[0].tolist()
@app.post("/classify", response_model=ResponseData, tags=["Classificator"])
async def classify_text(data: RequestData):
    # Zero-shot classification endpoint. The pipeline result is logged and
    # returned as-is; FastAPI validates it against ResponseData.
    # NOTE(review): classify() runs model inference synchronously inside an
    # async handler, so the event loop is blocked while it runs.
    prediction = classify(data)
    logging.info(prediction)
    return prediction
@app.post("/similarity", response_model=ResponseData, tags=["Similarity"])
async def classify_text(data: RequestData):
    # Embedding-similarity endpoint: scores every request label against the
    # request sequence and echoes both back alongside the scores.
    # NOTE(review): this reuses the function name of the /classify handler,
    # re-binding the module-level `classify_text` (pyflakes F811). Both routes
    # are still registered because the decorator runs at definition time, but
    # a distinct name (e.g. `similarity_text`) would be clearer.
    cosine_scores = similarity(data)
    logging.info(cosine_scores)
    return ResponseData(
        sequence=data.sequence,
        labels=data.labels,
        scores=cosine_scores,
    )
@app.get("/ping", tags=["TEST"])
async def ping():
    # Health check: constant reply, touches no model.
    reply = "pong"
    return reply
if __name__ == "__main__":
    # Run the API directly with uvicorn. The defaults keep the previous
    # hard-coded bind address (127.0.0.1:8000). UVICORN_HOST / UVICORN_PORT
    # can override them, e.g. to bind 0.0.0.0 inside a container, without
    # editing the code.
    uvicorn.run(
        app,
        host=os.environ.get("UVICORN_HOST", "127.0.0.1"),
        port=int(os.environ.get("UVICORN_PORT", "8000")),
    )