FrancisGOS commited on
Commit
683f058
·
1 Parent(s): b2a4481

Fix chat message with RAG

Browse files
app/configs/pinecone.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
2
  import pinecone
3
  pc = pinecone.Pinecone(os.getenv("PINECONE_API_KEY"))
4
- property_index = pc.Index("properties")
 
 
1
  import os
2
  import pinecone
3
  pc = pinecone.Pinecone(os.getenv("PINECONE_API_KEY"))
4
+ property_index = pc.Index("properties")
5
+ article_index = pc.Index("articles")
app/domains/chat_message/controller.py CHANGED
@@ -39,21 +39,34 @@ class ChatMessageController(Controller):
39
  if not user.device_token:
40
  return
41
  notify_service = NotificationService()
42
- title = "You have a new message"
43
- body = f"You have a new message from {user.name}. \n{message.content}"
 
 
 
 
44
  notify_service.send_to_token(
45
  token=user.device_token,
46
  title=title,
47
  body=body,
48
  data={
49
  "type": "chat",
50
- "content": message.content,
51
  "sender_id": str(message.sender_id),
52
  "chat_session_id": str(message.session_id),
53
  "created_at": str(message.created_at.timestamp()),
54
  },
55
  )
56
 
 
 
 
 
 
 
 
 
 
57
  @post("")
58
  async def create_message(
59
  self,
@@ -64,42 +77,29 @@ class ChatMessageController(Controller):
64
  chat_service: ChatMessageService,
65
  chat_session_service: ChatSessionService,
66
  ) -> Response:
67
- if not data.is_ai:
68
- message = await chat_service.create_message(data, request.user.id)
69
- return Response(
70
- chat_service.to_schema(message, schema_type=ChatMessageSchema),
71
- background=BackgroundTasks(
72
- [
73
- BackgroundTask(
74
- self.notify_message,
75
- request.user,
76
- message,
77
- ),
78
- BackgroundTask(
79
- chat_session_service.update_last_message,
80
- message.session_id,
81
- message,
82
- ),
83
- ]
 
84
  ),
85
  )
86
- message = await chat_service.ai_respond_to_user(data, request.user.id)
87
  return Response(
88
  chat_service.to_schema(message, schema_type=ChatMessageSchema),
89
- background=BackgroundTasks(
90
- [
91
- BackgroundTask(
92
- chat_session_service.update_last_message,
93
- message.session_id,
94
- message,
95
- ),
96
- BackgroundTask(
97
- self.notify_message,
98
- request.user,
99
- message,
100
- ),
101
- ]
102
- ),
103
  )
104
 
105
  @post("/ai", no_auth=True, status_code=HTTP_200_OK)
 
39
  if not user.device_token:
40
  return
41
  notify_service = NotificationService()
42
+ title = "AI Assistant"
43
+ body = (
44
+ f"You have a new message from {user.name}."
45
+ if message.sender_id
46
+ else "AI has the answer you need"
47
+ )
48
  notify_service.send_to_token(
49
  token=user.device_token,
50
  title=title,
51
  body=body,
52
  data={
53
  "type": "chat",
54
+ "id": str(message.id),
55
  "sender_id": str(message.sender_id),
56
  "chat_session_id": str(message.session_id),
57
  "created_at": str(message.created_at.timestamp()),
58
  },
59
  )
60
 
61
+ async def chat_with_ai(
62
+ self,
63
+ data: CreateMessageDTO,
64
+ user: User,
65
+ chat_service: ChatMessageService,
66
+ ):
67
+ message = await chat_service.ai_respond_to_user(data, user_id=user.id)
68
+ self.notify_message(user, message)
69
+
70
  @post("")
71
  async def create_message(
72
  self,
 
77
  chat_service: ChatMessageService,
78
  chat_session_service: ChatSessionService,
79
  ) -> Response:
80
+ message = await chat_service.create_message(data, request.user.id)
81
+ background_task_list = [
82
+ BackgroundTask(
83
+ chat_session_service.update_last_message,
84
+ message.session_id,
85
+ message,
86
+ ),
87
+ ]
88
+ if data.is_ai:
89
+ background_task_list.append(
90
+ BackgroundTask(self.chat_with_ai, data, request.user, chat_service)
91
+ )
92
+ else:
93
+ background_task_list.append(
94
+ BackgroundTask(
95
+ self.notify_message,
96
+ request.user,
97
+ message,
98
  ),
99
  )
 
100
  return Response(
101
  chat_service.to_schema(message, schema_type=ChatMessageSchema),
102
+ background=BackgroundTasks(background_task_list),
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  )
104
 
105
  @post("/ai", no_auth=True, status_code=HTTP_200_OK)
app/domains/chat_message/service.py CHANGED
@@ -1,13 +1,15 @@
1
  from collections.abc import AsyncGenerator
2
  from datetime import datetime
3
- from typing import Dict, List
4
  import uuid
5
  from venv import logger
6
  from sqlalchemy.dialects import postgresql # or mysql, sqlite depending on your DB
