# Spaces: Runtime error
# (Page-status text captured together with this file; it is not part of the program.)
import json | |
import os | |
import uuid | |
import pandas as pd | |
import streamlit as st | |
import argparse | |
import traceback | |
from typing import Dict | |
import requests | |
from utils.utils import load_data_split | |
from nsql.database import NeuralDB | |
from nsql.nsql_exec import NSQLExecutor | |
from nsql.nsql_exec_python import NPythonExecutor | |
from generation.generator import Generator | |
import time | |
# Directory that contains this script (joined with "./"); used to locate bundled resources.
ROOT_DIR = os.path.join(os.path.dirname(__file__), "./")

# Example tables offered in the demo.
# Each entry maps a display title to (index of the item in the loaded data, default question).
# NOTE(review): the "β" in the UAB Blazers title looks like a mis-encoded dash;
# it is left untouched because the string is used as a runtime key.
EXAMPLE_TABLES = {
    "Estonia men's national volleyball team": (
        558,
        "what are the total number of players from france?",
    ),
    "Highest mountain peaks of California": (
        5,
        "which is the lowest mountain?",
    ),
    "2010β11 UAB Blazers men's basketball team": (
        1,
        "how many players come from alabama?",
    ),
    "1999 European Tour": (
        209,
        "how many consecutive times was south africa the host country?",
    ),
    "Nissan SR20DET": (
        438,
        "which car is the only one with more than 230 hp?",
    ),
}
def load_data():
    """Fetch the demo data through ``load_data_split``.

    The two arguments are presumably (dataset name, split name) -- confirm
    against ``utils.utils``. The caller indexes the returned object by
    integer position.
    """
    demo_items = load_data_split("missing_squall", "validation")
    return demo_items
def get_key():
    """Fetch one API key from the key-distribution ("springboard") server.

    The public IP of the demo machine is printed first for diagnostics; a
    failure of that lookup is reported but does not block the key request.

    Returns:
        The first element of the ``data`` list in the server's JSON reply.

    Raises:
        requests.RequestException: If the key server cannot be reached within
            the timeout or answers with an HTTP error status.
    """
    # Without a timeout a dead endpoint would hang the whole app indefinitely.
    timeout_seconds = 30

    # Print the public IP of the demo machine (diagnostic only).
    try:
        ip = requests.get('https://checkip.amazonaws.com', timeout=timeout_seconds).text.strip()
        print(ip)
    except requests.RequestException as err:
        print(f"Could not determine the public IP: {err}")

    # The springboard machine was built to protect the keys; per the original
    # author's note, only the demo machine is meant to have access to it.
    # NOTE(review): the key travels over plain HTTP -- consider HTTPS.
    URL = "http://54.242.37.195:20217/api/predict"
    response = requests.post(
        url=URL,
        json={"data": "Hi, binder server. Give me a key!"},
        timeout=timeout_seconds,
    )
    # Fail with a clear HTTP error instead of an obscure JSON/KeyError later.
    response.raise_for_status()
    one_key = response.json()['data'][0]
    return one_key
