File size: 4,595 Bytes
2714773 5782838 9d6743e 2714773 d7a34dd 2714773 d7a34dd 2714773 9d6743e 2714773 ca38751 2714773 9c35ba8 2714773 978fd4d 2714773 6651d18 2714773 5758bb4 2714773 0f46be8 2714773 5758bb4 2714773 0bb8f5d 37e8e73 9bb7983 0bb8f5d 2714773 9c35ba8 b15f680 37e8e73 9c35ba8 95934ef 9c35ba8 2714773 978fd4d 5758bb4 2714773 0bb8f5d 9c35ba8 2714773 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import gradio as gr
import numpy as np
import time
import os
#import pyodbc
#import pkg_resources
# Disabled debugging aid, kept as an inert string literal so it never runs:
# it would print every installed package as "name==version" using
# pkg_resources, whose import is likewise commented out above.
'''
# Get a list of installed packages and their versions
installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
# Print the list of packages
for package, version in installed_packages.items():
print(f"{package}=={version}")
'''
# Disabled template, also an inert string literal: connect to SQL Server
# through pyodbc, read a query result into a pandas DataFrame named `df`, and
# close the connection. All connection values below are placeholders.
# NOTE(review): if this is ever enabled, load the credentials from the
# environment or a secret store instead of hard-coding them here.
'''
# Replace the connection parameters with your SQL Server information
server = 'your_server'
database = 'your_database'
username = 'your_username'
password = 'your_password'
driver = 'SQL Server' # This depends on the ODBC driver installed on your system
# Create the connection string
connection_string = f'DRIVER={{{driver}}};SERVER={server};DATABASE={database};UID={username};PWD={password}'
# Connect to the SQL Server
conn = pyodbc.connect(connection_string)
#============================================================================
# Replace "your_query" with your SQL query to fetch data from the database
query = 'SELECT * FROM your_table_name'
# Use pandas to read data from the SQL Server and store it in a DataFrame
df = pd.read_sql_query(query, conn)
# Close the SQL connection
conn.close()
'''
# Small in-memory sample table (one "year" column, one "city" column) that
# sqlquery passes to the table question-answering tokenizer.
data = dict(
    year=[1896, 1900, 1904, 2004, 2008, 2012],
    city=["athens", "paris", "st. louis", "athens", "beijing", "london"],
)
# Column-oriented dict -> DataFrame; keys become the column names.
table = pd.DataFrame(data)
# Hugging Face checkpoint identifiers used by the two tabs of the app.
chatbot_model_name = "microsoft/DialoGPT-medium"
sql_model_name = "microsoft/tapex-large-finetuned-wtq"

# Conversational model and tokenizer used by chat().
tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)

# Table question-answering model and tokenizer used by sqlquery().
sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)

# Module-level state from an earlier design, currently disabled:
#sql_response = None
#conversation_history = []
def chat(input, history=None):
    """Generate the chatbot's next reply and return the updated conversation.

    Args:
        input: The user's new message as plain text. The name shadows the
            builtin but is kept so existing callers are unaffected.
        history: Token ids of the conversation so far, in the nested-list
            form returned by a previous call. ``None`` or an empty list
            starts a new conversation.

    Returns:
        A ``(response, history)`` pair. ``response`` is a list of
        ``(first_text, second_text)`` tuples built by splitting the decoded
        conversation on ``"<|endoftext|>"`` and pairing consecutive pieces;
        ``history`` is the full generated token-id list to pass back in on
        the next call.
    """
    # A None default avoids a shared mutable default list and also accepts a
    # caller (e.g. a UI state slot) that hands in None for "no history yet".
    if history is None:
        history = []

    # An earlier draft routed inputs containing "?" to the table model here;
    # that encode/generate/decode path is what sqlquery() does.

    # Tokenize the new user message, terminated by the end-of-sequence token.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")

    # Prepend the previous turns. With no history the new message is the whole
    # prompt, which avoids concatenating a 1-D empty tensor with a 2-D one.
    if history:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # NOTE(review): max_length bounds prompt + reply together, so a very long
    # conversation leaves little or no room for a new reply - confirm whether
    # the history should be truncated.
    history = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()

    # Decode the whole conversation and pair consecutive pieces; a trailing
    # unpaired piece (the empty string after the final separator) is dropped.
    turns = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = list(zip(turns[::2], turns[1::2]))
    return response, history
def sqlquery(input):
    """Answer a natural-language question about ``table`` with the SQL model.

    Args:
        input: The question text entered by the user.

    Returns:
        The list of decoded strings from ``sql_tokenizer.batch_decode`` with
        special tokens skipped.
    """
    # The tokenizer's end-of-sequence token is appended to the raw question.
    question = input + sql_tokenizer.eos_token
    encoded = sql_tokenizer(table=table, query=question, return_tensors="pt")
    generated = sql_model.generate(**encoded)
    # Disabled variant: kept a global conversation_history of "User: ..." and
    # "Bot: ..." lines and returned the joined transcript instead of the
    # decoded answer list.
    return sql_tokenizer.batch_decode(generated, skip_special_tokens=True)
# "Chatbot" tab. chat() takes the message text plus a state value and returns
# the chatbot transcript plus the new state, so its history round-trips
# through the "state" slots between calls.
chat_interface = gr.Interface(
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    theme="default",
    # Hide the page footer.
    css=".footer {display:none !important}",
)
# "SQL Chat" tab: one free-text question in, sqlquery()'s answer out.
sql_interface = gr.Interface(
    fn=sqlquery,
    theme="default",
    # `prompt` is not a gr.Textbox argument (ignored or rejected depending on
    # the Gradio version); `label` is the supported way to caption the box.
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    #live=True,
    #capture_session=True,
    title="ST SQL Chat",
    description="Type your message in the box above, and the chatbot will respond.",
)
# Present both interfaces as tabs of a single app; tab_names lines up
# positionally with interface_list.
combine_interface = gr.TabbedInterface(
    interface_list=[chat_interface, sql_interface],
    tab_names=["Chatbot", "SQL Chat"],
)
# Start the Gradio app only when this file is executed as a script, so that
# importing the module does not launch a server.
if __name__ == "__main__":
    combine_interface.launch()