# (removed "Spaces: Sleeping" status-banner text — page-capture artifact, not code)
# Import required libraries
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
# Create FastAPI app instance
app = FastAPI()

# Load DialoGPT model and tokenizer once at startup.
# NOTE: HTTPException only makes sense inside a request handler; a failure
# here happens at import time, so abort startup with RuntimeError instead.
try:
    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
except Exception as e:
    raise RuntimeError(f"Model loading failed: {e}") from e
# Load courses data at startup.
# Expected shape (from usage in generate_response):
#   {"courses": {category_name: [course_name, ...]}}
try:
    with open("uts_courses.json", "r", encoding="utf-8") as file:
        courses_data = json.load(file)
except Exception as e:
    # A startup failure should abort the process; HTTPException is only
    # meaningful inside a request handler, not at import time.
    raise RuntimeError(f"Courses data loading failed: {e}") from e
# Define user input model
class UserInput(BaseModel):
    """Request body for the chat endpoint: one free-text user message."""

    user_input: str
# Generate response function
def generate_response(user_input: str) -> str:
    """Generate a chatbot reply for *user_input*.

    Command handling (case-insensitive):
      * "help"          -> static help message
      * "exit"          -> goodbye message
      * "list courses"  -> every category with its courses
      * a category name -> the courses in that category
    Anything else is answered by the DialoGPT model.

    Args:
        user_input: Raw user message text.

    Returns:
        Generated response text.
    """
    command = user_input.lower()
    if command == "help":
        return "I can help you with UTS courses information, feel free to ask!"
    if command == "exit":
        return "Goodbye!"
    if command == "list courses":
        # Generate course list
        course_list = "\n".join(
            f"{category}: {', '.join(courses)}"
            for category, courses in courses_data["courses"].items()
        )
        return f"Here are the available courses:\n{course_list}"
    # Case-insensitive category lookup. The original lowercased the input for
    # the membership test but indexed the dict with the raw-cased input, which
    # raised KeyError whenever the casing differed from the JSON key.
    for category, courses in courses_data["courses"].items():
        if category.lower() == command:
            return f"The courses in {user_input} category are: {', '.join(courses)}"
    # Fall back to DialoGPT for free-form conversation.
    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
    response_ids = model.generate(
        input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id
    )
    return tokenizer.decode(response_ids[0], skip_special_tokens=True)
# Define API route.
# BUG FIX: the original function had no route decorator, so FastAPI never
# registered it and the endpoint was unreachable.
@app.post("/chat")
async def chat(user_input: UserInput) -> dict:
    """Process user input and return the chatbot's response.

    Args:
        user_input: Request body containing the user's message.

    Returns:
        JSON-serializable dict with the response text under "response".

    Raises:
        HTTPException: 500 with detail text when response generation fails.
    """
    try:
        response = generate_response(user_input.user_input)
    except Exception as e:
        # Surface generation failures as an explicit HTTP error with detail
        # rather than an opaque unhandled 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"response": response}