william4416 commited on
Commit
37a9bbd
1 Parent(s): 5633877

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -9
app.py CHANGED
@@ -1,8 +1,10 @@
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import json
5
 
 
6
  app = FastAPI()
7
 
8
  # Load DialoGPT model and tokenizer
@@ -10,38 +12,60 @@ try:
10
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
11
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
12
  except Exception as e:
13
- raise HTTPException(status_code=500, detail=f"Failed to load model: {e}")
14
 
15
- # Load courses data from JSON file
16
  try:
17
  with open("uts_courses.json", "r") as file:
18
  courses_data = json.load(file)
19
  except Exception as e:
20
- raise HTTPException(status_code=500, detail=f"Failed to load courses data: {e}")
21
 
 
22
  class UserInput(BaseModel):
23
  user_input: str
24
 
 
25
  def generate_response(user_input: str):
 
 
 
 
 
 
 
 
 
26
  if user_input.lower() == "help":
27
- return "I can help you with information about UTS courses. Feel free to ask!"
28
  elif user_input.lower() == "exit":
29
  return "Goodbye!"
30
  elif user_input.lower() == "list courses":
 
31
  course_list = "\n".join([f"{category}: {', '.join(courses)}" for category, courses in courses_data["courses"].items()])
32
  return f"Here are the available courses:\n{course_list}"
33
  elif user_input.lower() in courses_data["courses"]:
34
- return f"The courses in {user_input} are: {', '.join(courses_data['courses'][user_input])}"
 
35
  else:
36
- # Tokenize the user input
37
  input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
38
- # Generate a response
39
  response_ids = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
40
- # Decode the response
41
  response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
42
  return response
43
 
 
44
  @app.post("/")
45
- def chat(user_input: UserInput):
 
 
 
 
 
 
 
 
 
46
  response = generate_response(user_input.user_input)
47
  return {"response": response}
 
 
1
+ # Import required libraries
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import json
6
 
7
+ # Create FastAPI app instance
8
  app = FastAPI()
9
 
10
  # Load DialoGPT model and tokenizer
 
12
  tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
13
  model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
14
  except Exception as e:
15
+ raise HTTPException(status_code=500, detail=f"Model loading failed: {e}")
16
 
17
+ # Load courses data
18
  try:
19
  with open("uts_courses.json", "r") as file:
20
  courses_data = json.load(file)
21
  except Exception as e:
22
+ raise HTTPException(status_code=500, detail=f"Courses data loading failed: {e}")
23
 
24
+ # Define user input model
25
  class UserInput(BaseModel):
26
  user_input: str
27
 
28
+ # Generate response function
29
  def generate_response(user_input: str):
30
+ """
31
+ Generate response based on user input
32
+
33
+ Args:
34
+ user_input: User input text
35
+
36
+ Returns:
37
+ Generated response text
38
+ """
39
  if user_input.lower() == "help":
40
+ return "I can help you with UTS courses information, feel free to ask!"
41
  elif user_input.lower() == "exit":
42
  return "Goodbye!"
43
  elif user_input.lower() == "list courses":
44
+ # Generate course list
45
  course_list = "\n".join([f"{category}: {', '.join(courses)}" for category, courses in courses_data["courses"].items()])
46
  return f"Here are the available courses:\n{course_list}"
47
  elif user_input.lower() in courses_data["courses"]:
48
+ # List courses under the specified course category
49
+ return f"The courses in {user_input} category are: {', '.join(courses_data['courses'][user_input])}"
50
  else:
51
+ # Use DialoGPT model to generate response
52
  input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
 
53
  response_ids = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
 
54
  response = tokenizer.decode(response_ids[0], skip_special_tokens=True)
55
  return response
56
 
57
+ # Define API route
58
  @app.post("/")
59
+ async def chat(user_input: UserInput):
60
+ """
61
+ Process user input and return response
62
+
63
+ Args:
64
+ user_input: User input JSON data
65
+
66
+ Returns:
67
+ JSON data containing the response text
68
+ """
69
  response = generate_response(user_input.user_input)
70
  return {"response": response}
71
+