7
  from sqlalchemy import and_, desc, or_, select
8
  from sqlalchemy.orm import noload
9
- from transformers import pipeline
10
-
 
 
11
  from database.models.property import Property
12
  from domains.properties.service import PropertyService
13
  from domains.chat_session.service import ChatSessionService
@@ -21,6 +23,7 @@ from domains.supabase.service import SupabaseService, provide_supabase_service
21
  from sqlalchemy.ext.asyncio import AsyncSession
22
  from google.genai import types
23
  from configs.gemai import client
 
24
  import re
25
  from litestar.exceptions import ValidationException, InternalServerException
26
  import requests
@@ -100,7 +103,7 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
100
  "model_id": message.id,
101
  }
102
  for image in data.image_list
103
- ],
104
  )
105
  return message
106
  except Exception as e:
@@ -297,9 +300,8 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
297
 
298
  async def summarize_session(self, session_id: uuid.UUID) -> str:
299
  """
300
- Summarize the entire chat session using a lightweight summarization model.
301
  """
302
- # Fetch all messages ordered oldest first
303
  query = (
304
  select(ChatMessage)
305
  .where(ChatMessage.session_id == session_id)
@@ -307,17 +309,39 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
307
  )
308
  result = await self.repository.session.execute(query)
309
  messages: List[ChatMessage] = result.scalars().all()
310
- # Concatenate speaker labels
311
- transcript = "\n".join(
312
- f"{ 'User' if msg.sender_id else 'Assistant' }: {msg.content}"
313
- for msg in messages
314
- )
315
- # Load summarizer (T5-small) on CPU
316
- summarizer = pipeline("summarization", model="t5-small", device=-1)
317
- summary_out = summarizer(
318
- transcript, max_length=150, min_length=50, do_sample=False
319
  )
320
- return summary_out[0]["summary_text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  async def build_chat_context(
323
  self, session_id: uuid.UUID, window_size: int = 10
@@ -374,8 +398,10 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
374
  context = await self.build_chat_context(data.session_id, window_size)
375
  else:
376
  context = []
 
377
  context.append(UserContent(data.content))
378
- system_instruction = """You are a real estate assistant that help user choose and find the best match properties. Detect if the user wants property suggestions in any language.
 
379
  Always respond helpfully. If suggestions are requested, at the very end append exactly one line with
380
  #PROPERTY_CRITERIA:<json>
381
  where <json> exactly matches the PropertySchema fields:
@@ -391,14 +417,19 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
391
  "average_rating": number,
392
  "status": boolean,
393
  }
394
- If not, do not append the tag."""
395
- system_instruction += f"Also, here is there summary of the conversation between you and this customer {summary}"
 
 
 
 
 
 
396
  try:
397
  response = client.models.generate_content(
398
  model="gemini-2.0-flash",
399
  contents=context,
400
  config=GenerateContentConfig(
401
- tools=[Tool(google_search=GoogleSearch())],
402
  system_instruction=system_instruction,
403
  ),
404
  )
@@ -414,15 +445,11 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
414
  pass
415
  raise
416
  assistant_text = response.text
417
- message = await self.create_message(
418
- CreateMessageDTO(session_id=data.session_id, content=data.content),
419
- user_id,
420
- auto_commit=False,
421
- )
422
  message = await self.create(
423
  {
424
  "content": assistant_text,
425
- "session_id": message.session_id,
426
  "sender_id": None,
427
  }
428
  )
@@ -434,6 +461,62 @@ class ChatMessageService(SQLAlchemyAsyncRepositoryService[ChatMessage]):
434
  finally:
435
  await self.repository.session.commit()
436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
 
438
  async def provide_chat_message_service(
439
  db_session: AsyncSession,
 
1
  from collections.abc import AsyncGenerator
2
  from datetime import datetime
3
+ from typing import Dict, List, Union
4
  import uuid
5
  from venv import logger
6
  from sqlalchemy.dialects import postgresql # or mysql, sqlite depending on your DB
7
  from sqlalchemy import and_, desc, or_, select
8
  from sqlalchemy.orm import noload
9
+ from transformers import pipeline, AutoTokenizer
10
+ from pinecone import SearchRerank
11
+ from database.models.article import Article
12
+ from domains.news.service import ArticleService
13
  from database.models.property import Property
14
  from domains.properties.service import PropertyService
15
  from domains.chat_session.service import ChatSessionService
 
23
  from sqlalchemy.ext.asyncio import AsyncSession
24
  from google.genai import types
25
  from configs.gemai import client
26
+ from configs.pinecone import article_index, pc
27
  import re
28
  from litestar.exceptions import ValidationException, InternalServerException
29
  import requests
 
103
  "model_id": message.id,
104
  }
105
  for image in data.image_list
106
+ ],
107
  )
108
  return message
109
  except Exception as e:
 
300
 
301
  async def summarize_session(self, session_id: uuid.UUID) -> str:
302
  """
303
+ Summarize the entire chat session by chunking the transcript to respect the model's token limit.
304
  """
 
