added more files
Browse files- TODO.md +10 -0
- check_model.py +7 -0
- download_model.py +11 -0
- login.py +11 -0
- main.py +308 -0
- requirements.txt +9 -0
- test_torch.py +10 -0
TODO.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Task: Completed ✅
|
| 2 |
+
|
| 3 |
+
## Steps:
|
| 4 |
+
1. [x] Confirm plan approved.
|
| 5 |
+
2. [x] Generate complete Markdown content (temp file: new_readme.md).
|
| 6 |
+
3. [x] Replace in final_merged_model/README.md with full whitepaper (YAML preserved, exhaustive content, updated script).
|
| 7 |
+
4. [x] Preview opened in VSCode.
|
| 8 |
+
5. [x] Complete task.
|
| 9 |
+
|
| 10 |
+
Final README.md is now an enterprise-grade whitepaper with all required sections, tables, 263% boost, guardrails, dataset details, test case tables, constraints/V2.0, and merged model usage script.
|
check_model.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModelForCausalLM
|
| 2 |
+
try:
|
| 3 |
+
model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-7B-Instruct', local_files_only=True)
|
| 4 |
+
print('Local model files:', list(model.state_dict().keys())[:5])
|
| 5 |
+
except Exception as e:
|
| 6 |
+
print('Model not cached:', str(e))
|
| 7 |
+
|
download_model.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 4 |
+
'Qwen/Qwen2.5-7B-Instruct',
|
| 5 |
+
dtype=torch.float16,
|
| 6 |
+
device_map='cpu',
|
| 7 |
+
trust_remote_code=True
|
| 8 |
+
)
|
| 9 |
+
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-7B-Instruct', trust_remote_code=True)
|
| 10 |
+
print('Qwen2.5-7B-Instruct base model downloaded and cached locally for CPU.')
|
| 11 |
+
|
login.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import login, whoami
|
| 2 |
+
import os
|
| 3 |
+
# Replace with your actual token from https://huggingface.co/settings/tokens
|
| 4 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
login(token=HF_TOKEN)
|
| 8 |
+
user_info = whoami()
|
| 9 |
+
print(f"\n[SUCCESS] Authenticated as: {user_info['name']}")
|
| 10 |
+
except Exception as e:
|
| 11 |
+
print(f"\n[LOGIC ERROR] Authentication Failed: {e}")
|
main.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 3 |
+
from peft import PeftModel
|
| 4 |
+
import psycopg2
|
| 5 |
+
from psycopg2 import pool
|
| 6 |
+
import re
|
| 7 |
+
import logging
|
| 8 |
+
from groq import Groq
|
| 9 |
+
import os
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
# Setup logging
|
| 13 |
+
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# Load environment variables
|
| 17 |
+
load_dotenv()
|
| 18 |
+
|
| 19 |
+
class EnergyIntelligenceBot:
|
| 20 |
+
def __init__(self):
|
| 21 |
+
self.db = None
|
| 22 |
+
self._validate_env_vars()
|
| 23 |
+
self._connect_db()
|
| 24 |
+
|
| 25 |
+
def _validate_env_vars(self):
|
| 26 |
+
"""Ensures all critical variables exist before starting."""
|
| 27 |
+
required_vars = ["DB_USER", "DB_PASSWORD", "DB_HOST", "DB_PORT", "DB_NAME", "GROQ_API_KEY"]
|
| 28 |
+
missing = [var for var in required_vars if not os.environ.get(var)]
|
| 29 |
+
if missing:
|
| 30 |
+
raise ValueError(f"CRITICAL: Missing environment variables: {', '.join(missing)}")
|
| 31 |
+
|
| 32 |
+
def _connect_db(self):
|
| 33 |
+
"""Safely connects to PostgreSQL."""
|
| 34 |
+
try:
|
| 35 |
+
self.db = psycopg2.pool.SimpleConnectionPool(
|
| 36 |
+
minconn=1,
|
| 37 |
+
maxconn=10,
|
| 38 |
+
user=os.environ.get("DB_USER"),
|
| 39 |
+
password=os.environ.get("DB_PASSWORD"),
|
| 40 |
+
host=os.environ.get("DB_HOST"),
|
| 41 |
+
port=os.environ.get("DB_PORT"),
|
| 42 |
+
database=os.environ.get("DB_NAME")
|
| 43 |
+
)
|
| 44 |
+
logger.info("PostgreSQL connection pool created successfully.")
|
| 45 |
+
except Exception as e:
|
| 46 |
+
logger.error(f"Database connection failed: {e}. Bot will run in text-only mode.")
|
| 47 |
+
|
| 48 |
+
def call_llm(self, messages, model="llama-3.1-8b-instant"):
|
| 49 |
+
"""Calls Groq API."""
|
| 50 |
+
api_key = os.environ.get("GROQ_API_KEY")
|
| 51 |
+
client = Groq(api_key=api_key)
|
| 52 |
+
return client.chat.completions.create(
|
| 53 |
+
messages=messages,
|
| 54 |
+
model=model
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def execute_and_format_query(self, generated_sql, user_message):
|
| 58 |
+
logger.info(f"Generated SQL text:\n{generated_sql}")
|
| 59 |
+
|
| 60 |
+
# --- ROBUST SQL CLEANING ---
|
| 61 |
+
sql_match = re.search(r"```sql\s*(.*?)\s*```", generated_sql, re.DOTALL | re.IGNORECASE)
|
| 62 |
+
|
| 63 |
+
if sql_match:
|
| 64 |
+
query = sql_match.group(1).strip()
|
| 65 |
+
else:
|
| 66 |
+
fallback_match = re.search(r"(SELECT[\s\S]*)", generated_sql, re.IGNORECASE)
|
| 67 |
+
query = fallback_match.group(1).strip() if fallback_match else generated_sql
|
| 68 |
+
|
| 69 |
+
query = re.sub(r';\s*(?=union)', ' ', query, flags=re.IGNORECASE)
|
| 70 |
+
query = query.split(';')[0].strip() + ';'
|
| 71 |
+
|
| 72 |
+
logger.info(f"Cleaned SQL query: {query}")
|
| 73 |
+
|
| 74 |
+
Qres = []
|
| 75 |
+
|
| 76 |
+
if not self.db:
|
| 77 |
+
return "Error: Database connection is not available."
|
| 78 |
+
|
| 79 |
+
conn = None
|
| 80 |
+
try:
|
| 81 |
+
conn = self.db.getconn()
|
| 82 |
+
with conn.cursor() as cur:
|
| 83 |
+
cur.execute(query)
|
| 84 |
+
Qres = cur.fetchall()
|
| 85 |
+
logger.info(f"Result Fetched: {str(Qres)[:100]}...")
|
| 86 |
+
except Exception as e:
|
| 87 |
+
logger.error(f"Error executing query: {e}")
|
| 88 |
+
if conn:
|
| 89 |
+
conn.rollback()
|
| 90 |
+
return f"An error occurred while executing the query: {e}"
|
| 91 |
+
finally:
|
| 92 |
+
if conn:
|
| 93 |
+
self.db.putconn(conn)
|
| 94 |
+
|
| 95 |
+
messages = [
|
| 96 |
+
{
|
| 97 |
+
"role": "system",
|
| 98 |
+
"content": """Task : Your main goal is to make SQL query results easy to interpret for users who may not have technical backgrounds while ensuring all information is correct and clear.
|
| 99 |
+
user will give the conversation history, which contains ONLY the user's messages. Your task is to generate a response based on this history.
|
| 100 |
+
|
| 101 |
+
You have electric monitoring systems data, of 4 locations :
|
| 102 |
+
CNS Equipment Room
|
| 103 |
+
Glide Path
|
| 104 |
+
Localizer
|
| 105 |
+
DVOR
|
| 106 |
+
|
| 107 |
+
NEVER CHANGE THE ACTUAL DATA.
|
| 108 |
+
USER PROVIDED DATA SHOULD BE AS THEY ARE DONT EVEN TRY TO CONVERT THEM, LIKE FOR ENERGY TO KWH. THEY ARE ALREADY IN KWH FORMAT.
|
| 109 |
+
keep this in mind while making response that you have electric data so form them correctly with their units, there will be current, voltage, energy, power factor ,etc.
|
| 110 |
+
the r,y,b will be denoting the different phases such as red, yellow and blue phase.
|
| 111 |
+
There are not any phases in energy , frequency data. means they are regular data, they dont have any phases.
|
| 112 |
+
|
| 113 |
+
1. Receive SQL Query Results:
|
| 114 |
+
When given an SQL query result, your task is to format it professionally and clearly so that it is easy to read and understand.
|
| 115 |
+
|
| 116 |
+
2. Structure the Answer:
|
| 117 |
+
Tables: If the SQL query result contains rows and columns, format the output as a neat table.
|
| 118 |
+
Bullet Points or Lists: Use bullet points or structured lists if the results are better conveyed this way.
|
| 119 |
+
|
| 120 |
+
3. Contextual Information:
|
| 121 |
+
Add brief, clear explanations where necessary to provide context or meaning behind the data, ensuring the user understands what the result represents.
|
| 122 |
+
|
| 123 |
+
4. Formatting Example:
|
| 124 |
+
Guidelines:
|
| 125 |
+
Maintain a clean and simple presentation.
|
| 126 |
+
When needed, include context or analysis like trends, anomalies, or insights from the data.
|
| 127 |
+
The final answer should only include the well-formatted result and necessary explanation—no technical jargon or SQL-specific terms.
|
| 128 |
+
|
| 129 |
+
5. NOTE :
|
| 130 |
+
Read the fluctuations or anomaly data and notice them, if they are in percentage or the actual values for show them with units.(all data like 234,256,.. then volts, and if all data like 3.4, 6.56, 11.34, then %)
|
| 131 |
+
Never use this types of words in the final answer like "Based on the provided SQL query " or anything that indicates towards the sql query.
|
| 132 |
+
just give natural answers as human can understand without any technical things like the sql related things.
|
| 133 |
+
|
| 134 |
+
NEVER CHANGE THE VALUES FOR THE ENERGY REPORT, DONT YOU DARE CHANGING THEM. KEEP IT AS THEY ARE.
|
| 135 |
+
"""
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"role": "user",
|
| 139 |
+
"content": f"{user_message} \nThis was the query : {query} \nAnd Here is the query result : {str(Qres)}"
|
| 140 |
+
}
|
| 141 |
+
]
|
| 142 |
+
|
| 143 |
+
try:
|
| 144 |
+
ai_msg = self.call_llm(messages)
|
| 145 |
+
response_content = ai_msg.choices[0].message.content
|
| 146 |
+
logger.info("Formatted response generated successfully.")
|
| 147 |
+
return response_content
|
| 148 |
+
except Exception as e:
|
| 149 |
+
logger.error(f"Error calling LLM for formatting: {e}")
|
| 150 |
+
return "Failed to format response via LLM."
|
| 151 |
+
|
| 152 |
+
def close_connections(self):
|
| 153 |
+
if hasattr(self, 'db') and self.db:
|
| 154 |
+
self.db.closeall()
|
| 155 |
+
logger.info("PostgreSQL connection pool closed.")
|
| 156 |
+
|
| 157 |
+
# --- Initialization & Safe Execution ---
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
bot = None
|
| 161 |
+
try:
|
| 162 |
+
bot = EnergyIntelligenceBot()
|
| 163 |
+
|
| 164 |
+
# Hardware Check
|
| 165 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 166 |
+
logger.info(f"Using device: {device}")
|
| 167 |
+
if device == "cpu":
|
| 168 |
+
logger.warning("No GPU found. Running this model on CPU will be extremely slow.")
|
| 169 |
+
else:
|
| 170 |
+
# GPU optimizations for inference
|
| 171 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 172 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 173 |
+
logger.info(f"GPU optimizations enabled for {torch.cuda.get_device_name()}")
|
| 174 |
+
|
| 175 |
+
base_model_id = "Qwen/Qwen2.5-7B-Instruct"
|
| 176 |
+
adapter_path = "./"
|
| 177 |
+
|
| 178 |
+
# Check if adapter exists
|
| 179 |
+
if not os.path.exists(adapter_path) and adapter_path != "./":
|
| 180 |
+
logger.warning(f"Adapter path {adapter_path} not found. Ensure the path is correct.")
|
| 181 |
+
|
| 182 |
+
# VRAM check for RTX 2050 4GB
|
| 183 |
+
if torch.cuda.is_available():
|
| 184 |
+
total_vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
| 185 |
+
logger.info(f"Total VRAM: {total_vram:.1f} GB")
|
| 186 |
+
if total_vram < 5:
|
| 187 |
+
logger.warning("Low VRAM detected (<5GB). Using aggressive offloading.")
|
| 188 |
+
|
| 189 |
+
if torch.cuda.is_available():
|
| 190 |
+
# Optimized 4-bit configuration for RTX 2050 4GB
|
| 191 |
+
bnb_config = BitsAndBytesConfig(
|
| 192 |
+
load_in_4bit=True,
|
| 193 |
+
bnb_4bit_use_double_quant=True,
|
| 194 |
+
bnb_4bit_quant_type="nf4",
|
| 195 |
+
bnb_4bit_compute_dtype=torch.float16
|
| 196 |
+
)
|
| 197 |
+
logger.info("Loading base model with AGGRESSIVE GPU quantization...")
|
| 198 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 199 |
+
base_model_id,
|
| 200 |
+
dtype=torch.float16,
|
| 201 |
+
device_map="auto", # Changed from "cuda:auto" to "auto" to better handle hybrid GPUs
|
| 202 |
+
quantization_config=bnb_config,
|
| 203 |
+
trust_remote_code=True,
|
| 204 |
+
low_cpu_mem_usage=True, # Critical to prevent RAM spike
|
| 205 |
+
max_memory={0: "3.5GiB", "cpu": "8GiB"} # Restricted RAM usage to keep system stable
|
| 206 |
+
)
|
| 207 |
+
else:
|
| 208 |
+
logger.info("Loading base model on CPU (no quantization)...")
|
| 209 |
+
if torch.cuda.is_available():
|
| 210 |
+
logger.info(f"Model is actually on: {base_model.device}")
|
| 211 |
+
logger.info(f"VRAM used: {torch.cuda.memory_allocated(0)/1024**3:.2f}GB")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 215 |
+
base_model_id,
|
| 216 |
+
dtype=torch.float16,
|
| 217 |
+
device_map="cpu",
|
| 218 |
+
trust_remote_code=True,
|
| 219 |
+
low_cpu_mem_usage=True
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# Verification check
|
| 223 |
+
logger.info(f"Model placement complete.")
|
| 224 |
+
if torch.cuda.is_available():
|
| 225 |
+
logger.info(f"Model is actually on: {base_model.device}")
|
| 226 |
+
logger.info(f"VRAM used: {torch.cuda.memory_allocated(0)/1024**3:.2f}GB")
|
| 227 |
+
if "cpu" in str(base_model.device):
|
| 228 |
+
logger.warning("MODEL IS ON CPU! Bitsandbytes may be failing to find CUDA kernels.")
|
| 229 |
+
|
| 230 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
|
| 231 |
+
if tokenizer.pad_token is None:
|
| 232 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 233 |
+
logger.info("Set pad_token to eos_token")
|
| 234 |
+
|
| 235 |
+
logger.info("Loading adapter...")
|
| 236 |
+
model = PeftModel.from_pretrained(base_model, adapter_path)
|
| 237 |
+
model.eval()
|
| 238 |
+
|
| 239 |
+
if device == "cuda":
|
| 240 |
+
logger.info("CPU offload ready (disabled due to PeftModel compatibility)")
|
| 241 |
+
|
| 242 |
+
print("Enter 'exit' to quit.")
|
| 243 |
+
while True:
|
| 244 |
+
user_question = input("Enter your question: ").strip().lower()
|
| 245 |
+
if user_question == "exit":
|
| 246 |
+
print("Exiting...")
|
| 247 |
+
break
|
| 248 |
+
if not user_question:
|
| 249 |
+
print("Please enter a question.")
|
| 250 |
+
continue
|
| 251 |
+
|
| 252 |
+
prompt = f"generate the sql for this:{user_question.capitalize()}"
|
| 253 |
+
|
| 254 |
+
messages = [
|
| 255 |
+
{"role": "system", "content": """You are an expert NLP-to-SQL agent. Database table is 'main_cns' with energy monitoring data.
|
| 256 |
+
|
| 257 |
+
CRITICAL RULES:
|
| 258 |
+
- ONLY generate ONE real SELECT query for 'main_cns' table.
|
| 259 |
+
- NO examples, fictional tables (like 'energy'), multiple queries, or explanations.
|
| 260 |
+
- Output ONLY the SQL query inside ```sql ... ``` block.
|
| 261 |
+
- STRICTLY READ-ONLY SELECT statements. No INSERT/UPDATE/DELETE."""},
|
| 262 |
+
{"role": "user", "content": prompt}
|
| 263 |
+
]
|
| 264 |
+
|
| 265 |
+
logger.info("Processing inputs...")
|
| 266 |
+
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 267 |
+
inputs = tokenizer(text, return_tensors="pt").to(model.device)
|
| 268 |
+
|
| 269 |
+
logger.info("Generating response...")
|
| 270 |
+
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
|
| 271 |
+
generated_ids = model.generate(
|
| 272 |
+
**inputs,
|
| 273 |
+
max_new_tokens=256,
|
| 274 |
+
do_sample=False,
|
| 275 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 276 |
+
use_cache=True
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
|
| 280 |
+
output_text = tokenizer.batch_decode(generated_ids_trimmed, skip_special_tokens=True)
|
| 281 |
+
|
| 282 |
+
generated_sql = output_text[0]
|
| 283 |
+
|
| 284 |
+
print("\n--- Model Response ---")
|
| 285 |
+
print(generated_sql)
|
| 286 |
+
|
| 287 |
+
if torch.cuda.is_available():
|
| 288 |
+
torch.cuda.empty_cache()
|
| 289 |
+
logger.info(f"Post-gen VRAM: {torch.cuda.memory_allocated(0)/1024**3:.1f}GB")
|
| 290 |
+
|
| 291 |
+
# Final formatting and DB execution
|
| 292 |
+
print("\n--- Executing SQL and Formatting Results ---")
|
| 293 |
+
formatted_response = bot.execute_and_format_query(generated_sql, prompt.capitalize())
|
| 294 |
+
|
| 295 |
+
print("\n--- Formatted Response ---")
|
| 296 |
+
print(formatted_response)
|
| 297 |
+
print("\n" + "="*80 + "\n")
|
| 298 |
+
|
| 299 |
+
except Exception as e:
|
| 300 |
+
logger.critical(f"Application crashed: {e}")
|
| 301 |
+
finally:
|
| 302 |
+
# Resource cleanup
|
| 303 |
+
if bot:
|
| 304 |
+
bot.close_connections()
|
| 305 |
+
if torch.cuda.is_available():
|
| 306 |
+
torch.cuda.empty_cache()
|
| 307 |
+
logger.info("GPU cache cleared.")
|
| 308 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.2.6
|
| 2 |
+
transformers==5.3.0
|
| 3 |
+
peft==0.18.1
|
| 4 |
+
bitsandbytes==0.49.2
|
| 5 |
+
psycopg2-binary==2.9.11
|
| 6 |
+
groq==1.0.0
|
| 7 |
+
python-dotenv==1.1.1
|
| 8 |
+
huggingface-hub==1.6.0
|
| 9 |
+
accelerate==1.13.0
|
test_torch.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
print(f"PyTorch version: {torch.__version__}")
|
| 3 |
+
print(f"CUDA available: {torch.cuda.is_available()}")
|
| 4 |
+
if torch.cuda.is_available():
|
| 5 |
+
print(f"CUDA version: {torch.version.cuda}")
|
| 6 |
+
print(f"GPU name: {torch.cuda.get_device_name(0)}")
|
| 7 |
+
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
| 8 |
+
else:
|
| 9 |
+
print("No GPU")
|
| 10 |
+
|