File size: 8,528 Bytes
40c895f
 
 
 
 
 
 
 
 
 
 
80be5da
40c895f
f7fa740
40c895f
 
4592479
40c895f
 
 
 
 
78715f1
 
40c895f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f23744
40c895f
 
ba3b60e
5f23744
40c895f
 
 
9ae4caa
40c895f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d119dc4
913de56
 
 
 
80be5da
 
 
 
 
 
 
 
66983ee
e2ff706
9c4badc
d119dc4
e2ff706
2ea5c26
 
 
8a0ec5f
 
e2ff706
2ea5c26
 
8a0ec5f
 
2ea5c26
8a0ec5f
 
2ea5c26
8a0ec5f
e2ff706
8a0ec5f
2ea5c26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a0ec5f
e2ff706
4592479
2ea5c26
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
from threading import Event, Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
import gradio as gr
import torch
import sqlparse

# Hugging Face repo id of the causal LM to serve, read from the environment.
# Fail fast with an actionable message instead of letting
# from_pretrained(None) raise an obscure error inside transformers.
model_name = os.getenv("HF_MODEL_NAME", None)
if not model_name:
    raise RuntimeError(
        "The HF_MODEL_NAME environment variable must be set to a Hugging Face "
        "model repo id (e.g. richardr1126/spider-skeleton-wizard-coder-8bit)."
    )
tok = AutoTokenizer.from_pretrained(model_name)

# Upper bound on the number of tokens generated per request.
max_new_tokens = 1024

print(f"Starting to load the model {model_name}")

m = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=0,  # place the whole model on device 0
    # load_in_8bit=True,  # opt-in 8-bit loading (needs bitsandbytes)
)

# Reuse EOS as the padding id for generation (presumably so generate() does
# not warn about / guess a missing pad_token_id).
m.config.pad_token_id = m.config.eos_token_id
m.generation_config.pad_token_id = m.config.eos_token_id

# Generation halts as soon as the last produced token is one of these.
stop_tokens = [";", "###", "Result"]
# convert_tokens_to_ids may return None for strings that are not vocabulary
# tokens; such entries can never match a generated id, so drop them.
stop_token_ids = [
    token_id
    for token_id in tok.convert_tokens_to_ids(stop_tokens)
    if token_id is not None
]

print(f"Successfully loaded the model {model_name} into memory")

class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop token is emitted.

    Only the most recent token of the first sequence in the batch is checked,
    against the module-level ``stop_token_ids`` list.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        newest_token = input_ids[0][-1]
        return any(newest_token == candidate for candidate in stop_token_ids)

