feat: add cold start
Browse files- app.py +1 -0
- class_mod/rest_qdrant.py +2 -2
- data_importer.py +20 -3
- utils/llm_caller.py +1 -0
app.py
CHANGED
|
@@ -48,6 +48,7 @@ def greet_json():
|
|
| 48 |
|
| 49 |
@app.post("/v1/generateTripPlan", response_model=PlanResponse)
|
| 50 |
def generate_trip_plan(request: PlanRequest):
|
|
|
|
| 51 |
try:
|
| 52 |
trip_plan = asyncio.run(agent.query_with_rag(request))
|
| 53 |
return PlanResponse(tripOverview=trip_plan.tripOverview,
|
|
|
|
| 48 |
|
| 49 |
@app.post("/v1/generateTripPlan", response_model=PlanResponse)
|
| 50 |
def generate_trip_plan(request: PlanRequest):
|
| 51 |
+
data_importer.coldStartDatabase()
|
| 52 |
try:
|
| 53 |
trip_plan = asyncio.run(agent.query_with_rag(request))
|
| 54 |
return PlanResponse(tripOverview=trip_plan.tripOverview,
|
class_mod/rest_qdrant.py
CHANGED
|
@@ -19,7 +19,7 @@ class RestQdrantClient:
|
|
| 19 |
r = self.session.get(f"{self.url}/collections/{collection_name}", timeout=self.timeout)
|
| 20 |
r.raise_for_status()
|
| 21 |
return r.json()
|
| 22 |
-
def search(self, collection_name, query_vector, limit=10, with_payload=True):
|
| 23 |
payload = {
|
| 24 |
"vector": query_vector,
|
| 25 |
"limit": limit,
|
|
@@ -28,7 +28,7 @@ class RestQdrantClient:
|
|
| 28 |
r = self.session.post(
|
| 29 |
f"{self.url}/collections/{collection_name}/points/search",
|
| 30 |
json=payload,
|
| 31 |
-
timeout=
|
| 32 |
)
|
| 33 |
r.raise_for_status()
|
| 34 |
return r.json()
|
|
|
|
| 19 |
r = self.session.get(f"{self.url}/collections/{collection_name}", timeout=self.timeout)
|
| 20 |
r.raise_for_status()
|
| 21 |
return r.json()
|
| 22 |
+
def search(self, collection_name, query_vector, limit=10, with_payload=True,timeout=1):
|
| 23 |
payload = {
|
| 24 |
"vector": query_vector,
|
| 25 |
"limit": limit,
|
|
|
|
| 28 |
r = self.session.post(
|
| 29 |
f"{self.url}/collections/{collection_name}/points/search",
|
| 30 |
json=payload,
|
| 31 |
+
timeout=timeout
|
| 32 |
)
|
| 33 |
r.raise_for_status()
|
| 34 |
return r.json()
|
data_importer.py
CHANGED
|
@@ -57,8 +57,8 @@ class DataImporter:
|
|
| 57 |
def insert_text(self, text: str, metadata: Optional[Dict] = None, custom_id: Optional[str] = None) -> str:
|
| 58 |
point_id = custom_id or str(uuid.uuid4())
|
| 59 |
embedding = self.encode_text(text)[0]
|
| 60 |
-
|
| 61 |
payload = {"text": text}
|
|
|
|
| 62 |
if metadata:
|
| 63 |
payload.update(metadata)
|
| 64 |
|
|
@@ -101,7 +101,7 @@ class DataImporter:
|
|
| 101 |
print(f"Error extracting from YouTube: {e}")
|
| 102 |
return None
|
| 103 |
|
| 104 |
-
def search_similar(self, query: str, limit: int =
|
| 105 |
"""Search with Qdrant availability check - always returns a list"""
|
| 106 |
if not self.qdrant_available or not self.client:
|
| 107 |
print("Warning: Qdrant not available, returning empty results")
|
|
@@ -113,7 +113,8 @@ class DataImporter:
|
|
| 113 |
results = self.client.search(
|
| 114 |
collection_name=self.collection_name,
|
| 115 |
query_vector=query_embedding,
|
| 116 |
-
limit=limit
|
|
|
|
| 117 |
)
|
| 118 |
print(f"Search results: {results}")
|
| 119 |
return [
|
|
@@ -128,3 +129,19 @@ class DataImporter:
|
|
| 128 |
except Exception as e:
|
| 129 |
print(f"Error searching: {e}")
|
| 130 |
raise ValueError(f"Search failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def insert_text(self, text: str, metadata: Optional[Dict] = None, custom_id: Optional[str] = None) -> str:
|
| 58 |
point_id = custom_id or str(uuid.uuid4())
|
| 59 |
embedding = self.encode_text(text)[0]
|
|
|
|
| 60 |
payload = {"text": text}
|
| 61 |
+
|
| 62 |
if metadata:
|
| 63 |
payload.update(metadata)
|
| 64 |
|
|
|
|
| 101 |
print(f"Error extracting from YouTube: {e}")
|
| 102 |
return None
|
| 103 |
|
| 104 |
+
def search_similar(self, query: str, limit: int = 1) -> List[Dict]:
|
| 105 |
"""Search with Qdrant availability check - always returns a list"""
|
| 106 |
if not self.qdrant_available or not self.client:
|
| 107 |
print("Warning: Qdrant not available, returning empty results")
|
|
|
|
| 113 |
results = self.client.search(
|
| 114 |
collection_name=self.collection_name,
|
| 115 |
query_vector=query_embedding,
|
| 116 |
+
limit=limit,
|
| 117 |
+
timeout=15
|
| 118 |
)
|
| 119 |
print(f"Search results: {results}")
|
| 120 |
return [
|
|
|
|
| 129 |
except Exception as e:
|
| 130 |
print(f"Error searching: {e}")
|
| 131 |
raise ValueError(f"Search failed: {str(e)}")
|
| 132 |
+
|
| 133 |
+
def coldStartDatabase(self):
|
| 134 |
+
coldstart_texts = "I want to go to Chiang Mai"
|
| 135 |
+
try:
|
| 136 |
+
query_embedding = self.encode_text(coldstart_texts)[0]
|
| 137 |
+
results = self.client.search(
|
| 138 |
+
collection_name=self.collection_name,
|
| 139 |
+
query_vector=query_embedding,
|
| 140 |
+
limit=1,
|
| 141 |
+
timeout=1
|
| 142 |
+
)
|
| 143 |
+
print(f"Cold start results: {results}")
|
| 144 |
+
except Exception as e:
|
| 145 |
+
print(f"finish cold start, with error: {e}")
|
| 146 |
+
|
| 147 |
+
|
utils/llm_caller.py
CHANGED
|
@@ -96,6 +96,7 @@ class LLMCaller:
|
|
| 96 |
query_vector=query_embedding,
|
| 97 |
limit=top_k,
|
| 98 |
with_payload=True,
|
|
|
|
| 99 |
)
|
| 100 |
|
| 101 |
# 4. Convert search results to RetrievedItem format
|
|
|
|
| 96 |
query_vector=query_embedding,
|
| 97 |
limit=top_k,
|
| 98 |
with_payload=True,
|
| 99 |
+
timeout=15
|
| 100 |
)
|
| 101 |
|
| 102 |
# 4. Convert search results to RetrievedItem format
|