mtyrrell committed on
Commit
ecc8726
·
1 Parent(s): 5952c14

refactor and ts cleanup

Browse files
app/__pycache__/main.cpython-311.pyc ADDED
Binary file (32.6 kB). View file
 
app/__pycache__/models.cpython-311.pyc ADDED
Binary file (2.49 kB). View file
 
app/__pycache__/nodes.cpython-311.pyc ADDED
Binary file (7.84 kB). View file
 
app/__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/app/__pycache__/utils.cpython-311.pyc and b/app/__pycache__/utils.cpython-311.pyc differ
 
app/main.py CHANGED
@@ -14,17 +14,19 @@ import os
14
  from datetime import datetime
15
  import logging
16
  from contextlib import asynccontextmanager
17
- import threading
18
  from langchain_core.runnables import RunnableLambda
19
- import tempfile
20
- import mimetypes
21
  import asyncio
22
  from typing import Generator
23
  import json
24
  import httpx
25
- import ast
26
 
27
- from utils import getconfig
 
 
28
 
29
  config = getconfig("params.cfg")
30
  RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
@@ -36,272 +38,9 @@ MAX_CONTEXT_CHARS = config.get("general", "MAX_CONTEXT_CHARS")
36
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
37
  logger = logging.getLogger(__name__)
38
 