305
  query = (
306
  select(ChatMessage)
307
  .where(ChatMessage.session_id == session_id)
 
309
  )
310
  result = await self.repository.session.execute(query)
311
  messages: List[ChatMessage] = result.scalars().all()
312
+ chunks: List[str] = []
313
+ current_chunk = []
314
+ current_tokens = 0
315
+ tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-vi")
316
+ summarizer = pipeline(
317
+ "summarization",
318
+ model="Helsinki-NLP/opus-mt-en-vi",
319
+ tokenizer=tokenizer,
320
+ device=-1,
321
  )
322
+ for msg in messages:
323
+ speaker = "User" if msg.sender_id else "Assistant"
324
+ line = f"{speaker}: {msg.content}"
325
+ tokens = len(tokenizer(line, add_special_tokens=False))
326
+ if current_tokens + tokens > tokenizer.model_max_length:
327
+ chunks.append("\n".join(current_chunk))
328
+ current_chunk = [line]
329
+ current_tokens = tokens
330
+ else:
331
+ current_chunk.append(line)
332
+ current_tokens += tokens
333
+ if current_chunk:
334
+ chunks.append("\n".join(current_chunk))
335
+ partial_summaries = []
336
+ for chunk in chunks:
337
+ summary_out = summarizer(
338
+ chunk, max_length=150, min_length=10, do_sample=False, truncation=True
339
+ )
340
+ partial_summaries.append(summary_out[0]["summary_text"])
341
+ combined = "\n".join(partial_summaries)
342
+ final_out = summarizer(combined, max_length=200, min_length=50, do_sample=False)
343
+
344
+ return final_out[0]["summary_text"]
345
 
