Spaces:
Runtime error
Runtime error
File size: 1,545 Bytes
e5b2387 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from app.model_utils import load_model_and_tokenizer, generate_summary
from app.classifier import train_classifier, classify_text
app = FastAPI()

# ---------------------------------------------------------------------------
# /rag: summarisation model, loaded once when this module is imported.
# ---------------------------------------------------------------------------
model_name = "sshleifer/distilbart-cnn-6-6"  # Example model
model, tokenizer = load_model_and_tokenizer(model_name)

# ---------------------------------------------------------------------------
# /classification: toy (text, label) training pairs and the classifier /
# vectorizer pair fitted on them at import time.
# ---------------------------------------------------------------------------
dummy_data = [
    ("I feel very sad and hopeless.", "Depression"),
    ("I have trouble sleeping at night.", "Insomnia"),
    ("I am constantly worrying about everything.", "Anxiety"),
    ("I feel energetic and happy.", "Happiness"),
    ("My mood swings a lot and I feel irritable.", "Mood Disorder"),
]
classifier, vectorizer = train_classifier(dummy_data)
class Prompt(BaseModel):
    """Request body for ``POST /rag``."""

    # Text forwarded verbatim to ``generate_summary``.
    prompt: str
class ClassificationInput(BaseModel):
    """Request body for ``POST /classification``."""

    # Free text passed to ``classify_text`` for labelling.
    data: str
@app.post("/rag")
def rag_endpoint(prompt: Prompt):
    """Summarise the submitted text with the module-level model.

    Args:
        prompt: Request body; its ``prompt`` field is passed to
            ``generate_summary`` along with the globally loaded ``model``
            and ``tokenizer``.

    Returns:
        ``{"summary": <value returned by generate_summary>}``.

    Raises:
        HTTPException: status 500 if ``generate_summary`` raises; the
            original exception message is used as ``detail``.
    """
    try:
        summary = generate_summary(prompt.prompt, model, tokenizer)
    except Exception as e:
        # Record the traceback server-side: once converted to an HTTP 500
        # response, only str(e) reaches the client and the stack is lost.
        logging.getLogger(__name__).exception("/rag summarisation failed")
        # NOTE(review): detail echoes the raw exception message to the
        # client, which may expose internals; consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"summary": summary}
@app.post("/classification")
def classification_endpoint(input: ClassificationInput):
    """Label the submitted text with the module-level classifier.

    Args:
        input: Request body; its ``data`` field is passed to
            ``classify_text`` with the globally trained ``classifier`` and
            ``vectorizer``. (The name shadows the builtin ``input`` but is
            kept so the signature stays unchanged.)

    Returns:
        ``{"category": <value returned by classify_text>}``.

    Raises:
        HTTPException: status 500 if ``classify_text`` raises; the original
            exception message is used as ``detail``.
    """
    try:
        category = classify_text(input.data, classifier, vectorizer)
    except Exception as e:
        # Record the traceback server-side: once converted to an HTTP 500
        # response, only str(e) reaches the client and the stack is lost.
        logging.getLogger(__name__).exception("/classification failed")
        # NOTE(review): detail echoes the raw exception message to the
        # client, which may expose internals; consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"category": category}
if __name__ == "__main__":
    import os

    import uvicorn

    # Bind address and port default to the previously hard-coded values but
    # can be overridden via the HOST / PORT environment variables (e.g. when
    # the hosting platform dictates the listening port).
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )
|