# MienOlle's picture
# More JSON fixes
# 06912ed
from huggingface_hub import hf_hub_download
import torch
from transformers import AutoModelForSequenceClassification as modelSC, AutoTokenizer as token
from fastapi import FastAPI
from pydantic import BaseModel
import os
from typing import List
# Application object; the route handler further down is registered on it
# through the @app.post decorator.
app = FastAPI()
# Route Hugging Face caches to /tmp/huggingface and make sure the directory
# exists before anything is downloaded into it.
# NOTE(review): these variables are assigned after huggingface_hub and
# transformers have already been imported; the explicit cache_dir= arguments
# below are what visibly direct the downloads here. Confirm whether the
# environment variables still have any effect at this point.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
# Fetch the weights file from the Hub; model_path is the local path that
# hf_hub_download returns.
model_path = hf_hub_download(repo_id="MienOlle/sentiment_analysis_api",
filename="sentimentAnalysis.pth",
cache_dir=os.environ["HF_HOME"]
)
# Tokenizer and 3-label sequence-classification model built from the base
# checkpoint; the downloaded state dict is loaded over the model's weights.
modelToken = token.from_pretrained("mdhugol/indonesia-bert-sentiment-classification", cache_dir=os.environ["TRANSFORMERS_CACHE"])
model = modelSC.from_pretrained("mdhugol/indonesia-bert-sentiment-classification", num_labels=3, cache_dir=os.environ["TRANSFORMERS_CACHE"])
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# map_location makes the stored tensors load onto the selected device.
# NOTE(review): torch.load unpickles the downloaded file; consider passing
# weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
model.to(device)
# Inference only: switch the model to evaluation mode.
model.eval()
# Request body schema accepted by the /predict route.
# Documented with comments rather than a docstring so the schema that
# pydantic generates for this model stays exactly as it was.
class TextInput(BaseModel):
    # Batch of texts to classify, sent as a JSON array of strings.
    text: List[str]
def predict(input):
    """Classify the sentiment of every text in a batch.

    Args:
        input: The texts to classify, as passed on to the tokenizer (the
            endpoint in this file supplies a list of strings). The name
            shadows the builtin ``input`` but is kept so existing keyword
            callers continue to work.

    Returns:
        A list with one label per input text, in input order. Each label is
        ``"positive"``, ``"neutral"`` or ``"negative"``. An empty list or
        tuple yields an empty list.
    """
    # Nothing to classify: answer directly instead of sending a zero-length
    # batch through the tokenizer and model, which are not expected to
    # handle it. Restricted to list/tuple so a single (even empty) string
    # keeps following the normal path.
    if isinstance(input, (list, tuple)) and not input:
        return []

    # Position i is the label reported for predicted class id i.
    # NOTE(review): assumes the checkpoint's class ids 0/1/2 map to this
    # order; confirm against the model's label configuration.
    labels = ("positive", "neutral", "negative")

    # Pad to the longest text in the batch and truncate at 512 tokens so the
    # batch forms a single rectangular tensor.
    inputs = modelToken(input, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Move every encoded tensor to the device the model lives on.
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    # Pure inference: gradients are not needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Highest-scoring class id for each row of the batch.
    class_ids = outputs.logits.argmax(dim=1).tolist()
    return [labels[class_id] for class_id in class_ids]
# POST /predict: run the classifier over the submitted texts and wrap the
# resulting labels in a JSON object under the "predictions" key.
# Documented with comments (no docstring, no return annotation) so the
# generated API description and response handling are unchanged.
@app.post("/predict")
def get_sentiment(data: TextInput):
    sentiments = predict(data.text)
    return {"predictions": sentiments}
# Script entry point: start the server when this file is executed directly.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )