File size: 2,860 Bytes
23c1edb
f65b03e
9d6743e
 
 
 
 
 
 
 
 
 
 
 
ca38751
9d6743e
 
 
 
e030ac0
9d6743e
 
 
 
 
 
829e215
 
 
9d6743e
 
 
 
 
 
c17ba77
829e215
 
f65b03e
9d6743e
829e215
9d6743e
8a4fd9e
69beb29
46920ac
 
 
 
cee9f1f
 
 
46920ac
cee9f1f
46920ac
 
 
 
 
 
 
 
 
 
 
 
cee9f1f
46920ac
 
 
 
f65b03e
23c1edb
b92090c
f65b03e
4e86ef1
f65b03e
 
 
 
830c2c9
 
9d6743e
 
23c1edb
4e86ef1
f65b03e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import torch

import numpy as np
import time
import os
#import pkg_resources

# Conversational model: DialoGPT generates free-form chat replies.
chatbot_model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)

# Table question-answering model: a TAPEX (BART-based) checkpoint that answers
# natural-language queries over a pandas table.
sql_model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)

# Demo table the TAPEX model is queried against
# (appears to list Summer Olympic host cities by year).
data = {
    "year": [1896, 1900, 1904, 2004, 2008, 2012],
    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
}
table = pd.DataFrame.from_dict(data)

# Module-level slot for the table model's answer; `predict` declares it global.
sql_response = None

def predict(input, history=None):
    """Answer one chat turn and return the updated transcript and state.

    Inputs containing "?" are treated as questions about the module-level
    ``table`` and answered by the TAPEX model. All other inputs get a DialoGPT
    reply. In both cases the reply is stored in the token history as
    EOS-terminated tokens, so the transcript keeps the DialoGPT layout
    (user, bot, user, bot, ... separated by the EOS token).

    Args:
        input: The user's message for this turn.
        history: Conversation token ids returned by a previous call (a list
            holding one list of ids, or the equivalent tensor). ``None`` or an
            empty list starts a new conversation.

    Returns:
        ``(response, history)``. ``response`` is a list of
        ``(user_message, bot_message)`` pairs for the chatbot display.
        ``history`` is the updated token-id list (one row) to pass back on the
        next call.
    """
    global sql_response

    # Avoid a shared mutable default; the UI state may also arrive as None.
    if history is None:
        history = []
    history_list = history.tolist() if isinstance(history, torch.Tensor) else history

    # Check if the user input is a question
    is_question = "?" in input

    # tokenize the new input sentence
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')

    # append the new user input tokens to the chat history
    if history_list:
        bot_input_ids = torch.cat([torch.LongTensor(history_list), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    if is_question:
        # Ask the table model the user's own question. The history must stay
        # purely token ids: mixing decoded strings into it breaks the
        # LongTensor conversion on the next turn.
        encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
        outputs = sql_model.generate(**encoding)
        sql_response = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        answer_ids = tokenizer.encode(sql_response + tokenizer.eos_token, return_tensors='pt')
        history = torch.cat([bot_input_ids, answer_ids], dim=-1).tolist()
    else:
        # NOTE(review): max_length caps the whole conversation, not just the
        # reply; very long chats may stop producing new text.
        history = model.generate(
            bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
        ).tolist()

    # convert the tokens to text, and then split the turns into (user, bot) pairs
    turns = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]

    return response, history


import gradio as gr

# UI wiring: a text box and a state value go in; a chatbot transcript and the
# updated state come back out. The CSS rule hides any element with the
# `footer` class.
_interface_options = dict(
    fn=predict,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
    theme="default",
    css=".footer {display:none !important}",
)
interface = gr.Interface(**_interface_options)

if __name__ == '__main__':
    # Start the web server only when run as a script, not on import.
    interface.launch()