346
  async def build_chat_context(
347
  self, session_id: uuid.UUID, window_size: int = 10
 
398
  context = await self.build_chat_context(data.session_id, window_size)
399
  else:
400
  context = []
401
+ articles = await self.rag_article(data.content)
402
  context.append(UserContent(data.content))
403
+ system_instruction = """
404
+ You are a real estate assistant that help user choose and find the best match properties. Detect if the user wants property suggestions in any language.
405
  Always respond helpfully. If suggestions are requested, at the very end append exactly one line with
406
  #PROPERTY_CRITERIA:<json>
407
  where <json> exactly matches the PropertySchema fields:
 
417
  "average_rating": number,
418
  "status": boolean,
419
  }
420
+ If not, do not append the tag.
421
+ You will be provided with a list of relative articles that might help you answer user.
422
+ Each article is separated by the mark: ======== Article <number> =======.
423
+ If there are conflicts in information of articles, use the newer information.
424
+ Here is the list of relative articles that you can based on to response to user: """
425
+ for i, article in enumerate(articles):
426
+ system_instruction += f"\n ======== Article {i + 1} ============ \nTitle: {article.title} \nContent: {article.content} \nPublished date: {article.publish_date.isoformat()}"
427
+ system_instruction += f" If you use information from any provided article. Reference that article with the link. Also, here is there summary of the conversation between you and this customer {summary}"
428
  try:
429
  response = client.models.generate_content(
430
  model="gemini-2.0-flash",
431
  contents=context,
432
  config=GenerateContentConfig(
 
433
  system_instruction=system_instruction,
434
  ),
435
  )
 
445
  pass
446
  raise
447
  assistant_text = response.text
448
+
 
 
 
 
449
  message = await self.create(
450
  {
451
  "content": assistant_text,
452
+ "session_id": data.session_id,
453
  "sender_id": None,
454
  }
455
  )
 
461
  finally:
462
  await self.repository.session.commit()
463
 
464
+ async def rag_article(self, query: str) -> list[Article]:
465
+ summarized_query = self.summarize_query_for_rag(
466
+ query, max_length=len(query) // 2
467
+ )
468
+ reranked_articles = self.get_relevant_articles(summarized_query, 20, 10)
469
+ article_service = ArticleService(session=self.repository.session)
470
+ full_articles = await article_service.list(
471
+ Article.id.in_([article["_id"] for article in reranked_articles])
472
+ )
473
+ return full_articles
474
+
475
+ def summarize_query_for_rag(
476
+ self,
477
+ text: str,
478
+ max_length: int = 100,
479
+ min_length: int = 5,
480
+ device: Union[str, int] = -1,
481
+ ) -> str:
482
+ """
483
+ Summarizes a user query in any language for use in a RAG retriever.
484
+
485
+ Args:
486
+ text (str): The input text/query in any supported language.
487
+ max_length (int): Maximum length of the summary/query.
488
+ min_length (int): Minimum length of the summary/query.
489
+ device (Union[str, int]): Device for inference (-1 for CPU, 0 or 1 for GPU).
490
+
491
+ Returns:
492
+ str: Summarized query text.
493
+ """
494
+ summarizer = pipeline(
495
+ "summarization",
496
+ model="Helsinki-NLP/opus-mt-en-vi",
497
+ tokenizer="Helsinki-NLP/opus-mt-en-vi",
498
+ device=device,
499
+ )
500
+ summary = summarizer(
501
+ text, max_length=max_length, min_length=min_length, do_sample=False
502
+ )
503
+ return summary[0]["summary_text"]
504
+
505
+ def get_relevant_articles(
506
+ self, query: str, retrieval_n: int = 10, rerank_n: int = 3
507
+ ) -> Dict:
508
+ result = article_index.search(
509
+ "__default__",
510
+ query={"top_k": retrieval_n, "inputs": {"text": query}},
511
+ rerank=SearchRerank(
512
+ model="bge-reranker-v2-m3",
513
+ rank_fields=["summary"],
514
+ top_n=rerank_n,
515
+ parameters={"truncate": "END"},
516
+ ),
517
+ )
518
+ return result.to_dict()["result"]["hits"]
519
+
520
 
521
  async def provide_chat_message_service(
522
  db_session: AsyncSession,
app/domains/properties/service.py CHANGED
@@ -199,17 +199,13 @@ class PropertyService(SQLAlchemyAsyncRepositoryService[Property]):
199
  pagination: LimitOffset,
200
  user_id: uuid.UUID,
201
  ) -> CursorPagination[str, Property]:
202
- # 1) Build Pinecone metadata filter
203
  meta_filter = self._build_pinecone_filter(search_param)
204
- # 2) Generate user embedding from past interactions
205
  user_embedding = await self._compute_user_embedding(user_id)
206
- # 3) Query Pinecone
207
  pine_res = property_index.query(
208
  vector=user_embedding,
209
  filter=meta_filter,
210
  top_k=pagination.limit,
211
- include_metadata=True,
212
- # next_page_token=search_param.next_page_token,
213
  )
214
  ids = [m["id"] for m in pine_res["matches"]]
215
  props = await self._fetch_properties_from_ids(ids)
@@ -261,19 +257,21 @@ class PropertyService(SQLAlchemyAsyncRepositoryService[Property]):
261
  if search_param.lat is not None and search_param.lng is not None:
262
  query = query.join(Property.address)
263
  radius_meters = search_param.radius * 1000
264
- radius_degrees = radius_meters / 111320.0
265
  lat = search_param.lat
266
  lng = search_param.lng
267
  min_lat = lat - radius_degrees
268
  max_lat = lat + radius_degrees
269
  min_lng = lng - radius_degrees
270
  max_lng = lng + radius_degrees
271
- query = query.where(and_(
272
- Address.latitude >= min_lat,
273
- Address.latitude <= max_lat,
274
- Address.longitude >= min_lng,
275
- Address.longitude <= max_lng,
276
- ))
 
 
277
  # price filters
278
  if search_param.min_price is not None:
279
  query = query.where(Property.price >= search_param.min_price)
@@ -281,10 +279,7 @@ class PropertyService(SQLAlchemyAsyncRepositoryService[Property]):
281
  query = query.where(Property.price <= search_param.max_price)
282
  # Have review
283
  if search_param.has_review:
284
- subquery = (
285
- select(Review.id)
286
- .where(Review.property_id == Property.id)
287
- )
288
  query = query.where(exists(subquery))
289
  # categorical
290
  if search_param.property_category:
@@ -392,16 +387,12 @@ class PropertyService(SQLAlchemyAsyncRepositoryService[Property]):
392
 
393
  async def _compute_user_embedding(self, user_id: uuid.UUID) -> list[float]:
394
  user_action_repository = UserActionRepository(session=self.repository.session)
395
- properties_action = await user_action_repository.get_relevant_properties(
396
  user_id=user_id
397
  )
398
- if len(properties_action) < 5:
399
- return next(
400
- iter(property_index.fetch(["0"], namespace="Mean").vectors.values())
401
- ).values
402
- result = await self.fetch_pinecone_document_by_id(
403
- [UUID(id) for id in properties_action.keys()]
404
- )
405
  vectors = [value.values for value in result.values()]
406
  mean_vector = np.mean(vectors, axis=0).tolist()
407
  return mean_vector
 
199
  pagination: LimitOffset,
200
  user_id: uuid.UUID,
201
  ) -> CursorPagination[str, Property]:
 
202
  meta_filter = self._build_pinecone_filter(search_param)
 
203
  user_embedding = await self._compute_user_embedding(user_id)
 
204
  pine_res = property_index.query(
205
  vector=user_embedding,
206
  filter=meta_filter,
207
  top_k=pagination.limit,
208
+ include_metadata=False,
 
209
  )
210
  ids = [m["id"] for m in pine_res["matches"]]
211
  props = await self._fetch_properties_from_ids(ids)
 
257
  if search_param.lat is not None and search_param.lng is not None:
258
  query = query.join(Property.address)
259
  radius_meters = search_param.radius * 1000
260
+ radius_degrees = radius_meters / 111320.0
261
  lat = search_param.lat
262
  lng = search_param.lng
263
  min_lat = lat - radius_degrees
264
  max_lat = lat + radius_degrees
265
  min_lng = lng - radius_degrees
