---
license: apache-2.0
---
# TinyLlama for Text-to-SQL
## Problem Statement
I need a small generative model that translates user questions into SQL queries without adding any commentary. A small model reduces operational costs, increases throughput, and lowers latency.
## Solution
### Part 1: Initial Experimentation (Refer to `Run_Tinyllama_Chat.ipynb`)
#### Step 1: Using an Off-the-Shelf Model
I started with the TinyLlama model. Below is an example of the initial request and response:
```
<|system|>
CREATE TABLE head(age INTEGER)</s>
<|user|>
How many heads of the departments are older than 56?</s>
<|assistant|>
I don't have access to the latest data or the current headcount of the departments...
```
The model did not return a SQL query. This is understandable: the prompt supplied only a table schema, with no instruction to answer in SQL.
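The prompts in this section follow TinyLlama's chat layout. Each turn starts with a role tag (`<|system|>`, `<|user|>`, `<|assistant|>`) on its own line and ends with `</s>`, and the prompt ends with an open `<|assistant|>` tag for the model to complete. Below is a minimal sketch that builds this string by hand, assuming the layout shown above. `build_prompt` is a hypothetical helper; in practice, `tokenizer.apply_chat_template` from `transformers` renders prompts from the model's bundled template.

```python
def build_prompt(schema: str, question: str, instruction: str = "") -> str:
    """Assemble a TinyLlama-style chat prompt (hypothetical helper).

    `schema` is the CREATE TABLE context; `instruction` is optional extra
    system text, such as the Step 2 "reply only in SQL" instruction.
    """
    # Prefix the schema with the instruction when one is given
    system = f"{instruction} {schema}".strip() if instruction else schema
    turns = [
        f"<|system|>\n{system}</s>",
        f"<|user|>\n{question}</s>",
        "<|assistant|>\n",  # left open so the model writes the answer
    ]
    return "\n".join(turns)

prompt = build_prompt(
    "CREATE TABLE head(age INTEGER)",
    "How many heads of the departments are older than 56?",
)
print(prompt)
```

Passing the Step 2 instruction text as `instruction` reproduces that step's system line, with the instruction placed directly before the schema.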
#### Step 2: Prompt Engineering
Next, I tried prompt engineering by adding an explicit instruction to the system prompt:
```
<|system|>
You can only reply in SQL query language. Provide only SQL for the user's query given this context --> CREATE TABLE head(age INTEGER)</s>
<|user|>
How many heads of the departments are older than 56?</s>
<|assistant|>
SELECT COUNT(*) FROM head WHERE age > 56
```
The model generated the SQL query but included additional commentary, which I wanted to avoid.
#### Step 3: Further Refinement
Despite additional prompt engineering efforts, the model still produced unwanted explanations:
```
<|assistant|>
To calculate the number of heads of the departments older than 56, you can use the following SQL query:
SELECT COUNT(*) FROM departments WHERE age > 56;
In the above query, "departments" is the name of the table and "age" is the column name...
```
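The response also queries a `departments` table that does not appear in the schema, instead of `head`. These issues led me to consider fine-tuning the model.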
---
### Part 2: Fine-Tuning the Model
I decided to fine-tune TinyLlama so that it responds with SQL only. The steps below replicate the fine-tuning process.
#### Set Up the Environment and Run the Fine-Tuning Job on RunPod.io
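```bash
#!/bin/bash
# Core libraries: Accelerate, Transformers, PEFT (LoRA), DeepSpeed (ZeRO-3), bitsandbytes (4-bit quantization)
pip install -q accelerate transformers peft deepspeed bitsandbytes --no-build-isolation
# Pin TRL to 0.9.6
pip install trl==0.9.6
# Build prerequisites for flash-attn
pip install packaging ninja
# Build flash-attn; MAX_JOBS caps the number of parallel compile jobs
MAX_JOBS=16 pip install flash-attn==2.6.0.post1 --no-build-isolation
# Fetch the training script and configs
git clone https://github.com/Rajesh-Nair/llm-text2sql-finetuning
cd llm-text2sql-finetuning
# Launch training with DeepSpeed ZeRO-3 + QLoRA; 2>&1 also captures stderr in the log
accelerate launch --config_file "ds_z3_qlora_config.yaml" train.py run_config.yaml 2>&1 | tee accelerate_output.log
```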
#### Key Components of Fine-Tuning
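1. **Dataset**: Used the `b-mc2/sql-create-context` dataset from the Hugging Face Hub, which pairs natural-language questions and `CREATE TABLE` context statements with target SQL queries. High-quality data is essential for improving model performance.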
2. **Accelerate**: Used Hugging Face `accelerate` to launch distributed training with minimal boilerplate code.
3. **Distributed Training**:
- Trained on two GPUs on a single RunPod.io node.
- Environment: NVIDIA L4 GPUs, PyTorch 2.1, Python 3.10, CUDA 11.8 (Ubuntu image).
4. **QLoRA**:
- Applied QLoRA (LoRA adapters trained on top of a 4-bit quantized base model) for memory-efficient fine-tuning.
- Configured LoRA with rank 8 on all linear layers.
5. **DeepSpeed ZeRO-3**: Used to shard optimizer states, gradients, and parameters across the GPUs.
6. **Mixed Precision**: Used to speed up training and reduce GPU memory usage.
7. **Batch Size & Gradient Accumulation**:
- Set the per-device batch size to 4.
- Accumulated gradients over 2 steps, giving an effective batch size of 4 × 2 steps × 2 GPUs = 16 per optimizer step.
- Larger per-device batch sizes sometimes caused GPU communication bottlenecks.
8. **Gradient Clipping**: Enabled to guard against exploding gradients.
9. **Training Duration & Cost**:
- Each epoch took approximately 1 hour.
- I stopped training manually after 3 epochs because the training loss was no longer improving meaningfully.
- Total fine-tuning cost on RunPod: under \$3.
10. **Training Logs**: Saved to `accelerate_output.log` (via `tee` in the launch command) for later analysis and reference.
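Each `b-mc2/sql-create-context` record has a `question`, a `context` (the `CREATE TABLE` statements), and an `answer` (the target SQL). The following minimal sketch turns one record into a training example that uses the chat layout from Part 1, so the assistant turn holds nothing but the SQL and the end-of-sequence token. `to_training_text` is a hypothetical helper. The exact formatting in `train.py` may differ, and the record below is only illustrative.

```python
def to_training_text(record: dict, eos: str = "</s>") -> str:
    """Format one dataset row as a full chat transcript (hypothetical helper)."""
    return (
        f"<|system|>\n{record['context']}{eos}\n"
        f"<|user|>\n{record['question']}{eos}\n"
        # The assistant turn holds only the SQL, so that is all the model learns to emit
        f"<|assistant|>\n{record['answer']}{eos}"
    )

example = {
    "question": "How many heads of the departments are older than 56?",
    "context": "CREATE TABLE head(age INTEGER)",
    "answer": "SELECT COUNT(*) FROM head WHERE age > 56",
}
print(to_training_text(example))
```

A common refinement, not confirmed for this project, is to compute the loss only on the assistant tokens, for example with TRL's `DataCollatorForCompletionOnlyLM`.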
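Counting LoRA's trainable parameters shows why a rank-8 adapter is cheap. For a linear layer with input size d_in and output size d_out, LoRA adds two matrices, A (r × d_in) and B (d_out × r), which is r·(d_in + d_out) parameters. The sketch below assumes TinyLlama-1.1B's published architecture: hidden size 2048, MLP size 5632, 22 decoder layers, and grouped-query attention with 4 key/value heads of size 64, so `k_proj` and `v_proj` output 256 values. It also assumes that "all linear layers" means the seven projection matrices in each decoder layer, excluding `lm_head`.

```python
HIDDEN = 2048         # hidden_size (assumed TinyLlama-1.1B config)
INTERMEDIATE = 5632   # intermediate_size of the MLP
LAYERS = 22           # num_hidden_layers
KV_DIM = 4 * 64       # num_key_value_heads * head_dim (grouped-query attention)
RANK = 8              # LoRA rank used in this project

# (d_in, d_out) for each adapted projection in one decoder layer
PROJECTIONS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(d_in: int, d_out: int, r: int = RANK) -> int:
    # A is r x d_in and B is d_out x r
    return r * (d_in + d_out)

per_layer = sum(lora_params(d_in, d_out) for d_in, d_out in PROJECTIONS.values())
total = per_layer * LAYERS
print(f"{per_layer:,} per layer, {total:,} total")  # 286,720 per layer, 6,307,840 total
```

Under these assumptions, the adapters train about 6.3 million parameters, roughly 0.6% of the base model's 1.1 billion.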
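Gradient clipping in Hugging Face training loops usually means clipping by global norm. It is typically configured through `max_grad_norm` in `TrainingArguments` or `gradient_clipping` in a DeepSpeed config. Which setting this project used, and its threshold, are assumptions here. The pure-Python sketch below shows the technique: if the combined L2 norm of all gradients exceeds `max_norm`, every gradient is scaled by the same factor, which caps the norm while preserving the update direction.

```python
import math

def clip_by_global_norm(grads: list[list[float]], max_norm: float) -> list[list[float]]:
    """Scale all gradient tensors (here flat lists) so their joint L2 norm <= max_norm."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm <= max_norm:
        return grads  # already within the limit; leave untouched
    scale = max_norm / total_norm  # one shared factor keeps the direction unchanged
    return [[g * scale for g in tensor] for tensor in grads]

# Joint norm of these gradients is 5.0, so each value is scaled by 1/5
clipped = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0)
print(clipped)
```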
#### Serving the Fine-Tuned Model
Refer to `Run_ft_Tinyllama_Chat.ipynb` for deploying the fine-tuned model.
Example Query and Response:
```
<|system|>
CREATE TABLE head(age INTEGER)</s>
<|user|>
How many heads of the departments are older than 56?</s>
<|assistant|>
SELECT COUNT(*) FROM head WHERE age > 56
```
The fine-tuned model now returns only the SQL query, as intended.
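To check the SQL-only requirement across many test questions rather than by eye, a rough heuristic can flag responses that open with prose or append text after the query. The following is a minimal sketch. `is_sql_only` is a hypothetical helper, not a SQL parser, and it will misjudge edge cases such as semicolons inside string literals.

```python
SQL_STARTS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE")

def is_sql_only(response: str) -> bool:
    """Heuristic: True if the response looks like a single bare SQL statement."""
    text = response.strip()
    # Chatty answers often open with prose such as "To calculate ..."
    if not text.upper().startswith(SQL_STARTS):
        return False
    # Allow one optional trailing semicolon; any text after it is commentary
    statement, _, rest = text.partition(";")
    if rest.strip():
        return False
    # An explanation appended without a semicolon usually sits in a new paragraph
    return "\n\n" not in statement

print(is_sql_only("SELECT COUNT(*) FROM head WHERE age > 56"))  # True
print(is_sql_only(
    "To calculate the number of heads of the departments older than 56, "
    "you can use the following SQL query:\n"
    "SELECT COUNT(*) FROM departments WHERE age > 56;"
))  # False
```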
---
### Final Model & Deployment
After fine-tuning, I merged the trained LoRA adapter into the base model and uploaded the merged model to Hugging Face as **TinyLlama-1.1B-Chat-Text2SQL**. The full code is available here: 🔗 [**llm-text2sql-finetuning**](https://github.com/Rajesh-Nair/llm-text2sql-finetuning)