import gc
import uuid
import logging
import asyncio
import transformers
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import (
    AutoTokenizer,
)
from optimum.pipelines import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification

# ----------------------------- #
#         Configurations        #
# ----------------------------- #
transformers.set_seed(42)

MODEL_NAME = "distilroberta-base-climate-sentiment-onnx-quantized"
BATCH_PROCESS_INTERVAL = 0.05
MAX_BATCH_SIZE = 16

# ----------------------------- #
#        Shared Storage         #
# ----------------------------- #
lock = asyncio.Lock()
query_queue: asyncio.Queue = asyncio.Queue()
results: dict[str, dict] = {}
classifier = None  # will be initialized in lifespan

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

logger = logging.getLogger(__name__)
logger.info("Starting the application...")
logger.info(f"Using model: {MODEL_NAME}")
logger.info(f"Batch process interval: {BATCH_PROCESS_INTERVAL}")
logger.info(f"Max batch size: {MAX_BATCH_SIZE}")


# ----------------------------- #
#      Model Initialization     #
# ----------------------------- #
def load_classifier(model_name: str):
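    """Load the tokenizer and quantized ONNX model, then build an ONNX Runtime
    text-classification pipeline that processes inputs in batches."""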
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = ORTModelForSequenceClassification.from_pretrained(
        model_name,
    )

    gc.collect()
    return pipeline(
        task="text-classification",
        accelerator="ort",
        model=model,
        tokenizer=tokenizer,
        framework="pt",
        batch_size=MAX_BATCH_SIZE,
        num_workers=1,
    )


# ----------------------------- #
#         Pydantic Schema       #
# ----------------------------- #
class Query(BaseModel):
    sentence: str


# ----------------------------- #
#      Queue Processing Task    #
# ----------------------------- #
async def process_queue():
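    """Background worker: every BATCH_PROCESS_INTERVAL seconds, drain up to
    MAX_BATCH_SIZE queued queries, run a single batched prediction, and publish
    the results keyed by query id."""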
    while True:
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)

        batch = []
        while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
            batch.append(await query_queue.get())

        if not batch:
            continue

        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]
        # Run the blocking inference call in a worker thread so the event loop
        # (and the polling /classify handlers) stays responsive.
        predictions = await asyncio.to_thread(classifier, sentences)

        async with lock:
            results.update(
                {
                    query_id: {
                        "sentence": sentence,
                        "label": pred["label"],
                        "score": pred["score"],
                    }
                    for query_id, pred, sentence in zip(ids, predictions, sentences)
                }
            )


# ----------------------------- #
#        Lifespan Handler       #
# ----------------------------- #
@asynccontextmanager
async def lifespan(_: FastAPI):
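    """Load the model and start the batch worker on startup; on shutdown, stop
    the worker and release the model."""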
    global classifier
    classifier = load_classifier(MODEL_NAME)
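    # Warm-up inference so the first real request does not pay the initialization cost.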
    _ = classifier("Hi")
    logger.info("Model loaded successfully.")
    queue_task = asyncio.create_task(process_queue())
    yield
    # Stop the batch worker before releasing the model it uses.
    queue_task.cancel()
    try:
        await queue_task
    except asyncio.CancelledError:
        pass
    logger.info("Shutting down the application...")
    classifier = None
    gc.collect()
    logger.info("Model unloaded successfully.")


# ----------------------------- #
#         FastAPI Setup         #
# ----------------------------- #
app = FastAPI(lifespan=lifespan)


# ----------------------------- #
#         API Endpoints         #
# ----------------------------- #
@app.post("/classify")
async def classify(query: Query):
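    """Enqueue the sentence, then poll the shared results dict until the batch
    worker has produced a prediction for this query id."""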
    logger.info(f"{query.sentence}")
    query_id = uuid.uuid4().hex
    await query_queue.put({"id": query_id, "sentence": query.sentence})

    while True:
        async with lock:
            if query_id in results:
                return {"id": query_id, "result": results.pop(query_id)}
        await asyncio.sleep(0.1)


app.mount("/", StaticFiles(directory="static", html=True), name="static")


@app.get("/")
def read_root():
    return FileResponse(path="static/index.html", media_type="text/html")
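

# Example client call (a sketch, not part of the application): assuming the app
# is served with uvicorn from a module named `main`, e.g.
# `uvicorn main:app --port 8000` (module name and port are assumptions), a
# request can be sent with:
#
#   curl -X POST http://localhost:8000/classify \
#        -H "Content-Type: application/json" \
#        -d '{"sentence": "The company committed to net-zero emissions by 2040."}'
#
# The response echoes the query id along with the predicted label and score:
# {"id": "<uuid hex>", "result": {"sentence": "...", "label": "...", "score": ...}}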