| """ |
| NoNoQL - Natural Language to SQL/MongoDB Query Generator |
| Streamlit Frontend Application (HuggingFace Spaces Version) |
| """ |
|
|
import json
import os
from datetime import datetime

import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
| |
# HuggingFace Hub repository holding the fine-tuned NoNoQL checkpoint; it is the
# default model source when the app runs inside a HuggingFace Space.
HF_MODEL_REPO: str = "mohhhhhit/nonoql"
|
|
| |
def is_hf_space():
    """Return True when the app appears to run inside a HuggingFace Space.

    The presence of the ``SPACE_ID`` environment variable is the signal; its
    value is irrelevant (an empty string still counts as "present").
    """
    return "SPACE_ID" in os.environ
|
|
| |
# Model source: the Hub repo inside a HuggingFace Space, otherwise a local
# "models" directory relative to the working directory.
if is_hf_space():
    DEFAULT_MODEL_PATH = HF_MODEL_REPO
else:
    DEFAULT_MODEL_PATH = "models"

# Persistent files live in a "data" folder next to this script.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
_DATA_DIR = os.path.join(_APP_DIR, "data")

# JSON list of past generations (see save_query_history / load_query_history).
HISTORY_FILE_PATH = os.path.join(_DATA_DIR, "query_history.json")

# Markdown description of the user's database, shown/edited in the sidebar.
SCHEMA_FILE_PATH = os.path.join(_DATA_DIR, "database_schema.txt")

# Schema shown when no schema file exists or it is empty/unreadable.
DEFAULT_SCHEMA = """**employees**
- employee_id, name, email
- department, salary, hire_date, age

**departments**
- department_id, department_name
- manager_id, budget, location

**projects**
- project_id, project_name
- start_date, end_date, budget, status

**orders**
- order_id, customer_name
- product_name, quantity
- order_date, total_amount

**products**
- product_id, product_name
- category, price
- stock_quantity, supplier"""
|
|
| |
# Browser tab title/icon and overall page layout.
_PAGE_CONFIG = dict(
    page_title="NoNoQL - Natural Language to SQL/MongoDB Query Generator",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)

# Global CSS: app title in the header bar, helper boxes used by the main page,
# and full-width blue buttons.
_CUSTOM_CSS = """
<style>
    /* Inject title into Streamlit header bar */
    header[data-testid="stHeader"] {
        background-color: rgba(14, 17, 23, 0.95) !important;
    }

    header[data-testid="stHeader"]::before {
        content: "NoNoQL";
        color: white;
        font-size: 1.3rem;
        font-weight: 600;
        position: absolute;
        left: 1rem;
        top: 50%;
        transform: translateY(-50%);
        z-index: 999;
    }

    .query-box {
        background-color: #f0f2f6;
        border-radius: 10px;
        padding: 20px;
        margin: 10px 0;
        border-left: 5px solid #1E88E5;
    }
    .success-box {
        background-color: #d4edda;
        border-radius: 10px;
        padding: 20px;
        margin: 10px 0;
        border-left: 5px solid #28a745;
    }
    .example-query {
        background-color: #fff3cd;
        border-radius: 5px;
        padding: 10px;
        margin: 5px 0;
        cursor: pointer;
    }
    .example-query:hover {
        background-color: #ffe69c;
    }
    .stButton>button {
        width: 100%;
        background-color: #1E88E5;
        color: white;
        font-size: 1.1rem;
        padding: 0.5rem 1rem;
        border-radius: 10px;
        border: none;
        margin-top: 1rem;
    }
    .stButton>button:hover {
        background-color: #1565C0;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
|
|
|
|
def extract_columns_from_nl(natural_language_query):
    """Extract a table/collection name and requested column names from a request.

    Understands phrasings such as "create table students with columns id, name
    and email" or "create collection users with name and email".

    Returns:
        tuple: ``(table_name, columns)``. ``table_name`` is a lowercase string or
        None. ``columns`` is a (possibly empty) list of lowercase identifiers:
        runs of non-word characters (spaces, hyphens, trailing punctuation such
        as the final "." of a sentence) are turned into single underscores and
        trimmed, so "first name" becomes "first_name" and "email." becomes
        "email" -- both usable as SQL/MongoDB field names.
    """
    import re

    nl = natural_language_query.lower().strip()

    table_match = re.search(r'(?:table|collection)\s+(?:named|called)?\s*(\w+)', nl)
    table_name = table_match.group(1) if table_match else None

    def split_columns(text):
        # Items are separated by commas and/or the word "and".
        names = []
        for raw in re.split(r',|\s+and\s+', text):
            name = re.sub(r'\W+', '_', raw.strip()).strip('_')
            if name:
                names.append(name)
        return names

    columns = []

    # Preferred form: "... columns [as|named|like|called] a, b and c".
    col_match = re.search(r'columns?\s+(?:as|named|like|called)?\s*([^,]+(?:,\s*[^,]+)*)', nl)
    if col_match:
        columns = split_columns(col_match.group(1))

    # Fallback: "... with a, b" / "... add [columns] a, b".
    if not columns:
        col_match = re.search(r'(?:add|with)\s+(?:columns?)?\s*([^,]+(?:,\s*[^,]+)*)', nl)
        if col_match:
            columns = split_columns(col_match.group(1))

    return table_name, columns
|
|
|
|
def fix_create_table_sql(generated_sql, table_name, requested_columns):
    """Rebuild a CREATE TABLE statement from the columns the user asked for.

    The model tends to hallucinate the column list of CREATE TABLE statements.
    When ``generated_sql`` contains ``CREATE TABLE``, it is rebuilt as
    ``CREATE TABLE [IF NOT EXISTS] <table_name> (<requested columns>);`` with a
    column type inferred from each column name. An ``IF NOT EXISTS`` clause in
    the model output is preserved, and the requested table name is used even if
    the model produced a different one.

    Only the first id-like column gets ``PRIMARY KEY`` (a table can have only
    one); further id columns become plain ``INT``.

    Args:
        generated_sql: Raw SQL produced by the model.
        table_name: Table name extracted from the request.
        requested_columns: Column names extracted from the request.

    Returns:
        The rebuilt statement, or ``generated_sql`` unchanged when it is not a
        CREATE TABLE statement, ``table_name``/``requested_columns`` are empty,
        or no non-blank column name remains.
    """
    import re

    if not table_name or not requested_columns:
        return generated_sql

    create_match = re.search(r'CREATE\s+TABLE\s+(IF\s+NOT\s+EXISTS\s+)?', generated_sql, re.IGNORECASE)
    if not create_match:
        return generated_sql

    def name_tokens(col_lower):
        # "student_id" -> ["student", "id"]; whole tokens avoid substring hits
        # such as "id" in "paid"/"valid" or "age" in "message".
        return re.split(r'[^a-z0-9]+', col_lower)

    def infer_type(col_lower):
        tokens = name_tokens(col_lower)
        if any(word in col_lower for word in ['name', 'title', 'description', 'address', 'city', 'email']):
            return 'VARCHAR(100)'
        if any(word in col_lower for word in ['phone', 'contact', 'mobile']):
            return 'VARCHAR(20)'
        if any(word in col_lower for word in ['date', 'created', 'updated']):
            return 'DATE'
        if any(word in col_lower for word in ['price', 'salary', 'amount', 'cost']):
            return 'DECIMAL(10,2)'
        if 'age' in tokens or 'count' in tokens or any(word in col_lower for word in ['quantity', 'stock']):
            return 'INT'
        if any(word in col_lower for word in ['status', 'type', 'category']):
            return 'VARCHAR(50)'
        return 'VARCHAR(100)'

    col_defs = []
    has_primary_key = False
    for col in requested_columns:
        col_clean = col.strip()
        if not col_clean:
            continue
        col_lower = col_clean.lower()
        if 'id' in name_tokens(col_lower):
            col_type = 'INT' if has_primary_key else 'INT PRIMARY KEY'
            has_primary_key = True
        else:
            col_type = infer_type(col_lower)
        col_defs.append(f"{col_clean} {col_type}")

    if not col_defs:
        # Nothing usable to rebuild from; "CREATE TABLE t ();" would be invalid.
        return generated_sql

    if_not_exists = 'IF NOT EXISTS ' if create_match.group(1) else ''
    return f"CREATE TABLE {if_not_exists}{table_name} ({', '.join(col_defs)});"
|
|
|
|
def fix_create_collection_mongo(generated_mongo, table_name, requested_columns):
    """Replace model output for a "create collection" request with a deterministic one.

    With requested fields, the result is an ``insertOne`` of a sample document
    whose placeholder values are chosen from each field name (ids -> ``1``,
    names/titles -> ``"sample_name"``, email, phone/contact, otherwise
    ``"sample_value"``). Without fields, it is ``db.createCollection('<name>');``.

    Args:
        generated_mongo: Raw MongoDB query produced by the model.
        table_name: Collection name extracted from the request.
        requested_columns: Field names extracted from the request (None is
            treated as no fields).

    Returns:
        The replacement query, or ``generated_mongo`` unchanged when no
        collection name is known.
    """
    import re

    if not table_name:
        return generated_mongo

    def sample_value(col_lower):
        # Whole-token id match so e.g. "paid" is not mistaken for an id field.
        if 'id' in re.split(r'[^a-z0-9]+', col_lower):
            return '1'
        if any(word in col_lower for word in ['name', 'title']):
            return '"sample_name"'
        if 'email' in col_lower:
            return '"user@example.com"'
        if any(word in col_lower for word in ['phone', 'contact']):
            return '"1234567890"'
        return '"sample_value"'

    doc_fields = []
    for col in requested_columns or []:
        col_clean = col.strip()
        if col_clean:
            doc_fields.append(f'"{col_clean}": {sample_value(col_clean.lower())}')

    if doc_fields:
        return f"db.{table_name}.insertOne({{{', '.join(doc_fields)}}});"
    return f"db.createCollection('{table_name}');"
|
|
|
|
def detect_comparison_operator(natural_language_query):
    """Detect the comparison operator implied by a natural-language request.

    Compound phrases are checked before the phrases they contain. "greater than
    or equal to" must win over "greater than", and "no more than" over "more
    than". Previously the plain phrases were tested first, so ``>=``/``<=`` were
    never returned for the "... or equal to" wordings.

    "minimum"/"maximum" are deliberately not treated as comparisons. A request
    like "employee with the maximum salary" asks for an aggregate or sort, and
    mapping it to ``<=``/``>=`` let the comparison fixers rewrite otherwise
    correct ``WHERE col = ...`` filters.

    Returns:
        One of ``'>='``, ``'<='``, ``'>'``, ``'<'``, ``'='``, or None when no
        comparison wording is found.
    """
    import re

    nl = natural_language_query.lower()

    # Ordered: first matching rule wins.
    rules = (
        ('>=', r'\b(?:(?:greater|more) than or equal(?: to)?|at least|no less than|not less than)\b'),
        ('<=', r'\b(?:(?:less|fewer) than or equal(?: to)?|at most|no more than|not more than)\b'),
        ('>', r'\b(?:greater than|more than|above|exceeds?)\b'),
        ('<', r'\b(?:less than|fewer than|below|under)\b'),
        ('=', r'\b(?:equals?|is)\b|='),
    )
    for operator, pattern in rules:
        if re.search(pattern, nl):
            return operator
    return None
|
|
|
|
def fix_sql_operation_type(generated_sql, natural_language_query):
    """Turn ``SELECT * FROM t [WHERE ...]`` into ``DELETE`` for delete/remove requests.

    Applies only when the request contains the word "delete" or "remove" and
    the model produced a ``SELECT * FROM <table>`` query. The WHERE clause (if
    any) is kept. A trailing ``;`` is preserved exactly when the model output
    had one; previously it was lost whenever there was no WHERE clause.

    Returns:
        The rewritten DELETE statement, or ``generated_sql`` unchanged.
    """
    import re

    if not re.search(r'\b(delete|remove)\b', natural_language_query.lower()):
        return generated_sql

    match = re.match(r'\s*SELECT\s+\*\s+FROM\s+(\w+)(\s+WHERE\s+.+)?', generated_sql, re.IGNORECASE)
    if not match:
        return generated_sql

    table = match.group(1)
    where_clause = (match.group(2) or '').rstrip().rstrip(';').rstrip()
    terminator = ';' if generated_sql.rstrip().endswith(';') else ''
    return f"DELETE FROM {table}{where_clause}{terminator}"
|
|
|
|
def fix_mongodb_operation_type(generated_mongo, natural_language_query):
    """Force ``deleteMany`` when the user asked to delete/remove documents.

    For requests containing the word "delete" or "remove", every
    ``.find(``, ``.findOne(``, ``.insertOne(`` or ``.deleteOne(`` call (optional
    whitespace before the parenthesis) is rewritten to ``.deleteMany(``.
    Other requests are returned untouched.

    NOTE(review): chained cursor calls such as ``.sort()``/``.limit()`` after a
    rewritten ``find`` are left in place and are not valid on deleteMany.
    """
    import re

    wants_delete = re.search(r'\b(delete|remove)\b', natural_language_query.lower()) is not None
    if not wants_delete:
        return generated_mongo

    # re.sub leaves the string as-is when nothing matches.
    return re.sub(r'\.(find|findOne|insertOne|deleteOne)\s*\(', '.deleteMany(', generated_mongo)
|
|
|
|
def fix_mongodb_missing_braces(generated_mongo):
    """Wrap a bare ``key: value`` call argument in curly braces.

    Example: ``db.c.find("field": value)`` -> ``db.c.find({"field": value})``.

    Calls with a double-quoted key are fixed first. Only if none is found are
    calls with an unquoted key (``find(field: value)``) fixed. A ``;`` at the
    end of the captured value is dropped.

    Each occurrence is rebuilt from its own match via a replacement function.
    Previously the first match's text was spliced into a ``re.sub`` template,
    which (a) reused it for every occurrence and (b) raised ``re.error`` when
    the value contained backslashes, e.g. a regex literal like ``/^J\\w*/``.
    """
    import re

    quoted_key = r'(\.\w+)\(\"(\w+)\":\s*([^)]+)\)'
    bare_key = r'(\.\w+)\((\w+):\s*([^)]+)\)'

    def wrapper(key_format):
        def replace(match):
            method, field, value = match.groups()
            value = value.strip().rstrip(';')
            return f"{method}({{{key_format.format(field)}: {value}}})"
        return replace

    fixed, count = re.subn(quoted_key, wrapper('"{}"'), generated_mongo)
    if count:
        return fixed
    return re.sub(bare_key, wrapper('{}'), generated_mongo)
|
|
|
|
def fix_comparison_operator_sql(generated_sql, natural_language_query):
    """Replace ``=`` with the comparison the request asks for in the main WHERE.

    When ``detect_comparison_operator`` finds ``>``, ``<``, ``>=`` or ``<=``, a
    leading ``WHERE <column> =`` is rewritten to use that operator.

    Only the FIRST ``WHERE`` keyword of the statement is considered, and only if
    it is immediately followed by ``<column> =``. The old version rewrote every
    ``WHERE col =`` in the text, including conditions inside subqueries (e.g.
    ``... WHERE salary > (SELECT AVG(salary) ... WHERE dept = 'x')``).
    Qualified names such as ``e.salary`` are supported.
    """
    import re

    correct_op = detect_comparison_operator(natural_language_query)
    if not correct_op or correct_op == '=':
        return generated_sql

    first_where = re.search(r'\bWHERE\b', generated_sql, re.IGNORECASE)
    if first_where is None:
        return generated_sql

    head = generated_sql[:first_where.start()]
    tail = re.sub(
        r'^(WHERE\s+[\w.]+)\s*=\s*',
        lambda match: f"{match.group(1)} {correct_op} ",
        generated_sql[first_where.start():],
        count=1,
        flags=re.IGNORECASE,
    )
    return head + tail
|
|
|
|
def fix_comparison_operator_mongodb(generated_mongo, natural_language_query):
    """Turn an equality filter into ``$gt``/``$lt``/``$gte``/``$lte`` when requested.

    When ``detect_comparison_operator`` reports a non-equality comparison, the
    first single-field filter ``{key: value}`` (key quoted or not) is rewritten
    to ``{key: {$op: value}}``.

    Only the filter argument of a query/update/delete/count call, or the body of
    a ``$match`` stage, is eligible. The old version rewrote the first
    ``{key: value}`` anywhere in the text. That included ``.sort({...})`` and
    projection objects, e.g. ``.sort({salary: -1})`` became
    ``.sort({salary: {$gt: -1}})`` when the filter already used an operator.
    """
    import re

    correct_op = detect_comparison_operator(natural_language_query)
    mongo_op = {'>': '$gt', '<': '$lt', '>=': '$gte', '<=': '$lte'}.get(correct_op)
    if mongo_op is None:
        return generated_mongo

    pattern = re.compile(
        r'((?:\.(?:find|findOne|deleteMany|deleteOne|updateMany|updateOne|countDocuments|count)'
        r'\s*\(\s*|\$match\s*:\s*))'
        r'\{\s*("?)(\w+)\2\s*:\s*([^,}{]+?)\s*\}'
    )

    def rewrite(match):
        prefix, quote, field, value = match.groups()
        # Same output shape as before: {"f": {$op: v}} or {f: {$op: v}}.
        return f'{prefix}{{{quote}{field}{quote}: {{{mongo_op}: {value}}}}}'

    return pattern.sub(rewrite, generated_mongo, count=1)
|
|
|
|
def parse_update_query(natural_language_query):
    """Parse an UPDATE request written in natural language.

    Accepted shapes (case-insensitive):

    * ``update <table> set <col> to <value> where <col> is|equals|is equal to|= <value>``
      (``=`` may be written without surrounding spaces, e.g. ``where id=101``)
    * ``update <table> set <col> = <value> where <col> = <value>``

    Values are cleaned before being returned. Surrounding whitespace and
    trailing sentence punctuation (``. ; ! ?``) are removed. One pair of
    matching surrounding quotes is also removed, so callers that add their own
    quotes do not produce ``''Sales''``.

    Example: "Update employees set department to Sales where employee_id is 101"
    -> ``("employees", "department", "Sales", "employee_id", "101")``.

    Returns:
        ``(table, set_column, set_value, where_column, where_value)`` or None
        when the request does not match either shape.
    """
    import re

    def clean(value):
        value = value.strip().rstrip('.;!?').strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]
        return value

    patterns = (
        r'update\s+(\w+)\s+set\s+(\w+)\s+to\s+([^\s]+(?:\s+[^\s]+)*?)\s+where\s+(\w+)'
        r'(?:\s+(?:is\s+equal\s+to|equal\s+to|equals?|is)\s+|\s*=\s*)(.+)',
        r'update\s+(\w+)\s+set\s+(\w+)\s*=\s*([^\s]+(?:\s+[^\s]+)*?)\s+where\s+(\w+)\s*=\s*(.+)',
    )

    for pattern in patterns:
        match = re.search(pattern, natural_language_query, re.IGNORECASE)
        if match:
            table_name, set_column, set_value, where_column, where_value = match.groups()
            return (table_name, set_column, clean(set_value), where_column, clean(where_value))

    return None
|
|
|
|
def fix_update_query_sql(generated_sql, natural_language_query):
    """Rebuild a malformed UPDATE statement from the natural-language request.

    Acts only when the request mentions "update" and the model output lacks an
    ``UPDATE <table> SET`` clause. The statement is then rebuilt from the parts
    returned by ``parse_update_query``. If the parser cannot understand the
    request, the model output is returned unchanged.

    Literal rendering:
        * plain numbers (``42``, ``-3.5``, ``.5``, ``1e3``) are emitted bare;
        * everything else becomes a single-quoted SQL string with embedded
          ``'`` doubled (``O'Brien`` -> ``'O''Brien'``).
    Unlike the former ``float()`` test, strings such as ``nan``, ``inf`` or
    ``1_000`` are quoted instead of being emitted as invalid bare tokens.
    """
    import re

    if 'update' not in natural_language_query.lower():
        return generated_sql
    if re.search(r'UPDATE\s+\w+\s+SET', generated_sql, re.IGNORECASE):
        return generated_sql

    parsed = parse_update_query(natural_language_query)
    if not parsed:
        return generated_sql
    table, set_col, set_val, where_col, where_val = parsed

    def to_literal(value):
        if re.fullmatch(r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?', value):
            return value
        return "'" + value.replace("'", "''") + "'"

    set_val_quoted = to_literal(set_val)
    where_val_quoted = to_literal(where_val)
    return f"UPDATE {table} SET {set_col} = {set_val_quoted} WHERE {where_col} = {where_val_quoted};"
|
|
|
|
def fix_update_query_mongodb(generated_mongo, natural_language_query):
    """Rebuild a malformed MongoDB update from the natural-language request.

    Acts only when the request mentions "update" and the model output contains
    no ``.update...`` call. It then emits
    ``db.<table>.updateMany({<where_col>: <v>}, {$set: {<set_col>: <v>}});`` from
    the parts returned by ``parse_update_query``. If parsing fails, the model
    output is returned unchanged.

    Literal rendering: plain numbers are emitted bare. Everything else is a JSON
    string (``json.dumps``), so embedded quotes/backslashes are escaped rather
    than breaking the query. ``nan``/``inf``/``1_000`` are treated as strings,
    not numbers.
    """
    import json
    import re

    if 'update' not in natural_language_query.lower():
        return generated_mongo
    if re.search(r'\.update', generated_mongo, re.IGNORECASE):
        return generated_mongo

    parsed = parse_update_query(natural_language_query)
    if not parsed:
        return generated_mongo
    table, set_col, set_val, where_col, where_val = parsed

    def to_literal(value):
        if re.fullmatch(r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?', value):
            return value
        return json.dumps(value, ensure_ascii=False)

    set_val_formatted = to_literal(set_val)
    where_val_formatted = to_literal(where_val)
    return f"db.{table}.updateMany({{{where_col}: {where_val_formatted}}}, {{$set: {{{set_col}: {set_val_formatted}}}}});"
|
|
|
|
class TexQLModel:
    """Unified seq2seq wrapper that generates SQL or MongoDB from natural language.

    One checkpoint serves both dialects. The target is selected by the
    ``"translate to {target_type}: "`` prefix built in ``generate_query``. Load
    progress and errors are reported through Streamlit widgets.
    """

    def __init__(self, model_path):
        """Load tokenizer and model from a HuggingFace repo id or a local directory.

        Args:
            model_path: Value passed to ``from_pretrained`` (``user/repo`` or a
                local checkpoint directory).

        On failure the exception is shown with ``st.error`` and ``self.loaded``
        stays False. ``self.tokenizer``/``self.model`` may then be None.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.loaded = False
        # Defined up-front so a failed load leaves well-defined attributes.
        self.tokenizer = None
        self.model = None

        source = 'HuggingFace Hub' if '/' in model_path else 'local path'
        try:
            with st.spinner(f"Loading model from {source}..."):
                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
                self.model.to(self.device)
                self.model.eval()
                self.loaded = True
            st.success(f"✅ Model loaded successfully on {self.device.upper()}")
        except Exception as e:  # UI boundary: report any load failure to the user
            st.error(f"❌ Error loading model: {str(e)}")
            if is_hf_space():
                st.info("💡 Model is loading from HuggingFace Hub - this may take a moment on first run")

    def generate_query(self, natural_language_query, target_type='sql', temperature=0.3,
                       num_beams=10, repetition_penalty=1.2, length_penalty=0.8):
        """Generate a SQL or MongoDB query from natural language.

        Args:
            natural_language_query: The user's natural-language request.
            target_type: ``'sql'`` or ``'mongodb'``.
            temperature: Kept for backward compatibility but not used. Decoding
                is deterministic (``do_sample=False``), and sampling temperature
                has no effect there. Passing it only produced transformers
                warnings.
            num_beams: Beam-search width (1 means greedy decoding).
            repetition_penalty: >1.0 discourages repeating tokens.
            length_penalty: Beam-search length preference (<1.0 favours shorter
                outputs). Only used when ``num_beams > 1``.

        Returns:
            The post-processed query string, or ``"Model not loaded"`` when
            loading failed.
        """
        if not self.loaded:
            return "Model not loaded"

        input_text = f"translate to {target_type}: {natural_language_query}"
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=256,
            truncation=True
        ).to(self.device)

        generation_kwargs = {
            'max_length': 512,
            'num_beams': num_beams,
            'repetition_penalty': repetition_penalty,
            'no_repeat_ngram_size': 3,
            'do_sample': False,
        }
        # length_penalty / early_stopping only affect beam search; for greedy
        # decoding (one beam) they are ignored and may trigger validation warnings.
        if num_beams > 1:
            generation_kwargs['length_penalty'] = length_penalty
            generation_kwargs['early_stopping'] = True

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        generated_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self._postprocess(generated_query, natural_language_query, target_type)

    @staticmethod
    def _postprocess(generated_query, natural_language_query, target_type):
        """Apply the rule-based repairs to raw model output, in a fixed order.

        Order: CREATE rebuild -> UPDATE rebuild -> operation-type fix ->
        (MongoDB only) missing-brace repair -> comparison-operator fix.
        """
        nl_lower = natural_language_query.lower()
        is_sql = target_type == 'sql'
        is_mongo = target_type == 'mongodb'

        if any(word in nl_lower for word in ['create', 'add columns']):
            table_name, requested_columns = extract_columns_from_nl(natural_language_query)
            if table_name and requested_columns:
                if is_sql:
                    generated_query = fix_create_table_sql(generated_query, table_name, requested_columns)
                elif is_mongo:
                    generated_query = fix_create_collection_mongo(generated_query, table_name, requested_columns)

        if 'update' in nl_lower and 'set' in nl_lower:
            if is_sql:
                generated_query = fix_update_query_sql(generated_query, natural_language_query)
            elif is_mongo:
                generated_query = fix_update_query_mongodb(generated_query, natural_language_query)

        if is_sql:
            generated_query = fix_sql_operation_type(generated_query, natural_language_query)
            generated_query = fix_comparison_operator_sql(generated_query, natural_language_query)
        elif is_mongo:
            generated_query = fix_mongodb_operation_type(generated_query, natural_language_query)
            generated_query = fix_mongodb_missing_braces(generated_query)
            generated_query = fix_comparison_operator_mongodb(generated_query, natural_language_query)

        return generated_query
|
|
|
|
@st.cache_resource
def load_model(model_path):
    """Load the unified NoNoQL model once per distinct ``model_path`` (cached).

    ``model_path`` may be an existing local path or a HuggingFace Hub repo id
    (anything containing ``/``, e.g. ``user/repo``). Surrounding whitespace is
    ignored.

    Returns:
        A ``TexQLModel`` (check ``.loaded`` for success), or None after showing
        an error when the path is neither an existing local path nor a repo id.

    The previous branches (``'/' in path or not exists`` / ``exists`` / else)
    covered every case, so the "not found" error was unreachable. A mistyped
    local folder then surfaced as an opaque Hub download error instead.
    Bare Hub ids without a slash (e.g. ``t5-small``) are now reported as not
    found.
    """
    model_path = model_path.strip()

    if os.path.exists(model_path) or '/' in model_path:
        return TexQLModel(model_path)

    st.error(f"❌ Model path not found: {model_path}")
    return None
|
|
|
|
def save_query_history(nl_query, sql_query, mongodb_query, max_history=500):
    """Record one generation in the session history, trim it, and persist it.

    Args:
        nl_query: The user's natural-language request.
        sql_query: Generated SQL.
        mongodb_query: Generated MongoDB query.
        max_history: Number of most recent entries to keep once the history
            grows beyond it.
    """
    entry = {
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'natural_language': nl_query,
        'sql': sql_query,
        'mongodb': mongodb_query
    }

    history = st.session_state.history if 'history' in st.session_state else []
    history.append(entry)

    # Keep only the newest entries when over the limit.
    st.session_state.history = history[-max_history:] if len(history) > max_history else history

    persist_query_history(st.session_state.history)
|
|
|
|
def delete_history_entry(index):
    """Remove the history entry at ``index`` and persist the result.

    Out-of-range indexes, or a session without history, are ignored.
    """
    if 'history' not in st.session_state:
        return

    history = st.session_state.history
    if not 0 <= index < len(history):
        return

    del history[index]
    persist_query_history(history)
|
|
|
|
def load_query_history(path=None):
    """Load the persisted query history.

    Args:
        path: JSON file to read; defaults to ``HISTORY_FILE_PATH``.

    Returns:
        A list of entry dicts. Entries that are not dicts with a string
        ``natural_language`` and a ``timestamp`` are dropped, because the UI
        indexes those keys directly. A missing, unreadable or non-list file
        yields ``[]``.
    """
    history_path = HISTORY_FILE_PATH if path is None else path
    if not os.path.exists(history_path):
        return []

    try:
        with open(history_path, "r", encoding="utf-8") as history_file:
            history = json.load(history_file)
    except (OSError, ValueError):  # ValueError covers JSON and Unicode decode errors
        return []

    if not isinstance(history, list):
        return []

    return [
        entry for entry in history
        if isinstance(entry, dict)
        and isinstance(entry.get('natural_language'), str)
        and 'timestamp' in entry
    ]
|
|
|
|
def persist_query_history(history, path=None):
    """Write the query history to disk atomically (best effort).

    The JSON is written to a temporary file in the same directory, which then
    replaces the target with ``os.replace``. A failure part-way through (e.g. a
    non-serialisable value) therefore leaves the previous file intact. The old
    direct write truncated it first and could leave corrupt JSON behind, which
    ``load_query_history`` would then treat as an empty history.

    Args:
        history: JSON-serialisable list of entries.
        path: Target file; defaults to ``HISTORY_FILE_PATH``.

    Returns:
        True on success, False if the history could not be written. Failures
        are not raised, so the UI keeps working on read-only or full disks.
    """
    import contextlib
    import tempfile

    target = HISTORY_FILE_PATH if path is None else path
    directory = os.path.dirname(target) or "."
    tmp_path = None
    try:
        os.makedirs(directory, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".history-", suffix=".tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as tmp_file:
            json.dump(history, tmp_file, indent=2)
        os.replace(tmp_path, target)
        tmp_path = None
        return True
    except (OSError, TypeError, ValueError):
        return False
    finally:
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
|
|
|
|
def load_schema(path=None):
    """Load the database-schema description shown in the sidebar.

    Args:
        path: Text file to read; defaults to ``SCHEMA_FILE_PATH``.

    Returns:
        The file contents, or ``DEFAULT_SCHEMA`` when the file is missing,
        unreadable, not valid UTF-8, or contains only whitespace.
    """
    schema_path = SCHEMA_FILE_PATH if path is None else path
    try:
        with open(schema_path, "r", encoding="utf-8") as schema_file:
            schema = schema_file.read()
    except (OSError, ValueError):  # missing file, permissions, decode errors
        return DEFAULT_SCHEMA

    return schema if schema.strip() else DEFAULT_SCHEMA
|
|
|
|
def persist_schema(schema, path=None):
    """Write the schema text to disk atomically (best effort).

    The text goes to a temporary file in the target directory, which then
    replaces the target. A failed write therefore never leaves a truncated
    schema file behind.

    Args:
        schema: Schema text to store.
        path: Target file; defaults to ``SCHEMA_FILE_PATH``.

    Returns:
        True on success, False if the schema could not be written (errors are
        not raised so the UI keeps working).
    """
    import contextlib
    import tempfile

    target = SCHEMA_FILE_PATH if path is None else path
    directory = os.path.dirname(target) or "."
    tmp_path = None
    try:
        os.makedirs(directory, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".schema-", suffix=".tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as tmp_file:
            tmp_file.write(schema)
        os.replace(tmp_path, target)
        tmp_path = None
        return True
    except (OSError, TypeError, ValueError):
        return False
    finally:
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
|
|
|
|
def _init_session_state():
    """Seed session state with persisted history/schema the first time a session runs."""
    if 'history' not in st.session_state:
        st.session_state.history = load_query_history()

    if 'schema' not in st.session_state:
        st.session_state.schema = load_schema()

    if 'schema_edit_mode' not in st.session_state:
        st.session_state.schema_edit_mode = False


def _render_schema_editor():
    """Sidebar section showing the schema, or an editor for it in edit mode."""
    st.subheader("📊 Database Schema")

    toggle_col, status_col = st.columns([1, 3])
    with toggle_col:
        if st.button("✏️ Edit" if not st.session_state.schema_edit_mode else "👁️ View"):
            st.session_state.schema_edit_mode = not st.session_state.schema_edit_mode
            st.rerun()
    with status_col:
        if st.session_state.schema_edit_mode:
            st.info("✏️ Editing Mode")
        else:
            st.caption("View your database tables and columns")

    if not st.session_state.schema_edit_mode:
        with st.expander("View Available Tables", expanded=False):
            st.markdown(st.session_state.schema)
        return

    edited_schema = st.text_area(
        "Edit Database Schema",
        value=st.session_state.schema,
        height=300,
        help="Define your database tables and columns. Use Markdown format."
    )

    save_col, reset_col = st.columns(2)
    with save_col:
        if st.button("💾 Save Schema", use_container_width=True):
            st.session_state.schema = edited_schema
            persist_schema(edited_schema)
            if is_hf_space():
                st.warning("⚠️ Schema saved to session only (HF Spaces has read-only filesystem)")
            else:
                st.success("Schema saved!")
            st.session_state.schema_edit_mode = False
            st.rerun()
    with reset_col:
        if st.button("🔄 Reset to Default", use_container_width=True):
            st.session_state.schema = DEFAULT_SCHEMA
            persist_schema(DEFAULT_SCHEMA)
            st.success("Schema reset to default!")
            st.rerun()


def _render_sidebar():
    """Draw the sidebar settings, load the model, and return what the main page needs.

    Returns:
        tuple: ``(model, generation_kwargs, max_history_size)``. ``model`` is a
        ``TexQLModel`` or None, and ``generation_kwargs`` holds the keyword
        arguments forwarded to ``TexQLModel.generate_query``.
    """
    with st.sidebar:
        st.header("⚙️ Configuration")

        st.subheader("Model Path")
        model_path = st.text_input(
            "NoNoQL Model Path",
            value=DEFAULT_MODEL_PATH,
            help="HuggingFace repo (user/repo) or local path"
        )
        if '/' in model_path:
            st.caption(f"📥 Loading from HuggingFace: [{model_path}](https://huggingface.co/{model_path})")
        else:
            st.caption(f"📂 Loading from local path: {model_path}")

        st.subheader("Generation Parameters")
        temperature = st.slider(
            "Temperature",
            min_value=0.1,
            max_value=1.0,
            value=0.3,
            step=0.1,
            help="Lower = more focused, Higher = more creative"
        )
        num_beams = st.slider(
            "Beam Search Width",
            min_value=1,
            max_value=10,
            value=10,
            help="Higher values improve accuracy (recommended: keep at 10)"
        )
        repetition_penalty = st.slider(
            "Repetition Penalty",
            min_value=1.0,
            max_value=2.0,
            value=1.2,
            step=0.1,
            help="Higher = less repetition (prevents hallucinating extra columns)"
        )
        length_penalty = st.slider(
            "Length Penalty",
            min_value=0.5,
            max_value=1.5,
            value=0.8,
            step=0.1,
            help="Lower = prefer shorter outputs, Higher = prefer longer outputs"
        )

        if st.button("🔄 Load/Reload Models"):
            st.cache_resource.clear()
            st.rerun()

        st.subheader("📚 History Settings")
        max_history_size = st.number_input(
            "Max History Entries",
            min_value=10,
            max_value=1000,
            value=500,
            step=10,
            help="Maximum number of queries to keep in history"
        )

        _render_schema_editor()

        with st.spinner("Loading model..."):
            model = load_model(model_path)

        if model and model.loaded:
            device_info = "🎮 GPU" if model.device == "cuda" else "💻 CPU"
            st.success(f"✅ Model Loaded ({device_info})")
            st.info("💡 This model generates both SQL and MongoDB queries")
        else:
            st.error("⚠️ Model Not Available - Please check the model path")

    generation_kwargs = {
        'temperature': temperature,
        'num_beams': num_beams,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
    }
    return model, generation_kwargs, int(max_history_size)


def _render_query_input():
    """Main-page query box with an example picker; returns the current query text."""
    st.subheader("🔤 Enter Your Query")

    with st.expander("💡 Example Queries - Click to expand"):
        examples = [
            "Show all employees",
            "Find employees where salary is greater than 50000",
            "Get all departments with budget more than 100000",
            "Insert a new employee with name John Doe, email john@example.com, department Engineering",
            "Update employees set department to Sales where employee_id is 101",
            "Delete orders with total_amount less than 1000",
            "Count all products in Electronics category",
            "Show top 10 employees ordered by salary",
        ]

        selected_example = st.selectbox(
            "Choose an example query:",
            [""] + examples,
            index=0,
            format_func=lambda x: "Select an example..." if x == "" else x
        )

        if selected_example and st.button("📝 Use This Example", use_container_width=True):
            st.session_state.user_query = selected_example
            st.rerun()

    return st.text_area(
        "or",
        value=st.session_state.get('user_query', ''),
        height=100,
        placeholder="write your query here..."
    )


def _render_results(result, just_generated):
    """Show the latest natural-language query with its SQL and MongoDB translations."""
    st.markdown("---")
    if just_generated:
        st.success("✅ Queries Generated Successfully!")

    st.markdown('<div class="query-box">', unsafe_allow_html=True)
    st.markdown("**📝 Your Query:**")
    st.code(result['natural_language'], language="text")
    st.markdown('</div>', unsafe_allow_html=True)

    sql_col, mongo_col = st.columns(2)
    with sql_col:
        st.markdown("### 🗄️ SQL Query")
        st.code(result['sql'], language="sql")
    with mongo_col:
        st.markdown("### 🍃 MongoDB Query")
        st.code(result['mongodb'], language="javascript")

    # st.code provides its own copy-to-clipboard icon. The former "Copy" buttons
    # sat inside the Generate-button branch: clicking one triggered a rerun in
    # which that branch (and the button's handler) never executed.
    st.caption("📋 Use the copy icon on each code block to copy a query.")


def _entry_matches(entry, needle):
    """True when lowercase *needle* occurs in the entry's NL, SQL or MongoDB text."""
    return any(
        needle in str(entry.get(field) or '').lower()
        for field in ('natural_language', 'sql', 'mongodb')
    )


def _render_history_entry(actual_idx, entry):
    """One collapsible history card with rerun/delete actions and both generated queries."""
    nl_text = str(entry.get('natural_language', ''))
    preview = nl_text if len(nl_text) <= 60 else f"{nl_text[:60]}..."

    with st.expander(f"🕐 {entry.get('timestamp', '')} - {preview}", expanded=False):
        text_col, rerun_col, delete_col = st.columns([3, 1, 1])
        with text_col:
            st.markdown("**Natural Language Query:**")
            st.info(nl_text)
        with rerun_col:
            if st.button("🔄 Rerun", key=f"rerun_{actual_idx}"):
                st.session_state.user_query = nl_text
                st.rerun()
        with delete_col:
            if st.button("🗑️ Delete", key=f"del_{actual_idx}"):
                delete_history_entry(actual_idx)
                st.rerun()

        sql_col, mongo_col = st.columns(2)
        with sql_col:
            st.markdown("**SQL Query:**")
            if entry.get('sql'):
                st.code(entry['sql'], language="sql")
            else:
                st.text("N/A")
        with mongo_col:
            st.markdown("**MongoDB Query:**")
            if entry.get('mongodb'):
                st.code(entry['mongodb'], language="javascript")
            else:
                st.text("N/A")


def _render_history():
    """History panel: search/sort/limit controls, bulk actions, and per-entry cards."""
    st.markdown("---")
    st.subheader("📚 Query History")

    search_col, sort_col, limit_col = st.columns([2, 1, 1])
    with search_col:
        search_term = st.text_input(
            "🔍 Search History",
            placeholder="Search in queries...",
            label_visibility="collapsed"
        )
    with sort_col:
        sort_order = st.selectbox(
            "Sort",
            ["Newest First", "Oldest First"],
            label_visibility="collapsed"
        )
    with limit_col:
        show_limit = int(st.number_input(
            "Show",
            min_value=5,
            max_value=100,
            value=10,
            step=5,
            label_visibility="collapsed"
        ))

    clear_col, export_col = st.columns(2)
    with clear_col:
        if st.button("🗑️ Clear All History"):
            st.session_state.history = []
            persist_query_history(st.session_state.history)
            st.rerun()
    with export_col:
        # A download button works with one click. Nesting it under another
        # st.button needed two clicks with a rerun in between.
        st.download_button(
            label="💾 Export History (JSON)",
            data=json.dumps(st.session_state.history, indent=2),
            file_name=f"nonoql_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )

    history = st.session_state.history

    # Carry each entry's position in the full history. Widget keys are then
    # unique even for identical entries; history.index(entry) returned the
    # first duplicate, which gave clashing keys and an O(n^2) loop.
    filtered = list(enumerate(history))
    if search_term:
        needle = search_term.lower()
        filtered = [(idx, entry) for idx, entry in filtered if _entry_matches(entry, needle)]

    if sort_order == "Oldest First":
        display = filtered[:show_limit]
    else:
        display = list(reversed(filtered[-show_limit:]))

    st.markdown(f"**Showing {len(display)} of {len(filtered)} queries** (Total: {len(history)})")

    if not display:
        st.info("No queries found matching your search.")

    for actual_idx, entry in display:
        _render_history_entry(actual_idx, entry)


def _render_footer():
    """Static page footer."""
    st.markdown("---")
    st.markdown("""
    <div style='text-align: center; color: #666; padding: 2rem;'>
        <p>NoNoQL - Natural Language to Query Generator</p>
        <p>Powered by T5 Transformer Models | Built with Streamlit</p>
    </div>
    """, unsafe_allow_html=True)


def main():
    """Streamlit entry point: sidebar settings, query generation, results and history."""
    if is_hf_space():
        st.info("🤗 Running on HuggingFace Spaces - Model loaded from Hub")

    _init_session_state()
    model, generation_kwargs, max_history_size = _render_sidebar()
    user_query = _render_query_input()

    just_generated = False
    if st.button("🚀 Generate Queries"):
        if not user_query.strip():
            st.warning("Please enter a query")
        elif not model or not model.loaded:
            st.error("Model is not loaded. Please check the model path and reload.")
        else:
            with st.spinner("Generating queries..."):
                sql_query = model.generate_query(user_query, target_type='sql', **generation_kwargs)
                mongodb_query = model.generate_query(user_query, target_type='mongodb', **generation_kwargs)

            save_query_history(user_query, sql_query, mongodb_query, max_history_size)

            # Keep the latest result in session state so it survives reruns caused
            # by later widget interactions (history search, rerun/delete, ...).
            st.session_state.last_result = {
                'natural_language': user_query,
                'sql': sql_query,
                'mongodb': mongodb_query,
            }
            just_generated = True

    if st.session_state.get('last_result'):
        _render_results(st.session_state.last_result, just_generated)

    if st.session_state.get('history'):
        _render_history()

    _render_footer()
|
|
|
|
if __name__ == "__main__":
    # `streamlit run` executes this file as __main__, so the guard still fires
    # there; importing the module (e.g. from tests) does not start the UI.
    main()
|
|