File size: 1,545 Bytes
e5b2387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from app.model_utils import load_model_and_tokenizer, generate_summary
from app.classifier import train_classifier, classify_text

app = FastAPI()

# --- /rag endpoint resources ------------------------------------------------
# The summarization checkpoint is loaded once, when this module is imported,
# so every request handled by /rag reuses the same model and tokenizer objects.
model_name = "sshleifer/distilbart-cnn-6-6"  # example checkpoint
model, tokenizer = load_model_and_tokenizer(model_name)

# --- /classification endpoint resources -------------------------------------
# Placeholder (text, label) pairs standing in for a real labelled dataset.
dummy_data = [
    ("I feel very sad and hopeless.", "Depression"),
    ("I have trouble sleeping at night.", "Insomnia"),
    ("I am constantly worrying about everything.", "Anxiety"),
    ("I feel energetic and happy.", "Happiness"),
    ("My mood swings a lot and I feel irritable.", "Mood Disorder"),
]

# Trained once at import time; the /classification handler passes both of
# these objects to classify_text on every request.
classifier, vectorizer = train_classifier(dummy_data)

class Prompt(BaseModel):
    # Request body for POST /rag. ``prompt`` is the raw text that the /rag
    # handler passes to generate_summary.
    prompt: str

class ClassificationInput(BaseModel):
    # Request body for POST /classification. ``data`` is the raw text that the
    # /classification handler passes to classify_text.
    data: str

@app.post("/rag")
def rag_endpoint(prompt: Prompt):
    """Summarize the submitted prompt text with the module-level model.

    Args:
        prompt: Request body whose ``prompt`` field holds the text to summarize.

    Returns:
        A JSON object ``{"summary": ...}`` holding whatever
        ``generate_summary`` returns.

    Raises:
        HTTPException: status 500 if summary generation raises any exception;
            the original error message is returned as ``detail``.
    """
    try:
        summary = generate_summary(prompt.prompt, model, tokenizer)
    except Exception as e:
        # Record the full traceback server-side: the HTTPException by itself
        # leaves no trace of the underlying failure in the server logs.
        logging.getLogger(__name__).exception("Summary generation failed")
        # NOTE(review): detail=str(e) exposes internal error text to API
        # clients; consider a generic message once callers no longer rely on it.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"summary": summary}

@app.post("/classification")
def classification_endpoint(input: ClassificationInput):
    """Classify the submitted text with the classifier trained at import time.

    Args:
        input: Request body whose ``data`` field holds the text to classify.
            (The name shadows the ``input`` builtin; it is kept unchanged for
            signature compatibility.)

    Returns:
        A JSON object ``{"category": ...}`` holding whatever
        ``classify_text`` returns.

    Raises:
        HTTPException: status 500 if classification raises any exception;
            the original error message is returned as ``detail``.
    """
    try:
        category = classify_text(input.data, classifier, vectorizer)
    except Exception as e:
        # Record the full traceback server-side: the HTTPException by itself
        # leaves no trace of the underlying failure in the server logs.
        logging.getLogger(__name__).exception("Text classification failed")
        # NOTE(review): detail=str(e) exposes internal error text to API
        # clients; consider a generic message once callers no longer rely on it.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"category": category}

if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 8000. uvicorn is imported here so that merely importing
    # this module does not require it.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8000)