Update app.py
app.py
CHANGED
@@ -1,135 +1,94 @@
-# Import the libraries
-import numpy as np
-import time
-
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch
-
-# Download Microsoft's DialoGPT model and tokenizer
-# The Hugging Face checkpoint for the model and its tokenizer is `"microsoft/DialoGPT-medium"`
-checkpoint = "microsoft/DialoGPT-medium"
-# download and cache tokenizer
-tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-# download and cache pre-trained model
-model = AutoModelForCausalLM.from_pretrained(checkpoint)
-
-# Build a ChatBot class with all necessary modules for a complete conversation
-class ChatBot():
-    # initialize
-    def __init__(self):
-        # once chat starts, the history will be stored for chat continuity
-        self.chat_history_ids = None
-        # make input ids global to use them anywhere within the object
-        self.bot_input_ids = None
-        # a flag to check whether to end the conversation
-        self.end_chat = False
-        # greet while starting
-        self.welcome()
-
-    def welcome(self):
-        print("Initializing ChatBot ...")
-        # some time to get user ready
-        time.sleep(2)
-        print('Type "bye" or "quit" or "exit" to end chat \n')
-        # give time to read what has been printed
-        time.sleep(3)
-        # Greet and introduce
-        greeting = np.random.choice([
-            "Welcome, I am ChatBot, here for your kind service",
-            "Hey, Great day! I am your virtual assistant",
-            "Hello, it's my pleasure meeting you",
-            "Hi, I am a ChatBot. Let's chat!"
-        ])
-        print("ChatBot >> " + greeting)
-
-    def user_input(self):
-        # receive input from user
-        text = input("User >> ")
-        # end conversation if the user wishes so
-        if text.lower().strip() in ['bye', 'quit', 'exit']:
-            # turn flag on
-            self.end_chat = True
-            # a closing comment
-            print('ChatBot >> See you soon! Bye!')
-            time.sleep(1)
-            print('\nQuitting ChatBot ...')
-        else:
-            # continue chat: encode the new user input, add the eos_token,
-            # and return a tensor in PyTorch
-            self.new_user_input_ids = tokenizer.encode(text + tokenizer.eos_token, \
-                                                       return_tensors='pt')
-
-    def bot_response(self):
-        # append the new user input tokens to the chat history
-        # if chat has already begun
-        if self.chat_history_ids is not None:
-            self.bot_input_ids = torch.cat([self.chat_history_ids, self.new_user_input_ids], dim=-1)
-        else:
-            # if this is the first turn
-            self.bot_input_ids = self.new_user_input_ids
-
-        # define the new chat_history_ids based on the preceding chats
-        # generate a response while limiting the total chat history to 1000 tokens
-        self.chat_history_ids = model.generate(self.bot_input_ids, max_length=1000, \
-                                               pad_token_id=tokenizer.eos_token_id)
-        # last output tokens from bot
-        response = tokenizer.decode(self.chat_history_ids[:, self.bot_input_ids.shape[-1]:][0], \
-                                    skip_special_tokens=True)
-        # in case the bot fails to answer
-        if response == "":
-            response = self.random_response()
-        # print bot response
-        print('ChatBot >> ' + response)
-
-    # in case there is no response from the model
-    def random_response(self):
-        i = -1
-        response = tokenizer.decode(self.chat_history_ids[:, self.bot_input_ids.shape[i]:][0], \
-                                    skip_special_tokens=True)
-        # iterate over history backwards to find the last non-empty response
-        while response == '':
-            i = i - 1
-            response = tokenizer.decode(self.chat_history_ids[:, self.bot_input_ids.shape[i]:][0], \
-                                        skip_special_tokens=True)
-        # if it is a question, answer suitably
-        if response.strip() == '?':
-            reply = np.random.choice(["I don't know",
-                                      "I am not sure"])
-        # not a question? answer suitably
-        else:
-            reply = np.random.choice(["Great",
-                                      "Fine. What's up?",
-                                      "Okay"])
-        return reply
-
-# build a ChatBot object
-bot = ChatBot()
-# start chatting
-while True:
-    # receive user input
-    bot.user_input()
-    # check whether to end chat
-    if bot.end_chat:
-        break
-    # output bot response
-    bot.bot_response()
-
-# Happy chatting!
+import gradio as gr
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import TapexTokenizer, BartForConditionalGeneration
+import pandas as pd
+import torch
+#import pkg_resources
+
+'''
+# Get a list of installed packages and their versions
+installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
+
+# Print the list of packages
+for package, version in installed_packages.items():
+    print(f"{package}=={version}")
+'''
+
+# Load the chatbot model
+chatbot_model_name = "microsoft/DialoGPT-medium"  # "gpt2"
+chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
+chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
+
+# Load the SQL model
+# The wikisql checkpoints take longer to process
+#model_name = "microsoft/tapex-large-finetuned-wikisql"
+#model_name = "microsoft/tapex-base-finetuned-wikisql"
+#model_name = "microsoft/tapex-base-finetuned-wtq"
+model_name = "microsoft/tapex-large-finetuned-wtq"
+#model_name = "google/tapas-base-finetuned-wtq"
+sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
+sql_model = BartForConditionalGeneration.from_pretrained(model_name)
+
+# Hard-coded demo table for table question-answering
+data = {
+    "year": [1896, 1900, 1904, 2004, 2008, 2012],
+    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
+}
+table = pd.DataFrame.from_dict(data)
+
+# running conversation history, shared across calls
+chat_history_ids = None
+
+
+def chatbot_response(user_message):
+    global chat_history_ids
+    # Check whether the user input is a question
+    is_question = "?" in user_message
+
+    if is_question:
+        # If the user input is a question, use TAPEx for table question-answering
+        encoding = sql_tokenizer(table=table, query=user_message, return_tensors="pt")
+        outputs = sql_model.generate(**encoding)
+        # batch_decode returns a list; take the single answer string
+        response = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+    else:
+        # Generate a chatbot response using the chatbot model
+        '''
+        inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
+        outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
+        response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
+        '''
+        # encode the new user input, add the eos_token, and return a PyTorch tensor
+        new_user_input_ids = chatbot_tokenizer.encode("User: " + user_message + chatbot_tokenizer.eos_token, return_tensors='pt')
+
+        # append the new user input tokens to the chat history
+        if chat_history_ids is not None:
+            bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
+        else:
+            bot_input_ids = new_user_input_ids
+
+        # generate a response while limiting the total chat history to 1000 tokens
+        chat_history_ids = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=chatbot_tokenizer.eos_token_id)
+        # decode only the tokens generated after the prompt
+        response = chatbot_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+
+    return response
+
+
+# Define the chatbot interface using Gradio
+chatbot_interface = gr.Interface(
+    fn=chatbot_response,
+    inputs=gr.Textbox(label="You:"),
+    outputs=gr.Textbox(),
+    live=True,
+    title="ST Chatbot",
+    description="Type your message in the box above, and the chatbot will respond.",
+)
+
+# Launch the Gradio interface
+if __name__ == "__main__":
+    chatbot_interface.launch()
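Note: the else-branch above mirrors the multi-turn pattern from the DialoGPT model card: each user turn is encoded with a trailing eos_token, concatenated onto the running history, and only the tokens generated past the prompt are decoded. A minimal standalone sketch of that pattern (the two hard-coded turns are illustrative only):

    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

    chat_history_ids = None
    for turn in ["Hello, how are you?", "What do you do for fun?"]:
        # encode the user turn plus the end-of-sequence token
        new_ids = tokenizer.encode(turn + tokenizer.eos_token, return_tensors="pt")
        # grow the context with the whole conversation so far
        input_ids = new_ids if chat_history_ids is None else torch.cat([chat_history_ids, new_ids], dim=-1)
        # generate, capping the total history at 1000 tokens as in app.py
        chat_history_ids = model.generate(input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
        # decode only the newly generated tokens
        print(tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True))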
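Note: the question branch relies on TAPEx, which linearizes the pandas table together with the natural-language query into one input sequence and lets the BART seq2seq model generate the answer as text. A sketch against the same Olympics table (the base checkpoint is used here only to keep the download small; the query is illustrative):

    from transformers import TapexTokenizer, BartForConditionalGeneration
    import pandas as pd

    tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-base-finetuned-wtq")
    model = BartForConditionalGeneration.from_pretrained("microsoft/tapex-base-finetuned-wtq")

    table = pd.DataFrame.from_dict({
        "year": [1896, 1900, 1904, 2004, 2008, 2012],
        "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"],
    })
    query = "in which year did beijing host the olympic games?"

    # the tokenizer flattens the table and the query into a single sequence
    encoding = tokenizer(table=table, query=query, return_tensors="pt")
    outputs = model.generate(**encoding)
    # batch_decode returns one string per query
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])  # e.g. " 2008"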
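Note: routing on `"?" in user_message` sends every message containing a question mark to the table model, including questions the table cannot answer. A hypothetical, stricter router (the helper name and the column/cell check are illustrative, not part of app.py) could require the question to mention the table's contents:

    def is_table_question(message, table):
        # route to TAPEx only if the question mentions a column name or a cell value
        msg = message.lower()
        in_columns = any(col.lower() in msg for col in table.columns)
        in_cells = any(str(cell).lower() in msg for cell in table.values.ravel())
        return "?" in message and (in_columns or in_cells)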
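Note: chat_history_ids is a module-level global, so all visitors to the Space share one conversation history. If per-session history is wanted, one possible refactor, assuming a Gradio version that supports gr.State, is to thread the history through the interface instead; this sketch reuses chatbot_tokenizer and chatbot_model from above and is not the committed code:

    import gradio as gr
    import torch

    def respond(user_message, history_ids):
        # same encode/concatenate/generate logic as chatbot_response, but the
        # history tensor travels with the session instead of a global variable
        new_ids = chatbot_tokenizer.encode(user_message + chatbot_tokenizer.eos_token, return_tensors="pt")
        input_ids = new_ids if history_ids is None else torch.cat([history_ids, new_ids], dim=-1)
        history_ids = chatbot_model.generate(input_ids, max_length=1000, pad_token_id=chatbot_tokenizer.eos_token_id)
        reply = chatbot_tokenizer.decode(history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
        return reply, history_ids

    demo = gr.Interface(
        fn=respond,
        inputs=[gr.Textbox(label="You:"), gr.State()],  # gr.State starts as None
        outputs=[gr.Textbox(), gr.State()],
    )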