|
import os |
|
from threading import Event, Thread |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
StoppingCriteria, |
|
StoppingCriteriaList, |
|
TextIteratorStreamer, |
|
) |
|
import gradio as gr |
|
import torch |
|
import sqlparse |
|
|
|
# Model id (or local checkpoint path) is supplied through the environment so
# the same script can serve different checkpoints.
model_name = os.getenv("HF_MODEL_NAME", None)
if not model_name:
    # Fail fast with an actionable message instead of letting
    # `from_pretrained(None)` fail deep inside the library.
    raise RuntimeError(
        "The HF_MODEL_NAME environment variable must be set to the model "
        "name or path to load."
    )

tok = AutoTokenizer.from_pretrained(model_name)

# Upper bound on the number of tokens generated per request.
max_new_tokens = 1024

print(f"Starting to load the model {model_name}")

m = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): an integer device_map presumably places the whole model on
    # that device index (GPU 0) — confirm against the installed transformers.
    device_map=0,
)

# Reuse the end-of-sequence id as the padding id, on both the model config and
# the generation config.
m.config.pad_token_id = m.config.eos_token_id
m.generation_config.pad_token_id = m.config.eos_token_id

# Strings whose token ids end generation (consumed by StopOnTokens).
# NOTE(review): convert_tokens_to_ids looks each string up as a single
# vocabulary entry; verify that "###" and "Result" are single tokens for the
# deployed tokenizer, otherwise they presumably resolve to the unknown-token id.
stop_tokens = [";", "###", "Result"]
stop_token_ids = tok.convert_tokens_to_ids(stop_tokens)

print(f"Successfully loaded the model {model_name} into memory")
|
|
|
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires once the newest token is a stop token.

    The set of stop ids is read from the module-level ``stop_token_ids``.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id.

        Only ``input_ids[0]`` is inspected; ``scores`` and ``kwargs`` are
        accepted for interface compatibility and otherwise unused.
        """
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in stop_token_ids)
|
|
|
def bot(input_message: str, db_info="", temperature=0.1, top_p=0.9, top_k=0, repetition_penalty=1.08):
    """Generate a SQL query for a natural-language request.

    Builds an instruction-style prompt from ``input_message`` and ``db_info``,
    runs the module-level model ``m`` on a worker thread while collecting its
    streamed text, extracts the query from the output, pretty-prints it with
    ``sqlparse`` and wraps it in a fenced ``sql`` Markdown block.

    Args:
        input_message: The question to convert to SQL.
        db_info: Schema description appended to the question in the prompt.
        temperature: Sampling temperature; sampling is enabled only when > 0.
        top_p: Nucleus-sampling threshold passed to ``generate``.
        top_k: Top-k cutoff passed to ``generate``.
        repetition_penalty: Repetition penalty passed to ``generate``.

    Returns:
        A Markdown string of the form ```` ```sql\\n<query>\\n``` ````.
    """
    stop = StopOnTokens()

    messages = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to sql: {input_message} {db_info}\n\n### Response:\n\n"

    input_ids = tok(messages, return_tensors="pt").input_ids
    input_ids = input_ids.to(m.device)
    # skip_prompt=True: only newly generated text is streamed back.
    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Generation runs on a worker thread so this thread can drain the streamer.
    t1 = Thread(target=m.generate, kwargs=generate_kwargs)
    t1.start()

    partial_text = "".join(streamer)

    # The stream is exhausted, so generation is finishing; wait for the worker
    # so no thread is left running after the function returns.
    t1.join()

    # The query is taken from the segment after the first "|"
    # (presumably the model emits "<skeleton> | <query>" — confirm against the
    # model's output format). If no "|" was produced, fall back to the whole
    # text instead of raising IndexError.
    segments = partial_text.split("|")
    final_query = (segments[1] if len(segments) > 1 else segments[0]).strip()

    try:
        formatted_query = sqlparse.format(final_query, reindent=True, keyword_case='upper')
    except Exception:
        # Deliberate best-effort: formatting is cosmetic, so on any failure the
        # unformatted query is returned.
        formatted_query = final_query

    final_query_markdown = f"```sql\n{formatted_query}\n```"
    return final_query_markdown
|
|
|
# Schema description shared by every example question below.
_EXAMPLE_SCHEMA = "| stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id | concert.stadium_id = stadium.stadium_id | singer_in_concert.singer_id = singer.singer_id | singer_in_concert.concert_id = concert.concert_id |"

# Example questions, each paired with _EXAMPLE_SCHEMA when shown in the UI.
_EXAMPLE_QUESTIONS = [
    "What is the average, minimum, and maximum age for all French singers?",
    "Show location and name for all stadiums with a capacity between 5000 and 10000.",
    "What are the number of concerts that occurred in the stadium with the largest capacity ?",
    "How many male singers performed in concerts in the year 2023?",
    "List the names of all singers who performed in a concert with the theme 'Rock'",
]

# Layout: header, output box, the two text inputs, a collapsed hyperparameter
# panel, the run button, then the examples. Components render in creation order.
with gr.Blocks(css_theme="light") as demo:
    header_md = gr.Markdown("""
    # SQL Skeleton WizardCoder Demo
    """)

    output_box = gr.Code(
        label="Generated SQL",
        lines=2,
        placeholder="Output will appear here after running the model.",
    )
    input_text = gr.Textbox(lines=3, placeholder='Input text here...', label='Input Text')
    db_info = gr.Textbox(
        lines=6,
        placeholder='Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...',
        label='Database Info',
    )

    # Slider defaults mirror the default arguments of `bot`.
    with gr.Accordion("Hyperparameters", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.1, step=0.1)
        top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
        top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.1)

    run_button = gr.Button("Generate SQL")

    # Inputs are listed in the positional order of `bot`'s parameters.
    bot_inputs = [input_text, db_info, temperature, top_p, top_k, repetition_penalty]

    # Each example supplies only the question and the schema; the four
    # hyperparameter inputs get no example values.
    examples = gr.Examples(
        [[question, _EXAMPLE_SCHEMA] for question in _EXAMPLE_QUESTIONS],
        inputs=bot_inputs,
        fn=bot,
    )

    run_button.click(fn=bot, inputs=bot_inputs, outputs=output_box)

demo.launch()
|
|