# bewchatbot / app.py
# Author: william4416 — commit 37a9bbd ("Update app.py"), 2.4 kB.
# Import required libraries
import json

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Create FastAPI app instance
app = FastAPI()

# Load the DialoGPT model and tokenizer once, at import time, so every request
# reuses the same instances.
MODEL_NAME = "microsoft/DialoGPT-large"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
except Exception as e:
    # HTTPException only makes sense inside a request handler; at import time
    # it merely aborts startup with a misleading HTTP status. Fail with a plain
    # startup error and keep the underlying cause chained for the traceback.
    raise RuntimeError(f"Model loading failed: {e}") from e
# Load the course catalogue used by generate_response.
try:
    # Explicit encoding: the default depends on the platform locale.
    with open("uts_courses.json", "r", encoding="utf-8") as file:
        courses_data = json.load(file)
except (OSError, ValueError) as e:
    # OSError: missing/unreadable file. ValueError: covers json.JSONDecodeError
    # and UnicodeDecodeError. Raised at import time, so this is a startup
    # failure rather than an HTTP error; the cause stays chained.
    raise RuntimeError(f"Courses data loading failed: {e}") from e
# Request body schema for the chat endpoint
class UserInput(BaseModel):
    """JSON body accepted by ``POST /``: ``{"user_input": "<message>"}``."""

    # The message text typed by the user.
    user_input: str
# Generate response function
def generate_response(user_input: str) -> str:
    """
    Generate a chatbot reply for a single user message.

    Built-in commands ("help", "exit", "list courses") and course-category
    names are matched case-insensitively after stripping surrounding
    whitespace; blank input is answered with the help text. Any other input
    is passed to the DialoGPT model.

    Args:
        user_input: User input text
    Returns:
        Generated response text
    """
    text = user_input.strip()
    command = text.lower()

    if command in ("", "help"):
        return "I can help you with UTS courses information, feel free to ask!"
    if command == "exit":
        return "Goodbye!"

    courses = courses_data["courses"]
    if command == "list courses":
        # One "category: course, course, ..." line per category.
        course_list = "\n".join(
            f"{category}: {', '.join(names)}" for category, names in courses.items()
        )
        return f"Here are the available courses:\n{course_list}"

    # Compare case-insensitively, but index with the real key so that
    # mixed-case input (or capitalised keys in the JSON) cannot raise KeyError
    # or silently fail to match.
    categories = {name.lower(): name for name in courses}
    if command in categories:
        category = categories[command]
        return f"The courses in {category} category are: {', '.join(courses[category])}"

    # Fall back to free-form chat with DialoGPT.
    input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
    response_ids = model.generate(
        input_ids,
        # Bound the reply itself; max_length would also count the prompt, so a
        # long message could leave no room for an answer.
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the newly generated
    # tokens so the reply does not repeat the user's message back to them.
    return tokenizer.decode(response_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
# Define API route
@app.post("/")
async def chat(user_input: UserInput):
    """
    Process user input and return response

    Args:
        user_input: Parsed request body carrying the user's message
    Returns:
        A dict of the form {"response": <reply text>}, serialised to JSON by FastAPI
    """
    # generate_response may run DialoGPT inference, which is blocking. Calling
    # it directly inside an ``async def`` handler would stall the event loop, and
    # every other in-flight request, until generation finishes. Run it in
    # FastAPI's worker thread pool instead.
    response = await run_in_threadpool(generate_response, user_input.user_input)
    return {"response": response}