delayedkarma committed
Commit b5afe62
1 Parent(s): 31e3617

Update README.md

Files changed (1)
  1. README.md +59 -0
README.md CHANGED
@@ -23,6 +23,7 @@ should probably proofread and complete it, then remove this comment. -->
# mistral-7b-text-to-sql

This model is a fine-tuned version of [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) on the b-mc2/sql-create-context dataset.
+ These are the adapter weights; the code to use them for generation is given below, followed by a sketch for merging the adapter into the base model locally. The full model will be uploaded at a later date.

## Model description

@@ -31,6 +32,64 @@ This model is a fine-tuned version of [mistralai/Mistral-7B-v0.1](https://huggin
- License: Apache 2.0
- Finetuned from model: Mistral-7B-v0.1

+ ## How to get started with the model
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, pipeline
+ from datasets import load_dataset
+ from peft import AutoPeftModelForCausalLM
+ from random import randint
+
+ peft_model_id = "delayedkarma/mistral-7b-text-to-sql"
+
+ # Load the base model with the PEFT adapter applied
+ model = AutoPeftModelForCausalLM.from_pretrained(
+     peft_model_id,
+     device_map="auto",
+     torch_dtype=torch.float16
+ )
+ tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
+
+ # Wrap model and tokenizer in a text-generation pipeline
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+ # System prompt; {schema} is filled in per sample
+ system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.
+ SCHEMA:
+ {schema}"""
+
+ def create_conversation(sample):
+     return {
+         "messages": [
+             {"role": "system", "content": system_message.format(schema=sample["context"])},
+             {"role": "user", "content": sample["question"]},
+             {"role": "assistant", "content": sample["answer"]}
+         ]
+     }
+
+ # Load the dataset from the Hub and keep a small random subset
+ dataset = load_dataset("b-mc2/sql-create-context", split="train")
+ dataset = dataset.shuffle().select(range(100))
+
+ # Convert the dataset to OAI-style messages
+ dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)
+
+ dataset = dataset.train_test_split(test_size=20/100)
+
+ # Pick a random test sample (randint is inclusive at both ends, hence the -1)
+ eval_dataset = dataset["test"]
+ rand_idx = randint(0, len(eval_dataset) - 1)
+
+ # Prompt with the system and user messages only, then generate greedily
+ prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False, add_generation_prompt=True)
+ outputs = pipe(prompt, max_new_tokens=256, do_sample=False, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)
+
+ print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
+ print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
+ print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
+ ```
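+
+ You can also skip the dataset and pass your own schema and question through the same chat template. A minimal sketch, reusing `pipe` and `system_message` from the snippet above; the `CREATE TABLE` statement and question here are made-up examples:
+
+ ```python
+ # Hypothetical schema and question, for illustration only
+ schema = "CREATE TABLE employees (id INTEGER, name VARCHAR, salary INTEGER, department VARCHAR)"
+ question = "What is the average salary per department?"
+
+ messages = [
+     {"role": "system", "content": system_message.format(schema=schema)},
+     {"role": "user", "content": question},
+ ]
+ prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ outputs = pipe(prompt, max_new_tokens=256, do_sample=False,
+                eos_token_id=pipe.tokenizer.eos_token_id,
+                pad_token_id=pipe.tokenizer.pad_token_id)
+ print(outputs[0]["generated_text"][len(prompt):].strip())
+ ```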
+
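+ Until the full model is published, one way to get a standalone checkpoint is to fold the adapter into the base weights locally. A minimal sketch, assuming a LoRA-style adapter (so `merge_and_unload` applies) and enough memory for the fp16 weights; the output directory name is arbitrary:
+
+ ```python
+ import torch
+ from peft import AutoPeftModelForCausalLM
+ from transformers import AutoTokenizer
+
+ # Load the adapter on top of the base model, then merge the LoRA weights
+ # into the base weights so peft is no longer needed at inference time
+ model = AutoPeftModelForCausalLM.from_pretrained(
+     "delayedkarma/mistral-7b-text-to-sql",
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True,
+ )
+ merged_model = model.merge_and_unload()
+ merged_model.save_pretrained("mistral-7b-text-to-sql-merged", safe_serialization=True)
+
+ # Save the tokenizer alongside so the merged folder is self-contained
+ tokenizer = AutoTokenizer.from_pretrained("delayedkarma/mistral-7b-text-to-sql")
+ tokenizer.save_pretrained("mistral-7b-text-to-sql-merged")
+ ```
+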
## Training procedure

### Training hyperparameters