import gc
import uuid
import logging
import asyncio
import transformers
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from contextlib import asynccontextmanager
from transformers import (
    AutoTokenizer,
)
from optimum.pipelines import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification
# ----------------------------- #
#        Configurations         #
# ----------------------------- #
transformers.set_seed(42)
MODEL_NAME = "distilroberta-base-climate-sentiment-onnx-quantized"
BATCH_PROCESS_INTERVAL = 0.05
MAX_BATCH_SIZE = 16
# ----------------------------- #
#        Shared Storage         #
# ----------------------------- #
lock = asyncio.Lock()
query_queue: asyncio.Queue = asyncio.Queue()
results: dict[str, dict] = {}
classifier = None # will be initialized in lifespan
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
logger.info("Starting the application...")
logger.info(f"Using model: {MODEL_NAME}")
logger.info(f"Batch process interval: {BATCH_PROCESS_INTERVAL}")
logger.info(f"Max batch size: {MAX_BATCH_SIZE}")
# ----------------------------- #
#     Model Initialization      #
# ----------------------------- #
def load_classifier(model_name: str):
    """Load the ONNX model and wrap it in a text-classification pipeline."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = ORTModelForSequenceClassification.from_pretrained(
        model_name,
    )
    # Reclaim any memory left over from model loading before serving traffic.
    gc.collect()
    return pipeline(
        task="text-classification",
        accelerator="ort",
        model=model,
        tokenizer=tokenizer,
        framework="pt",
        batch_size=MAX_BATCH_SIZE,
        num_workers=1,
    )
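# Illustrative output shape for the pipeline (the actual label names depend
# on the model's fine-tuning; the values below are assumptions, not outputs
# from this model):
#   classifier(["Renewables keep getting cheaper."])
#   -> [{"label": "opportunity", "score": 0.97}]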
# ----------------------------- #
#        Pydantic Schema        #
# ----------------------------- #
class Query(BaseModel):
    sentence: str
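# Example request body matching the schema above (illustrative sentence only):
#   {"sentence": "Rising sea levels threaten coastal infrastructure."}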
# ----------------------------- #
#     Queue Processing Task     #
# ----------------------------- #
async def process_queue():
    """Drain queued queries on a fixed interval and classify them in batches."""
    while True:
        await asyncio.sleep(BATCH_PROCESS_INTERVAL)
        batch = []
        # Collect up to MAX_BATCH_SIZE pending queries without blocking.
        while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
            batch.append(await query_queue.get())
        if not batch:
            continue
        sentences = [item["sentence"] for item in batch]
        ids = [item["id"] for item in batch]
        # The pipeline call is synchronous; run it in a worker thread so that
        # inference does not block the event loop.
        predictions = await asyncio.to_thread(classifier, sentences)
        async with lock:
            results.update(
                {
                    query_id: {
                        "sentence": sentence,
                        "label": pred["label"],
                        "score": pred["score"],
                    }
                    for query_id, pred, sentence in zip(ids, predictions, sentences)
                }
            )
# ----------------------------- #
#       Lifespan Handler        #
# ----------------------------- #
@asynccontextmanager
async def lifespan(_: FastAPI):
    global classifier
    classifier = load_classifier(MODEL_NAME)
    # Warm-up call so the first real request does not pay the one-time
    # session-initialization cost.
    _ = classifier("Hi")
    logger.info("Model loaded successfully.")
    queue_task = asyncio.create_task(process_queue())
    yield
    logger.info("Shutting down the application...")
    queue_task.cancel()
    try:
        await queue_task
    except asyncio.CancelledError:
        pass
    # Release the model only after the batch worker has stopped using it.
    classifier = None
    gc.collect()
    logger.info("Model unloaded successfully.")
# ----------------------------- #
#         FastAPI Setup         #
# ----------------------------- #
app = FastAPI(lifespan=lifespan)
# ----------------------------- #
#         API Endpoints         #
# ----------------------------- #
@app.post("/classify")
async def classify(query: Query):
    logger.info(f"Received query: {query.sentence}")
    query_id = uuid.uuid4().hex
    await query_queue.put({"id": query_id, "sentence": query.sentence})
    # Poll until the batch worker publishes a result for this query id.
    while True:
        async with lock:
            if query_id in results:
                return {"id": query_id, "result": results.pop(query_id)}
        await asyncio.sleep(0.1)
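# Example call (assuming the server listens on localhost:8000):
#   curl -X POST http://localhost:8000/classify \
#        -H "Content-Type: application/json" \
#        -d '{"sentence": "Extreme weather is driving up insurance costs."}'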
@app.get("/")
def read_root():
    return FileResponse(path="static/index.html", media_type="text/html")
# Register the static mount last: a mount at "/" matches every path, so any
# route declared after it would never be reachable.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
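# Minimal local entry point; a sketch assuming uvicorn is installed and this
# module is named main.py (filename assumed). In production you would more
# typically run `uvicorn main:app --host 0.0.0.0 --port 8000`.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)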