"""Chat-style Gradio app that answers questions about a CSV table with TAPAS.

The user first sends a URL pointing to a CSV file; every later message is
treated as a natural-language question about that table and answered by the
``google/tapas-base-finetuned-wtq`` table-question-answering pipeline.
"""
import os
import subprocess
import sys


def _pip_install(*args):
    """Best-effort ``pip install`` into the interpreter running this script.

    Like the original ``os.system`` calls, a failed install is not fatal here;
    the imports below fail loudly if a package is really missing.
    """
    subprocess.run([sys.executable, "-m", "pip", "install", *args], check=False)


# Dependencies are installed at start-up, so the imports that need them must
# come after these calls (hence the E402 suppressions below).
_pip_install("gradio==2.3.5b0")
# The original shell command expanded ${CUDA}; an unset variable expanded to "".
# NOTE(review): presumably CUDA holds a wheel tag such as "cpu" or "cu102" —
# confirm it is set in the deployment environment, otherwise the URL is invalid.
_pip_install(
    "torch-scatter",
    "-f",
    f"https://pytorch-geometric.com/whl/torch-1.9.0+{os.environ.get('CUDA', '')}.html",
)

import html  # noqa: E402

import gradio as gr  # noqa: E402
import pandas as pd  # noqa: E402
from transformers import pipeline  # noqa: E402

# Currently loaded dataset; empty until the user sends a CSV URL.
# NOTE(review): this is a module-level global, so it is shared by every session
# of the app — one user loading a new CSV replaces the table for everyone.
table = pd.DataFrame()
tqa = pipeline(task="table-question-answering", model="google/tapas-base-finetuned-wtq")


def chat(message):
    """Handle one chat turn and return the whole conversation as HTML.

    A message starting with ``http`` is treated as the URL of a CSV file that
    becomes the current table. Any other message is answered as a question
    about the current table, or, if no table is loaded yet, with usage
    instructions.

    Args:
        message: Text typed by the user.

    Returns:
        HTML markup rendering this session's full conversation history.
    """
    global table
    history = gr.get_state() or []
    url = message.strip()

    if url.startswith("http"):
        try:
            new_table = pd.read_csv(url)
        except (OSError, ValueError) as err:
            # Unreachable URL, HTTP error, or content that is not parseable CSV
            # (pandas parser errors subclass ValueError). Keep the previously
            # loaded table and tell the user instead of failing the request.
            response = f"sorry, I could not load a csv file from that url ({err})"
        else:
            # The TAPAS tokenizer requires every cell value to be text.
            table = new_table.astype(str)
            response = "thank you for the dataset... now you can ask questions about it"
    elif table.empty:
        response = (
            "Hi! Please send a url of a dataset in csv format. Then ask as many "
            "questions as you want about it. If you want to talk about another "
            "dataset, just send a new link."
        )
    else:
        response = tqa(table=table, query=message)["answer"]

    history.append((message, response))
    gr.set_state(history)

    # Escape both sides: messages are user input and answers come from cells of
    # an arbitrary remote CSV, so raw interpolation would allow HTML injection.
    rows = "".join(
        f"<div class='user_msg'>{html.escape(str(user_msg))}</div>"
        f"<div class='resp_msg'>{html.escape(str(resp_msg))}</div>"
        for user_msg, resp_msg in history
    )
    return f"<div class='chatbox'>{rows}</div>"


iface = gr.Interface(
    chat,
    "text",
    "html",
    css="""
    .chatbox {display:flex;flex-direction:column}
    .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
    .user_msg {background-color:cornflowerblue;color:white;align-self:start}
    .resp_msg {background-color:lightgray;align-self:self-end}
    """,
    allow_screenshot=False,
    allow_flagging=False,
)

if __name__ == "__main__":
    iface.launch(debug=True)