mtyrrell committed on
Commit
1900b1d
·
1 Parent(s): 687387a

retriever adapter

Browse files
Files changed (5) hide show
  1. .gitignore +2 -1
  2. app/models.py +2 -1
  3. app/nodes.py +16 -12
  4. app/retriever_adapter.py +73 -0
  5. params.cfg +2 -1
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .env
2
- *.DS_Store
 
 
1
  .env
2
+ *.DS_Store
3
+ __pycache__/
app/models.py CHANGED
@@ -18,6 +18,8 @@ class GraphState(TypedDict):
18
  metadata: Optional[Dict[str, Any]]
19
  file_type: Optional[str]
20
  workflow_type: Optional[str] # 'standard' or 'geojson_direct'
 
 
21
 
22
  class ChatUIInput(BaseModel):
23
  """Input model for text-only ChatUI requests"""
@@ -27,4 +29,3 @@ class ChatUIFileInput(BaseModel):
27
  """Input model for ChatUI requests with file attachments"""
28
  text: str
29
  files: Optional[List[Dict[str, Any]]] = None
30
-
 
18
  metadata: Optional[Dict[str, Any]]
19
  file_type: Optional[str]
20
  workflow_type: Optional[str] # 'standard' or 'geojson_direct'
21
+ metadata_filters: Optional[Dict[str, Any]]
22
+ metadata: Dict[str, Any]
23
 
24
  class ChatUIInput(BaseModel):
25
  """Input model for text-only ChatUI requests"""
 
29
  """Input model for ChatUI requests with file attachments"""
30
  text: str
31
  files: Optional[List[Dict[str, Any]]] = None
 
app/nodes.py CHANGED
@@ -1,8 +1,7 @@
1
- from utils import detect_file_type, convert_context_to_list, merge_state, getconfig
2
- from models import GraphState
3
- from datetime import datetime
4
  import tempfile
5
  import os
 
 
6
  from gradio_client import Client, file
7
  import logging
8
  import dotenv
@@ -10,6 +9,9 @@ import httpx
10
  import json
11
  from typing import Generator, Optional
12
 
 
 
 
13
  dotenv.load_dotenv()
14
 
15
  logger = logging.getLogger(__name__)
@@ -22,6 +24,7 @@ INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed
22
  GEOJSON_INGESTOR = config.get("ingestor", "GEOJSON_INGESTOR", fallback="https://giz-eudr-chatfed-ingestor.hf.space")
23
  MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
24
 
 
25
 
26
  #----------------------------------------
27
  # LANGGRAPH NODE FUNCTIONS
@@ -122,19 +125,18 @@ def geojson_direct_result_node(state: GraphState) -> GraphState:
122
 
123
 
124
  def retrieve_node(state: GraphState) -> GraphState:
125
- """Retrieve relevant context from vector store"""
126
  start_time = datetime.now()
127
  logger.info(f"Retrieval: {state['query'][:50]}...")
128
 
129
  try:
130
- client = Client(RETRIEVER, hf_token=os.getenv("HF_TOKEN"))
131
- context = client.predict(
 
 
132
  query=state["query"],
133
- reports_filter=state.get("reports_filter", ""),
134
- sources_filter=state.get("sources_filter", ""),
135
- subtype_filter=state.get("subtype_filter", ""),
136
- year_filter=state.get("year_filter", ""),
137
- api_name="/retrieve"
138
  )
139
 
140
  duration = (datetime.now() - start_time).total_seconds()
@@ -142,7 +144,9 @@ def retrieve_node(state: GraphState) -> GraphState:
142
  metadata.update({
143
  "retrieval_duration": duration,
144
  "context_length": len(context) if context else 0,
145
- "retrieval_success": True
 
 
146
  })
147
 
148
  return {"context": context, "metadata": metadata}
 
 
 
 
1
  import tempfile
2
  import os
3
+ from models import GraphState
4
+ from datetime import datetime
5
  from gradio_client import Client, file
6
  import logging
7
  import dotenv
 
9
  import json
10
  from typing import Generator, Optional
11
 
12
+ from utils import detect_file_type, convert_context_to_list, merge_state, getconfig
13
+ from retriever_adapter import RetrieverAdapter
14
+
15
  dotenv.load_dotenv()
16
 
17
  logger = logging.getLogger(__name__)
 
24
  GEOJSON_INGESTOR = config.get("ingestor", "GEOJSON_INGESTOR", fallback="https://giz-eudr-chatfed-ingestor.hf.space")
