bstraehle commited on
Commit
2dfbd8a
1 Parent(s): 3378cee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -4,24 +4,26 @@ from datasets import load_dataset
4
  from huggingface_hub import HfApi, login
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
6
 
7
- hf_profile = "bstraehle"
 
 
8
 
9
- action_1 = "Fine-tune pre-trained model"
10
- action_2 = "Prompt fine-tuned model"
 
11
 
12
- system_prompt = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
13
- user_prompt = "What is the total trade value and average price for each trader and stock in the trade_history table?"
14
- sql_schema = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
15
 
16
- model_name = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
17
- dataset_name = "gretelai/synthetic_text_to_sql"
18
-
19
- def process(action, model_name, dataset_name, system_prompt, user_prompt, sql_schema):
20
  #raise gr.Error("Please clone and bring your own credentials.")
21
  if action == action_1:
22
- result = fine_tune_model(model_name, dataset_name)
 
 
23
  elif action == action_2:
24
- result = prompt_model(model_name, system_prompt, user_prompt, sql_schema)
25
  return result
26
 
27
  def fine_tune_model(model_name, dataset_name):
@@ -104,7 +106,6 @@ def fine_tune_model(model_name, dataset_name):
104
  def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
105
  pipe = pipeline("text-generation",
106
  model=model_name,
107
- model_kwargs={"torch_dtype": torch.bfloat16},
108
  device_map="auto",
109
  max_new_tokens=1000)
110
 
@@ -126,7 +127,8 @@ def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
126
 
127
  def load_model(model_name):
128
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
129
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
130
 
131
  if not tokenizer.pad_token:
132
  tokenizer.pad_token = tokenizer.eos_token
@@ -135,10 +137,11 @@ def load_model(model_name):
135
 
136
  demo = gr.Interface(fn=process,
137
  inputs=[gr.Radio([action_1, action_2], label = "Action", value = action_2),
138
- gr.Textbox(label = "Model Name", value = model_name, lines = 1),
139
- gr.Textbox(label = "Dataset Name", value = dataset_name, lines = 1),
140
- gr.Textbox(label = "System Prompt", value = system_prompt, lines = 2),
141
- gr.Textbox(label = "User Prompt", value = user_prompt, lines = 2),
142
- gr.Textbox(label = "SQL Schema", value = sql_schema, lines = 2)],
 
143
  outputs=[gr.Textbox(label = "Prompt Completion", value = os.environ["OUTPUT"])])
144
  demo.launch()
 
4
  from huggingface_hub import HfApi, login
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, pipeline
6
 
7
+ ACTION_1 = "Prompt base model"
8
+ ACTION_2 = "Fine-tune base model"
9
+ ACTION_3 = "Prompt fine-tuned model"
10
 
11
+ SYSTEM_PROMPT = "You are a text to SQL query translator. Given a question in English, generate a SQL query based on the provided SCHEMA. Do not generate any additional text. SCHEMA: {schema}"
12
+ USER_PROMPT = "What is the total trade value and average price for each trader and stock in the trade_history table?"
13
+ SQL_SCHEMA = "CREATE TABLE trade_history (id INT, trader_id INT, stock VARCHAR(255), price DECIMAL(5,2), quantity INT, trade_time TIMESTAMP);"
14
 
15
+ BASE_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
16
+ FT_MODEL_NAME = "bstraehle/Meta-Llama-3.1-8B-Instruct-text-to-sql"
17
+ DATASET_NAME = "gretelai/synthetic_text_to_sql"
18
 
19
+ def process(action, base_model_name, ft_model_name, dataset_name, system_prompt, user_prompt, sql_schema):
 
 
 
20
  #raise gr.Error("Please clone and bring your own credentials.")
21
  if action == action_1:
22
+ result = prompt_model(base_model_name, system_prompt, user_prompt, sql_schema)
23
+ elif action == action_2:
24
+ result = fine_tune_model(base_model_name, dataset_name)
25
  elif action == action_2:
26
+ result = prompt_model(ft_model_name, system_prompt, user_prompt, sql_schema)
27
  return result
28
 
29
  def fine_tune_model(model_name, dataset_name):
 
106
  def prompt_model(model_name, system_prompt, user_prompt, sql_schema):
107
  pipe = pipeline("text-generation",
108
  model=model_name,
 
109
  device_map="auto",
110
  max_new_tokens=1000)
111
 
 
127
 
128
  def load_model(model_name):
129
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
130
+ #tokenizer = AutoTokenizer.from_pretrained(model_name)
131
+ tokenizer = model.tokenizer
132
 
133
  if not tokenizer.pad_token:
134
  tokenizer.pad_token = tokenizer.eos_token
 
137
 
138
  demo = gr.Interface(fn=process,
139
  inputs=[gr.Radio([action_1, action_2], label = "Action", value = action_2),
140
+ gr.Textbox(label = "Base Model Name", value = BASE_MODEL_NAME, lines = 1),
141
+ gr.Textbox(label = "Fine-Tuned Model Name", value = FT_MODEL_NAME, lines = 1),
142
+ gr.Textbox(label = "Dataset Name", value = DATASET_NAME, lines = 1),
143
+ gr.Textbox(label = "System Prompt", value = SYSTEM_PROMPT, lines = 2),
144
+ gr.Textbox(label = "User Prompt", value = USER_PROMPT, lines = 2),
145
+ gr.Textbox(label = "SQL Schema", value = SQL_SCHEMA, lines = 2)],
146
  outputs=[gr.Textbox(label = "Prompt Completion", value = os.environ["OUTPUT"])])
147
  demo.launch()