bstraehle commited on
Commit
d157f84
1 Parent(s): 67ecd94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -16,22 +16,22 @@ SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in Eng
16
  USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
17
  SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
18
 
19
- BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"
20
  FT_MODEL_NAME = "Meta-Llama-3.1-8B-text-to-sql"
21
  DATASET_NAME = "gretelai/synthetic_text_to_sql"
22
 
23
- def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_context):
24
  raise gr.Error("Please clone and bring your own Hugging Face credentials.")
25
 
26
  if action == ACTION_1:
27
- result = prompt_model(base_model_name, system_prompt, user_prompt, sql_context)
28
  elif action == ACTION_2:
29
- result = fine_tune_model(base_model_name, dataset_name)
30
  elif action == ACTION_3:
31
  result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_context)
32
  return result
33
 
34
- def fine_tune_model(base_model_name, dataset_name):
35
  # Load dataset
36
 
37
  dataset = load_dataset(dataset_name)
@@ -44,7 +44,7 @@ def fine_tune_model(base_model_name, dataset_name):
44
 
45
  # Load model
46
 
47
- model, tokenizer = load_model(base_model_name)
48
 
49
  print("### Model")
50
  print(model)
@@ -80,7 +80,7 @@ def fine_tune_model(base_model_name, dataset_name):
80
  # Configure training arguments
81
 
82
  training_args = Seq2SeqTrainingArguments(
83
- output_dir=f"./{FT_MODEL_NAME}",
84
  num_train_epochs=3, # 37,500 steps
85
  #max_steps=1, # overwrites num_train_epochs
86
  # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
@@ -106,8 +106,8 @@ def fine_tune_model(base_model_name, dataset_name):
106
 
107
  # Push model and tokenizer to HF
108
 
109
- model.push_to_hub(FT_MODEL_NAME)
110
- tokenizer.push_to_hub(FT_MODEL_NAME)
111
 
112
  def prompt_model(model_name, system_prompt, user_prompt, sql_context):
113
  pipe = pipeline("text-generation",
@@ -142,9 +142,9 @@ def load_model(model_name):
142
 
143
  demo = gr.Interface(fn=process,
144
  inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
145
- gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
146
- gr.Textbox(label = "Fine-Tuned Model Name", value = f"{HF_ACCOUNT}/{FT_MODEL_NAME}", lines = 1),
147
  gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
 
148
  gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
149
  gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
150
  gr.Textbox(label = "SQL Context", value = SQL_CONTEXT, lines = 4)],
 
16
  USER_PROMPT = "How many new users joined from countries with stricter data privacy laws than the United States in the past month?"
17
  SQL_CONTEXT = "CREATE TABLE users (user_id INT, country VARCHAR(50), joined_date DATE); CREATE TABLE data_privacy_laws (country VARCHAR(50), privacy_level INT); INSERT INTO users (user_id, country, joined_date) VALUES (1, 'USA', '2023-02-15'), (2, 'Germany', '2023-02-27'); INSERT INTO data_privacy_laws (country, privacy_level) VALUES ('USA', 5), ('Germany', 8);"
18
 
19
+ PT_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"
20
  FT_MODEL_NAME = "Meta-Llama-3.1-8B-text-to-sql"
21
  DATASET_NAME = "gretelai/synthetic_text_to_sql"
22
 
23
+ def process(action, pt_model_name, dataset_name, ft_model_name, system_prompt, user_prompt, sql_context):
24
  raise gr.Error("Please clone and bring your own Hugging Face credentials.")
25
 
26
  if action == ACTION_1:
27
+ result = prompt_model(pt_model_name, system_prompt, user_prompt, sql_context)
28
  elif action == ACTION_2:
29
+ result = fine_tune_model(pt_model_name, dataset_name, ft_model_name)
30
  elif action == ACTION_3:
31
  result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_context)
32
  return result
33
 
34
+ def fine_tune_model(pt_model_name, dataset_name, ft_model_name):
35
  # Load dataset
36
 
37
  dataset = load_dataset(dataset_name)
 
44
 
45
  # Load model
46
 
47
+ model, tokenizer = load_model(pt_model_name)
48
 
49
  print("### Model")
50
  print(model)
 
80
  # Configure training arguments
81
 
82
  training_args = Seq2SeqTrainingArguments(
83
+ output_dir=f"./{ft_model_name}",
84
  num_train_epochs=3, # 37,500 steps
85
  #max_steps=1, # overwrites num_train_epochs
86
  # TODO https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
 
106
 
107
  # Push model and tokenizer to HF
108
 
109
+ model.push_to_hub(ft_model_name)
110
+ tokenizer.push_to_hub(ft_model_name)
111
 
112
  def prompt_model(model_name, system_prompt, user_prompt, sql_context):
113
  pipe = pipeline("text-generation",
 
142
 
143
  demo = gr.Interface(fn=process,
144
  inputs=[gr.Radio([ACTION_1, ACTION_2, ACTION_3], label = "Action", value = ACTION_3),
145
+ gr.Textbox(label = "Pre-Trained Model Name", value = PT_MODEL_NAME, lines = 1),
 
146
  gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
147
+ gr.Textbox(label = "Fine-Tuned Model Name", value = f"{HF_ACCOUNT}/{FT_MODEL_NAME}", lines = 1),
148
  gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
149
  gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
150
  gr.Textbox(label = "SQL Context", value = SQL_CONTEXT, lines = 4)],