Yogeshwaran27 commited on
Commit
d49fa53
1 Parent(s): 838615d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py CHANGED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
+ import argparse
4
+
5
def generate_prompt(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
    """Build the model prompt by filling a template with the question and schema.

    Args:
        question: Natural-language question to embed in the prompt.
        prompt_file: Path to a template file containing ``{user_question}``
            and ``{table_metadata_string}`` placeholders.
        metadata_file: Path to the SQL schema text inserted into the template.

    Returns:
        The fully formatted prompt string.

    Raises:
        FileNotFoundError: If either input file does not exist.
        KeyError: If the template is missing an expected placeholder.
    """
    # Explicit encoding: the original relied on the platform/locale default,
    # which breaks on non-UTF-8 systems (e.g. Windows cp1252).
    with open(prompt_file, "r", encoding="utf-8") as f:
        prompt = f.read()

    with open(metadata_file, "r", encoding="utf-8") as f:
        table_metadata_string = f.read()

    return prompt.format(
        user_question=question, table_metadata_string=table_metadata_string
    )
16
+
17
+
18
def get_tokenizer_model(model_name):
    """Load the tokenizer and half-precision causal LM for ``model_name``.

    The model is placed automatically across available devices
    (``device_map="auto"``) and keeps its generation KV cache enabled.
    Custom modeling code from the hub is trusted (``trust_remote_code=True``).

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Collect keyword options once so the load call stays readable.
    load_options = dict(
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    return loaded_tokenizer, loaded_model
28
+
29
def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
    """Generate a SQL query answering ``question`` with defog/sqlcoder-34b-alpha.

    Loads the model, formats the prompt from the given template/schema files,
    runs deterministic beam-search generation, and extracts the SQL between
    the ```` ```sql ```` fence and the first semicolon.

    Returns:
        The generated SQL query string, terminated with a semicolon.
    """
    tokenizer, model = get_tokenizer_model("defog/sqlcoder-34b-alpha")
    prompt = generate_prompt(question, prompt_file, metadata_file)

    # make sure the model stops generating at triple ticks
    # eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
    eos_token_id = tokenizer.eos_token_id

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=300,
        do_sample=False,
        num_beams=5,  # do beam search with 5 beams for high quality results
    )

    outputs = generator(
        prompt,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,
    )
    raw_text = outputs[0]["generated_text"]

    # Keep only the SQL body: text after the last ```sql fence, before the
    # closing fence, truncated at the first semicolon, re-terminated.
    fenced = raw_text.split("```sql")[-1].split("```")[0]
    return fenced.split(";")[0].strip() + ";"
58
+
59
if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser(description="Run inference on a question")
    parser.add_argument(
        "-q",
        "--question",
        type=str,
        # Fix: originally optional, so omitting -q passed question=None into
        # run_inference and failed only after the model had started loading.
        # required=True makes argparse fail fast with a clear usage error.
        required=True,
        help="Question to run inference on",
    )
    args = parser.parse_args()
    question = args.question
    print("Loading a model and generating a SQL query for answering your question...")
    print(run_inference(question))