39
- # CORE FUNCTIONALITY - KEEP THESE
40
- # File type detection
41
- def detect_file_type(filename: str, file_content: bytes = None) -> str:
42
- """Detect file type based on extension and content"""
43
- if not filename:
44
- return "unknown"
45
-
46
- # Get file extension
47
- _, ext = os.path.splitext(filename.lower())
48
-
49
- # Define file type mappings
50
- file_type_mappings = {
51
- '.geojson': 'geojson',
52
- '.json': 'json', # Could be geojson, will check content
53
- '.pdf': 'text',
54
- '.docx': 'text',
55
- '.doc': 'text',
56
- '.txt': 'text',
57
- '.md': 'text',
58
- '.csv': 'text',
59
- '.xlsx': 'text',
60
- '.xls': 'text'
61
- }
62
-
63
- detected_type = file_type_mappings.get(ext, 'unknown')
64
-
65
- # For JSON files, check if it's actually GeoJSON
66
- if detected_type == 'json' and file_content:
67
- try:
68
- import json
69
- content_str = file_content.decode('utf-8')
70
- data = json.loads(content_str)
71
- # Check if it has GeoJSON structure
72
- if isinstance(data, dict) and ('type' in data and data.get('type') == 'FeatureCollection'):
73
- detected_type = 'geojson'
74
- elif isinstance(data, dict) and ('type' in data and data.get('type') in ['Feature', 'Point', 'LineString', 'Polygon', 'MultiPoint', 'MultiLineString', 'MultiPolygon', 'GeometryCollection']):
75
- detected_type = 'geojson'
76
- except:
77
- pass # Keep as json if parsing fails
78
-
79
- logger.info(f"Detected file type: {detected_type} for file: {filename}")
80
- return detected_type
81
-
82
- # Models - KEEP THESE
83
- class GraphState(TypedDict):
84
- query: str
85
- context: str
86
- ingestor_context: str
87
- result: str
88
- sources: Optional[List[Dict[str, str]]] # Added for ChatUI sources
89
- reports_filter: str
90
- sources_filter: str
91
- subtype_filter: str
92
- year_filter: str
93
- file_content: Optional[bytes]
94
- filename: Optional[str]
95
- metadata: Optional[Dict[str, Any]]
96
- file_type: Optional[str]
97
- workflow_type: Optional[str] # 'standard' or 'geojson_direct'
98
-
99
- class ChatFedInput(TypedDict):
100
- query: str
101
- reports_filter: Optional[str]
102
- sources_filter: Optional[str]
103
- subtype_filter: Optional[str]
104
- year_filter: Optional[str]
105
- session_id: Optional[str]
106
- user_id: Optional[str]
107
- file_content: Optional[bytes]
108
- filename: Optional[str]
109
-
110
- class ChatFedOutput(TypedDict):
111
- result: str
112
- metadata: Dict[str, Any]
113
-
114
- class ChatUIInput(BaseModel):
115
- text: str
116
-
117
- # CORE PROCESSING NODES - KEEP THESE
118
- # File type detection node
119
- def detect_file_type_node(state: GraphState) -> GraphState:
120
- """Detect file type and determine workflow"""
121
- file_type = "unknown"
122
- workflow_type = "standard"
123
-
124
- if state.get("file_content") and state.get("filename"):
125
- file_type = detect_file_type(state["filename"], state["file_content"])
126
-
127
- # Determine workflow based on file type
128
- if file_type == "geojson":
129
- workflow_type = "geojson_direct"
130
- else:
131
- workflow_type = "standard"
132
-
133
- metadata = state.get("metadata", {})
134
- metadata.update({
135
- "file_type": file_type,
136
- "workflow_type": workflow_type
137
- })
138
-
139
- return {
140
- "file_type": file_type,
141
- "workflow_type": workflow_type,
142
- "metadata": metadata
143
- }
144
-
145
- # Module functions
146
- def ingest_node(state: GraphState) -> GraphState:
147
- """Process file through appropriate ingestor based on file type"""
148
- start_time = datetime.now()
149
-
150
- # If no file provided, skip this step
151
- if not state.get("file_content") or not state.get("filename"):
152
- logger.info("No file provided, skipping ingestion")
153
- return {"ingestor_context": "", "metadata": state.get("metadata", {})}
154
-
155
- file_type = state.get("file_type", "unknown")
156
- logger.info(f"Ingesting {file_type} file: {state['filename']}")
157
-
158
- try:
159
- # Choose ingestor based on file type
160
- if file_type == "geojson":
161
- ingestor_url = GEOJSON_INGESTOR
162
- logger.info(f"Using GeoJSON ingestor: {ingestor_url}")
163
- else:
164
- ingestor_url = INGESTOR
165
- logger.info(f"Using standard ingestor: {ingestor_url}")
166
-
167
- client = Client(ingestor_url)
168
-
169
- # Create a temporary file to upload
170
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
171
- tmp_file.write(state["file_content"])
172
- tmp_file_path = tmp_file.name
173
-
174
- try:
175
- # Call the ingestor's ingest endpoint
176
- ingestor_context = client.predict(
177
- file(tmp_file_path),
178
- api_name="/ingest"
179
- )
180
-
181
- logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
182
-
183
- # Handle error cases
184
- if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
185
- raise Exception(ingestor_context)
186
-
187
- finally:
188
- # Clean up temporary file
189
- os.unlink(tmp_file_path)
190
-
191
- duration = (datetime.now() - start_time).total_seconds()
192
- metadata = state.get("metadata", {})
193
- metadata.update({
194
- "ingestion_duration": duration,
195
- "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
196
- "ingestion_success": True,
197
- "ingestor_used": ingestor_url
198
- })
199
-
200
- return {
201
- "ingestor_context": ingestor_context,
202
- "metadata": metadata
203
- }
204
-
205
- except Exception as e:
206
- duration = (datetime.now() - start_time).total_seconds()
207
- logger.error(f"Ingestion failed: {str(e)}")
208
-
209
- metadata = state.get("metadata", {})
210
- metadata.update({
211
- "ingestion_duration": duration,
212
- "ingestion_success": False,
213
- "ingestion_error": str(e)
214
- })
215
- return {"ingestor_context": "", "metadata": metadata}
216
-
217
- def geojson_direct_result_node(state: GraphState) -> GraphState:
218
- """For GeoJSON files, return ingestor results directly without retrieval/generation"""
219
- logger.info("Processing GeoJSON file - returning direct results")
220
-
221
- ingestor_context = state.get("ingestor_context", "")
222
-
223
- # For GeoJSON files, the ingestor result is the final result
224
- result = ingestor_context if ingestor_context else "No results from GeoJSON processing."
225
-
226
- metadata = state.get("metadata", {})
227
- metadata.update({
228
- "processing_type": "geojson_direct",
229
- "result_length": len(result)
230
- })
231
-
232
- return {
233
- "result": result,
234
- "metadata": metadata
235
- }
236
-
237
- def retrieve_node(state: GraphState) -> GraphState:
238
- start_time = datetime.now()
239
- logger.info(f"Retrieval: {state['query'][:50]}...")
240
-
241
- try:
242
- client = Client(RETRIEVER)
243
- context = client.predict(
244
- query=state["query"],
245
- reports_filter=state.get("reports_filter", ""),
246
- sources_filter=state.get("sources_filter", ""),
247
- subtype_filter=state.get("subtype_filter", ""),
248
- year_filter=state.get("year_filter", ""),
249
- api_name="/retrieve"
250
- )
251
-
252
- duration = (datetime.now() - start_time).total_seconds()
253
- metadata = state.get("metadata", {})
254
- metadata.update({
255
- "retrieval_duration": duration,
256
- "context_length": len(context) if context else 0,
257
- "retrieval_success": True
258
- })
259
-
260
- return {"context": context, "metadata": metadata}
261
-
262
- except Exception as e:
263
- duration = (datetime.now() - start_time).total_seconds()
264
- logger.error(f"Retrieval failed: {str(e)}")
265
-
266
- metadata = state.get("metadata", {})
267
- metadata.update({
268
- "retrieval_duration": duration,
269
- "retrieval_success": False,
270
- "retrieval_error": str(e)
271
- })
272
- return {"context": "", "metadata": metadata}
273
 
274
- # Helper function to convert retrieval context to expected format
275
- def convert_context_to_list(context: str) -> List[Dict[str, Any]]:
276
- """Convert string context to list format expected by generator"""
277
- try:
278
- # Try to parse as list first
279
- if context.startswith('['):
280
- return ast.literal_eval(context)
281
- else:
282
- # If it's a string, wrap it in a simple format
283
- return [{
284
- "answer": context,
285
- "answer_metadata": {
286
- "filename": "Retrieved Context",
287
- "page": "Unknown",
288
- "year": "Unknown",
289
- "source": "Retriever"
290
- }
291
- }]
292
- except:
293
- # Fallback: simple string wrapping
294
- return [{
295
- "answer": context,
296
- "answer_metadata": {
297
- "filename": "Retrieved Context",
298
- "page": "Unknown",
299
- "year": "Unknown",
300
- "source": "Retriever"
301
- }
302
- }]
303
 
304
- # MAIN STREAMING GENERATOR - KEEP THIS (but consider simplifying the fallback logic)
305
  async def generate_node_streaming(state: GraphState) -> Generator[GraphState, None, None]:
306
  """Streaming version that calls generator's FastAPI endpoint"""
307
  start_time = datetime.now()
@@ -456,7 +195,7 @@ async def generate_node_streaming(state: GraphState) -> Generator[GraphState, No
456
  except json.JSONDecodeError:
457
  raise Exception(data_content)
458
 
459
- # TODO: CONSIDER REMOVING THIS GRADIO FALLBACK IF FASTAPI IS RELIABLE
460
  except Exception as fastapi_error:
461
  logger.warning(f"FastAPI endpoint failed: {fastapi_error}")
462
  logger.info("Falling back to Gradio client")
@@ -538,7 +277,7 @@ async def generate_node_streaming(state: GraphState) -> Generator[GraphState, No
538
  })
539
  yield {"result": f"Error: {str(e)}", "metadata": metadata}
540
 
541
- # Conditional routing function - KEEP THIS
542
  def route_workflow(state: GraphState) -> str:
543
  """Route to appropriate workflow based on file type"""
544
  workflow_type = state.get("workflow_type", "standard")
@@ -864,7 +603,6 @@ async def root():
864
  "message": "ChatFed Orchestrator API",
865
  "endpoints": {
866
  "health": "/health",
867
- # "chatfed": "/chatfed", # Commented out - test if ChatUI needs this
868
  "chatfed-ui-stream": "/chatfed-ui-stream",
869
  "chatfed-with-file": "/chatfed-with-file",
870
  # "chatfed-with-file-stream": "/chatfed-with-file/stream",
@@ -873,7 +611,7 @@ async def root():
873
 
874
 
875
 
876
- # # FILE UPLOAD ADAPTER - KEEP THIS
877
  async def chatfed_with_file_adapter(
878
  query: str,
879
  file_content: Optional[bytes] = None,
@@ -942,7 +680,7 @@ async def chatfed_with_file_adapter(
942
  logger.error(f"File upload streaming failed: {str(e)}")
943
  yield f"Error: {str(e)}"
944
 
945
- # TODO: PROBABLY REMOVE - NON-STREAMING FILE UPLOAD
946
  # @app.post("/chatfed-with-file")
947
  # async def chatfed_with_file(
948
  # query: str = Form(...),
@@ -984,7 +722,7 @@ async def chatfed_with_file_adapter(
984
  # media_type="text/plain"
985
  # )
986
 
987
- # MAIN FILE UPLOAD STREAMING ENDPOINT - KEEP THIS
988
  @app.post("/chatfed-with-file/stream")
989
  async def chatfed_with_file_stream(
990
  query: str = Form(...),
@@ -1050,16 +788,6 @@ async def chatfed_with_file_stream(
1050
  }
1051
  )
1052
 
1053
- # TODO: TEST IF CHATUI NEEDS THESE LANGSERVE ENDPOINTS
1054
- # If ChatUI works without these, they can be removed
1055
- # add_routes(
1056
- # app,
1057
- # RunnableLambda(process_query_langserve),
1058
- # path="/chatfed",
1059
- # input_type=ChatFedInput,
1060
- # output_type=ChatFedOutput
1061
- # )
1062
-
1063
  add_routes(
1064
  app,
1065
  RunnableLambda(chatui_adapter),
 
14
  from datetime import datetime
15
  import logging
16
  from contextlib import asynccontextmanager
17
+ # import threading
18
  from langchain_core.runnables import RunnableLambda
19
+ # import tempfile
20
+ # import mimetypes
21
  import asyncio
22
  from typing import Generator
23
  import json
24
  import httpx
25
+ # import ast
26
 
27
+ from utils import getconfig, convert_context_to_list
28
+ from nodes import detect_file_type_node, ingest_node, geojson_direct_result_node, retrieve_node
29
+ from models import GraphState, ChatFedInput, ChatFedOutput, ChatUIInput
30
 
31
  config = getconfig("params.cfg")
32
  RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
 
38
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
39
  logger = logging.getLogger(__name__)
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # MAIN STREAMING GENERATOR
44
  async def generate_node_streaming(state: GraphState) -> Generator[GraphState, None, None]:
45
  """Streaming version that calls generator's FastAPI endpoint"""
46
  start_time = datetime.now()
 
195
  except json.JSONDecodeError:
196
  raise Exception(data_content)
197
 
198
+ # GRADIO FALLBACK
199
  except Exception as fastapi_error:
200
  logger.warning(f"FastAPI endpoint failed: {fastapi_error}")
201
  logger.info("Falling back to Gradio client")
 
277
  })
278
  yield {"result": f"Error: {str(e)}", "metadata": metadata}
279
 
280
+ # Conditional routing function
281
  def route_workflow(state: GraphState) -> str:
282
  """Route to appropriate workflow based on file type"""
283
  workflow_type = state.get("workflow_type", "standard")
 
603
  "message": "ChatFed Orchestrator API",
604
  "endpoints": {
605
  "health": "/health",
 
606
  "chatfed-ui-stream": "/chatfed-ui-stream",
607
  "chatfed-with-file": "/chatfed-with-file",
608
  # "chatfed-with-file-stream": "/chatfed-with-file/stream",
 
611
 
612
 
613
 
614
+ # # FILE UPLOAD ADAPTER
615
  async def chatfed_with_file_adapter(
616
  query: str,
617
  file_content: Optional[bytes] = None,
 
680
  logger.error(f"File upload streaming failed: {str(e)}")
681
  yield f"Error: {str(e)}"
682
 
683
+ # NON-STREAMING FILE UPLOAD
684
  # @app.post("/chatfed-with-file")
685
  # async def chatfed_with_file(
686
  # query: str = Form(...),
 
722
  # media_type="text/plain"
723
  # )
724
 
725
+ # MAIN FILE UPLOAD STREAMING ENDPOINT
726
  @app.post("/chatfed-with-file/stream")
727
  async def chatfed_with_file_stream(
728
  query: str = Form(...),
 
788
  }
789
  )
790
 
 
 
 
 
 
 
 
 
 
 
791
  add_routes(
792
  app,
793
  RunnableLambda(chatui_adapter),
app/models.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Models
2
+ from typing import Optional, Dict, Any, List
3
+ from typing_extensions import TypedDict
4
+ from pydantic import BaseModel
5
+
6
+ class GraphState(TypedDict):
7
+ query: str
8
+ context: str
9
+ ingestor_context: str
10
+ result: str
11
+ sources: Optional[List[Dict[str, str]]] # Added for ChatUI sources
12
+ reports_filter: str
13
+ sources_filter: str
14
+ subtype_filter: str
15
+ year_filter: str
16
+ file_content: Optional[bytes]
17
+ filename: Optional[str]
18
+ metadata: Optional[Dict[str, Any]]
19
+ file_type: Optional[str]
20
+ workflow_type: Optional[str] # 'standard' or 'geojson_direct'
21
+
22
+ class ChatFedInput(TypedDict):
23
+ query: str
24
+ reports_filter: Optional[str]
25
+ sources_filter: Optional[str]
26
+ subtype_filter: Optional[str]
27
+ year_filter: Optional[str]
28
+ session_id: Optional[str]
29
+ user_id: Optional[str]
30
+ file_content: Optional[bytes]
31
+ filename: Optional[str]
32
+
33
+ class ChatFedOutput(TypedDict):
34
+ result: str
35
+ metadata: Dict[str, Any]
36
+
37
+ class ChatUIInput(BaseModel):
38
+ text: str
app/nodes.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import detect_file_type
2
+ from models import GraphState
3
+ from datetime import datetime
4
+ import tempfile
5
+ import os
6
+ from gradio_client import Client, file
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # CORE PROCESSING NODES
12
+ #----------------------------------------
13
+ # File type detection node
14
+ def detect_file_type_node(state: GraphState) -> GraphState:
15
+ """Detect file type and determine workflow"""
16
+ file_type = "unknown"
17
+ workflow_type = "standard"
18
+
19
+ if state.get("file_content") and state.get("filename"):
20
+ file_type = detect_file_type(state["filename"], state["file_content"])
21
+
22
+ # Determine workflow based on file type
23
+ if file_type == "geojson":
24
+ workflow_type = "geojson_direct"
25
+ else:
26
+ workflow_type = "standard"
27
+
28
+ metadata = state.get("metadata", {})
29
+ metadata.update({
30
+ "file_type": file_type,
31
+ "workflow_type": workflow_type
32
+ })
33
+
34
+ return {
35
+ "file_type": file_type,
36
+ "workflow_type": workflow_type,
37
+ "metadata": metadata
38
+ }
39
+
40
+ # Module functions
41
+ def ingest_node(state: GraphState) -> GraphState:
42
+ """Process file through appropriate ingestor based on file type"""
43
+ start_time = datetime.now()
44
+
45
+ # If no file provided, skip this step
46
+ if not state.get("file_content") or not state.get("filename"):
47
+ logger.info("No file provided, skipping ingestion")
48
+ return {"ingestor_context": "", "metadata": state.get("metadata", {})}
49
+
50
+ file_type = state.get("file_type", "unknown")
51
+ logger.info(f"Ingesting {file_type} file: {state['filename']}")
52
+
53
+ try:
54
+ # Choose ingestor based on file type
55
+ if file_type == "geojson":
56
+ ingestor_url = GEOJSON_INGESTOR
57
+ logger.info(f"Using GeoJSON ingestor: {ingestor_url}")
58
+ else:
59
+ ingestor_url = INGESTOR
60
+ logger.info(f"Using standard ingestor: {ingestor_url}")
61
+
62
+ client = Client(ingestor_url)
63
+
64
+ # Create a temporary file to upload
65
+ with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
66
+ tmp_file.write(state["file_content"])
67
+ tmp_file_path = tmp_file.name
68
+
69
+ try:
70
+ # Call the ingestor's ingest endpoint
71
+ ingestor_context = client.predict(
72
+ file(tmp_file_path),
73
+ api_name="/ingest"
74
+ )
75
+
76
+ logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
77
+
78
+ # Handle error cases
79
+ if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
80
+ raise Exception(ingestor_context)
81
+
82
+ finally:
83
+ # Clean up temporary file
84
+ os.unlink(tmp_file_path)
85
+
86
+ duration = (datetime.now() - start_time).total_seconds()
87
+ metadata = state.get("metadata", {})
88
+ metadata.update({
89
+ "ingestion_duration": duration,
90
+ "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
91
+ "ingestion_success": True,
92
+ "ingestor_used": ingestor_url
93
+ })
94
+
95
+ return {
96
+ "ingestor_context": ingestor_context,
97
+ "metadata": metadata
98
+ }
99
+
100
+ except Exception as e:
101
+ duration = (datetime.now() - start_time).total_seconds()
102
+ logger.error(f"Ingestion failed: {str(e)}")
103
+
104
+ metadata = state.get("metadata", {})
105
+ metadata.update({
106
+ "ingestion_duration": duration,
107
+ "ingestion_success": False,
108
+ "ingestion_error": str(e)
109
+ })
110
+ return {"ingestor_context": "", "metadata": metadata}
111
+
112
+ def geojson_direct_result_node(state: GraphState) -> GraphState:
113
+ """For GeoJSON files, return ingestor results directly without retrieval/generation"""
114
+ logger.info("Processing GeoJSON file - returning direct results")
115
+
116
+ ingestor_context = state.get("ingestor_context", "")
117
+
118
+ # For GeoJSON files, the ingestor result is the final result
119
+ result = ingestor_context if ingestor_context else "No results from GeoJSON processing."
120
+
121
+ metadata = state.get("metadata", {})
122
+ metadata.update({
123
+ "processing_type": "geojson_direct",
124
+ "result_length": len(result)
125
+ })
126
+
127
+ return {
128
+ "result": result,
129
+ "metadata": metadata
130
+ }
131
+
132
+ def retrieve_node(state: GraphState) -> GraphState:
133
+ start_time = datetime.now()
134
+ logger.info(f"Retrieval: {state['query'][:50]}...")
135
+
136
+ try:
137
+ client = Client(RETRIEVER)
138
+ context = client.predict(
139
+ query=state["query"],
140
+ reports_filter=state.get("reports_filter", ""),
141
+ sources_filter=state.get("sources_filter", ""),
142
+ subtype_filter=state.get("subtype_filter", ""),
143
+ year_filter=state.get("year_filter", ""),
144
+ api_name="/retrieve"
145
+ )
146
+
147
+ duration = (datetime.now() - start_time).total_seconds()
148
+ metadata = state.get("metadata", {})
149
+ metadata.update({
150
+ "retrieval_duration": duration,
151
+ "context_length": len(context) if context else 0,
152
+ "retrieval_success": True
153
+ })
154
+
155
+ return {"context": context, "metadata": metadata}
156
+
157
+ except Exception as e:
158
+ duration = (datetime.now() - start_time).total_seconds()
159
+ logger.error(f"Retrieval failed: {str(e)}")
160
+
161
+ metadata = state.get("metadata", {})
162
+ metadata.update({
163
+ "retrieval_duration": duration,
164
+ "retrieval_success": False,
165
+ "retrieval_error": str(e)
166
+ })
167
+ return {"context": "", "metadata": metadata}
app/utils.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import ast
5
  import re
6
  from dotenv import load_dotenv
 
7
 
8
  # Local .env file
9
  load_dotenv()
@@ -44,3 +45,74 @@ def get_auth(provider: str) -> dict:
44
 
45
  return auth_config
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import ast
5
  import re
6
  from dotenv import load_dotenv
7
+ from typing import Optional, Dict, Any, List
8
 
9
  # Local .env file
10
  load_dotenv()
 
45
 
46
  return auth_config
47
 
48
+ # File type detection
49
+ def detect_file_type(filename: str, file_content: bytes = None) -> str:
50
+ """Detect file type based on extension and content"""
51
+ if not filename:
52
+ return "unknown"
53
+
54
+ # Get file extension
55
+ _, ext = os.path.splitext(filename.lower())
56
+
57
+ # Define file type mappings
58
+ file_type_mappings = {
59
+ '.geojson': 'geojson',
60
+ '.json': 'json', # Could be geojson, will check content
61
+ '.pdf': 'text',
62
+ '.docx': 'text',
63
+ '.doc': 'text',
64
+ '.txt': 'text',
65
+ '.md': 'text',
66
+ '.csv': 'text',
67
+ '.xlsx': 'text',
68
+ '.xls': 'text'
69
+ }
70
+
71
+ detected_type = file_type_mappings.get(ext, 'unknown')
72
+
73
+ # For JSON files, check if it's actually GeoJSON
74
+ if detected_type == 'json' and file_content:
75
+ try:
76
+ import json
77
+ content_str = file_content.decode('utf-8')
78
+ data = json.loads(content_str)
79
+ # Check if it has GeoJSON structure
80
+ if isinstance(data, dict) and ('type' in data and data.get('type') == 'FeatureCollection'):
81
+ detected_type = 'geojson'
82
+ elif isinstance(data, dict) and ('type' in data and data.get('type') in ['Feature', 'Point', 'LineString', 'Polygon', 'MultiPoint', 'MultiLineString', 'MultiPolygon', 'GeometryCollection']):
83
+ detected_type = 'geojson'
84
+ except:
85
+ pass # Keep as json if parsing fails
86
+
87
+ logger.info(f"Detected file type: {detected_type} for file: {filename}")
88
+ return detected_type
89
+
90
+ # Helper function to convert retrieval context to expected format
91
+ def convert_context_to_list(context: str) -> List[Dict[str, Any]]:
92
+ """Convert string context to list format expected by generator"""
93
+ try:
94
+ # Try to parse as list first
95
+ if context.startswith('['):
96
+ return ast.literal_eval(context)
97
+ else:
98
+ # If it's a string, wrap it in a simple format
99
+ return [{
100
+ "answer": context,
101
+ "answer_metadata": {
102
+ "filename": "Retrieved Context",
103
+ "page": "Unknown",
104
+ "year": "Unknown",
105
+ "source": "Retriever"
106
+ }
107
+ }]
108
+ except:
109
+ # Fallback: simple string wrapping
110
+ return [{
111
+ "answer": context,
112
+ "answer_metadata": {
113
+ "filename": "Retrieved Context",
114
+ "page": "Unknown",
115
+ "year": "Unknown",
116
+ "source": "Retriever"
117
+ }
118
+ }]