Update app.py
app.py CHANGED
@@ -4,24 +4,26 @@ from datasets import load_dataset
 from huggingface_hub import HfApi, login
 from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
 
-… (earlier default definitions, truncated in this diff view)
-dataset_name = "gretelai/synthetic_text_to_sql"
+ACTION_1 = "Prompt base model"
+ACTION_2 = "Fine-tune base model"
+ACTION_3 = "Prompt fine-tuned model"
+
+SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
+USER_PROMPT = "What is the total trade value and average price for each trader and stock in the trade_history table?"
+SQL_SCHEMA = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
+
+BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+FT_MODEL_NAME = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
+DATASET_NAME = "gretelai/synthetic_text_to_sql"
 
-def process(action, model_name, dataset_name, system_prompt, user_prompt, sql_schema):
+def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_schema):
     #raise gr.Error("Please clone and bring your own credentials.")
-    if action == action_1:
-        result = fine_tune_model(model_name, dataset_name)
-    elif action == action_2:
-        result = prompt_model(model_name, system_prompt, user_prompt, sql_schema)
+    if action == ACTION_1:
+        result = prompt_model(base_model_name, system_prompt, user_prompt, sql_schema)
+    elif action == ACTION_2:
+        result = fine_tune_model(base_model_name, dataset_name)
+    elif action == ACTION_3:
+        result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_schema)
     return result
 
 def fine_tune_model(model_name, dataset_name):
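For reference, SYSTEM_PROMPT carries a {schema} placeholder, which suggests prompt_model fills it in via str.format before generation. A minimal sketch of that assembly using the standard chat-message convention for instruct models (build_messages is a hypothetical helper, not code from this commit):

    def build_messages(system_prompt, user_prompt, sql_schema):
        # Fill the {schema} placeholder, then use the usual system/user
        # roles expected by instruct-tuned chat models.
        return [
            {"role": "system", "content": system_prompt.format(schema=sql_schema)},
            {"role": "user", "content": user_prompt},
        ]

    messages = build_messages(SYSTEM_PROMPT, USER_PROMPT, SQL_SCHEMA)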
@@ -104,7 +106,6 @@ def fine_tune_model(model_name, dataset_name):
 def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
     pipe = pipeline("text-generation",
                     model=model_name,
-                    model_kwargs={"torch_dtype": torch.bfloat16},
                     device_map="auto",
                     max_new_tokens=1000)
 
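This hunk drops the bfloat16 override without a replacement, so the pipeline will typically fall back to float32 and roughly double the memory footprint of an 8B model. If half precision is still wanted, a minimal sketch using the pipeline's own torch_dtype argument (assumes import torch at the top of app.py and a reasonably recent transformers):

    pipe = pipeline("text-generation",
                    model=model_name,
                    torch_dtype=torch.bfloat16,  # same effect as the removed model_kwargs entry
                    device_map="auto",
                    max_new_tokens=1000)
    # Chat-style invocation, assuming the messages list sketched above:
    result = pipe(messages)[0]["generated_text"][-1]["content"]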
@@ -126,7 +127,8 @@ def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
 
 def load_model(model_name):
     model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    #tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = model.tokenizer
 
     if not tokenizer.pad_token:
         tokenizer.pad_token = tokenizer.eos_token
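One caution on this hunk: models returned by AutoModelForCausalLM.from_pretrained do not expose a .tokenizer attribute, so model.tokenizer will raise AttributeError the first time load_model runs; the line being commented out here is the standard pairing. A minimal sketch of the usual pattern:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    def load_model(model_name):
        # Model and tokenizer are loaded separately from the same repo id.
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Llama checkpoints ship without a pad token; reuse EOS for padding.
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer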
@@ -135,10 +137,11 @@ def load_model(model_name):
 
 demo = gr.Interface(fn=process,
-                    inputs=[gr.Radio([action_1, action_2], label = "Action", value = action_2),
-                            gr.Textbox(label = "Model Name", value = …),
-                            gr.Textbox(label = "Dataset Name", value = …),
-                            gr.Textbox(label = "System Prompt", value = …),
-                            gr.Textbox(label = "User Prompt", value = …),
-                            gr.Textbox(label = "SQL Schema", value = …)],
+                    inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_2),
+                            gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
+                            gr.Textbox(label = "Fine-Tuned Model Name", value = FT_MODEL_NAME, lines = 1),
+                            gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
+                            gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
+                            gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
+                            gr.Textbox(label = "SQL Schema", value = SQL_SCHEMA, lines = 2)],
                     outputs=[gr.Textbox(label = "Prompt Completion", value = os.environ["OUTPUT"])])
 demo.launch()
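A small robustness note: os.environ["OUTPUT"] raises KeyError when that Space variable is unset, which will stop a fresh clone from launching. A hedged variant of the output wiring (os and gr are already imported in app.py):

    # Default to an empty completion so the app launches even when the
    # OUTPUT variable is not configured on the Space.
    output_box = gr.Textbox(label = "Prompt Completion",
                            value = os.environ.get("OUTPUT", ""))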
|