---
library_name: transformers
tags: []
---
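
# BERT Transformer Model Trained on Custom Database

This is a BERT model fine-tuned on a custom dataset to map natural-language questions to SQL queries. It is a sequence-classification model: each output class corresponds to one SQL query from the training data, so it selects the best-matching known query rather than writing new SQL from scratch.
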
## Model Details

- **Model Type**: BERT (`BertForSequenceClassification`)
- **Training Data**: Custom dataset
- **Use Case**: Mapping natural-language questions to SQL queries (classification over the fixed set of queries seen during training)
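
Because each output class stands for one SQL query, the `label_map` used in the usage example (SQL query string to class index) is what ties predictions back to SQL. This card does not say how that mapping was built. The sketch below assumes one plausible scheme, assigning indices to distinct queries in first-seen order; `build_label_map` and the example pairs are hypothetical and not taken from the actual training data.

```python
def build_label_map(pairs):
    """Assign a class index to each distinct SQL query, in first-seen order.

    `pairs` is a hypothetical list of (question, sql_query) training examples.
    Several questions may share the same query, and hence the same class.
    """
    label_map = {}
    for _question, sql in pairs:
        if sql not in label_map:
            label_map[sql] = len(label_map)
    return label_map


# Illustrative training pairs (hypothetical, not from the real dataset).
pairs = [
    ("How many employees are there?", "SELECT COUNT(*) FROM employees;"),
    ("Count the employees.", "SELECT COUNT(*) FROM employees;"),
    ("List all projects.", "SELECT * FROM projects;"),
]
label_map = build_label_map(pairs)
print(label_map)
# {'SELECT COUNT(*) FROM employees;': 0, 'SELECT * FROM projects;': 1}
```

Whatever scheme was actually used, inference must use the exact same mapping, otherwise class indices will point at the wrong queries.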
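
## Usage

You can use this model with the Hugging Face `transformers` library. The model outputs a class index, so you also need the `label_map` dictionary used during fine-tuning (each SQL query mapped to its class index) to turn the prediction back into a query string:
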
```python
import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('VPrashant/sql_bert')
model = BertForSequenceClassification.from_pretrained('VPrashant/sql_bert')

# Mapping from each SQL query string to its class index, as used during
# fine-tuning. It is not defined in this snippet; replace the empty dict
# with the mapping from training before running a prediction.
label_map = {}


def predict_sql_query(question, tokenizer, model, label_map):
    """Return the SQL query whose class the model scores highest for `question`."""
    inputs = tokenizer(question, return_tensors='pt', max_length=128,
                       truncation=True, padding='max_length')

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    predicted_label = torch.argmax(logits, dim=1).item()

    # Invert {query: index} into {index: query} to recover the query text.
    reverse_label_map = {i: query for query, i in label_map.items()}
    return reverse_label_map[predicted_label]


question = "Which projects have more than 5 employees working on them?"

# Predict the SQL query
predicted_query = predict_sql_query(question, tokenizer, model, label_map)
print(f"Predicted SQL Query: {predicted_query}")
```
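
Since the model only scores a fixed set of known queries, it can help to inspect the runner-up candidates instead of just the argmax, for example to flag low-confidence predictions. Below is a minimal plain-Python sketch, assuming you pass one row of logits converted to a list (for instance `outputs.logits[0].tolist()`) together with the same `label_map` as in the usage example. `top_k_queries` is a hypothetical helper, not part of this repository, and the toy logits are made up.

```python
import math


def top_k_queries(logits, label_map, k=3):
    """Rank candidate SQL queries by softmax probability.

    `logits` is one row of raw class scores, and `label_map` is the
    {sql_query: class_index} dict used during training.
    """
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Invert the label map, then sort class indices by probability.
    reverse_label_map = {i: query for query, i in label_map.items()}
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(reverse_label_map[i], probs[i]) for i in ranked[:k]]


# Toy example with made-up logits and a hypothetical three-query label map.
label_map = {"SELECT 1;": 0, "SELECT 2;": 1, "SELECT 3;": 2}
for query, p in top_k_queries([0.5, 2.0, -1.0], label_map, k=2):
    print(f"{p:.3f}  {query}")
```

A low top probability, or two candidates with similar scores, is a hint that the question may not match any query the model was trained on.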