266
  max_lng = lng + radius_degrees
267
+ query = query.where(
268
+ and_(
269
+ Address.latitude >= min_lat,
270
+ Address.latitude <= max_lat,
271
+ Address.longitude >= min_lng,
272
+ Address.longitude <= max_lng,
273
+ )
274
+ )
275
  # price filters
276
  if search_param.min_price is not None:
277
  query = query.where(Property.price >= search_param.min_price)
 
279
  query = query.where(Property.price <= search_param.max_price)
280
  # Have review
281
  if search_param.has_review:
282
+ subquery = select(Review.id).where(Review.property_id == Property.id)
 
 
 
283
  query = query.where(exists(subquery))
284
  # categorical
285
  if search_param.property_category:
 
387
 
388
  async def _compute_user_embedding(self, user_id: uuid.UUID) -> list[float]:
389
  user_action_repository = UserActionRepository(session=self.repository.session)
390
+ property_id_list = await user_action_repository.get_relevant_properties(
391
  user_id=user_id
392
  )
393
+ if len(property_id_list) == 0:
394
+ return next(iter(property_index.fetch(["0"]).vectors.values())).values
395
+ result = await self.fetch_pinecone_document_by_id(property_id_list)
 
 
 
 
396
  vectors = [value.values for value in result.values()]
397
  mean_vector = np.mean(vectors, axis=0).tolist()
398
  return mean_vector
app/domains/user_action/service.py CHANGED
@@ -1,5 +1,6 @@
1
  from collections import defaultdict
2
  from collections.abc import AsyncGenerator
 
3
  import uuid
4
 
5
  from sqlalchemy import select
@@ -11,29 +12,28 @@ from sqlalchemy.ext.asyncio import AsyncSession
11
 
12
  class UserActionRepository(SQLAlchemyAsyncRepository[UserAction]):
13
  model_type = UserAction
14
- async def get_relevant_properties(self, user_id: uuid.UUID) -> dict:
 
15
  prop_ids_subq = (
16
  select(UserAction.property_id)
17
  .where(UserAction.user_id == user_id)
 
18
  .distinct()
19
  .limit(10)
20
  ).subquery()
21
 
22
- # Step 2: fetch all actions for those properties
23
  result = await self.session.execute(
24
  select(UserAction)
25
  .where(
26
  UserAction.user_id == user_id,
27
- UserAction.property_id.in_(select(prop_ids_subq))
28
  )
29
  .order_by(UserAction.property_id, UserAction.created_at)
30
  )
31
  actions = result.scalars().all()
 
 
32
 
33
- grouped: dict = {}
34
- for act in actions:
35
- grouped[str(act.property_id)].append(act)
36
- return grouped
37
  class UserActionService(SQLAlchemyAsyncRepositoryService[UserAction]):
38
  repository_type = UserActionRepository
39
 
 
1
  from collections import defaultdict
2
  from collections.abc import AsyncGenerator
3
+ from typing import List
4
  import uuid
5
 
6
  from sqlalchemy import select
 
12
 
13
  class UserActionRepository(SQLAlchemyAsyncRepository[UserAction]):
14
  model_type = UserAction
15
+
16
+ async def get_relevant_properties(self, user_id: uuid.UUID) -> List[uuid.UUID]:
17
  prop_ids_subq = (
18
  select(UserAction.property_id)
19
  .where(UserAction.user_id == user_id)
20
+ .where(UserAction.action == "view")
21
  .distinct()
22
  .limit(10)
23
  ).subquery()
24
 
 
25
  result = await self.session.execute(
26
  select(UserAction)
27
  .where(
28
  UserAction.user_id == user_id,
29
+ UserAction.property_id.in_(select(prop_ids_subq)),
30
  )
31
  .order_by(UserAction.property_id, UserAction.created_at)
32
  )
33
  actions = result.scalars().all()
34
+ return [action.property_id for action in actions]
35
+
36
 
 
 
 
 
37
  class UserActionService(SQLAlchemyAsyncRepositoryService[UserAction]):
38
  repository_type = UserActionRepository
39
 
app/seed/factories/article.py CHANGED
@@ -13,6 +13,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
13
  from configs.gemai import client
14
  from google.genai.types import GenerateContentConfig
15
  from advanced_alchemy.utils.text import slugify
 
 
16
 
