nileshhanotia commited on
Commit
42de116
1 Parent(s): 8f0ce1c

Update sql_generator.py

Browse files
Files changed (1) hide show
  1. sql_generator.py +23 -6
sql_generator.py CHANGED
@@ -8,7 +8,7 @@ class SQLGenerator:
8
  self.model_name = "premai-io/prem-1B-SQL"
9
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
  self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
11
-
12
  def generate_query(self, natural_language_query):
13
  schema_info = """
14
  CREATE TABLE products (
@@ -43,15 +43,17 @@ class SQLGenerator:
43
  outputs = self.model.generate(
44
  inputs["input_ids"],
45
  max_length=256,
46
- do_sample=False,
47
  num_return_sequences=1,
48
  eos_token_id=self.tokenizer.eos_token_id,
49
  pad_token_id=self.tokenizer.pad_token_id,
50
- temperature=0.7, # Adjust temperature for more creative output
51
  top_k=50 # Consider top k predictions for variability
52
  )
53
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
54
-
 
 
55
  def fetch_shopify_data(self, endpoint):
56
  headers = {
57
  'X-Shopify-Access-Token': ACCESS_TOKEN,
@@ -59,9 +61,24 @@ class SQLGenerator:
59
  }
60
  url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
61
  response = requests.get(url, headers=headers)
62
-
63
  if response.status_code == 200:
64
  return response.json()
65
  else:
66
  print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
67
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  self.model_name = "premai-io/prem-1B-SQL"
9
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
10
  self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
11
+
12
  def generate_query(self, natural_language_query):
13
  schema_info = """
14
  CREATE TABLE products (
 
43
  outputs = self.model.generate(
44
  inputs["input_ids"],
45
  max_length=256,
46
+ do_sample=True, # Enable sampling to use temperature
47
  num_return_sequences=1,
48
  eos_token_id=self.tokenizer.eos_token_id,
49
  pad_token_id=self.tokenizer.pad_token_id,
50
+ temperature=0.7, # Allow temperature to affect output
51
  top_k=50 # Consider top k predictions for variability
52
  )
53
+
54
+ generated_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
55
+ return generated_query # Return the generated SQL query
56
+
57
  def fetch_shopify_data(self, endpoint):
58
  headers = {
59
  'X-Shopify-Access-Token': ACCESS_TOKEN,
 
61
  }
62
  url = f"https://{SHOP_NAME}/admin/api/2023-10/{endpoint}.json"
63
  response = requests.get(url, headers=headers)
64
+
65
  if response.status_code == 200:
66
  return response.json()
67
  else:
68
  print(f"Error fetching {endpoint}: {response.status_code} - {response.text}")
69
  return None
70
+
71
+ # Example of how to use the SQLGenerator class
72
+ if __name__ == "__main__":
73
+ sql_generator = SQLGenerator()
74
+
75
+ # Example natural language query
76
+ natural_language_query = "Show me shirts with red color"
77
+
78
+ # Generate SQL query
79
+ sql_query = sql_generator.generate_query(natural_language_query)
80
+ print(f"Generated SQL Query: {sql_query}")
81
+
82
+ # Fetch data from Shopify (example endpoint)
83
+ shopify_data = sql_generator.fetch_shopify_data("products")
84
+ print(f"Shopify Data: {shopify_data}")