25
  MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
26
 
27
+ retriever_adapter = RetrieverAdapter("params.cfg")
28
 
29
  #----------------------------------------
30
  # LANGGRAPH NODE FUNCTIONS
 
125
 
126
 
127
  def retrieve_node(state: GraphState) -> GraphState:
128
+ """Retrieve relevant context using adapter"""
129
  start_time = datetime.now()
130
  logger.info(f"Retrieval: {state['query'][:50]}...")
131
 
132
  try:
133
+ # Get filters from state (provided by ChatUI or LLM agent)
134
+ filters = state.get("metadata_filters")
135
+
136
+ context = retriever_adapter.retrieve(
137
  query=state["query"],
138
+ filters=filters,
139
+ hf_token=os.getenv("HF_TOKEN")
 
 
 
140
  )
141
 
142
  duration = (datetime.now() - start_time).total_seconds()
 
144
  metadata.update({
145
  "retrieval_duration": duration,
146
  "context_length": len(context) if context else 0,
147
+ "retrieval_success": True,
148
+ "filters_applied": filters,
149
+ "retriever_config": retriever_adapter.get_metadata()
150
  })
151
 
152
  return {"context": context, "metadata": metadata}
app/retriever_adapter.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import logging
from typing import Dict, Any, Optional
from gradio_client import Client
from utils import getconfig

logger = logging.getLogger(__name__)

class RetrieverAdapter:
    """
    Thin pass-through adapter between the graph nodes and the remote retriever.

    The adapter only knows *where* to send a query (URL and optional
    collection name, read from config); the metadata filters themselves are
    supplied per-call by the ChatUI frontend or an LLM agent, never from config.
    """

    def __init__(self, config_path: str = "params.cfg"):
        # Read endpoint settings once at construction time.
        cfg = getconfig(config_path)
        self.config = cfg
        self.retriever_url = cfg.get("retriever", "RETRIEVER")
        # Collection name is optional; stays None when the retriever doesn't need one.
        self.collection_name = cfg.get("retriever", "COLLECTION_NAME", fallback=None)

        logger.info(f"RetrieverAdapter initialized: {self.retriever_url}")
        if self.collection_name:
            logger.info(f"Collection: {self.collection_name}")

    def retrieve(
        self,
        query: str,
        filters: Optional[Dict[str, Any]] = None,
        hf_token: Optional[str] = None
    ) -> str:
        """
        Execute retrieval with query and optional filters.

        Args:
            query: Search query
            filters: Metadata filters dict (from ChatUI or LLM agent)
            hf_token: HuggingFace token for authentication

        Returns:
            Retrieved context string

        Raises:
            Exception: re-raised from the gradio client after logging.
        """
        try:
            remote = Client(self.retriever_url, hf_token=hf_token)

            # Assemble the remote call's keyword arguments; the query is mandatory,
            # everything else is added only when present.
            call_kwargs: Dict[str, Any] = {"query": query}
            if self.collection_name:
                call_kwargs["collection_name"] = self.collection_name
            if filters:
                call_kwargs["filters"] = filters
            call_kwargs["api_name"] = "/retrieve"

            # Build a loggable summary: drop the routing key and mask dict
            # payloads so filter contents never leak into the log.
            log_params: Dict[str, Any] = {}
            for key, value in call_kwargs.items():
                if key == "api_name":
                    continue
                if isinstance(value, dict):
                    log_params[key] = f"<dict with {len(value)} keys>"
                else:
                    log_params[key] = value
            logger.info(f"Retrieval request: {log_params}")

            return remote.predict(**call_kwargs)

        except Exception as e:
            logger.error(f"Retrieval failed: {str(e)}")
            raise

    def get_metadata(self) -> Dict[str, Any]:
        """Return retriever configuration metadata"""
        return {
            "retriever_url": self.retriever_url,
            "collection_name": self.collection_name
        }
params.cfg CHANGED
@@ -1,5 +1,6 @@
1
  [retriever]
2
- RETRIEVER = https://giz-chatfed-retriever0-3.hf.space
 
3
 
4
  [generator]
5
  GENERATOR = https://giz-eudr-chabo-generator.hf.space
 
1
  [retriever]
2
+ RETRIEVER = https://giz-chatfed-retriever0-3.hf.space/
3
+ # Optional, only set if the retriever needs it (configparser does not strip inline comments by default)
+ COLLECTION_NAME = EUDR
4
 
5
  [generator]
6
  GENERATOR = https://giz-eudr-chabo-generator.hf.space