File size: 3,556 Bytes
23c1edb ca38751 23c1edb f24bed6 23c1edb 4e86ef1 23c1edb 4e86ef1 23c1edb e030ac0 c17ba77 23c1edb c17ba77 23c1edb 69beb29 830c2c9 23c1edb 830c2c9 501ede0 23c1edb c17ba77 830c2c9 501ede0 23c1edb 830c2c9 501ede0 830c2c9 4e86ef1 23c1edb 830c2c9 23c1edb 4e86ef1 23c1edb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import torch
import numpy as np
import time
import os
# Debug helper (disabled): to list installed package versions, iterate
# pkg_resources.working_set and print f"{pkg.key}=={pkg.version}" for each entry.

# --- Conversational model ---------------------------------------------------
# DialoGPT produces replies for messages that are not questions about the table.
# ("gpt2" was tried as an alternative checkpoint.)
chatbot_model_name = "microsoft/DialoGPT-medium"
chatbot_tokenizer, chatbot_model = (
    AutoTokenizer.from_pretrained(chatbot_model_name),
    AutoModelForCausalLM.from_pretrained(chatbot_model_name),
)

# --- Table question-answering model -----------------------------------------
# Other checkpoints that were tried (the wikisql variants took longer to process):
#   microsoft/tapex-large-finetuned-wikisql
#   microsoft/tapex-base-finetuned-wikisql
#   microsoft/tapex-base-finetuned-wtq
#   google/tapas-base-finetuned-wtq  (a TAPAS checkpoint; presumably needs TAPAS
#                                     classes rather than the TAPEX ones below — verify)
model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
sql_model = BartForConditionalGeneration.from_pretrained(model_name)
# Demo table that TAPEX answers questions against: (year, city) records, which
# appear to be Summer Olympics host cities. `data` is kept as a module-level
# dict of columns; `table` is the DataFrame built from it.
_year_city_rows = [
    (1896, "athens"),
    (1900, "paris"),
    (1904, "st. louis"),
    (2004, "athens"),
    (2008, "beijing"),
    (2012, "london"),
]
data = {
    "year": [year for year, _ in _year_city_rows],
    "city": [city for _, city in _year_city_rows],
}
table = pd.DataFrame(data)
def _history_to_pairs(history):
    """Decode DialoGPT token history into (user_text, bot_text) tuples for display.

    Args:
        history: Nested list of token ids as produced by
            ``chatbot_model.generate(...).tolist()``, or an empty list.

    Returns:
        A list of ``(user_text, bot_text)`` tuples; empty when there is no history.
    """
    if not history:
        return []
    # Turns are separated by the EOS token, so splitting the decoded text
    # yields user, bot, user, bot, ... segments.
    turns = chatbot_tokenizer.decode(history[0]).split("<|endoftext|>")
    return [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]


def chatbot_response(user_message, history=None):
    """Answer one chat turn, routing table questions to TAPEX and the rest to DialoGPT.

    A message containing "?" is treated as a question about the module-level
    ``table`` and answered with ``sql_tokenizer``/``sql_model`` (TAPEX). Any other
    message is appended to the DialoGPT conversation, and a reply is generated
    with ``chatbot_model``.

    Args:
        user_message: The text the user submitted.
        history: DialoGPT conversation state from the previous call (a nested
            list of token ids), or ``None``/``[]`` to start a new conversation.

    Returns:
        A ``(pairs, history)`` tuple. ``pairs`` is a list of
        ``(user_text, bot_text)`` tuples for the Gradio "chatbot" output, and
        ``history`` is the updated token-id state for the Gradio "state" output.
    """
    # Avoid a shared mutable default, and accept None, which an uninitialised
    # Gradio state may pass.
    if history is None:
        history = []

    if "?" in user_message:
        # Table question: answer with TAPEX. The DialoGPT token history is
        # returned unchanged, so later small-talk turns continue from it.
        encoding = sql_tokenizer(table=table, query=user_message, return_tensors="pt")
        outputs = sql_model.generate(**encoding)
        answer = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        # The "chatbot" output expects (user, bot) pairs, not bare strings.
        pairs = _history_to_pairs(history) + [(user_message, answer)]
        return pairs, history

    # Small talk: encode the new message (previously the builtin `input` was
    # used here by mistake) terminated by the EOS token that separates turns.
    new_user_input_ids = chatbot_tokenizer.encode(
        user_message + chatbot_tokenizer.eos_token, return_tensors="pt"
    )
    # Prepend the previous conversation tokens so the model sees the context.
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    # NOTE(review): history grows without bound; once it approaches
    # max_length=1000 tokens, generation has no room left — consider trimming.
    history = chatbot_model.generate(
        bot_input_ids, max_length=1000, pad_token_id=chatbot_tokenizer.eos_token_id
    ).tolist()
    return _history_to_pairs(history), history
# Gradio UI. Inputs are a text box and hidden conversation state; outputs are a
# chat transcript and the updated state. chatbot_response takes (text, state)
# and returns (pairs, state) in the same order.
chatbot_interface = gr.Interface(
    fn=chatbot_response,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    # live=True was removed. It reran chatbot_response whenever the input
    # changed, which ran the models on partially typed text and could append
    # those fragments to the conversation state. Now the function runs once
    # per submitted message.
    # capture_session=True was removed as well. It is a deprecated Gradio
    # option for TensorFlow sessions, and this app uses PyTorch models.
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
)
# Start the Gradio server only when this file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    app = chatbot_interface
    app.launch()
|