17
  safety_settings = [
18
  {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
@@ -26,86 +28,56 @@ safety_settings = [
26
  "threshold": "BLOCK_MEDIUM_AND_ABOVE",
27
  },
28
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  async def generate_tags_and_summary(article_html_content: str) -> dict:
32
  """
33
- Generates tags and a short description for an article using Gemini.
34
-
35
- Args:
36
- article_html_content: The HTML content of the article.
37
-
38
- Returns:
39
- A dictionary with "tags" (list of strings) and "short_description" (string).
40
- Returns empty values if generation fails.
41
- """
42
- prompt = f"""
43
- Analyze the following Vietnamese news article content (provided in HTML format) and perform two tasks:
44
- 1. Generate a concise short description (summary) of the article in Vietnamese. This description should be no more than 80 words and capture the main points.
45
- 2. Extract 3 to 7 relevant keywords (tags) for this article in Vietnamese. These tags should be single words or short phrases.
46
-
47
- Article Content:
48
- ```html
49
- {article_html_content[:15000]}
50
- ```
51
-
52
- Provide your response strictly as a JSON object with two keys: "short_description" and "tags".
53
- The "tags" value should be a list of strings.
54
- Example JSON output:
55
- {{
56
- "short_description": "Một bản tóm tắt ngắn gọn của bài báo bằng tiếng Việt...,
57
- "tags": ["bất động sản", "thị trường", "dự án mới", "Việt Nam"]
58
- }}
59
  """
 
 
 
 
60
  try:
61
- print(
62
- f"Sending content to Gemini (first 100 chars): {article_html_content[:100]}..."
63
- )
64
- response = client.models.generate_content(
65
- model="gemini-1.0-flash",
66
- contents=[
67
- prompt,
68
- ],
69
- config=GenerateContentConfig(
70
- safety_settings=safety_settings,
71
- top_p=1,
72
- temperature=0.7,
73
- max_output_tokens=2048,
74
- response_modalities=["TEXT"],
75
- ),
76
  )
77
- cleaned_response_text = response.text.strip()
78
- if cleaned_response_text.startswith("```json"):
79
- cleaned_response_text = cleaned_response_text[7:]
80
- if cleaned_response_text.endswith("```"):
81
- cleaned_response_text = cleaned_response_text[:-3]
82
- cleaned_response_text = cleaned_response_text.strip()
83
- data = json.loads(cleaned_response_text)
84
-
85
- tags = data.get("tags", [])
86
- short_desc = data.get("short_description", "")
87
-
88
- if not isinstance(tags, list):
89
- print(
90
- f"Warning: Gemini returned tags not as a list: {tags}. Using empty list."
91
- )
92
- tags = []
93
- if not isinstance(short_desc, str):
94
- print(
95
- f"Warning: Gemini returned short_description not as a string: {short_desc}. Using empty string."
96
- )
97
- short_desc = ""
98
-
99
- return {"tags": tags, "short_description": short_desc}
100
-
101
  except Exception as e:
102
- print(f"Error generating tags/summary with Gemini: {e}")
103
- print(
104
- f"Failed prompt was based on content (first 100 chars): {article_html_content[:100]}..."
105
- )
106
- if hasattr(response, "prompt_feedback") and response.prompt_feedback:
107
- print(f"Gemini Prompt Feedback: {response.prompt_feedback}")
108
- return {"tags": [], "short_description": "Không thể tạo tóm tắt."}
 
 
 
 
 
 
 
 
 
 
109
 
110
 
111
  class ArticleFactory(BaseFactory):
@@ -150,7 +122,6 @@ class ArticleFactory(BaseFactory):
150
  max_tokens=10000,
151
  )
152
  text = response.choices[0].message.content.strip()
153
- print(text)
154
  articles = json.loads(text)
155
  if not isinstance(articles, list):
156
  raise ValueError("Expected a JSON list of articles.")
@@ -182,35 +153,36 @@ class ArticleFactory(BaseFactory):
182
  await import_articles_from_json(
183
  os.path.join(fixture_path, "articles.json"), session
184
  )
185
- articles_data = self.fetch_articles_from_openai(count)
186
- for article_data in articles_data:
187
- result = await session.execute(
188
- select(Article).filter_by(title=article_data.get("title"))
189
- )
190
- if result.scalars().first():
191
- continue
192
-
193
- publish_date_str = article_data.get("publish_date")
194
- try:
195
- publish_date = datetime.fromisoformat(publish_date_str)
196
- except Exception:
197
- publish_date = datetime.now(timezone.utc)
198
-
199
- tag_names = article_data.get("tags", [])
200
- tags = await self.get_or_create_tags(session, tag_names)
201
-
202
- article = Article(
203
- id=uuid.uuid4(),
204
- title=article_data.get("title"),
205
- publish_date=publish_date,
206
- content=article_data.get("content"),
207
- short_description=article_data.get("short_description"),
208
- author=article_data.get("author"),
209
- tags=tags,
210
- created_at=datetime.now(timezone.utc),
211
- updated_at=datetime.now(timezone.utc),
212
- )
213
- await self.repository(session=session).add(article)
 
214
  except Exception as e:
215
  await session.rollback()
216
  print(f"Error during ArticleFactory seeding: {e}")
@@ -223,14 +195,15 @@ class ArticleFactory(BaseFactory):
223
  await self.repository(session=session).delete_where(Article.id.is_not(None))
224
  await session.commit()
225
 
226
-
227
  def parse_vietnamese_datetime(date_str: str) -> datetime | None:
228
  """
229
- Tries to parse common Vietnamese datetime string formats.
230
  Returns a timezone-aware datetime object (UTC) or None if parsing fails.
231
  """
232
  if not date_str or not isinstance(date_str, str):
233
  return None
 
 
234
  if "T" in date_str and ("Z" in date_str or "+" in date_str or "-" in date_str[10:]):
235
  try:
236
  dt = datetime.fromisoformat(date_str)
@@ -240,6 +213,7 @@ def parse_vietnamese_datetime(date_str: str) -> datetime | None:
240
  except ValueError:
241
  pass
242
 
 
243
  formats_to_try = [
244
  "%d/%m/%Y %H:%M:%S",
245
  "%d/%m/%Y %H:%M",
@@ -247,7 +221,9 @@ def parse_vietnamese_datetime(date_str: str) -> datetime | None:
247
  "%d-%m-%Y %H:%M",
248
  "%Y-%m-%d %H:%M:%S",
249
  "%Y/%m/%d %H:%M:%S",
 
250
  ]
 
251
  for fmt in formats_to_try:
252
  try:
253
  dt_naive = datetime.strptime(date_str.strip(), fmt)
@@ -255,10 +231,10 @@ def parse_vietnamese_datetime(date_str: str) -> datetime | None:
255
  return dt_aware
256
  except ValueError:
257
  continue
 
258
  print(f"Warning: Could not parse date string: {date_str}")
259
  return None
260
 
261
-
262
  async def get_or_create_tags(session: AsyncSession, tag_names: List[str]) -> List[Tag]:
263
  """
264
  Retrieves existing Tag objects or creates new ones for each tag name.
@@ -296,10 +272,9 @@ async def process_article_data(session: AsyncSession, article_data: Dict[str, An
296
  return None
297
  gemini_data = await generate_tags_and_summary(html_content)
298
  tag_names = gemini_data.get("tags", [])
299
- # short_description = gemini_data.get("short_description")
300
-
301
- # if not short_description:
302
- short_description = "Tóm tắt không có sẵn."
303
  if not tag_names:
304
  print(f"No tags generated for article: {title}")
305
  publish_date = parse_vietnamese_datetime(published_date_str)
@@ -308,14 +283,14 @@ async def process_article_data(session: AsyncSession, article_data: Dict[str, An
308
  f"Using current time for article '{title}' due to unparseable date: {published_date_str}"
309
  )
310
  publish_date = datetime.now(timezone.utc)
311
- db_tags = await get_or_create_tags(session, tag_names)
312
  new_article = Article(
313
  title=title,
314
  publish_date=publish_date,
315
  content=html_content,
316
  short_description=short_description[:499],
317
  author=source_name,
318
- tags=db_tags,
319
  )
320
  return new_article
321
 
@@ -347,6 +322,8 @@ async def import_articles_from_json(json_filepath: str, session: AsyncSession):
347
 
348
  articles_to_add = []
349
  for i, item_data in enumerate(data_from_json):
 
 
350
  print(f"\n--- Processing item {i+1}/{len(data_from_json)} ---")
351
  article_obj = await process_article_data(session, item_data)
352
  if article_obj:
 
13
  from configs.gemai import client
14
  from google.genai.types import GenerateContentConfig
15
  from advanced_alchemy.utils.text import slugify
16
+ from transformers import pipeline
17
+ import re
18
 
19
  safety_settings = [
20
  {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
 
28
  "threshold": "BLOCK_MEDIUM_AND_ABOVE",
29
  },
30
  ]
31
+ _SUMMARY_PIPELINE = pipeline(
32
+ "summarization",
33
+ model="google/long-t5-tglobal-base",
34
+ tokenizer="google/long-t5-tglobal-base",
35
+ device=-1,
36
+ )
37
+
38
+ _KEYPHRASE_PIPELINE = pipeline(
39
+ "text2text-generation",
40
+ model="google/long-t5-tglobal-base",
41
+ tokenizer="google/long-t5-tglobal-base",
42
+ framework="pt",
43
+ device=-1,
44
+ )
45
 
46
 
47
  async def generate_tags_and_summary(article_html_content: str) -> dict:
48
  """
49
+ Summarize and extract tags using small transformer models.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """
51
+ text = re.sub(r"<[^>]+>", " ", article_html_content)
52
+ text = re.sub(r"\s+", " ", text).strip()
53
+ if len(text) < 50:
54
+ return {"tags": [], "short_description": text}
55
  try:
56
+ summary_out = _SUMMARY_PIPELINE(
57
+ text,
58
+ max_length=200,
59
+ min_length=30,
60
+ do_sample=False,
 
 
 
 
 
 
 
 
 
 
61
  )
62
+ short_description = summary_out[0]["summary_text"].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  except Exception as e:
64
+ print(f"Summarization error: {e}")
65
+ short_description = text[:300] + ("…" if len(text) > 300 else "")
66
+ try:
67
+ prompt = "extract keyphrases: " + text[:1000] # limit length
68
+ kpop = _KEYPHRASE_PIPELINE(prompt, max_length=64, do_sample=False)
69
+ raw = kpop[0]["generated_text"]
70
+ tags = re.split(r"[;,]\s*", raw)
71
+ tags = list(dict.fromkeys([t.strip().lower() for t in tags if t.strip()]))
72
+ tags = tags[:7]
73
+ except Exception as e:
74
+ print(f"Keyphrase extraction error: {e}")
75
+ tags = []
76
+
77
+ return {
78
+ "short_description": short_description,
79
+ "tags": tags,
80
+ }
81
 
82
 
83
  class ArticleFactory(BaseFactory):
 
122
  max_tokens=10000,
123
  )
124
  text = response.choices[0].message.content.strip()
 
125
  articles = json.loads(text)
126
  if not isinstance(articles, list):
127
  raise ValueError("Expected a JSON list of articles.")
 
153
  await import_articles_from_json(
154
  os.path.join(fixture_path, "articles.json"), session
155
  )
156
+ else:
157
+ articles_data = self.fetch_articles_from_openai(count)
158
+ for article_data in articles_data:
159
+ result = await session.execute(
160
+ select(Article).filter_by(title=article_data.get("title"))
161
+ )
162
+ if result.scalars().first():
163
+ continue
164
+
165
+ publish_date_str = article_data.get("publish_date")
166
+ try:
167
+ publish_date = datetime.fromisoformat(publish_date_str)
168
+ except Exception:
169
+ publish_date = datetime.now(timezone.utc)
170
+
171
+ tag_names = article_data.get("tags", [])
172
+ tags = await self.get_or_create_tags(session, tag_names)
173
+
174
+ article = Article(
175
+ id=uuid.uuid4(),
176
+ title=article_data.get("title"),
177
+ publish_date=publish_date,
178
+ content=article_data.get("content"),
179
+ short_description=article_data.get("short_description"),
180
+ author=article_data.get("author"),
181
+ tags=tags,
182
+ created_at=datetime.now(timezone.utc),
183
+ updated_at=datetime.now(timezone.utc),
184
+ )
185
+ await self.repository(session=session).add(article)
186
  except Exception as e:
187
  await session.rollback()
188
  print(f"Error during ArticleFactory seeding: {e}")
 
195
  await self.repository(session=session).delete_where(Article.id.is_not(None))
196
  await session.commit()
197
 
 
198
  def parse_vietnamese_datetime(date_str: str) -> datetime | None:
199
  """
200
+ Tries to parse common Vietnamese datetime string formats, including RFC 1123.
201
  Returns a timezone-aware datetime object (UTC) or None if parsing fails.
202
  """
203
  if not date_str or not isinstance(date_str, str):
204
  return None
205
+
206
+ # First: handle ISO8601 with 'T' and timezone info
207
  if "T" in date_str and ("Z" in date_str or "+" in date_str or "-" in date_str[10:]):
208
  try:
209
  dt = datetime.fromisoformat(date_str)
 
213
  except ValueError:
214
  pass
215
 
216
+ # Try known formats, including RFC 1123
217
  formats_to_try = [
218
  "%d/%m/%Y %H:%M:%S",
219
  "%d/%m/%Y %H:%M",
 
221
  "%d-%m-%Y %H:%M",
222
  "%Y-%m-%d %H:%M:%S",
223
  "%Y/%m/%d %H:%M:%S",
224
+ "%a, %d %b %Y %H:%M:%S GMT", # RFC 1123 (e.g., "Sun, 01 Jun 2025 01:16:00 GMT")
225
  ]
226
+
227
  for fmt in formats_to_try:
228
  try:
229
  dt_naive = datetime.strptime(date_str.strip(), fmt)
 
231
  return dt_aware
232
  except ValueError:
233
  continue
234
+
235
  print(f"Warning: Could not parse date string: {date_str}")
236
  return None
237
 
 
238
  async def get_or_create_tags(session: AsyncSession, tag_names: List[str]) -> List[Tag]:
239
  """
240
  Retrieves existing Tag objects or creates new ones for each tag name.
 
272
  return None
273
  gemini_data = await generate_tags_and_summary(html_content)
274
  tag_names = gemini_data.get("tags", [])
275
+ short_description = gemini_data.get("short_description")
276
+ if not short_description:
277
+ short_description = "Tóm tắt không có sẵn."
 
278
  if not tag_names:
279
  print(f"No tags generated for article: {title}")
280
  publish_date = parse_vietnamese_datetime(published_date_str)
 
283
  f"Using current time for article '{title}' due to unparseable date: {published_date_str}"
284
  )
285
  publish_date = datetime.now(timezone.utc)
286
+ # db_tags = await get_or_create_tags(session, tag_names)
287
  new_article = Article(
288
  title=title,
289
  publish_date=publish_date,
290
  content=html_content,
291
  short_description=short_description[:499],
292
  author=source_name,
293
+ tags=[],
294
  )
295
  return new_article
296
 
 
322
 
323
  articles_to_add = []
324
  for i, item_data in enumerate(data_from_json):
325
+ if i > 2:
326
+ break
327
  print(f"\n--- Processing item {i+1}/{len(data_from_json)} ---")
328
  article_obj = await process_article_data(session, item_data)
329
  if article_obj: