moheesh committed on
Commit
c177f34
·
1 Parent(s): be9c684

removed lazy loading

Browse files
Files changed (1) hide show
  1. src/app.py +102 -69
src/app.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Streamlit App for SQL Learning Assistant
3
- Integrates: RAG + Fine-tuned Model + Gemini Enhancement
4
  """
5
 
6
  import streamlit as st
@@ -26,60 +26,103 @@ st.set_page_config(
26
  )
27
 
28
  # =============================================================================
29
- # CACHED LOADERS - Load on-demand, cache forever
30
  # =============================================================================
31
 
32
@st.cache_resource(show_spinner=False)
def load_chromadb():
    """Ensure the local ChromaDB directory exists and is populated.

    Returns the path of the local store ("chromadb_data"). When the
    directory holds no data files and the HF_CHROMADB_ID environment
    variable names a HuggingFace dataset repo, that repo is snapshotted
    into the directory first.
    """
    chromadb_path = "chromadb_data"
    hf_chromadb_id = os.getenv("HF_CHROMADB_ID", None)

    # Heuristic: the store counts as populated when a chroma/sqlite file
    # is present, or when the directory contains more than two entries.
    has_files = False
    if os.path.exists(chromadb_path):
        if os.path.isdir(chromadb_path):
            local_files = os.listdir(chromadb_path)
        else:
            local_files = []
        data_files = [f for f in local_files
                      if 'chroma' in f.lower() or 'sqlite' in f.lower()]
        has_files = bool(data_files) or len(local_files) > 2

    if hf_chromadb_id and not has_files:
        # Imported lazily so the app starts even without huggingface_hub
        # when the store is already on disk.
        from huggingface_hub import snapshot_download
        os.makedirs(chromadb_path, exist_ok=True)
        snapshot_download(repo_id=hf_chromadb_id, repo_type="dataset",
                          local_dir=chromadb_path)

    return chromadb_path
49
-
50
@st.cache_resource(show_spinner=False)
def load_retriever():
    """Build and cache the RAG retriever (ensures ChromaDB data exists first)."""
    # The vector store must be on disk before the retriever opens it.
    load_chromadb()
    from rag.retriever import SQLRetriever
    retriever = SQLRetriever()
    return retriever
56
-
57
@st.cache_resource(show_spinner=False)
def load_model():
    """Build and cache the fine-tuned SQL generation model."""
    from finetuning.inference import SQLGenerator
    generator = SQLGenerator()
    return generator
62
-
63
@st.cache_resource(show_spinner=False)
def load_prompt_builder():
    """Build and cache the prompt builder."""
    from prompts.prompt_builder import PromptBuilder
    builder = PromptBuilder()
    return builder
68
-
69
@st.cache_resource(show_spinner=False)
def load_gemini():
    """Return a cached Gemini client, or None when no API keys are configured."""
    from pipeline.integrated import GeminiClient, GEMINI_KEYS
    if not GEMINI_KEYS:
        return None
    return GeminiClient()
76
 
77
  # =============================================================================
78
  # HELPER FUNCTION TO RUN PIPELINE
79
  # =============================================================================
80
 
81
  def run_pipeline(question, num_examples=3):
82
- """Run the full pipeline - loads components on first use."""
83
  result = {
84
  'question': question,
85
  'success': False,
@@ -89,29 +132,26 @@ def run_pipeline(question, num_examples=3):
89
  # Step 1: RAG
90
  rag_context = ""
91
  examples = []
92
- try:
93
- with st.spinner("πŸ” Loading RAG system..."):
94
- retriever = load_retriever()
95
- if retriever:
96
  examples = retriever.retrieve(question, top_k=num_examples)
97
  rag_context = "Similar SQL examples:\n\n"
98
  for i, r in enumerate(examples, 1):
99
  rag_context += f"Example {i}:\nQuestion: {r['question']}\nSQL: {r['sql']}\n\n"
100
- except Exception as e:
101
- st.warning(f"RAG error: {e}")
102
 
103
  result['steps']['rag'] = {'examples': examples, 'num_examples': len(examples), 'context': rag_context}
104
 
105
  # Step 2: Prompt
106
  prompt = ""
107
- try:
108
- prompt_builder = load_prompt_builder()
109
- if prompt_builder:
110
  prompt_result = prompt_builder.build_prompt(question=question, rag_context=rag_context)
111
  if prompt_result['success']:
112
  prompt = prompt_result['prompt']
113
- except:
114
- pass
115
  if not prompt:
116
  prompt = f"{rag_context}\nQuestion: {question}\n\nSQL:"
117
 
@@ -119,13 +159,11 @@ def run_pipeline(question, num_examples=3):
119
 
120
  # Step 3: Fine-tuned Model
121
  finetuned_sql = None
122
- try:
123
- with st.spinner("πŸ€– Loading AI model..."):
124
- model = load_model()
125
- if model:
126
  finetuned_sql = model.generate(question, rag_context)
127
- except Exception as e:
128
- st.warning(f"Model error: {e}")
129
 
130
  result['steps']['finetuned'] = {'sql': finetuned_sql, 'error': None if finetuned_sql else 'Model not available'}
131
 
@@ -134,9 +172,8 @@ def run_pipeline(question, num_examples=3):
134
 
135
  # Step 4: Gemini Enhancement
136
  enhanced_sql = finetuned_sql
137
- try:
138
- gemini = load_gemini()
139
- if gemini:
140
  enhance_prompt = f"""You are an SQL expert. Review and enhance this SQL query.
141
 
142
  Original Question: {question}
@@ -158,23 +195,22 @@ Enhanced SQL:"""
158
  enhanced_sql = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
159
  if enhanced_sql.lower().startswith("sql"):
160
  enhanced_sql = enhanced_sql[3:].strip()
161
- except Exception as e:
162
- st.warning(f"Gemini enhance error: {e}")
163
 
164
  result['steps']['gemini_enhance'] = {'sql': enhanced_sql, 'info': {'enhanced': enhanced_sql != finetuned_sql}}
165
  result['final_sql'] = enhanced_sql
166
 
167
  # Step 5: Explanation
168
  explanation = ""
169
- try:
170
- gemini = load_gemini()
171
- if gemini:
172
  explain_prompt = f"Explain this SQL query in simple terms (2-3 sentences):\n\nSQL: {enhanced_sql}"
173
  response, error = gemini.generate(explain_prompt)
174
  if response and not error:
175
  explanation = response.strip()
176
- except:
177
- pass
178
 
179
  result['explanation'] = explanation
180
  result['success'] = True
@@ -275,14 +311,11 @@ with st.sidebar:
275
  st.markdown("### πŸ“Š System Status")
276
  col1, col2 = st.columns(2)
277
  with col1:
278
- st.markdown("βœ… **RAG**")
279
- st.markdown("βœ… **Model**")
280
  with col2:
281
- st.markdown("βœ… **Prompts**")
282
- if os.getenv("GEMINI_API_KEY"):
283
- st.markdown("βœ… **Gemini**")
284
- else:
285
- st.markdown("❌ **Gemini**")
286
 
287
  st.markdown("---")
288
 
 
1
  """
2
  Streamlit App for SQL Learning Assistant
3
+ Eager Loading - Load everything at startup
4
  """
5
 
6
  import streamlit as st
 
26
  )
27
 
28
  # =============================================================================
29
+ # LOAD ALL COMPONENTS AT STARTUP (EAGER LOADING)
30
  # =============================================================================
31
 
32
@st.cache_resource(show_spinner=True)
def load_all_components():
    """Load all pipeline components once at startup (eager loading).

    Returns a dict with keys 'retriever', 'model', 'prompt_builder' and
    'gemini'. Every load step is best-effort: a component that fails to
    load is left as None (with the error printed to the server log) so
    the app degrades gracefully instead of crashing at startup.
    """
    components = {
        'retriever': None,
        'model': None,
        'prompt_builder': None,
        'gemini': None
    }

    # 1. Load ChromaDB first
    print("=" * 50)
    print("LOADING ALL COMPONENTS AT STARTUP")
    print("=" * 50)

    chromadb_path = "chromadb_data"
    hf_chromadb_id = os.getenv("HF_CHROMADB_ID")

    # Check if ChromaDB has actual files
    has_files = False
    if os.path.exists(chromadb_path):
        local_files = os.listdir(chromadb_path) if os.path.isdir(chromadb_path) else []
        has_files = any('chroma' in f.lower() or 'sqlite' in f.lower() for f in local_files) or len(local_files) > 2

    if not has_files and hf_chromadb_id:
        # FIX: the download was the only step not wrapped in try/except;
        # a network error here would escape the cached loader and crash
        # the whole app. Keep it best-effort like every other step --
        # the retriever below simply stays None if the data is missing.
        try:
            print(f"☁️ Downloading ChromaDB from HuggingFace: {hf_chromadb_id}")
            from huggingface_hub import snapshot_download
            os.makedirs(chromadb_path, exist_ok=True)
            snapshot_download(repo_id=hf_chromadb_id, repo_type="dataset", local_dir=chromadb_path)
            print("✓ ChromaDB downloaded!")
        except Exception as e:
            print(f"✗ ChromaDB download error: {e}")

    # 2. Load RAG Retriever
    try:
        print("Loading RAG Retriever...")
        from rag.retriever import SQLRetriever
        components['retriever'] = SQLRetriever()
        print("✓ RAG Retriever loaded")
    except Exception as e:
        print(f"✗ RAG error: {e}")

    # 3. Load Fine-tuned Model
    try:
        print("Loading Fine-tuned Model...")
        from finetuning.inference import SQLGenerator
        components['model'] = SQLGenerator()
        print("✓ Fine-tuned Model loaded")
    except Exception as e:
        print(f"✗ Model error: {e}")

    # 4. Load Prompt Builder
    try:
        print("Loading Prompt Builder...")
        from prompts.prompt_builder import PromptBuilder
        components['prompt_builder'] = PromptBuilder()
        print("✓ Prompt Builder loaded")
    except Exception as e:
        print(f"✗ Prompt Builder error: {e}")

    # 5. Load Gemini
    try:
        print("Loading Gemini...")
        from pipeline.integrated import GeminiClient, GEMINI_KEYS
        if GEMINI_KEYS:
            components['gemini'] = GeminiClient()
            print("✓ Gemini loaded")
        else:
            print("⚠️ No Gemini API keys found")
    except Exception as e:
        print(f"✗ Gemini error: {e}")

    print("=" * 50)
    print("ALL COMPONENTS LOADED")
    print("=" * 50)

    return components
107
+
108
# =============================================================================
# LOAD COMPONENTS NOW (AT STARTUP)
# =============================================================================

# Eagerly load every pipeline component on the first script run;
# st.cache_resource inside load_all_components() makes Streamlit reruns
# reuse the same objects instead of reloading.
with st.spinner("🚀 Loading SQL Learning Assistant... Please wait..."):
    COMPONENTS = load_all_components()

# Module-level handles used by run_pipeline() and the sidebar status panel.
# Any of these may be None if the corresponding component failed to load.
retriever = COMPONENTS['retriever']
model = COMPONENTS['model']
prompt_builder = COMPONENTS['prompt_builder']
gemini = COMPONENTS['gemini']
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # =============================================================================
121
  # HELPER FUNCTION TO RUN PIPELINE
122
  # =============================================================================
123
 
124
  def run_pipeline(question, num_examples=3):
125
+ """Run the full pipeline using pre-loaded components."""
126
  result = {
127
  'question': question,
128
  'success': False,
 
132
  # Step 1: RAG
133
  rag_context = ""
134
  examples = []
135
+ if retriever:
136
+ try:
 
 
137
  examples = retriever.retrieve(question, top_k=num_examples)
138
  rag_context = "Similar SQL examples:\n\n"
139
  for i, r in enumerate(examples, 1):
140
  rag_context += f"Example {i}:\nQuestion: {r['question']}\nSQL: {r['sql']}\n\n"
141
+ except Exception as e:
142
+ st.warning(f"RAG error: {e}")
143
 
144
  result['steps']['rag'] = {'examples': examples, 'num_examples': len(examples), 'context': rag_context}
145
 
146
  # Step 2: Prompt
147
  prompt = ""
148
+ if prompt_builder:
149
+ try:
 
150
  prompt_result = prompt_builder.build_prompt(question=question, rag_context=rag_context)
151
  if prompt_result['success']:
152
  prompt = prompt_result['prompt']
153
+ except:
154
+ pass
155
  if not prompt:
156
  prompt = f"{rag_context}\nQuestion: {question}\n\nSQL:"
157
 
 
159
 
160
  # Step 3: Fine-tuned Model
161
  finetuned_sql = None
162
+ if model:
163
+ try:
 
 
164
  finetuned_sql = model.generate(question, rag_context)
165
+ except Exception as e:
166
+ st.warning(f"Model error: {e}")
167
 
168
  result['steps']['finetuned'] = {'sql': finetuned_sql, 'error': None if finetuned_sql else 'Model not available'}
169
 
 
172
 
173
  # Step 4: Gemini Enhancement
174
  enhanced_sql = finetuned_sql
175
+ if gemini:
176
+ try:
 
177
  enhance_prompt = f"""You are an SQL expert. Review and enhance this SQL query.
178
 
179
  Original Question: {question}
 
195
  enhanced_sql = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
196
  if enhanced_sql.lower().startswith("sql"):
197
  enhanced_sql = enhanced_sql[3:].strip()
198
+ except Exception as e:
199
+ st.warning(f"Gemini enhance error: {e}")
200
 
201
  result['steps']['gemini_enhance'] = {'sql': enhanced_sql, 'info': {'enhanced': enhanced_sql != finetuned_sql}}
202
  result['final_sql'] = enhanced_sql
203
 
204
  # Step 5: Explanation
205
  explanation = ""
206
+ if gemini:
207
+ try:
 
208
  explain_prompt = f"Explain this SQL query in simple terms (2-3 sentences):\n\nSQL: {enhanced_sql}"
209
  response, error = gemini.generate(explain_prompt)
210
  if response and not error:
211
  explanation = response.strip()
212
+ except:
213
+ pass
214
 
215
  result['explanation'] = explanation
216
  result['success'] = True
 
311
  st.markdown("### πŸ“Š System Status")
312
  col1, col2 = st.columns(2)
313
  with col1:
314
+ st.markdown("βœ… **RAG**" if retriever else "❌ **RAG**")
315
+ st.markdown("βœ… **Model**" if model else "❌ **Model**")
316
  with col2:
317
+ st.markdown("βœ… **Prompts**" if prompt_builder else "❌ **Prompts**")
318
+ st.markdown("βœ… **Gemini**" if gemini else "❌ **Gemini**")
 
 
 
319
 
320
  st.markdown("---")
321