"""Streamlit app for demoing nsql-llama-2-70B."""
import json
import os

import pandas as pd
import requests
import streamlit as st
from manifest import Manifest, Response
from manifest.connections.client_pool import ClientConnection

# Generated text is cut off at the first of these markers. ";" ends a SQL
# statement and "--" starts the next comment line in the prompt format built
# by generate_prompt; "###" and "```" are presumably section / code-fence
# delimiters the model may emit after its answer.
STOP_TOKENS = ["###", ";", "--", "```"]

# (connect, read) timeout in seconds for the backend request. With
# stream=True the read timeout bounds the wait between streamed chunks,
# not the total generation time.
REQUEST_TIMEOUT = (10, 300)


def generate_prompt(question, schema):
    """Build the model prompt: the schema, an instruction comment, then the question.

    Args:
        question: Natural-language question to translate into SQL.
        schema: ``CREATE TABLE`` statements describing the tables to query.

    Returns:
        The prompt string sent to the model.
    """
    return f"""{schema}\n\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- {question}\n"""


def _iter_stream_tokens(response):
    """Yield the non-empty ``stream_token`` values from ``data: `` lines of ``response``."""
    if response.encoding is None:
        response.encoding = "utf-8"
    for line in response.iter_lines(decode_unicode=True):
        if not (line and line.startswith("data: ")):
            continue
        output = json.loads(line[len("data: ") :])
        token = output.get("stream_token", "")
        if token:
            yield token


def _truncate_at_stop(chunks, stop_tokens=None):
    """Yield text from ``chunks`` up to (not including) the first stop token.

    A stop token can be split across two chunks, so the last
    ``len(longest stop token) - 1`` characters are held back until the next
    chunk arrives (or the stream ends) before being emitted.

    Args:
        chunks: Iterable of text fragments.
        stop_tokens: Markers that end the output; defaults to ``STOP_TOKENS``.
    """
    stops = [s for s in (STOP_TOKENS if stop_tokens is None else stop_tokens) if s]
    if not stops:
        yield from chunks
        return
    hold = max(len(s) for s in stops) - 1
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        hits = [i for i in (buffer.find(s) for s in stops) if i != -1]
        if hits:
            head = buffer[: min(hits)]
            if head:
                yield head
            return
        # Everything before the held-back tail can no longer be the start of
        # a stop token (any such token would already be fully in the buffer).
        safe = len(buffer) - hold
        if safe > 0:
            yield buffer[:safe]
            buffer = buffer[safe:]
    if buffer:
        yield buffer


def generate_sql(question, schema, timeout=REQUEST_TIMEOUT):
    """Stream the model's SQL answer for ``question`` against ``schema``.

    Sends the prompt to the backend named by ``st.secrets["backend_url"]``
    and yields the streamed text, stopping before the first ``STOP_TOKENS``
    marker. This is a generator: the request is only made once iteration
    starts.

    Args:
        question: Natural-language question.
        schema: ``CREATE TABLE`` statements for the target database.
        timeout: ``requests`` timeout (seconds, or a (connect, read) tuple).

    Yields:
        Text fragments of the generated SQL.

    Raises:
        requests.HTTPError: If the backend responds with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    prompt = generate_prompt(question, schema)
    url = st.secrets["backend_url"]
    headers = {
        "Content-Type": "application/json",
        "key": st.secrets["key"],
    }
    data = {
        "inputs": [prompt],
        "params": {
            "do_sample": {"type": "bool", "value": "false"},
            "max_tokens_to_generate": {"type": "int", "value": "1000"},
            "repetition_penalty": {"type": "float", "value": "1"},
            "temperature": {"type": "float", "value": "1"},
            "top_k": {"type": "int", "value": "50"},
            "top_logprobs": {"type": "int", "value": "0"},
            "top_p": {"type": "float", "value": "1"},
        },
    }
    # The context manager releases the streamed connection, including when
    # output stops early at a stop token.
    with requests.post(
        url, headers=headers, data=json.dumps(data), stream=True, timeout=timeout
    ) as r:
        r.raise_for_status()
        yield from _truncate_at_stop(_iter_stream_tokens(r))


st.title("nsql-llama-2-70B Demo")

expander = st.expander("Database Schema")

# Editable database schema, pre-filled with an example.
# TODO(Bo Li): update this with the new example
default_schema = """CREATE TABLE stadium (
    stadium_id number,
    location text,
    name text,
    capacity number,
    highest number,
    lowest number,
    average number
)

CREATE TABLE singer (
    singer_id number,
    name text,
    country text,
    song_name text,
    song_release_year text,
    age number,
    is_male others
)

CREATE TABLE concert (
    concert_id number,
    concert_name text,
    theme text,
    stadium_id text,
    year text
)

CREATE TABLE singer_in_concert (
    concert_id number,
    singer_id text
)"""
schema = expander.text_area("Current schema:", value=default_schema, height=500)

# Input field for the natural-language question.
text_prompt = st.text_input(
    "Please let me know what question do you want to ask?",
    value="What is the maximum, the average, and the minimum capacity of stadiums ?",
)

if st.button("Generate SQL"):
    if not text_prompt.strip():
        st.warning("Please enter a question first.")
    else:
        try:
            st.write_stream(generate_sql(text_prompt, schema))
        except requests.RequestException as err:
            # Show backend/network failures in the UI instead of a raw traceback.
            st.error(f"Backend request failed: {err}")