def read_markdown(path):
    """Render the Markdown file at ``path`` in the Streamlit page.

    Args:
        path: Location of the Markdown file to display.

    The text is passed to ``st.markdown`` with ``unsafe_allow_html=True``, so
    any HTML embedded in the file is rendered as-is; pass trusted files only.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the host's
    # locale default encoding.
    with open(path, "r", encoding="utf-8") as f:
        output = f.read()
    st.markdown(output, unsafe_allow_html=True)
def generate_binder_program(_args, _generator, _data_item):
    """Build a prompt for ``_data_item`` and ask the generator for one program.

    The prompt is the few-shot prompt, a blank line, and the generate prompt.
    While the tokenized prompt does not fit the input budget
    (``max_api_total_tokens - max_generation_tokens``), the number of few-shot
    examples is reduced by one and the prompt is rebuilt.

    Args:
        _args: Options object; reads ``n_shots``, ``prompt_file``,
            ``generate_type``, ``max_api_total_tokens``,
            ``max_generation_tokens`` and ``verbose``.
        _generator: Object providing ``build_few_shot_prompt_from_file``,
            ``build_generate_prompt`` and ``generate_one_pass``.
        _data_item: The item handed to ``build_generate_prompt``.

    Returns:
        ``response_dict["0"][0][0]`` of the generator's reply -- presumably the
        text of the first generated program (confirm against ``Generator``).

    Raises:
        ValueError: If the prompt is still too long with zero few-shot examples.
    """
    n_shots = _args.n_shots
    few_shot_prompt = _generator.build_few_shot_prompt_from_file(
        file_path=_args.prompt_file,
        n_shots=n_shots
    )
    generate_prompt = _generator.build_generate_prompt(
        data_item=_data_item,
        generate_type=(_args.generate_type,)
    )
    prompt = few_shot_prompt + "\n\n" + generate_prompt

    # Ensure the input length fits the Codex max input tokens by shrinking n_shots.
    max_prompt_tokens = _args.max_api_total_tokens - _args.max_generation_tokens
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=os.path.join(ROOT_DIR, "utils", "gpt2")
    )
    while len(tokenizer.tokenize(prompt)) >= max_prompt_tokens:  # TODO: Add shrink rows
        # An explicit check (not ``assert``) so the guard survives ``python -O``
        # and the loop can never continue with a negative shot count.
        if n_shots <= 0:
            raise ValueError(
                "Prompt does not fit within {} tokens even with zero few-shot "
                "examples.".format(max_prompt_tokens)
            )
        n_shots -= 1
        few_shot_prompt = _generator.build_few_shot_prompt_from_file(
            file_path=_args.prompt_file,
            n_shots=n_shots
        )
        prompt = few_shot_prompt + "\n\n" + generate_prompt

    response_dict = _generator.generate_one_pass(
        prompts=[("0", prompt)],  # the "0" is a placeholder id; it only matters with multiple threads
        verbose=_args.verbose
    )
    print(response_dict)
    return response_dict["0"][0][0]
# Set up: command-line options. Every option has a default, so the app also
# starts when no arguments are given.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--prompt_file",
    default="templates/prompts/prompt_wikitq_v3.txt",
    type=str,
)

# Binder program generation options
parser.add_argument(
    "--prompt_style",
    default="create_table_select_3_full_table",
    type=str,
    choices=[
        "create_table_select_3_full_table",
        "create_table_select_full_table",
        "create_table_select_3",
        "create_table",
        "create_table_select_3_full_table_w_all_passage_image",
        "create_table_select_3_full_table_w_gold_passage_image",
        "no_table",
    ],
)
parser.add_argument(
    "--generate_type",
    default="nsql",
    type=str,
    choices=["nsql", "sql", "answer", "npython", "python"],
)
parser.add_argument("--n_shots", default=14, type=int)
parser.add_argument("--seed", default=42, type=int)

# Codex options
# todo: Allow adjusting Codex parameters
parser.add_argument("--engine", default="code-davinci-002", type=str)
parser.add_argument("--max_generation_tokens", default=512, type=int)
parser.add_argument("--max_api_total_tokens", default=8001, type=int)
parser.add_argument("--temperature", default=0.0, type=float)
parser.add_argument("--sampling_n", default=1, type=int)
parser.add_argument("--top_p", default=1.0, type=float)
parser.add_argument(
    "--stop_tokens",
    default="\n\n",
    type=str,
    help="Split stop tokens by ||",
)
parser.add_argument(
    "--qa_retrieve_pool_file",
    default="templates/qa_retrieve_pool.json",
    type=str,
)

# Debug options
# NOTE(review): with ``store_false`` verbose is True by default and passing
# -v/--verbose switches it OFF -- confirm this is intended.
parser.add_argument("-v", "--verbose", action="store_false")

args = parser.parse_args()
# One API key is fetched every time the script runs.
keys = [get_key()]

# The title
st.markdown("# Binder Playground")

# Summary about Binder
read_markdown('resources/summary.md')

# Introduction of Binder
# todo: Write Binder introduction here
# read_markdown('resources/introduction.md')
st.image('resources/intro.png')

# Upload tables/Switch tables
st.markdown('### Try Binder!')
col1, _ = st.columns(2)
with col1:
    # The options are the example titles, in the order they are declared.
    selected_table_title = st.selectbox(
        "Select an example table",
        tuple(EXAMPLE_TABLES),
    )

# Here we just use our own examples: look up the chosen table in the loaded data.
data_items = load_data()
example_index, example_question = EXAMPLE_TABLES[selected_table_title]
data_item = data_items[example_index]
table = data_item['table']
header = table['header']
rows = table['rows']
title = table['page_title']
# todo: try to cache this db instead of re-creating it again and again.
db = NeuralDB([{"title": title, "table": table}])
df = db.get_table_df()
st.markdown(f"Title: {title}")
st.dataframe(df)

# Let user input the question (pre-filled with the example's default question)
question = st.text_input("Ask a question about the table:", value=example_question)

with col1:
    # todo: Why selecting language will flush the page?
    selected_language = st.selectbox("Select a programming language", ("SQL", "Python"))

# Each supported language selects its prompt file and generation type.
language_settings = {
    'SQL': ('templates/prompts/prompt_wikitq_v3.txt', 'nsql'),
    'Python': ('templates/prompts/prompt_wikitq_python_simplified_v4.txt', 'npython'),
}
if selected_language not in language_settings:
    raise ValueError(f'{selected_language} language is not supported.')
args.prompt_file, args.generate_type = language_settings[selected_language]

# Nothing below runs until the user presses the button.
button = st.button("Generate program")
if not button:
    st.stop()
# Generate Binder Program
generator = Generator(args, keys=keys)
with st.spinner("Generating program ..."):
    binder_program = generate_binder_program(
        args,
        generator,
        {"question": question, "table": db.get_table_df(), "title": title},
    )

# Do execution: show the program and pick the matching executor.
st.markdown("#### Binder program")
if selected_language == 'SQL':
    with st.container():
        st.write(binder_program)
    executor = NSQLExecutor(args, keys=keys)
    # The SQL executor is handed the database object itself.
    exec_target = db
elif selected_language == 'Python':
    st.code(binder_program, language='python')
    executor = NPythonExecutor(args, keys=keys)
    # The Python executor is handed the table as a DataFrame; ``db`` itself is
    # left untouched rather than being rebound to a different type.
    exec_target = db.get_table_df()
else:
    raise ValueError(f'{selected_language} language is not supported.')

try:
    # A unique stamp is passed to the executor and used in the names of the
    # visualisation files read back below.
    stamp = str(uuid.uuid4())
    os.makedirs('tmp_for_vis/', exist_ok=True)
    with st.spinner("Executing program ..."):
        exec_answer = executor.nsql_exec(stamp, binder_program, exec_target)

    # todo: Make it more pretty!
    # todo: Do we need vis for Python?
    if selected_language == 'SQL':
        with open("tmp_for_vis/{}_tmp_for_vis_steps.txt".format(stamp), "r") as f:
            steps = json.load(f)
        st.markdown("#### Steps & Intermediate results")
        for i, step in enumerate(steps):
            st.markdown(step)
            # NOTE(review): "β" looks like a mis-encoded arrow glyph; kept
            # byte-for-byte here -- confirm against the original file.
            st.text("β")
            with st.spinner('...'):
                time.sleep(1)  # brief pause between steps (presumably for visual pacing)
            with open("tmp_for_vis/{}_result_step_{}.txt".format(stamp, i), "r") as f:
                result_in_this_step = json.load(f)
            if isinstance(result_in_this_step, dict):
                # Tabular intermediate result stored as {"header": ..., "rows": ...}.
                st.dataframe(
                    pd.DataFrame(
                        result_in_this_step["rows"],
                        columns=result_in_this_step["header"],
                    )
                )
            else:
                st.markdown(result_in_this_step)
            st.text("β")

    # Unwrap a single-element answer list for display.
    if isinstance(exec_answer, list) and len(exec_answer) == 1:
        exec_answer = exec_answer[0]
    st.markdown(f'Execution answer: {exec_answer}')
except Exception as e:
    # Keep the full traceback in the server log, and also tell the user that
    # execution failed instead of silently showing nothing.
    traceback.print_exc()
    st.error(f"Program execution failed: {type(e).__name__}: {e}")