def bot(input_message: str, db_info="", temperature=0.1, top_p=0.9, top_k=0, repetition_penalty=1.08):
    """Translate a natural-language question into a formatted SQL query.

    Builds an instruction prompt from the question and schema text, runs
    ``m.generate`` on a worker thread while draining a ``TextIteratorStreamer``,
    extracts the query from the response and pretty-prints it with sqlparse.

    Args:
        input_message: The natural-language question.
        db_info: Schema description appended to the question
            (e.g. ``"| table : col_a , col_b | ..."``).
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_p: Nucleus-sampling mass (only used when sampling).
        top_k: Top-k cutoff, ``0`` disables it (only used when sampling).
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The SQL query, reindented with upper-case keywords when sqlparse can
        format it, otherwise as generated.

    Raises:
        Exception: whatever ``m.generate`` raised is re-raised here rather
            than leaving the stream to block until its timeout.
    """
    stop = StopOnTokens()

    # Format the user's input message
    messages = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n\nConvert text to sql: {input_message} {db_info}\n\n### Response:\n\n"

    input_ids = tok(messages, return_tensors="pt").input_ids.to(m.device)
    streamer = TextIteratorStreamer(tok, timeout=100.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0.0
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    if do_sample:
        # Sampling knobs are ignored by greedy decoding (and only trigger
        # warnings there), so pass them only when sampling.
        generate_kwargs.update(
            temperature=temperature,
            top_p=top_p,
            # UI sliders may hand over numbers as floats; transformers'
            # top-k warper only accepts ints.
            top_k=int(top_k),
        )

    # Generate on a worker thread so the streamer can be drained here. If
    # generate() fails, record the error and close the stream so the drain
    # below ends immediately instead of blocking until the streamer timeout.
    generation_errors = []

    def _generate():
        try:
            m.generate(**generate_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()
    partial_text = "".join(streamer)
    worker.join()
    if generation_errors:
        raise generation_errors[0]

    # The response is expected to look like "<skeleton> | <query>"
    # (NOTE(review): inferred from this parsing and the model name - confirm).
    # Keep everything after the FIRST "|" so queries that contain "|" (e.g.
    # the "||" operator) are not truncated; with no separator, use the raw
    # response.
    _, separator, query_part = partial_text.partition("|")
    final_query = query_part.strip() if separator else partial_text

    try:
        # Best-effort pretty-printing: reindent and upper-case keywords.
        formatted_query = sqlparse.format(final_query, reindent=True, keyword_case='upper')
    except Exception:
        # If formatting fails, use the original, unformatted query
        formatted_query = final_query

    return formatted_query

# Gradio UI: question + schema inputs, sampling controls, examples, and a
# footer describing the model; the button runs `bot` and shows its SQL.
with gr.Blocks(theme='gradio/soft') as demo:
    header = gr.HTML("""
        <h1 style="text-align: center">SQL Skeleton WizardCoder Demo</h1>
        <h3 style="text-align: center">🧙‍♂️ Generate SQL queries from Natural Language 🧙‍♂️</h3>
    """)

    output_box = gr.Code(label="Generated SQL", lines=2, interactive=True)
    input_text = gr.Textbox(lines=3, placeholder='Write your question here...', label='NL Input')
    db_info = gr.Textbox(lines=4, placeholder='Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...', label='Database Info')

    with gr.Accordion("Hyperparameters", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
        top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
        top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)

    run_button = gr.Button("Generate SQL", variant="primary")

    # Schema shared by every example (singer/concert/stadium tables and
    # their join keys), kept in one place instead of repeated per row.
    example_db_info = (
        "| stadium : stadium_id , location , name , capacity , highest , lowest , average "
        "| singer : singer_id , name , country , song_name , song_release_year , age , is_male "
        "| concert : concert_id , concert_name , theme , stadium_id , year "
        "| singer_in_concert : concert_id , singer_id "
        "| concert.stadium_id = stadium.stadium_id "
        "| singer_in_concert.singer_id = singer.singer_id "
        "| singer_in_concert.concert_id = concert.concert_id |"
    )

    with gr.Accordion("Examples", open=True):
        examples = gr.Examples([
            ["What is the average, minimum, and maximum age for all French singers?", example_db_info],
            ["Show location and name for all stadiums with a capacity between 5000 and 10000.", example_db_info],
            ["What are the number of concerts that occurred in the stadium with the largest capacity ?", example_db_info],
            ["How many male singers performed in concerts in the year 2023?", example_db_info],
            ["List the names of all singers who performed in a concert with the theme 'Rock'", example_db_info]
        ], inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty], fn=bot)

    bitsandbytes_model = "richardr1126/spider-skeleton-wizard-coder-8bit"
    merged_model = "richardr1126/spider-skeleton-wizard-coder-merged"
    initial_model = "WizardLM/WizardCoder-15B-V1.0"
    finetuned_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
    dataset = "richardr1126/spider-skeleton-context-instruct"

    # The env var named here must match the one read at startup (HF_MODEL_NAME).
    footer = gr.HTML(f"""
        <p>🛠️ If you want you can <strong>duplicate this Space</strong>, then change the HF_MODEL_NAME spaces env variable to use any Transformers model.</p>
        <p>🌐 Leveraging the <a href='https://huggingface.co/{bitsandbytes_model}'><strong>bitsandbytes 8-bit version</strong></a> of <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a> model.</p>
        <p>🔗 How it's made: <a href='https://huggingface.co/{initial_model}'><strong>{initial_model}</strong></a> was finetuned to create <a href='https://huggingface.co/{finetuned_model}'><strong>{finetuned_model}</strong></a>, then merged together to create <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a>.</p>
        <p>📉 Fine-tuning was performed using QLoRA techniques on the <a href='https://huggingface.co/datasets/{dataset}'><strong>{dataset}</strong></a> dataset. You can view training metrics on the <a href='https://huggingface.co/{finetuned_model}'><strong>QLoRa adapter HF Repo</strong></a>.</p>
    """)


    run_button.click(fn=bot, inputs=[input_text, db_info, temperature, top_p, top_k, repetition_penalty], outputs=output_box, api_name="txt2sql")

# Process one request at a time (presumably because a single model instance
# is shared) and allow at most 10 further requests to wait in the queue.
demo.queue(concurrency_count=1, max_size=10)
demo.launch()