taspol committed on
Commit
152241c
·
1 Parent(s): c6f0ab2

feat: add cold start

Browse files
Files changed (4) hide show
  1. app.py +1 -0
  2. class_mod/rest_qdrant.py +2 -2
  3. data_importer.py +20 -3
  4. utils/llm_caller.py +1 -0
app.py CHANGED
@@ -48,6 +48,7 @@ def greet_json():
48
 
49
  @app.post("/v1/generateTripPlan", response_model=PlanResponse)
50
  def generate_trip_plan(request: PlanRequest):
 
51
  try:
52
  trip_plan = asyncio.run(agent.query_with_rag(request))
53
  return PlanResponse(tripOverview=trip_plan.tripOverview,
 
48
 
49
  @app.post("/v1/generateTripPlan", response_model=PlanResponse)
50
  def generate_trip_plan(request: PlanRequest):
51
+ data_importer.coldStartDatabase()
52
  try:
53
  trip_plan = asyncio.run(agent.query_with_rag(request))
54
  return PlanResponse(tripOverview=trip_plan.tripOverview,
class_mod/rest_qdrant.py CHANGED
@@ -19,7 +19,7 @@ class RestQdrantClient:
19
  r = self.session.get(f"{self.url}/collections/{collection_name}", timeout=self.timeout)
20
  r.raise_for_status()
21
  return r.json()
22
- def search(self, collection_name, query_vector, limit=10, with_payload=True):
23
  payload = {
24
  "vector": query_vector,
25
  "limit": limit,
@@ -28,7 +28,7 @@ class RestQdrantClient:
28
  r = self.session.post(
29
  f"{self.url}/collections/{collection_name}/points/search",
30
  json=payload,
31
- timeout=self.timeout
32
  )
33
  r.raise_for_status()
34
  return r.json()
 
19
  r = self.session.get(f"{self.url}/collections/{collection_name}", timeout=self.timeout)
20
  r.raise_for_status()
21
  return r.json()
22
+ def search(self, collection_name, query_vector, limit=10, with_payload=True,timeout=1):
23
  payload = {
24
  "vector": query_vector,
25
  "limit": limit,
 
28
  r = self.session.post(
29
  f"{self.url}/collections/{collection_name}/points/search",
30
  json=payload,
31
+ timeout=timeout
32
  )
33
  r.raise_for_status()
34
  return r.json()
data_importer.py CHANGED
@@ -57,8 +57,8 @@ class DataImporter:
57
  def insert_text(self, text: str, metadata: Optional[Dict] = None, custom_id: Optional[str] = None) -> str:
58
  point_id = custom_id or str(uuid.uuid4())
59
  embedding = self.encode_text(text)[0]
60
-
61
  payload = {"text": text}
 
62
  if metadata:
63
  payload.update(metadata)
64
 
@@ -101,7 +101,7 @@ class DataImporter:
101
  print(f"Error extracting from YouTube: {e}")
102
  return None
103
 
104
- def search_similar(self, query: str, limit: int = 5) -> List[Dict]:
105
  """Search with Qdrant availability check - always returns a list"""
106
  if not self.qdrant_available or not self.client:
107
  print("Warning: Qdrant not available, returning empty results")
@@ -113,7 +113,8 @@ class DataImporter:
113
  results = self.client.search(
114
  collection_name=self.collection_name,
115
  query_vector=query_embedding,
116
- limit=limit
 
117
  )
118
  print(f"Search results: {results}")
119
  return [
@@ -128,3 +129,19 @@ class DataImporter:
128
  except Exception as e:
129
  print(f"Error searching: {e}")
130
  raise ValueError(f"Search failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def insert_text(self, text: str, metadata: Optional[Dict] = None, custom_id: Optional[str] = None) -> str:
58
  point_id = custom_id or str(uuid.uuid4())
59
  embedding = self.encode_text(text)[0]
 
60
  payload = {"text": text}
61
+
62
  if metadata:
63
  payload.update(metadata)
64
 
 
101
  print(f"Error extracting from YouTube: {e}")
102
  return None
103
 
104
+ def search_similar(self, query: str, limit: int = 1) -> List[Dict]:
105
  """Search with Qdrant availability check - always returns a list"""
106
  if not self.qdrant_available or not self.client:
107
  print("Warning: Qdrant not available, returning empty results")
 
113
  results = self.client.search(
114
  collection_name=self.collection_name,
115
  query_vector=query_embedding,
116
+ limit=limit,
117
+ timeout=15
118
  )
119
  print(f"Search results: {results}")
120
  return [
 
129
  except Exception as e:
130
  print(f"Error searching: {e}")
131
  raise ValueError(f"Search failed: {str(e)}")
132
+
133
+ def coldStartDatabase(self):
134
+ coldstart_texts = "I want to go to Chiang Mai"
135
+ try:
136
+ query_embedding = self.encode_text(coldstart_texts)[0]
137
+ results = self.client.search(
138
+ collection_name=self.collection_name,
139
+ query_vector=query_embedding,
140
+ limit=1,
141
+ timeout=1
142
+ )
143
+ print(f"Cold start results: {results}")
144
+ except Exception as e:
145
+ print(f"finish cold start, with error: {e}")
146
+
147
+
utils/llm_caller.py CHANGED
@@ -96,6 +96,7 @@ class LLMCaller:
96
  query_vector=query_embedding,
97
  limit=top_k,
98
  with_payload=True,
 
99
  )
100
 
101
  # 4. Convert search results to RetrievedItem format
 
96
  query_vector=query_embedding,
97
  limit=top_k,
98
  with_payload=True,
99
+ timeout=15
100
  )
101
 
102
  # 4. Convert search results to RetrievedItem format