Spaces:
Running
Running
Commit
·
683f058
1
Parent(s):
b2a4481
Fix chat message with RAG
Browse files- app/configs/pinecone.py +2 -1
- app/domains/chat_message/controller.py +35 -35
- app/domains/chat_message/service.py +109 -26
- app/domains/properties/service.py +15 -24
- app/domains/user_action/service.py +7 -7
- app/seed/factories/article.py +88 -111
app/configs/pinecone.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
| 2 |
import pinecone
|
| 3 |
pc = pinecone.Pinecone(os.getenv("PINECONE_API_KEY"))
|
| 4 |
-
property_index = pc.Index("properties")
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import pinecone
|
| 3 |
pc = pinecone.Pinecone(os.getenv("PINECONE_API_KEY"))
|
| 4 |
+
property_index = pc.Index("properties")
|
| 5 |
+
article_index = pc.Index("articles")
|
app/domains/chat_message/controller.py
CHANGED
|
@@ -39,21 +39,34 @@ class ChatMessageController(Controller):
|
|
| 39 |
if not user.device_token:
|
| 40 |
return
|
| 41 |
notify_service = NotificationService()
|
| 42 |
-
title = "
|
| 43 |
-
body =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
notify_service.send_to_token(
|
| 45 |
token=user.device_token,
|
| 46 |
title=title,
|
| 47 |
body=body,
|
| 48 |
data={
|
| 49 |
"type": "chat",
|
| 50 |
-
"
|
| 51 |
"sender_id": str(message.sender_id),
|
| 52 |
"chat_session_id": str(message.session_id),
|
| 53 |
"created_at": str(message.created_at.timestamp()),
|
| 54 |
},
|
| 55 |
)
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
@post("")
|
| 58 |
async def create_message(
|
| 59 |
self,
|
|
@@ -64,42 +77,29 @@ class ChatMessageController(Controller):
|
|
| 64 |
chat_service: ChatMessageService,
|
| 65 |
chat_session_service: ChatSessionService,
|
| 66 |
) -> Response:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
| 84 |
),
|
| 85 |
)
|
| 86 |
-
message = await chat_service.ai_respond_to_user(data, request.user.id)
|
| 87 |
return Response(
|
| 88 |
chat_service.to_schema(message, schema_type=ChatMessageSchema),
|
| 89 |
-
background=BackgroundTasks(
|
| 90 |
-
[
|
| 91 |
-
BackgroundTask(
|
| 92 |
-
chat_session_service.update_last_message,
|
| 93 |
-
message.session_id,
|
| 94 |
-
message,
|
| 95 |
-
),
|
| 96 |
-
BackgroundTask(
|
| 97 |
-
self.notify_message,
|
| 98 |
-
request.user,
|
| 99 |
-
message,
|
| 100 |
-
),
|
| 101 |
-
]
|
| 102 |
-
),
|
| 103 |
)
|
| 104 |
|
| 105 |
@post("/ai", no_auth=True, status_code=HTTP_200_OK)
|
|
|
|
| 39 |
if not user.device_token:
|
| 40 |
return
|
| 41 |
notify_service = NotificationService()
|
| 42 |
+
title = "AI Assistant"
|
| 43 |
+
body = (
|
| 44 |
+
f"You have a new message from {user.name}."
|
| 45 |
+
if message.sender_id
|
| 46 |
+
else "AI has the answer you need"
|
| 47 |
+
)
|
| 48 |
notify_service.send_to_token(
|
| 49 |
token=user.device_token,
|
| 50 |
title=title,
|
| 51 |
body=body,
|
| 52 |
data={
|
| 53 |
"type": "chat",
|
| 54 |
+
"id": str(message.id),
|
| 55 |
"sender_id": str(message.sender_id),
|
| 56 |
"chat_session_id": str(message.session_id),
|
| 57 |
"created_at": str(message.created_at.timestamp()),
|
| 58 |
},
|
| 59 |
)
|
| 60 |
|
| 61 |
+
async def chat_with_ai(
|
| 62 |
+
self,
|
| 63 |
+
data: CreateMessageDTO,
|
| 64 |
+
user: User,
|
| 65 |
+
chat_service: ChatMessageService,
|
| 66 |
+
):
|
| 67 |
+
message = await chat_service.ai_respond_to_user(data, user_id=user.id)
|
| 68 |
+
self.notify_message(user, message)
|
| 69 |
+
|
| 70 |
@post("")
|
| 71 |
async def create_message(
|
| 72 |
self,
|
|
|
|
| 77 |
chat_service: ChatMessageService,
|
| 78 |
chat_session_service: ChatSessionService,
|
| 79 |
) -> Response:
|
| 80 |
+
message = await chat_service.create_message(data, request.user.id)
|
| 81 |
+
background_task_list = [
|
| 82 |
+
BackgroundTask(
|
| 83 |
+
chat_session_service.update_last_message,
|
| 84 |
+
message.session_id,
|
| 85 |
+
message,
|
| 86 |
+
),
|
| 87 |
+
]
|
| 88 |
+
if data.is_ai:
|
| 89 |
+
background_task_list.append(
|
| 90 |
+
BackgroundTask(self.chat_with_ai, data, request.user, chat_service)
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
background_task_list.append(
|
| 94 |
+
BackgroundTask(
|
| 95 |
+
self.notify_message,
|
| 96 |
+
request.user,
|
| 97 |
+
message,
|
| 98 |
),
|
| 99 |
)
|
|
|
|
| 100 |
return Response(
|
| 101 |
chat_service.to_schema(message, schema_type=ChatMessageSchema),
|
| 102 |
+
background=BackgroundTasks(background_task_list),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
)
|
| 104 |
|
| 105 |
@post("/ai", no_auth=True, status_code=HTTP_200_OK)
|
app/domains/chat_message/service.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
| 1 |
from collections.abc import AsyncGenerator
|
| 2 |
from datetime import datetime
|
| 3 |
-
from typing import Dict, List
|
| 4 |
import uuid
|
| 5 |
from venv import logger
|
| 6 |
from sqlalchemy.dialects import postgresql # or mysql, sqlite depending on your DB
|
| 7 |
from sqlalchemy import and_, desc, or_, select
|
| 8 |
from sqlalchemy.orm import noload
|
| 9 |
-
from transformers import pipeline
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
from database.models.property import Property
|
| 12 |
from domains.properties.service import PropertyService
|
| 13 |
from domains.chat_session.service import ChatSessionService
|
|
@@ -21,6 +23,7 @@ from domains.supabase.service import SupabaseService, provide_supabase_service
|
|
| 21 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 22 |
from google.genai import types
|
| 23 |
from configs.gemai import client
|
|
|
|
| 24 |
import re
|
| 25 |
from litestar.exceptions import ValidationException, InternalServerException
|
| 26 |
import requests
|
|
@@ -100,7 +103,7 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
|
|
| 100 |
"model_id": message.id,
|
| 101 |
}
|
| 102 |
for image in data.image_list
|
| 103 |
-
],
|
| 104 |
)
|
| 105 |
return message
|
| 106 |
except Exception as e:
|
|
@@ -297,9 +300,8 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
|
|
| 297 |
|
| 298 |
async def summarize_session(self, session_id: uuid.UUID) -> str:
|
| 299 |
"""
|
| 300 |
-
Summarize the entire chat session
|
| 301 |
"""
|
| 302 |
-
# Fetch all messages ordered oldest first
|
| 303 |
query = (
|
| 304 |
select(ChatMessage)
|
| 305 |
.where(ChatMessage.session_id == session_id)
|
|
@@ -307,17 +309,39 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
|
|
| 307 |
)
|
| 308 |
result = await self.repository.session.execute(query)
|
| 309 |
messages: List[ChatMessage] = result.scalars().all()
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
)
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
async def build_chat_context(
|
| 323 |
self, session_id: uuid.UUID, window_size: int = 10
|
|
@@ -374,8 +398,10 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
|
|
| 374 |
context = await self.build_chat_context(data.session_id, window_size)
|
| 375 |
else:
|
| 376 |
context = []
|
|
|
|
| 377 |
context.append(UserContent(data.content))
|
| 378 |
-
system_instruction = """
|
|
|
|
| 379 |
Always respond helpfully. If suggestions are requested, at the very end append exactly one line with
|
| 380 |
#PROPERTY_CRITERIA:<json>
|
| 381 |
where <json> exactly matches the PropertySchema fields:
|
|
@@ -391,14 +417,19 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
|
|
| 391 |
"average_rating": number,
|
| 392 |
"status": boolean,
|
| 393 |
}
|
| 394 |
-
If not, do not append the tag.
|
| 395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
try:
|
| 397 |
response = client.models.generate_content(
|
| 398 |
model="gemini-2.0-flash",
|
| 399 |
contents=context,
|
| 400 |
config=GenerateContentConfig(
|
| 401 |
-
tools=[Tool(google_search=GoogleSearch())],
|
| 402 |
system_instruction=system_instruction,
|
| 403 |
),
|
| 404 |
)
|
|
@@ -414,15 +445,11 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
|
|
| 414 |
pass
|
| 415 |
raise
|
| 416 |
assistant_text = response.text
|
| 417 |
-
|
| 418 |
-
CreateMessageDTO(session_id=data.session_id, content=data.content),
|
| 419 |
-
user_id,
|
| 420 |
-
auto_commit=False,
|
| 421 |
-
)
|
| 422 |
message = await self.create(
|
| 423 |
{
|
| 424 |
"content": assistant_text,
|
| 425 |
-
"session_id":
|
| 426 |
"sender_id": None,
|
| 427 |
}
|
| 428 |
)
|
|
@@ -434,6 +461,62 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
|
|
| 434 |
finally:
|
| 435 |
await self.repository.session.commit()
|
| 436 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
async def provide_chat_message_service(
|
| 439 |
db_session: AsyncSession,
|
|
|
|
| 1 |
from collections.abc import AsyncGenerator
|
| 2 |
from datetime import datetime
|
| 3 |
+
from typing import Dict, List, Union
|
| 4 |
import uuid
|
| 5 |
from venv import logger
|
| 6 |
from sqlalchemy.dialects import postgresql # or mysql, sqlite depending on your DB
|
| 7 |
from sqlalchemy import and_, desc, or_, select
|
| 8 |
from sqlalchemy.orm import noload
|
| 9 |
+
from transformers import pipeline, AutoTokenizer
|
| 10 |
+
from pinecone import SearchRerank
|
| 11 |
+
from database.models.article import Article
|
| 12 |
+
from domains.news.service import ArticleService
|
| 13 |
from database.models.property import Property
|
| 14 |
from domains.properties.service import PropertyService
|
| 15 |
from domains.chat_session.service import ChatSessionService
|
|
|
|
| 23 |
from sqlalchemy.ext.asyncio import AsyncSession
|
| 24 |
from google.genai import types
|
| 25 |
from configs.gemai import client
|
| 26 |
+
from configs.pinecone import article_index, pc
|
| 27 |
import re
|
| 28 |
from litestar.exceptions import ValidationException, InternalServerException
|
| 29 |
import requests
|
|
|
|
| 103 |
"model_id": message.id,
|
| 104 |
}
|
| 105 |
for image in data.image_list
|
| 106 |
+
],
|
| 107 |
)
|
| 108 |
return message
|
| 109 |
except Exception as e:
|
|
|
|
| 300 |
|
| 301 |
async def summarize_session(self, session_id: uuid.UUID) -> str:
|
| 302 |
"""
|
| 303 |
+
Summarize the entire chat session by chunking the transcript to respect the model's token limit.
|
| 304 |
"""
|
|
|
|
| 305 |
query = (
|
| 306 |
select(ChatMessage)
|
| 307 |
.where(ChatMessage.session_id == session_id)
|
|
|
|
| 309 |
)
|
| 310 |
result = await self.repository.session.execute(query)
|
| 311 |
messages: List[ChatMessage] = result.scalars().all()
|
| 312 |
+
chunks: List[str] = []
|
| 313 |
+
current_chunk = []
|
| 314 |
+
current_tokens = 0
|
| 315 |
+
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-vi")
|
| 316 |
+
summarizer = pipeline(
|
| 317 |
+
"summarization",
|
| 318 |
+
model="Helsinki-NLP/opus-mt-en-vi",
|
| 319 |
+
tokenizer=tokenizer,
|
| 320 |
+
device=-1,
|
| 321 |
)
|
| 322 |
+
for msg in messages:
|
| 323 |
+
speaker = "User" if msg.sender_id else "Assistant"
|
| 324 |
+
line = f"{speaker}: {msg.content}"
|
| 325 |
+
tokens = len(tokenizer(line, add_special_tokens=False))
|
| 326 |
+
if current_tokens + tokens > tokenizer.model_max_length:
|
| 327 |
+
chunks.append("\n".join(current_chunk))
|
| 328 |
+
current_chunk = [line]
|
| 329 |
+
current_tokens = tokens
|
| 330 |
+
else:
|
| 331 |
+
current_chunk.append(line)
|
| 332 |
+
current_tokens += tokens
|
| 333 |
+
if current_chunk:
|
| 334 |
+
chunks.append("\n".join(current_chunk))
|
| 335 |
+
partial_summaries = []
|
| 336 |
+
for chunk in chunks:
|
| 337 |
+
summary_out = summarizer(
|
| 338 |
+
chunk, max_length=150, min_length=10, do_sample=False, truncation=True
|
| 339 |
+
)
|
| 340 |
+
partial_summaries.append(summary_out[0]["summary_text"])
|
| 341 |
+
combined = "\n".join(partial_summaries)
|
| 342 |
+
final_out = summarizer(combined, max_length=200, min_length=50, do_sample=False)
|
| 343 |
+
|
| 344 |
+
return final_out[0]["summary_text"]
|
| 345 |
|
| 346 |
async def build_chat_context(
|
| 347 |
self, session_id: uuid.UUID, window_size: int = 10
|
|
|
|
| 398 |
context = await self.build_chat_context(data.session_id, window_size)
|
| 399 |
else:
|
| 400 |
context = []
|
| 401 |
+
articles = await self.rag_article(data.content)
|
| 402 |
context.append(UserContent(data.content))
|
| 403 |
+
system_instruction = """
|
| 404 |
+
You are a real estate assistant that help user choose and find the best match properties. Detect if the user wants property suggestions in any language.
|
| 405 |
Always respond helpfully. If suggestions are requested, at the very end append exactly one line with
|
| 406 |
#PROPERTY_CRITERIA:<json>
|
| 407 |
where <json> exactly matches the PropertySchema fields:
|
|
|
|
| 417 |
"average_rating": number,
|
| 418 |
"status": boolean,
|
| 419 |
}
|
| 420 |
+
If not, do not append the tag.
|
| 421 |
+
You will be provided with a list of relative articles that might help you answer user.
|
| 422 |
+
Each article is separated by the mark: ======== Article <number> =======.
|
| 423 |
+
If there are conflicts in information of articles, use the newer information.
|
| 424 |
+
Here is the list of relative articles that you can based on to response to user: """
|
| 425 |
+
for i, article in enumerate(articles):
|
| 426 |
+
system_instruction += f"\n ======== Article {i + 1} ============ \nTitle: {article.title} \nContent: {article.content} \nPublished date: {article.publish_date.isoformat()}"
|
| 427 |
+
system_instruction += f" If you use information from any provided article. Reference that article with the link. Also, here is there summary of the conversation between you and this customer {summary}"
|
| 428 |
try:
|
| 429 |
response = client.models.generate_content(
|
| 430 |
model="gemini-2.0-flash",
|
| 431 |
contents=context,
|
| 432 |
config=GenerateContentConfig(
|
|
|
|
| 433 |
system_instruction=system_instruction,
|
| 434 |
),
|
| 435 |
)
|
|
|
|
| 445 |
pass
|
| 446 |
raise
|
| 447 |
assistant_text = response.text
|
| 448 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
message = await self.create(
|
| 450 |
{
|
| 451 |
"content": assistant_text,
|
| 452 |
+
"session_id": data.session_id,
|
| 453 |
"sender_id": None,
|
| 454 |
}
|
| 455 |
)
|
|
|
|
| 461 |
finally:
|
| 462 |
await self.repository.session.commit()
|
| 463 |
|
| 464 |
+
async def rag_article(self, query: str) -> list[Article]:
|
| 465 |
+
summarized_query = self.summarize_query_for_rag(
|
| 466 |
+
query, max_length=len(query) // 2
|
| 467 |
+
)
|
| 468 |
+
reranked_articles = self.get_relevant_articles(summarized_query, 20, 10)
|
| 469 |
+
article_service = ArticleService(session=self.repository.session)
|
| 470 |
+
full_articles = await article_service.list(
|
| 471 |
+
Article.id.in_([article["_id"] for article in reranked_articles])
|
| 472 |
+
)
|
| 473 |
+
return full_articles
|
| 474 |
+
|
| 475 |
+
def summarize_query_for_rag(
|
| 476 |
+
self,
|
| 477 |
+
text: str,
|
| 478 |
+
max_length: int = 100,
|
| 479 |
+
min_length: int = 5,
|
| 480 |
+
device: Union[str, int] = -1,
|
| 481 |
+
) -> str:
|
| 482 |
+
"""
|
| 483 |
+
Summarizes a user query in any language for use in a RAG retriever.
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
text (str): The input text/query in any supported language.
|
| 487 |
+
max_length (int): Maximum length of the summary/query.
|
| 488 |
+
min_length (int): Minimum length of the summary/query.
|
| 489 |
+
device (Union[str, int]): Device for inference (-1 for CPU, 0 or 1 for GPU).
|
| 490 |
+
|
| 491 |
+
Returns:
|
| 492 |
+
str: Summarized query text.
|
| 493 |
+
"""
|
| 494 |
+
summarizer = pipeline(
|
| 495 |
+
"summarization",
|
| 496 |
+
model="Helsinki-NLP/opus-mt-en-vi",
|
| 497 |
+
tokenizer="Helsinki-NLP/opus-mt-en-vi",
|
| 498 |
+
device=device,
|
| 499 |
+
)
|
| 500 |
+
summary = summarizer(
|
| 501 |
+
text, max_length=max_length, min_length=min_length, do_sample=False
|
| 502 |
+
)
|
| 503 |
+
return summary[0]["summary_text"]
|
| 504 |
+
|
| 505 |
+
def get_relevant_articles(
|
| 506 |
+
self, query: str, retrieval_n: int = 10, rerank_n: int = 3
|
| 507 |
+
) -> Dict:
|
| 508 |
+
result = article_index.search(
|
| 509 |
+
"__default__",
|
| 510 |
+
query={"top_k": retrieval_n, "inputs": {"text": query}},
|
| 511 |
+
rerank=SearchRerank(
|
| 512 |
+
model="bge-reranker-v2-m3",
|
| 513 |
+
rank_fields=["summary"],
|
| 514 |
+
top_n=rerank_n,
|
| 515 |
+
parameters={"truncate": "END"},
|
| 516 |
+
),
|
| 517 |
+
)
|
| 518 |
+
return result.to_dict()["result"]["hits"]
|
| 519 |
+
|
| 520 |
|
| 521 |
async def provide_chat_message_service(
|
| 522 |
db_session: AsyncSession,
|
app/domains/properties/service.py
CHANGED
|
@@ -199,17 +199,13 @@ class PropertyService(SQLAlchemyAsyncRepositoryService[Property]):
|
|
| 199 |
pagination: LimitOffset,
|
| 200 |
user_id: uuid.UUID,
|
| 201 |
) -> CursorPagination[str, Property]:
|
| 202 |
-
# 1) Build Pinecone metadata filter
|
| 203 |
meta_filter = self._build_pinecone_filter(search_param)
|
| 204 |
-
# 2) Generate user embedding from past interactions
|
| 205 |
user_embedding = await self._compute_user_embedding(user_id)
|
| 206 |
-
# 3) Query Pinecone
|
| 207 |
pine_res = property_index.query(
|
| 208 |
vector=user_embedding,
|
| 209 |
filter=meta_filter,
|
| 210 |
top_k=pagination.limit,
|
| 211 |
-
include_metadata=
|
| 212 |
-
# next_page_token=search_param.next_page_token,
|
| 213 |
)
|
| 214 |
ids = [m["id"] for m in pine_res["matches"]]
|
| 215 |
props = await self._fetch_properties_from_ids(ids)
|
|
@@ -261,19 +257,21 @@ class PropertyService(SQLAlchemyAsyncRepositoryService[Property]):
|
|
| 261 |
if search_param.lat is not None and search_param.lng is not None:
|
| 262 |
query = query.join(Property.address)
|
| 263 |
radius_meters = search_param.radius * 1000
|
| 264 |
-
radius_degrees = radius_meters / 111320.0
|
| 265 |
lat = search_param.lat
|
| 266 |
lng = search_param.lng
|
| 267 |
min_lat = lat - radius_degrees
|
| 268 |
max_lat = lat + radius_degrees
|
| 269 |
min_lng = lng - radius_degrees
|
| 270 |
max_lng = lng + radius_degrees
|
| 271 |
-
query = query.where(
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
|
|
|
|
|
|
| 277 |
# price filters
|
| 278 |
if search_param.min_price is not None:
|
| 279 |
query = query.where(Property.price >= search_param.min_price)
|
|
@@ -281,10 +279,7 @@ class PropertyService(SQLAlchemyAsyncRepositoryService[Property]):
|
|
| 281 |
query = query.where(Property.price <= search_param.max_price)
|
| 282 |
# Have review
|
| 283 |
if search_param.has_review:
|
| 284 |
-
subquery = (
|
| 285 |
-
select(Review.id)
|
| 286 |
-
.where(Review.property_id == Property.id)
|
| 287 |
-
)
|
| 288 |
query = query.where(exists(subquery))
|
| 289 |
# categorical
|
| 290 |
if search_param.property_category:
|
|
@@ -392,16 +387,12 @@ class PropertyService(SQLAlchemyAsyncRepositoryService[Property]):
|
|
| 392 |
|
| 393 |
async def _compute_user_embedding(self, user_id: uuid.UUID) -> list[float]:
|
| 394 |
user_action_repository = UserActionRepository(session=self.repository.session)
|
| 395 |
-
|
| 396 |
user_id=user_id
|
| 397 |
)
|
| 398 |
-
if len(
|
| 399 |
-
return next(
|
| 400 |
-
|
| 401 |
-
).values
|
| 402 |
-
result = await self.fetch_pinecone_document_by_id(
|
| 403 |
-
[UUID(id) for id in properties_action.keys()]
|
| 404 |
-
)
|
| 405 |
vectors = [value.values for value in result.values()]
|
| 406 |
mean_vector = np.mean(vectors, axis=0).tolist()
|
| 407 |
return mean_vector
|
|
|
|
| 199 |
pagination: LimitOffset,
|
| 200 |
user_id: uuid.UUID,
|
| 201 |
) -> CursorPagination[str, Property]:
|
|
|
|
| 202 |
meta_filter = self._build_pinecone_filter(search_param)
|
|
|
|
| 203 |
user_embedding = await self._compute_user_embedding(user_id)
|
|
|
|
| 204 |
pine_res = property_index.query(
|
| 205 |
vector=user_embedding,
|
| 206 |
filter=meta_filter,
|
| 207 |
top_k=pagination.limit,
|
| 208 |
+
include_metadata=False,
|
|
|
|
| 209 |
)
|
| 210 |
ids = [m["id"] for m in pine_res["matches"]]
|
| 211 |
props = await self._fetch_properties_from_ids(ids)
|
|
|
|
| 257 |
if search_param.lat is not None and search_param.lng is not None:
|
| 258 |
query = query.join(Property.address)
|
| 259 |
radius_meters = search_param.radius * 1000
|
| 260 |
+
radius_degrees = radius_meters / 111320.0
|
| 261 |
lat = search_param.lat
|
| 262 |
lng = search_param.lng
|
| 263 |
min_lat = lat - radius_degrees
|
| 264 |
max_lat = lat + radius_degrees
|
| 265 |
min_lng = lng - radius_degrees
|
| 266 |
max_lng = lng + radius_degrees
|
| 267 |
+
query = query.where(
|
| 268 |
+
and_(
|
| 269 |
+
Address.latitude >= min_lat,
|
| 270 |
+
Address.latitude <= max_lat,
|
| 271 |
+
Address.longitude >= min_lng,
|
| 272 |
+
Address.longitude <= max_lng,
|
| 273 |
+
)
|
| 274 |
+
)
|
| 275 |
# price filters
|
| 276 |
if search_param.min_price is not None:
|
| 277 |
query = query.where(Property.price >= search_param.min_price)
|
|
|
|
| 279 |
query = query.where(Property.price <= search_param.max_price)
|
| 280 |
# Have review
|
| 281 |
if search_param.has_review:
|
| 282 |
+
subquery = select(Review.id).where(Review.property_id == Property.id)
|
|
|
|
|
|
|
|
|
|
| 283 |
query = query.where(exists(subquery))
|
| 284 |
# categorical
|
| 285 |
if search_param.property_category:
|
|
|
|
| 387 |
|
| 388 |
async def _compute_user_embedding(self, user_id: uuid.UUID) -> list[float]:
|
| 389 |
user_action_repository = UserActionRepository(session=self.repository.session)
|
| 390 |
+
property_id_list = await user_action_repository.get_relevant_properties(
|
| 391 |
user_id=user_id
|
| 392 |
)
|
| 393 |
+
if len(property_id_list) == 0:
|
| 394 |
+
return next(iter(property_index.fetch(["0"]).vectors.values())).values
|
| 395 |
+
result = await self.fetch_pinecone_document_by_id(property_id_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
vectors = [value.values for value in result.values()]
|
| 397 |
mean_vector = np.mean(vectors, axis=0).tolist()
|
| 398 |
return mean_vector
|
app/domains/user_action/service.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from collections import defaultdict
|
| 2 |
from collections.abc import AsyncGenerator
|
|
|
|
| 3 |
import uuid
|
| 4 |
|
| 5 |
from sqlalchemy import select
|
|
@@ -11,29 +12,28 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
| 11 |
|
| 12 |
class UserActionRepository(SQLAlchemyAsyncRepository[UserAction]):
|
| 13 |
model_type = UserAction
|
| 14 |
-
|
|
|
|
| 15 |
prop_ids_subq = (
|
| 16 |
select(UserAction.property_id)
|
| 17 |
.where(UserAction.user_id == user_id)
|
|
|
|
| 18 |
.distinct()
|
| 19 |
.limit(10)
|
| 20 |
).subquery()
|
| 21 |
|
| 22 |
-
# Step 2: fetch all actions for those properties
|
| 23 |
result = await self.session.execute(
|
| 24 |
select(UserAction)
|
| 25 |
.where(
|
| 26 |
UserAction.user_id == user_id,
|
| 27 |
-
UserAction.property_id.in_(select(prop_ids_subq))
|
| 28 |
)
|
| 29 |
.order_by(UserAction.property_id, UserAction.created_at)
|
| 30 |
)
|
| 31 |
actions = result.scalars().all()
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
grouped: dict = {}
|
| 34 |
-
for act in actions:
|
| 35 |
-
grouped[str(act.property_id)].append(act)
|
| 36 |
-
return grouped
|
| 37 |
class UserActionService(SQLAlchemyAsyncRepositoryService[UserAction]):
|
| 38 |
repository_type = UserActionRepository
|
| 39 |
|
|
|
|
| 1 |
from collections import defaultdict
|
| 2 |
from collections.abc import AsyncGenerator
|
| 3 |
+
from typing import List
|
| 4 |
import uuid
|
| 5 |
|
| 6 |
from sqlalchemy import select
|
|
|
|
| 12 |
|
| 13 |
class UserActionRepository(SQLAlchemyAsyncRepository[UserAction]):
|
| 14 |
model_type = UserAction
|
| 15 |
+
|
| 16 |
+
async def get_relevant_properties(self, user_id: uuid.UUID) -> List[uuid.UUID]:
|
| 17 |
prop_ids_subq = (
|
| 18 |
select(UserAction.property_id)
|
| 19 |
.where(UserAction.user_id == user_id)
|
| 20 |
+
.where(UserAction.action == "view")
|
| 21 |
.distinct()
|
| 22 |
.limit(10)
|
| 23 |
).subquery()
|
| 24 |
|
|
|
|
| 25 |
result = await self.session.execute(
|
| 26 |
select(UserAction)
|
| 27 |
.where(
|
| 28 |
UserAction.user_id == user_id,
|
| 29 |
+
UserAction.property_id.in_(select(prop_ids_subq)),
|
| 30 |
)
|
| 31 |
.order_by(UserAction.property_id, UserAction.created_at)
|
| 32 |
)
|
| 33 |
actions = result.scalars().all()
|
| 34 |
+
return [action.property_id for action in actions]
|
| 35 |
+
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
class UserActionService(SQLAlchemyAsyncRepositoryService[UserAction]):
|
| 38 |
repository_type = UserActionRepository
|
| 39 |
|
app/seed/factories/article.py
CHANGED
|
@@ -13,6 +13,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
| 13 |
from configs.gemai import client
|
| 14 |
from google.genai.types import GenerateContentConfig
|
| 15 |
from advanced_alchemy.utils.text import slugify
|
|
|
|
|
|
|
| 16 |
|
| 17 |
safety_settings = [
|
| 18 |
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
|
|
@@ -26,86 +28,56 @@ safety_settings = [
|
|
| 26 |
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
|
| 27 |
},
|
| 28 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
async def generate_tags_and_summary(article_html_content: str) -> dict:
|
| 32 |
"""
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
Args:
|
| 36 |
-
article_html_content: The HTML content of the article.
|
| 37 |
-
|
| 38 |
-
Returns:
|
| 39 |
-
A dictionary with "tags" (list of strings) and "short_description" (string).
|
| 40 |
-
Returns empty values if generation fails.
|
| 41 |
-
"""
|
| 42 |
-
prompt = f"""
|
| 43 |
-
Analyze the following Vietnamese news article content (provided in HTML format) and perform two tasks:
|
| 44 |
-
1. Generate a concise short description (summary) of the article in Vietnamese. This description should be no more than 80 words and capture the main points.
|
| 45 |
-
2. Extract 3 to 7 relevant keywords (tags) for this article in Vietnamese. These tags should be single words or short phrases.
|
| 46 |
-
|
| 47 |
-
Article Content:
|
| 48 |
-
```html
|
| 49 |
-
{article_html_content[:15000]}
|
| 50 |
-
```
|
| 51 |
-
|
| 52 |
-
Provide your response strictly as a JSON object with two keys: "short_description" and "tags".
|
| 53 |
-
The "tags" value should be a list of strings.
|
| 54 |
-
Example JSON output:
|
| 55 |
-
{{
|
| 56 |
-
"short_description": "Một bản tóm tắt ngắn gọn của bài báo bằng tiếng Việt...,
|
| 57 |
-
"tags": ["bất động sản", "thị trường", "dự án mới", "Việt Nam"]
|
| 58 |
-
}}
|
| 59 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
try:
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
contents=[
|
| 67 |
-
prompt,
|
| 68 |
-
],
|
| 69 |
-
config=GenerateContentConfig(
|
| 70 |
-
safety_settings=safety_settings,
|
| 71 |
-
top_p=1,
|
| 72 |
-
temperature=0.7,
|
| 73 |
-
max_output_tokens=2048,
|
| 74 |
-
response_modalities=["TEXT"],
|
| 75 |
-
),
|
| 76 |
)
|
| 77 |
-
|
| 78 |
-
if cleaned_response_text.startswith("```json"):
|
| 79 |
-
cleaned_response_text = cleaned_response_text[7:]
|
| 80 |
-
if cleaned_response_text.endswith("```"):
|
| 81 |
-
cleaned_response_text = cleaned_response_text[:-3]
|
| 82 |
-
cleaned_response_text = cleaned_response_text.strip()
|
| 83 |
-
data = json.loads(cleaned_response_text)
|
| 84 |
-
|
| 85 |
-
tags = data.get("tags", [])
|
| 86 |
-
short_desc = data.get("short_description", "")
|
| 87 |
-
|
| 88 |
-
if not isinstance(tags, list):
|
| 89 |
-
print(
|
| 90 |
-
f"Warning: Gemini returned tags not as a list: {tags}. Using empty list."
|
| 91 |
-
)
|
| 92 |
-
tags = []
|
| 93 |
-
if not isinstance(short_desc, str):
|
| 94 |
-
print(
|
| 95 |
-
f"Warning: Gemini returned short_description not as a string: {short_desc}. Using empty string."
|
| 96 |
-
)
|
| 97 |
-
short_desc = ""
|
| 98 |
-
|
| 99 |
-
return {"tags": tags, "short_description": short_desc}
|
| 100 |
-
|
| 101 |
except Exception as e:
|
| 102 |
-
print(f"
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
class ArticleFactory(BaseFactory):
|
|
@@ -150,7 +122,6 @@ class ArticleFactory(BaseFactory):
|
|
| 150 |
max_tokens=10000,
|
| 151 |
)
|
| 152 |
text = response.choices[0].message.content.strip()
|
| 153 |
-
print(text)
|
| 154 |
articles = json.loads(text)
|
| 155 |
if not isinstance(articles, list):
|
| 156 |
raise ValueError("Expected a JSON list of articles.")
|
|
@@ -182,35 +153,36 @@ class ArticleFactory(BaseFactory):
|
|
| 182 |
await import_articles_from_json(
|
| 183 |
os.path.join(fixture_path, "articles.json"), session
|
| 184 |
)
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
| 214 |
except Exception as e:
|
| 215 |
await session.rollback()
|
| 216 |
print(f"Error during ArticleFactory seeding: {e}")
|
|
@@ -223,14 +195,15 @@ class ArticleFactory(BaseFactory):
|
|
| 223 |
await self.repository(session=session).delete_where(Article.id.is_not(None))
|
| 224 |
await session.commit()
|
| 225 |
|
| 226 |
-
|
| 227 |
def parse_vietnamese_datetime(date_str: str) -> datetime | None:
|
| 228 |
"""
|
| 229 |
-
Tries to parse common Vietnamese datetime string formats.
|
| 230 |
Returns a timezone-aware datetime object (UTC) or None if parsing fails.
|
| 231 |
"""
|
| 232 |
if not date_str or not isinstance(date_str, str):
|
| 233 |
return None
|
|
|
|
|
|
|
| 234 |
if "T" in date_str and ("Z" in date_str or "+" in date_str or "-" in date_str[10:]):
|
| 235 |
try:
|
| 236 |
dt = datetime.fromisoformat(date_str)
|
|
@@ -240,6 +213,7 @@ def parse_vietnamese_datetime(date_str: str) -> datetime | None:
|
|
| 240 |
except ValueError:
|
| 241 |
pass
|
| 242 |
|
|
|
|
| 243 |
formats_to_try = [
|
| 244 |
"%d/%m/%Y %H:%M:%S",
|
| 245 |
"%d/%m/%Y %H:%M",
|
|
@@ -247,7 +221,9 @@ def parse_vietnamese_datetime(date_str: str) -> datetime | None:
|
|
| 247 |
"%d-%m-%Y %H:%M",
|
| 248 |
"%Y-%m-%d %H:%M:%S",
|
| 249 |
"%Y/%m/%d %H:%M:%S",
|
|
|
|
| 250 |
]
|
|
|
|
| 251 |
for fmt in formats_to_try:
|
| 252 |
try:
|
| 253 |
dt_naive = datetime.strptime(date_str.strip(), fmt)
|
|
@@ -255,10 +231,10 @@ def parse_vietnamese_datetime(date_str: str) -> datetime | None:
|
|
| 255 |
return dt_aware
|
| 256 |
except ValueError:
|
| 257 |
continue
|
|
|
|
| 258 |
print(f"Warning: Could not parse date string: {date_str}")
|
| 259 |
return None
|
| 260 |
|
| 261 |
-
|
| 262 |
async def get_or_create_tags(session: AsyncSession, tag_names: List[str]) -> List[Tag]:
|
| 263 |
"""
|
| 264 |
Retrieves existing Tag objects or creates new ones for each tag name.
|
|
@@ -296,10 +272,9 @@ async def process_article_data(session: AsyncSession, article_data: Dict[str, An
|
|
| 296 |
return None
|
| 297 |
gemini_data = await generate_tags_and_summary(html_content)
|
| 298 |
tag_names = gemini_data.get("tags", [])
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
short_description = "Tóm tắt không có sẵn."
|
| 303 |
if not tag_names:
|
| 304 |
print(f"No tags generated for article: {title}")
|
| 305 |
publish_date = parse_vietnamese_datetime(published_date_str)
|
|
@@ -308,14 +283,14 @@ async def process_article_data(session: AsyncSession, article_data: Dict[str, An
|
|
| 308 |
f"Using current time for article '{title}' due to unparseable date: {published_date_str}"
|
| 309 |
)
|
| 310 |
publish_date = datetime.now(timezone.utc)
|
| 311 |
-
db_tags = await get_or_create_tags(session, tag_names)
|
| 312 |
new_article = Article(
|
| 313 |
title=title,
|
| 314 |
publish_date=publish_date,
|
| 315 |
content=html_content,
|
| 316 |
short_description=short_description[:499],
|
| 317 |
author=source_name,
|
| 318 |
-
tags=
|
| 319 |
)
|
| 320 |
return new_article
|
| 321 |
|
|
@@ -347,6 +322,8 @@ async def import_articles_from_json(json_filepath: str, session: AsyncSession):
|
|
| 347 |
|
| 348 |
articles_to_add = []
|
| 349 |
for i, item_data in enumerate(data_from_json):
|
|
|
|
|
|
|
| 350 |
print(f"\n--- Processing item {i+1}/{len(data_from_json)} ---")
|
| 351 |
article_obj = await process_article_data(session, item_data)
|
| 352 |
if article_obj:
|
|
|
|
| 13 |
from configs.gemai import client
|
| 14 |
from google.genai.types import GenerateContentConfig
|
| 15 |
from advanced_alchemy.utils.text import slugify
|
| 16 |
+
from transformers import pipeline
|
| 17 |
+
import re
|
| 18 |
|
| 19 |
safety_settings = [
|
| 20 |
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
|
|
|
|
| 28 |
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
|
| 29 |
},
|
| 30 |
]
|
| 31 |
+
_SUMMARY_PIPELINE = pipeline(
|
| 32 |
+
"summarization",
|
| 33 |
+
model="google/long-t5-tglobal-base",
|
| 34 |
+
tokenizer="google/long-t5-tglobal-base",
|
| 35 |
+
device=-1,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
_KEYPHRASE_PIPELINE = pipeline(
|
| 39 |
+
"text2text-generation",
|
| 40 |
+
model="google/long-t5-tglobal-base",
|
| 41 |
+
tokenizer="google/long-t5-tglobal-base",
|
| 42 |
+
framework="pt",
|
| 43 |
+
device=-1,
|
| 44 |
+
)
|
| 45 |
|
| 46 |
|
| 47 |
async def generate_tags_and_summary(article_html_content: str) -> dict:
|
| 48 |
"""
|
| 49 |
+
Summarize and extract tags using small transformer models.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"""
|
| 51 |
+
text = re.sub(r"<[^>]+>", " ", article_html_content)
|
| 52 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 53 |
+
if len(text) < 50:
|
| 54 |
+
return {"tags": [], "short_description": text}
|
| 55 |
try:
|
| 56 |
+
summary_out = _SUMMARY_PIPELINE(
|
| 57 |
+
text,
|
| 58 |
+
max_length=200,
|
| 59 |
+
min_length=30,
|
| 60 |
+
do_sample=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
)
|
| 62 |
+
short_description = summary_out[0]["summary_text"].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
except Exception as e:
|
| 64 |
+
print(f"Summarization error: {e}")
|
| 65 |
+
short_description = text[:300] + ("…" if len(text) > 300 else "")
|
| 66 |
+
try:
|
| 67 |
+
prompt = "extract keyphrases: " + text[:1000] # limit length
|
| 68 |
+
kpop = _KEYPHRASE_PIPELINE(prompt, max_length=64, do_sample=False)
|
| 69 |
+
raw = kpop[0]["generated_text"]
|
| 70 |
+
tags = re.split(r"[;,]\s*", raw)
|
| 71 |
+
tags = list(dict.fromkeys([t.strip().lower() for t in tags if t.strip()]))
|
| 72 |
+
tags = tags[:7]
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f"Keyphrase extraction error: {e}")
|
| 75 |
+
tags = []
|
| 76 |
+
|
| 77 |
+
return {
|
| 78 |
+
"short_description": short_description,
|
| 79 |
+
"tags": tags,
|
| 80 |
+
}
|
| 81 |
|
| 82 |
|
| 83 |
class ArticleFactory(BaseFactory):
|
|
|
|
| 122 |
max_tokens=10000,
|
| 123 |
)
|
| 124 |
text = response.choices[0].message.content.strip()
|
|
|
|
| 125 |
articles = json.loads(text)
|
| 126 |
if not isinstance(articles, list):
|
| 127 |
raise ValueError("Expected a JSON list of articles.")
|
|
|
|
| 153 |
await import_articles_from_json(
|
| 154 |
os.path.join(fixture_path, "articles.json"), session
|
| 155 |
)
|
| 156 |
+
else:
|
| 157 |
+
articles_data = self.fetch_articles_from_openai(count)
|
| 158 |
+
for article_data in articles_data:
|
| 159 |
+
result = await session.execute(
|
| 160 |
+
select(Article).filter_by(title=article_data.get("title"))
|
| 161 |
+
)
|
| 162 |
+
if result.scalars().first():
|
| 163 |
+
continue
|
| 164 |
+
|
| 165 |
+
publish_date_str = article_data.get("publish_date")
|
| 166 |
+
try:
|
| 167 |
+
publish_date = datetime.fromisoformat(publish_date_str)
|
| 168 |
+
except Exception:
|
| 169 |
+
publish_date = datetime.now(timezone.utc)
|
| 170 |
+
|
| 171 |
+
tag_names = article_data.get("tags", [])
|
| 172 |
+
tags = await self.get_or_create_tags(session, tag_names)
|
| 173 |
+
|
| 174 |
+
article = Article(
|
| 175 |
+
id=uuid.uuid4(),
|
| 176 |
+
title=article_data.get("title"),
|
| 177 |
+
publish_date=publish_date,
|
| 178 |
+
content=article_data.get("content"),
|
| 179 |
+
short_description=article_data.get("short_description"),
|
| 180 |
+
author=article_data.get("author"),
|
| 181 |
+
tags=tags,
|
| 182 |
+
created_at=datetime.now(timezone.utc),
|
| 183 |
+
updated_at=datetime.now(timezone.utc),
|
| 184 |
+
)
|
| 185 |
+
await self.repository(session=session).add(article)
|
| 186 |
except Exception as e:
|
| 187 |
await session.rollback()
|
| 188 |
print(f"Error during ArticleFactory seeding: {e}")
|
|
|
|
| 195 |
await self.repository(session=session).delete_where(Article.id.is_not(None))
|
| 196 |
await session.commit()
|
| 197 |
|
|
|
|
| 198 |
def parse_vietnamese_datetime(date_str: str) -> datetime | None:
|
| 199 |
"""
|
| 200 |
+
Tries to parse common Vietnamese datetime string formats, including RFC 1123.
|
| 201 |
Returns a timezone-aware datetime object (UTC) or None if parsing fails.
|
| 202 |
"""
|
| 203 |
if not date_str or not isinstance(date_str, str):
|
| 204 |
return None
|
| 205 |
+
|
| 206 |
+
# First: handle ISO8601 with 'T' and timezone info
|
| 207 |
if "T" in date_str and ("Z" in date_str or "+" in date_str or "-" in date_str[10:]):
|
| 208 |
try:
|
| 209 |
dt = datetime.fromisoformat(date_str)
|
|
|
|
| 213 |
except ValueError:
|
| 214 |
pass
|
| 215 |
|
| 216 |
+
# Try known formats, including RFC 1123
|
| 217 |
formats_to_try = [
|
| 218 |
"%d/%m/%Y %H:%M:%S",
|
| 219 |
"%d/%m/%Y %H:%M",
|
|
|
|
| 221 |
"%d-%m-%Y %H:%M",
|
| 222 |
"%Y-%m-%d %H:%M:%S",
|
| 223 |
"%Y/%m/%d %H:%M:%S",
|
| 224 |
+
"%a, %d %b %Y %H:%M:%S GMT", # RFC 1123 (e.g., "Sun, 01 Jun 2025 01:16:00 GMT")
|
| 225 |
]
|
| 226 |
+
|
| 227 |
for fmt in formats_to_try:
|
| 228 |
try:
|
| 229 |
dt_naive = datetime.strptime(date_str.strip(), fmt)
|
|
|
|
| 231 |
return dt_aware
|
| 232 |
except ValueError:
|
| 233 |
continue
|
| 234 |
+
|
| 235 |
print(f"Warning: Could not parse date string: {date_str}")
|
| 236 |
return None
|
| 237 |
|
|
|
|
| 238 |
async def get_or_create_tags(session: AsyncSession, tag_names: List[str]) -> List[Tag]:
|
| 239 |
"""
|
| 240 |
Retrieves existing Tag objects or creates new ones for each tag name.
|
|
|
|
| 272 |
return None
|
| 273 |
gemini_data = await generate_tags_and_summary(html_content)
|
| 274 |
tag_names = gemini_data.get("tags", [])
|
| 275 |
+
short_description = gemini_data.get("short_description")
|
| 276 |
+
if not short_description:
|
| 277 |
+
short_description = "Tóm tắt không có sẵn."
|
|
|
|
| 278 |
if not tag_names:
|
| 279 |
print(f"No tags generated for article: {title}")
|
| 280 |
publish_date = parse_vietnamese_datetime(published_date_str)
|
|
|
|
| 283 |
f"Using current time for article '{title}' due to unparseable date: {published_date_str}"
|
| 284 |
)
|
| 285 |
publish_date = datetime.now(timezone.utc)
|
| 286 |
+
# db_tags = await get_or_create_tags(session, tag_names)
|
| 287 |
new_article = Article(
|
| 288 |
title=title,
|
| 289 |
publish_date=publish_date,
|
| 290 |
content=html_content,
|
| 291 |
short_description=short_description[:499],
|
| 292 |
author=source_name,
|
| 293 |
+
tags=[],
|
| 294 |
)
|
| 295 |
return new_article
|
| 296 |
|
|
|
|
| 322 |
|
| 323 |
articles_to_add = []
|
| 324 |
for i, item_data in enumerate(data_from_json):
|
| 325 |
+
if i > 2:
|
| 326 |
+
break
|
| 327 |
print(f"\n--- Processing item {i+1}/{len(data_from_json)} ---")
|
| 328 |
article_obj = await process_article_data(session, item_data)
|
| 329 |
if article_obj:
|