File size: 963 Bytes
502401a
dce67e9
502401a
 
 
 
 
 
 
 
 
dce67e9
502401a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dce67e9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import asyncio
import logging

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Root logging: timestamp with millisecond suffix, left-aligned 8-char level name.
# The root level is DEBUG, so DEBUG records from any propagating logger are emitted.
_LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG, datefmt=_LOG_DATE_FORMAT)

# The zero-shot model is loaded once, at import time, and shared by every request.
# The model path is relative, so it resolves against the process's working directory.
# use_fast=False asks for the non-"fast" (Python) tokenizer implementation.
classifier = pipeline(
    "zero-shot-classification",
    model="models/classificator",
    use_fast=False,
)

app = FastAPI()


class RequestData(BaseModel):
    """JSON request body for ``POST /classify``."""

    # Forwarded to the pipeline as ``multi_label``: when true, each label is
    # scored independently instead of the scores being normalized across labels.
    # The camelCase name is the JSON key clients send, so it is part of the API.
    multiLabel: bool
    # Text to classify.
    sequence: str
    # Candidate labels to score ``sequence`` against.
    labels: list[str]


class ResponseData(BaseModel):
    """JSON response body for ``POST /classify``.

    Mirrors the dict the zero-shot pipeline returns for a single input
    sequence; FastAPI validates and filters the endpoint's return value
    through this model (``response_model=ResponseData``).
    """

    # The classified text, as echoed back by the pipeline.
    sequence: str
    # Candidate labels in the order the pipeline returns them (documented by
    # transformers as sorted by likelihood); index-aligned with ``scores``.
    labels: list[str]
    # One score per entry in ``labels``, same order.
    scores: list[float]


@app.post("/classify", response_model=ResponseData, tags=["Classificator"])
async def classify_text(data: RequestData):
    """Classify ``data.sequence`` against ``data.labels`` with the zero-shot model.

    The pipeline call is synchronous and compute-bound. Calling it directly
    inside this ``async`` handler would block the event loop and stall every
    other request (including ``/ping``) for the whole inference, so it runs
    in a worker thread via ``asyncio.to_thread``.

    Returns:
        The pipeline's result dict (``sequence``, ``labels``, ``scores``),
        validated and serialized through ``ResponseData``.

    Raises:
        HTTPException: 422 when ``sequence`` is empty or ``labels`` has no
            entries. The pipeline rejects those with a ``ValueError``, which
            would otherwise surface to the client as a 500.
    """
    if not data.sequence or not data.labels:
        raise HTTPException(
            status_code=422,
            detail="'sequence' must be non-empty and 'labels' must contain at least one label.",
        )

    # NOTE(review): concurrent requests now run inference in parallel threads
    # instead of being serialized by a blocked event loop. If the model or
    # tokenizer proves not to be thread-safe, guard this call with a lock.
    result = await asyncio.to_thread(
        classifier, data.sequence, data.labels, multi_label=data.multiLabel
    )
    logging.info("%s", result)

    return result


@app.get("/ping", tags=["TEST"])
def ping():
    """Liveness check: always responds with the string ``"pong"``."""
    reply = "pong"
    return reply


if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn, bound to the
    # loopback interface only, so it is not reachable from other hosts.
    server_options = {"host": "127.0.0.1", "port": 8000}
    uvicorn.run(app, **server_options)