Manoj Kumar committed on
Commit
4c735b8
·
1 Parent(s): d7b8d30

initial commit

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +61 -0
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.11.0
8
- app_file: gpt_neo.db.py
9
  pinned: false
10
  python: 3.9
11
  ---
 
5
  colorTo: red
6
  sdk: gradio
7
  sdk_version: 5.11.0
8
+ app_file: gpt_neo_db.py
9
  pinned: false
10
  python: 3.9
11
  ---
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
# Database schema: maps each table name to the ordered list of its column names.
db_schema = dict(
    products=["product_id", "name", "price", "description", "type"],
    orders=["order_id", "product_id", "quantity", "order_date"],
    customers=["customer_id", "name", "email", "phone_number"],
)
11
+
12
# Checkpoint used for SQL generation. Another causal-LM checkpoint
# (e.g. "Llama-2-7b") can be substituted here.
model_name = "EleutherAI/gpt-neox-20b"

# The tokenizer and the model are loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Weights are requested in half precision; device placement is delegated to
# the library via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
16
+
17
def generate_sql_query(context, question, max_new_tokens=256):
    """Generate an SQL query that answers ``question`` given ``context``.

    Relies on the module-level ``tokenizer`` and ``model`` objects.

    Args:
        context (str): Description of the database schema or table relationships.
        question (str): User's natural language query.
        max_new_tokens (int): Upper bound on the number of tokens generated
            *after* the prompt. Defaults to 256.

    Returns:
        str: The text the model produced after the prompt, with surrounding
        whitespace stripped.
    """
    # Prepare the prompt
    prompt = (
        "\n"
        f"Context: {context}\n"
        "\n"
        f"Question: {question}\n"
        "\n"
        "Write an SQL query to address the question based on the context.\n"
        "Query:\n"
    )

    # Tokenize input. The encoding is moved to the device the model reports
    # instead of a hard-coded "cuda"/"cpu" guess, because the model is loaded
    # with device_map="auto" and its placement is decided by the library.
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=1024
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    # Fall back to the EOS token for padding when the tokenizer defines no
    # dedicated pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Generate SQL query. `max_new_tokens` bounds only the continuation;
    # `max_length` would also count the prompt tokens, so a prompt longer than
    # that limit would leave no room for an answer. Passing **inputs forwards
    # the attention mask together with the input ids.
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=5,
        early_stopping=True,
        pad_token_id=pad_token_id,
    )

    # Decode only the newly generated tokens. Splitting the full text on
    # "Query:" breaks when the question, the context or the generated text
    # itself contains that marker.
    generated_tokens = output[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
47
+
48
# Schema as a context for the model: the schema dict rendered as indented JSON.
schema_description = json.dumps(db_schema, indent=4)

# Interactive question loop: read a question, print the generated SQL, repeat
# until the user types "exit"/"quit" or the input stream ends.
print("Ask a question about the database schema.")
while True:
    try:
        user_question = input("Question: ").strip()
    except (EOFError, KeyboardInterrupt):
        # stdin was closed (non-interactive run) or the user pressed Ctrl-C:
        # leave the loop cleanly instead of crashing with a traceback.
        print("\nExiting...")
        break

    # Skip blank lines so an accidental Enter does not trigger a generation pass.
    if not user_question:
        continue

    if user_question.lower() in ("exit", "quit"):
        print("Exiting...")
        break

    # Generate SQL query
    sql_query = generate_sql_query(schema_description, user_question)
    print(f"Generated SQL Query:\n{sql_query}\n")