"""FastAPI service exposing zero-shot classification and embedding similarity.

Endpoints:
    POST /classify   -- zero-shot classification of a sequence against labels.
    POST /similarity -- cosine similarity between a sequence and each label.
    GET  /ping       -- liveness probe.
"""
import logging
import os

# Must be set BEFORE `transformers` is imported: the library resolves its
# cache directory from the environment when it is first imported, so assigning
# the variable afterwards has no effect.
os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

import torch  # noqa: E402
import torch.nn.functional as F  # noqa: E402
import uvicorn  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import AutoModel, AutoTokenizer, pipeline  # noqa: E402

logging.basicConfig(
    format='%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s',
    level=logging.DEBUG,
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Zero-shot classifier loaded once at import time from a local model directory.
classifier = pipeline(
    "zero-shot-classification",
    model="models/classificator",
    use_fast=False,
)


def mean_pooling(model_output, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over axis 1, ignoring masked-out positions.

    Args:
        model_output: Model output whose first element holds the token
            embeddings (presumably shaped (batch, seq, dim) -- TODO confirm).
        attention_mask: Mask that is expanded along a new last axis to the
            size of the token embeddings; zero entries exclude a token.

    Returns:
        The mask-weighted sum of token embeddings along axis 1 divided by the
        mask sum along axis 1.
    """
    token_embeddings = model_output[0]
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    summed = torch.sum(token_embeddings * input_mask_expanded, 1)
    # Clamp the denominator so a fully masked row cannot divide by zero.
    counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return summed / counts


# Sentence-embedding model and tokenizer, loaded once at import time.
tokenizer = AutoTokenizer.from_pretrained('models/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('models/all-MiniLM-L6-v2')

app = FastAPI()


class RequestData(BaseModel):
    """Request payload shared by the /classify and /similarity endpoints."""

    # Forwarded as `multi_label` to the zero-shot pipeline; not read by
    # `similarity`.
    multiLabel: bool
    # Text to classify / compare.
    sequence: str
    # Candidate labels the sequence is scored against.
    labels: list[str]


class ResponseData(BaseModel):
    """Response payload: the input sequence, labels, and one score per label."""

    sequence: str
    labels: list[str]
    scores: list[float]


def classify(data: RequestData):
    """Run the zero-shot classification pipeline on the request.

    Returns:
        Whatever the pipeline returns for a single sequence; the /classify
        endpoint validates it against `ResponseData`.
    """
    return classifier(data.sequence, data.labels, multi_label=data.multiLabel)


def similarity(data: RequestData) -> list[float]:
    """Score each label by cosine similarity to the sequence.

    The sequence and all labels are embedded in one batch, mean-pooled and
    L2-normalised, so the dot product of two rows is their cosine similarity.

    Returns:
        One score per label, in the same order as `data.labels`.
    """
    sentences = [data.sequence, *data.labels]
    encoded_input = tokenizer(
        sentences, padding=True, truncation=True, return_tensors='pt'
    )
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        model_output = model(**encoded_input)
    sentence_embeddings = mean_pooling(
        model_output, encoded_input['attention_mask']
    )
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    # Row 0 is the sequence; the remaining rows are the labels.
    cosine_scores = sentence_embeddings[:1] @ sentence_embeddings[1:].T
    return cosine_scores.tolist()[0]


# NOTE(review): both POST handlers run blocking model inference inside
# `async def`, which occupies the event loop for the duration of the call.
# Consider plain `def` handlers (run in FastAPI's threadpool) once the models
# are confirmed safe for concurrent use.
@app.post("/classify", response_model=ResponseData, tags=["Classificator"])
async def classify_text(data: RequestData):
    """Zero-shot classify `data.sequence` against `data.labels`."""
    result = classify(data)
    logger.info("%s", result)
    return result


@app.post("/similarity", response_model=ResponseData, tags=["Similarity"])
async def similarity_text(data: RequestData):
    """Return the cosine similarity of `data.sequence` to each label."""
    scores = similarity(data)
    logger.info("%s", scores)
    return ResponseData(
        sequence=data.sequence,
        labels=data.labels,
        scores=scores,
    )


@app.get("/ping", tags=["TEST"])
async def ping():
    """Liveness probe; always answers "pong"."""
    return "pong"


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)