| |
|
| | from fastapi import FastAPI
|
| | import torch
|
| | from transformers import BertTokenizer
|
| | from model.sentiment_model import SentimentAnalysisModel
|
| | from schemas.sentiment import SentimentRequest, SentimentResponse
|
| | from services.inference import predict_sentiment
|
| | from schemas.sentiment import BatchSentimentRequest
|
| | from services.inference import batch_predict
|
| | from fastapi.concurrency import run_in_threadpool
|
| |
|
| |
|
| | app = FastAPI()
|
| |
|
| |
|
| | @app.on_event("startup")
|
| | def startup_event():
|
| | global tokenizer, model, device
|
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| |
|
| | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
| | model = SentimentAnalysisModel("bert-base-uncased")
|
| | model.load_state_dict(torch.load("bert_imdb_sentiment.pth", map_location=device))
|
| | model.to(device)
|
| | model.eval()
|
| |
|
| |
|
| | @app.post("/predict", response_model=SentimentResponse)
|
| | async def predict_api(req: SentimentRequest):
|
| | label, conf = await run_in_threadpool(
|
| | predict_sentiment, req.text, tokenizer, model, device
|
| | )
|
| | return SentimentResponse(label=label, confidence=conf)
|
| |
|
| |
|
| |
|
| | @app.post("/predict_batch")
|
| | async def predict_batch_api(req: BatchSentimentRequest):
|
| | results = await run_in_threadpool(
|
| | batch_predict, req.texts, tokenizer, model, device
|
| | )
|
| | return results
|
| |
|
| |
|
| |
|
| | @app.get("/health")
|
| | def health():
|
| | return {"status": "ok"}
|
| |
|