---
license: apache-2.0
datasets:
- bugdaryan/sql-create-context-instruction
language:
- en
tags:
- text2sql
widget:
- text: "[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using ```: Schema: CREATE TABLE head (age INTEGER) Question: How many heads of the departments are older than 56 ? [/INST] Here is the SQLite query to answer to the question: How many heads of the departments are older than 56 ?: ```"
  example_title: "Example 1"
- text: "[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using ```: Schema: CREATE TABLE people (first_name VARCHAR) Question: List the first names of people in alphabetical order? [/INST] Here is the SQLite query to answer to the question: List the first names of people in alphabetical order?: ```"
  example_title: "Example 2"
- text: "[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using ```: Schema: CREATE TABLE weather (zip_code VARCHAR, mean_sea_level_pressure_inches INTEGER) Question: What is the zip code in which the average mean sea level pressure is the lowest? [/INST] Here is the SQLite query to answer to the question: What is the zip code in which the average mean sea level pressure is the lowest?: ```"
  example_title: "Example 3"
- text: "[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using ```: Schema: CREATE TABLE weather (date VARCHAR, mean_temperature_f VARCHAR, mean_humidity VARCHAR, max_gust_speed_mph VARCHAR) Question: What are the date, mean temperature and mean humidity for the top 3 days with the largest max gust speeds? [/INST] Here is the SQLite query to answer to the question: What are the date, mean temperature and mean humidity for the top 3 days with the largest max gust speeds?: ```"
  example_title: "Example 4"
- text: "[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using ```: Schema: CREATE TABLE trip (end_station_id VARCHAR); CREATE TABLE station (id VARCHAR, city VARCHAR) Question: Count the number of trips that did not end in San Francisco city. [/INST] Here is the SQLite query to answer to the question: Count the number of trips that did not end in San Francisco city.: ```"
  example_title: "Example 5"
---

# Model Card for MistralSQL-7B

## Model Information

- **Model Name:** MistralSQL-7B
- **Base Model Name:** mistralai/Mistral-7B-Instruct-v0.1
- **Base Model URL:** [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
- **Dataset Name:** bugdaryan/sql-create-context-instruction
- **Dataset URL:** [SQL Create Context Dataset](https://huggingface.co/datasets/bugdaryan/sql-create-context-instruction)
- **Dataset Description:** This dataset is built upon SQL Create Context, sourced from WikiSQL and Spider, providing 78,577 examples of natural language queries, SQL CREATE TABLE statements, and SQL queries answering questions using the CREATE statement as context.

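To make the dataset description concrete, the sketch below shows what a single record might look like and checks that its SQL answer runs against the schema supplied as context. The field names (`context`, `question`, `answer`), the helper `run_answer`, and the sample rows are assumptions made for illustration, not the dataset's actual column names or contents. The schema and question come from the first widget example; the query is one plausible answer.

```python
import sqlite3

# Hypothetical record: field names are assumptions, not the dataset's real columns.
record = {
    "context": "CREATE TABLE head (age INTEGER)",
    "question": "How many heads of the departments are older than 56 ?",
    "answer": "SELECT COUNT(*) FROM head WHERE age > 56",
}


def run_answer(record, rows):
    """Build the schema in an in-memory SQLite database, insert rows, run the answer."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute(record["context"])
        conn.executemany("INSERT INTO head (age) VALUES (?)", rows)
        return conn.execute(record["answer"]).fetchall()
    finally:
        conn.close()


# Made-up sample data: two of the four ages are greater than 56.
result = run_answer(record, [(45,), (57,), (60,), (56,)])
print(result)  # [(2,)]
```

Executing the answer against a throwaway database is a cheap way to confirm that a query is at least consistent with its `CREATE TABLE` context.
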
## Model Parameters

- **LoRA Attention Dimension:** 64
- **LoRA Alpha Parameter:** 16
- **LoRA Dropout Probability:** 0.1
- **Bitsandbytes Parameters:**
  - Activate 4-bit precision base model loading: True
  - Compute dtype for 4-bit base models: float16
  - Quantization type (fp4 or nf4): nf4
  - Activate nested quantization for 4-bit base models: False
- **TrainingArguments Parameters:**
  - Output directory: "./results"
  - Number of training epochs: 1
  - Enable fp16/bf16 training: False/True
  - Batch size per GPU for training: 80
  - Batch size per GPU for evaluation: 4
  - Gradient accumulation steps: 1
  - Enable gradient checkpointing: True
  - Maximum gradient norm (gradient clipping): 0.3
  - Initial learning rate (AdamW optimizer): 2e-4
  - Weight decay: 0.001
  - Optimizer: paged_adamw_32bit
  - Learning rate schedule: cosine
  - Number of training steps (overrides num_train_epochs): -1
  - Ratio of steps for a linear warmup: 0.03
  - Group sequences into batches with the same length: True
  - Save checkpoint every X update steps: 0
  - Log every X update steps: 10
- **SFT Parameters:**
  - Maximum sequence length: 500
  - Packing: False

## Inference Parameters

- **Temperature:** 0.7

## Hardware and Software

- **Training Hardware**: 2 RTX A6000 48GB GPUs

## License

- Apache-2.0

## Instruction Format

To leverage instruction fine-tuning, prompts should be surrounded by `[INST]` and `[/INST]` tokens. The very first instruction should begin with the beginning-of-sentence token id; subsequent instructions should not. The assistant's generation ends with the end-of-sentence token id.

For example:

````python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = 'bugdaryan/MistralSQL-7b'

# Load the fine-tuned model and its tokenizer, then wrap them in a text-generation pipeline.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipe = pipeline('text-generation', model=model, tokenizer=tokenizer)

# Database schema given to the model as context.
table = (
    "CREATE TABLE sales ( sale_id number PRIMARY KEY, product_id number, "
    "customer_id number, salesperson_id number, sale_date DATE, quantity number, "
    "FOREIGN KEY (product_id) REFERENCES products(product_id), "
    "FOREIGN KEY (customer_id) REFERENCES customers(customer_id), "
    "FOREIGN KEY (salesperson_id) REFERENCES salespeople(salesperson_id)); "
    "CREATE TABLE product_suppliers ( supplier_id number PRIMARY KEY, product_id number, "
    "supply_price number, FOREIGN KEY (product_id) REFERENCES products(product_id)); "
    "CREATE TABLE customers ( customer_id number PRIMARY KEY, name text, address text ); "
    "CREATE TABLE salespeople ( salesperson_id number PRIMARY KEY, name text, region text );"
)

question = 'Find the salesperson who made the most sales.'

# The prompt ends with an opening code fence so that the model continues with the query itself.
prompt = (
    "[INST] Write SQLite query to answer the following question given the database schema. "
    "Please wrap your code answer using ```: "
    f"Schema: {table} Question: {question} [/INST] "
    f"Here is the SQLite query to answer to the question: {question}: ``` "
)

ans = pipe(prompt, max_new_tokens=100)

# The generated text includes the prompt, which contains two code fences;
# the model's query is the segment that follows the second fence.
print(ans[0]['generated_text'].split('```')[2])
````
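
The prompt construction and answer extraction used above can be factored into small helpers and exercised without loading the model. The following is a minimal sketch, not part of the released code: `build_prompt`, `extract_sql`, and `is_valid_sqlite` are hypothetical names, the "generated" text is simulated, and the validity check only asks SQLite to compile the query against the schema (via `EXPLAIN`), which says nothing about whether the query answers the question.

```python
import sqlite3

FENCE = "`" * 3  # three backticks, the delimiter the prompt asks the model to use

PROMPT_TEMPLATE = (
    "[INST] Write SQLite query to answer the following question given the database schema. "
    "Please wrap your code answer using {fence}: Schema: {schema} Question: {question} [/INST] "
    "Here is the SQLite query to answer to the question: {question}: {fence} "
)


def build_prompt(schema, question):
    """Fill the instruction template with a schema and a question."""
    return PROMPT_TEMPLATE.format(fence=FENCE, schema=schema, question=question)


def extract_sql(generated_text):
    """Return the text after the prompt's second fence, up to the next fence if any."""
    parts = generated_text.split(FENCE)
    if len(parts) < 3:
        raise ValueError("expected at least two code fences in the generated text")
    return parts[2].strip()


def is_valid_sqlite(schema, sql):
    """Check that SQLite can compile `sql` against `schema` in an in-memory database."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema)
        conn.execute("EXPLAIN " + sql)
        return True
    except sqlite3.Error:
        return False
    finally:
        conn.close()


schema = "CREATE TABLE head (age INTEGER)"
prompt = build_prompt(schema, "How many heads of the departments are older than 56 ?")

# Simulated pipeline output: the prompt followed by a query and a closing fence.
fake_output = prompt + "SELECT COUNT(*) FROM head WHERE age > 56\n" + FENCE
sql = extract_sql(fake_output)
print(sql)
print(is_valid_sqlite(schema, sql))
```

Splitting on the fence assumes that neither the schema nor the question contains three consecutive backticks; if they might, the extraction should instead strip the known prompt prefix before looking for the closing fence.
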