teaevo commited on
Commit
0f46be8
1 Parent(s): 4876387

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -130
app.py CHANGED
@@ -1,138 +1,45 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- from transformers import TapexTokenizer, BartForConditionalGeneration
4
- import pandas as pd
5
  import gradio as gr
 
6
 
7
- import numpy as np
8
- import time
9
- import os
10
 
11
- #import pyodbc
 
12
 
13
- #import pkg_resources
 
14
 
15
- '''
16
- # Get a list of installed packages and their versions
17
- installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
 
 
18
 
19
- # Print the list of packages
20
- for package, version in installed_packages.items():
21
- print(f"{package}=={version}")
22
- '''
 
23
 
24
- '''
25
- # Replace the connection parameters with your SQL Server information
26
- server = 'your_server'
27
- database = 'your_database'
28
- username = 'your_username'
29
- password = 'your_password'
30
- driver = 'SQL Server' # This depends on the ODBC driver installed on your system
31
-
32
- # Create the connection string
33
- connection_string = f'DRIVER={{{driver}}};SERVER={server};DATABASE={database};UID={username};PWD={password}'
34
-
35
- # Connect to the SQL Server
36
- conn = pyodbc.connect(connection_string)
37
-
38
- #============================================================================
39
- # Replace "your_query" with your SQL query to fetch data from the database
40
- query = 'SELECT * FROM your_table_name'
41
-
42
- # Use pandas to read data from the SQL Server and store it in a DataFrame
43
- df = pd.read_sql_query(query, conn)
44
-
45
- # Close the SQL connection
46
- conn.close()
47
- '''
48
-
49
- data = {
50
- "year": [1896, 1900, 1904, 2004, 2008, 2012],
51
- "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
52
- }
53
- table = pd.DataFrame.from_dict(data)
54
-
55
-
56
- # Load the chatbot model
57
- chatbot_model_name = "microsoft/DialoGPT-medium"
58
- tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
59
- model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
60
-
61
- # Load the SQL Model
62
- sql_model_name = "microsoft/tapex-large-finetuned-wtq"
63
- sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
64
- sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)
65
-
66
- #sql_response = None
67
-
68
- def predict(input, history=[]):
69
-
70
- #global sql_response
71
- # Check if the user input is a question
72
- #is_question = "?" in input
73
-
74
- '''
75
- if is_question:
76
- sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
77
- sql_outputs = sql_model.generate(**sql_encoding)
78
- sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
79
-
80
- else:
81
- '''
82
-
83
- # tokenize the new input sentence
84
- new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
85
-
86
- # append the new user input tokens to the chat history
87
- bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
88
-
89
- # generate a response
90
- history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
91
-
92
- # convert the tokens to text, and then split the responses into the right format
93
- response = tokenizer.decode(history[0]).split("<|endoftext|>")
94
- response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
95
 
96
- return response, history
97
-
98
-
99
- def sqlquery(input):
100
-
101
- sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
102
- sql_outputs = sql_model.generate(**sql_encoding)
103
- sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
104
-
105
- return sql_response
106
-
107
-
108
- chat_interface = gr.Interface(
109
- fn=predict,
110
- theme="default",
111
- css=".footer {display:none !important}",
112
- inputs=["text", "state"],
113
- outputs=["chatbot", "state"],
114
- title="ST Chatbot",
115
- description="Type your message in the box above, and the chatbot will respond.",
116
- )
117
-
118
- sql_interface = gr.Interface(
119
- fn=sqlquery,
120
- theme="default",
121
- inputs=gr.Textbox(prompt="You:"),
122
- outputs=gr.Textbox(),
123
- live=True,
124
- capture_session=True,
125
- title="ST SQL Chat",
126
- description="Type your message in the box above, and the chatbot will respond.",
127
- )
128
-
129
- combine_interface = gr.TabbedInterface(
130
- interface_list=[
131
- chat_interface,
132
- sql_interface
133
- ],
134
- tab_names=['Chatbot' ,'SQL Chat'],
135
- )
136
-
137
- if __name__ == '__main__':
138
- combine_interface.launch()
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TapasForQuestionAnswering, TapasTokenizer
3
 
4
+ # Load the models and tokenizers
5
+ tapas_model_name = "google/tapas-large-finetuned-wtq"
6
+ dialogpt_model_name = "microsoft/DialoGPT-medium"
7
 
8
+ tapas_tokenizer = TapasTokenizer.from_pretrained(tapas_model_name)
9
+ tapas_model = TapasForQuestionAnswering.from_pretrained(tapas_model_name)
10
 
11
+ dialogpt_tokenizer = AutoTokenizer.from_pretrained(dialogpt_model_name)
12
+ dialogpt_model = AutoModelForSeq2SeqLM.from_pretrained(dialogpt_model_name)
13
 
14
+ def answer_table_question(table, question):
15
+ encoding = tapas_tokenizer(table=table, query=question, return_tensors="pt")
16
+ outputs = tapas_model.generate(**encoding)
17
+ response = tapas_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
18
+ return response
19
 
20
+ def generate_dialog_response(prompt, conversation_history):
21
+ bot_input = dialogpt_tokenizer.encode(prompt + dialogpt_tokenizer.eos_token, return_tensors="pt")
22
+ chat_history_ids = dialogpt_model.generate(bot_input, max_length=1000, pad_token_id=dialogpt_tokenizer.eos_token_id)
23
+ response = dialogpt_tokenizer.decode(chat_history_ids[:, bot_input.shape[-1]:][0], skip_special_tokens=True)
24
+ return response
25
 
26
+ def chatbot_interface(user_input, table=gr.inputs.Textbox()):
27
+ global conversation_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ conversation_history.append(user_input)
30
+
31
+ # Check if user asks a question related to the table
32
+ if "table" in user_input:
33
+ question = user_input
34
+ answer = answer_table_question(table, question)
35
+ conversation_history.append(answer)
36
+ return "Bot (TAPAS): " + answer
37
+ else:
38
+ dialog_prompt = "User: " + " ".join(conversation_history) + "\nBot:"
39
+ response = generate_dialog_response(dialog_prompt, conversation_history)
40
+ conversation_history.append(response)
41
+ return "Bot (DialoGPT): " + response
42
+
43
+ conversation_history = []
44
+ iface = gr.Interface(fn=chatbot_interface, inputs=["text", "text"], outputs="text", live=True)
45
+ iface.launch()