william4416 committed on
Commit
bf950aa
·
verified ·
1 Parent(s): a495b29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -56
app.py CHANGED
@@ -1,68 +1,67 @@
1
- import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import json
 
 
 
 
 
4
 
5
# Load pre-trained model and tokenizer (replace with desired model if needed).
# NOTE(review): from_pretrained downloads the weights on first run — this
# happens at import time, so app startup blocks on the download.
model_name = "microsoft/DialoGPT-large"  # Replace with your preferred model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
9
 
10
# Generate a DialoGPT reply to `message`, then append any lookups from local
# JSON files whose trigger keyword appears in the message.
def chat(message, history):
    """Return the bot's reply for one user turn.

    Parameters:
        message: the raw user text.
        history: prior conversation turns (accepted for the Gradio chat
            signature but unused — see the commented-out append below).

    Returns:
        The generated reply, possibly extended with JSON-file lookups.
    """
    # Preprocess user input into model token ids.
    input_ids = tokenizer(message, return_tensors="pt")["input_ids"]

    # Generate response with beam search to improve fluency.
    generated_outputs = model.generate(
        input_ids,
        max_length=512,          # Adjust max_length as needed for response length
        num_beams=5,             # Experiment with num_beams for better phrasing
        no_repeat_ngram_size=2,  # Prevent repetition in responses
        early_stopping=True,     # Stop generation if response seems complete
    )

    # BUG FIX: decode only the newly generated tokens. `generate` returns the
    # prompt followed by the continuation, so decoding the full sequence
    # echoed the user's message back at the start of every reply.
    response = tokenizer.batch_decode(
        generated_outputs[:, input_ids.shape[-1]:], skip_special_tokens=True
    )[0]

    # Map of local JSON files to the trigger keyword each one answers for.
    json_files = {
        "fileone.json": "your_key_in_fileone",
        "filesecond.json": "your_key_in_filesecond",
        "filethird.json": "your_key_in_filethird",
        "filefourth.json": "your_key_in_filefourth",
        "filefifth.json": "your_key_in_filefifth",
    }

    for filename, key in json_files.items():
        if key.lower() in message.lower():
            try:
                with open(filename, "r") as f:
                    data = json.load(f)
                relevant_info = data.get(key, "No relevant information found")
                # BUG FIX: these f-strings had no placeholder where the file
                # name clearly belongs — interpolate the actual filename.
                response += f"\nHere's some information I found in {filename}: {relevant_info}"
            except FileNotFoundError:
                response += f"\nCouldn't find the file: {filename}"
            except json.JSONDecodeError:
                response += f"\nError processing the JSON data in file: {filename}"

    # Update history with current conversation (optional)
    # history.append([message, response])  # Uncomment if you want conversation history

    return response
52
 
53
# Check Gradio version and handle catch_exceptions accordingly.
# NOTE(review): this relies on gr.Interface raising TypeError for an
# unexpected `catch_exceptions` keyword on Gradio versions that lack it —
# confirm against the installed Gradio release.
try:
    # Option 1: Use catch_exceptions if Gradio supports it
    interface = gr.Interface(chat, inputs="textbox", outputs="textbox", catch_exceptions=True)
except TypeError:  # If catch_exceptions is not supported
    # Option 2: Manual error handling within chat function.
    # Wraps chat() so any exception is returned as a message instead of
    # crashing the Gradio request handler.
    def chat_with_error_handling(message, history):
        try:
            return chat(message, history)
        except Exception as e:
            return f"An error occurred: {str(e)}"

    interface = gr.Interface(chat_with_error_handling, inputs="textbox", outputs="textbox")
 
 
 
 
 
66
 
67
# Launch the Gradio app and share link.
# share=True asks Gradio for a temporary public tunnel URL in addition to
# the local server.
interface.launch(share=True)
 
 
 
1
  import json
2
+ from transformers import pipeline
3
+ import nltk
4
+ from nltk.corpus import stopwords
5
+ from nltk.tokenize import word_tokenize
6
+ from nltk.stem import WordNetLemmatizer
7
 
8
# Download NLTK resources (tokenizer models, WordNet, stop-word lists).
# These hit the network on first run; subsequent runs are no-ops.
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('stopwords')

# Load the JSON data from the file.
# NOTE(review): assumes uts_courses.json sits in the working directory and
# contains a top-level "courses" mapping (see find_courses_by_field) —
# confirm the file's schema.
with open('uts_courses.json') as f:
    data = json.load(f)

# Load the question-answering pipeline.
# No model is named, so transformers picks its default QA checkpoint and
# downloads it at startup.
qa_pipeline = pipeline("question-answering")

# Define stop words and lemmatizer used by preprocess_input.
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
23
 
24
# Text normalization applied to every user message before matching.
def preprocess_input(user_input):
    """Lowercase and tokenize *user_input*, drop stop words and
    non-alphanumeric tokens, lemmatize the rest, and return the tokens
    re-joined with single spaces."""
    kept = []
    for token in word_tokenize(user_input.lower()):
        if not token.isalnum() or token in stop_words:
            continue
        kept.append(lemmatizer.lemmatize(token))
    return " ".join(kept)
 
 
 
29
 
30
# Look up the course list for one field of study in the loaded JSON data.
def find_courses_by_field(field):
    """Return the list of courses stored under *field* in data['courses'],
    or an empty list when the field is unknown.

    Uses dict.get instead of the original membership-test-then-index,
    avoiding a double lookup; a missing 'courses' key still raises
    KeyError, exactly as before.
    """
    return data['courses'].get(field, [])
 
 
 
 
 
 
 
 
36
 
37
# Route one user message: exit command, course listing, or free-form QA.
def generate_response(user_input):
    """Return the chatbot's reply to *user_input*.

    Recognizes 'exit', a "courses ... available ..." listing request, and
    otherwise falls through to the extractive QA pipeline.
    """
    user_input = preprocess_input(user_input)
    if user_input == 'exit':
        return "Exiting the program."
    if "courses" in user_input and "available" in user_input:
        # BUG FIX: preprocess_input strips the stop word "in", so splitting
        # on "in " raised IndexError on almost every listing query.  Split
        # defensively and fall back to the last token as the field name.
        parts = user_input.split("in ", 1)
        field = parts[1] if len(parts) > 1 else user_input.rsplit(" ", 1)[-1]
        courses = find_courses_by_field(field)
        if courses:
            response = f"Courses in {field}: {', '.join(courses)}"
        else:
            response = f"No courses found in {field}."
    else:
        # BUG FIX: the transformers QA pipeline requires a plain-text
        # context; passing the raw dict raised a type error.  Serialize the
        # course data to a JSON string instead.
        answer = qa_pipeline(question=user_input, context=json.dumps(data))
        response = answer['answer']
    return response
53
 
54
# Console entry point for the chatbot.
def main():
    """Run the interactive console loop for the UTS course chatbot.

    Prints a greeting, then reads messages from stdin and prints the bot's
    reply until the exit response is produced.
    """
    greeting = (
        "Welcome! I'm the UTS Course Chatbot. How can I assist you today?",
        "You can ask questions about UTS courses or type 'exit' to end the conversation.",
    )
    for line in greeting:
        print(line)

    reply = None
    while reply != "Exiting the program.":
        reply = generate_response(input("You: "))
        print("Bot:", reply)
65
 
66
# Start the interactive loop only when run as a script (not on import).
if __name__ == "__main__":
    main()