from transformers import AutoModelForTokenClassification, AutoTokenizer | |
from fastapi import FastAPI | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.middleware.cors import CORSMiddleware | |
app = FastAPI() | |
# Setup middlewares | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# Setup static files | |
app.mount("/static", StaticFiles(directory="static"), name="static") | |
async def ping(): | |
return {"status": "pong"} | |
async def predict(input_text: str): | |
tokenizer = AutoTokenizer.from_pretrained("your_model_name") | |
model = AutoModelForTokenClassification.from_pretrained("your_model_name") | |
inputs = tokenizer([input_text], return_tensors="pt", padding=True, truncation=True) | |
outputs = model(**inputs) | |
prediction = outputs.logits.argmax().item() | |
return {"prediction": prediction} | |