Junhoee committed
Commit 06b2015 · verified · 1 Parent(s): 558f57f

Upload 6 files

megumin_agent/agent.py CHANGED
@@ -127,7 +127,7 @@ root_agent = LlmAgent(
     이 tool은 스타일/페르소나용 사례 top-3와 사실/설정용 사례 top-3를 5:5 비중으로 함께 돌려줍니다.
     persona_matches는 메구밍의 성격, 말투, 감정선, 답변 리듬을 참고하는 용도입니다.
     fact_matches는 설정, 관계, 사건, 세계관 사실을 참고하는 용도입니다.
-    두 종류의 사례를 모두 참고하되, 검색된 답변을 그대로 복사하지 마세요.
+    두 종류의 사례를 모두 참고하되 검색된 답변을 그대로 복사하지 마세요.
     검색 결과가 약하거나 없는 경우에도 메구밍 페르소나는 유지하되, 모르는 내용은 지어내지 말고 솔직하게 답하세요.
     최종 답변은 언제나 메구밍의 페르소나를 강하게 반영해야 하며, 내부 tool 이름이나 구현 세부사항은 드러내지 마세요.
     """.strip(),
megumin_agent/bootstrap.py CHANGED
@@ -1,531 +1,94 @@
 from __future__ import annotations

-import json
-import math
 import os
-import re
-import unicodedata
-from dataclasses import dataclass
-from functools import lru_cache
 from pathlib import Path
-from typing import Any
-from typing import Iterable

-import faiss
-import numpy as np
-from google import genai
-from google.genai import types


-QUESTION_KEYS = (
-    "question",
-    "query",
-    "q",
-    "prompt",
-    "user",
-    "instruction",
-    "input",
-)
-ANSWER_KEYS = (
-    "answer",
-    "response",
-    "a",
-    "output",
-    "assistant",
-    "completion",
-)
-COLLECTION_KEYS = ("items", "data", "examples", "dataset", "records")
-EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-001")
-EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
-EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
-FAISS_INDEX_FILENAME = os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
-FAISS_QA_INDEX_FILENAME = os.getenv(
-    "MEGUMIN_FAISS_QA_INDEX_FILENAME",
-    "megumin_question_answer.faiss",
-)
-FAISS_METADATA_FILENAME = os.getenv(
-    "MEGUMIN_FAISS_METADATA_FILENAME",
-    "megumin_questions_meta.json",
-)
-PERSONA_DATASET_PATTERNS = ("megumin_qa_dataset.json",)
-FACT_DATASET_PATTERNS = ("namuwiki*.json",)


-def _normalize_text(value: Any) -> str:
-    text = str(value or "")
-    text = unicodedata.normalize("NFKC", text).strip()
-    text = re.sub(r"\s+", " ", text)
-    return text


-def _safe_excerpt(text: str, limit: int = 220) -> str:
-    compact = re.sub(r"\s+", " ", str(text or "")).strip()
-    if len(compact) <= limit:
-        return compact
-    return compact[: limit - 3].rstrip() + "..."


-def _normalize_patterns(patterns: Iterable[str] | None) -> tuple[str, ...]:
-    normalized = tuple(pattern.strip() for pattern in (patterns or ()) if pattern.strip())
-    return normalized


-def _record_search_text(record: "QaRecord", mode: str) -> str:
-    if mode == "question_answer":
-        return f"{record.question}\n{record.answer}".strip()
-    return record.question


-@dataclass(frozen=True)
-class QaRecord:
-    question: str
-    answer: str
-    source_file: str
-    metadata: dict[str, Any]

-    @property
-    def normalized_question(self) -> str:
-        return _normalize_text(self.question)


-@dataclass(frozen=True)
-class VectorStore:
-    records: tuple[QaRecord, ...]
-    index: faiss.Index
-    embedding_model: str
-    dimension: int


-def _extract_collection(payload: Any) -> list[Any]:
-    if isinstance(payload, list):
-        return payload
-    if isinstance(payload, dict):
-        for key in COLLECTION_KEYS:
-            value = payload.get(key)
-            if isinstance(value, list):
-                return value
-    return []


-def _pick_first(mapping: dict[str, Any], keys: tuple[str, ...]) -> str | None:
-    lowered = {str(key).lower(): value for key, value in mapping.items()}
-    for key in keys:
-        if key in lowered and lowered[key] not in (None, ""):
-            return str(lowered[key]).strip()
-    return None


-def _record_from_mapping(item: dict[str, Any], source_file: str) -> QaRecord | None:
-    question = _pick_first(item, QUESTION_KEYS)
-    answer = _pick_first(item, ANSWER_KEYS)
-    if not question or not answer:
-        return None

-    metadata = {
-        key: value
-        for key, value in item.items()
-        if str(key).lower() not in QUESTION_KEYS + ANSWER_KEYS
-    }
-    return QaRecord(
-        question=question,
-        answer=answer,
-        source_file=source_file,
-        metadata=metadata,
-    )


-def _load_json_records(path: Path) -> list[QaRecord]:
-    raw_text = path.read_text(encoding="utf-8")
-    stripped = raw_text.strip()
-    if not stripped:
-        return []
-
-    records: list[QaRecord] = []
     try:
-        payload = json.loads(stripped)
-    except json.JSONDecodeError:
-        payload = None
-
-    if payload is not None:
-        for item in _extract_collection(payload):
-            if isinstance(item, dict):
-                record = _record_from_mapping(item, path.name)
-                if record:
-                    records.append(record)
-        if records:
-            return records
-
-    for line in stripped.splitlines():
-        line = line.strip()
-        if not line:
-            continue
-        try:
-            item = json.loads(line)
-        except json.JSONDecodeError:
-            continue
-        if isinstance(item, dict):
-            record = _record_from_mapping(item, path.name)
-            if record:
-                records.append(record)
-
-    return records
-
-
-def _load_metadata_records(path: Path) -> tuple[QaRecord, ...]:
-    payload = json.loads(path.read_text(encoding="utf-8"))
-    records: list[QaRecord] = []
-    for item in _extract_collection(payload):
-        if isinstance(item, dict):
-            record = _record_from_mapping(item, path.name)
-            if record:
-                records.append(record)
-    return tuple(records)
-
-
-def _iter_matching_paths(root: Path, include_patterns: tuple[str, ...]) -> list[Path]:
-    if not include_patterns:
-        return sorted(root.glob("*.json"))
-
-    seen: set[Path] = set()
-    paths: list[Path] = []
-    for pattern in include_patterns:
-        for path in sorted(root.glob(pattern)):
-            if path in seen or path.suffix.lower() != ".json":
-                continue
-            seen.add(path)
-            paths.append(path)
-    return paths
-
-
-@lru_cache(maxsize=16)
-def _load_records(dataset_dir: str, include_patterns: tuple[str, ...] = ()) -> tuple[QaRecord, ...]:
-    root = Path(dataset_dir)
-    if not root.exists():
-        return tuple()
-
-    all_records: list[QaRecord] = []
-    for path in _iter_matching_paths(root, include_patterns):
-        try:
-            all_records.extend(_load_json_records(path))
-        except OSError:
-            continue
-        except UnicodeDecodeError:
-            continue
-    return tuple(all_records)
-
-
-@lru_cache(maxsize=2)
-def _get_genai_client() -> genai.Client:
-    return genai.Client()
-
-
-def _embed_texts(
-    texts: list[str],
-    *,
-    task_type: str,
-    embedding_model: str,
-    output_dimensionality: int,
-) -> np.ndarray:
-    if not texts:
-        return np.zeros((0, output_dimensionality), dtype="float32")
-
-    batches: list[np.ndarray] = []
-    batch_size = max(1, min(EMBEDDING_BATCH_SIZE, 100))
-    for start in range(0, len(texts), batch_size):
-        chunk = texts[start : start + batch_size]
-        response = _get_genai_client().models.embed_content(
-            model=embedding_model,
-            contents=chunk,
-            config=types.EmbedContentConfig(
-                task_type=task_type,
-                output_dimensionality=output_dimensionality,
-            ),
         )
-        vectors = np.array(
-            [embedding.values for embedding in response.embeddings],
-            dtype="float32",
-        )
-        if vectors.size == 0:
-            continue
-        faiss.normalize_L2(vectors)
-        batches.append(vectors)
-
-    if not batches:
-        return np.zeros((0, output_dimensionality), dtype="float32")
-    return np.vstack(batches)
-
-
-def _index_artifact_paths(dataset_dir: str | Path) -> tuple[Path, Path]:
-    root = Path(dataset_dir)
-    return (
-        root / FAISS_INDEX_FILENAME,
-        root / FAISS_METADATA_FILENAME,
-    )
-
-
-def _build_index_from_records(
-    records: tuple[QaRecord, ...],
-    *,
-    embedding_model: str,
-    output_dimensionality: int,
-    mode: str,
-) -> faiss.IndexFlatIP:
-    search_texts = [_record_search_text(record, mode) for record in records]
-    vectors = _embed_texts(
-        search_texts,
-        task_type="RETRIEVAL_DOCUMENT",
-        embedding_model=embedding_model,
-        output_dimensionality=output_dimensionality,
-    )
-    if vectors.size == 0:
-        raise RuntimeError("No embeddings were generated for the dataset records.")
-
-    index = faiss.IndexFlatIP(int(vectors.shape[1]))
-    index.add(vectors)
-    return index
-
-
-def build_and_save_faiss_index(
-    dataset_dir: str | Path,
-    *,
-    embedding_model: str = EMBEDDING_MODEL_NAME,
-    output_dimensionality: int = EMBEDDING_DIMENSION,
-    index_filename: str = FAISS_INDEX_FILENAME,
-    qa_index_filename: str = FAISS_QA_INDEX_FILENAME,
-    metadata_filename: str = FAISS_METADATA_FILENAME,
-    include_patterns: Iterable[str] | None = None,
-) -> tuple[Path, Path, Path]:
-    root = Path(dataset_dir)
-    records = _load_records(str(root.resolve()), _normalize_patterns(include_patterns))
-    if not records:
-        raise FileNotFoundError(f"No JSON records found under {root}")
-
-    question_index = _build_index_from_records(
-        records,
-        embedding_model=embedding_model,
-        output_dimensionality=output_dimensionality,
-        mode="question",
-    )
-    qa_index = _build_index_from_records(
-        records,
-        embedding_model=embedding_model,
-        output_dimensionality=output_dimensionality,
-        mode="question_answer",
-    )
-    index_path = root / index_filename
-    qa_index_path = root / qa_index_filename
-    metadata_path = root / metadata_filename
-    faiss.write_index(question_index, str(index_path))
-    faiss.write_index(qa_index, str(qa_index_path))
-    metadata_payload = {
-        "items": [
-            {
-                "question": record.question,
-                "answer": record.answer,
-                "source_file": record.source_file,
-                **record.metadata,
-            }
-            for record in records
-        ]
-    }
-    metadata_path.write_text(
-        json.dumps(metadata_payload, ensure_ascii=False, indent=2),
-        encoding="utf-8",
-    )
-    return index_path, qa_index_path, metadata_path
-
-
-@lru_cache(maxsize=8)
-def _load_vector_store(
-    dataset_dir: str,
-    embedding_model: str,
-    output_dimensionality: int,
-    include_patterns: tuple[str, ...] = (),
-    index_filename: str | None = FAISS_INDEX_FILENAME,
-    qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
-    metadata_filename: str | None = FAISS_METADATA_FILENAME,
-    mode: str = "question",
-) -> VectorStore:
-    selected_index_filename = index_filename if mode == "question" else qa_index_filename
-    if selected_index_filename and metadata_filename:
-        index_path = Path(dataset_dir) / selected_index_filename
-        metadata_path = Path(dataset_dir) / metadata_filename
-    else:
-        index_path = metadata_path = None
-
-    if index_path and metadata_path and index_path.exists() and metadata_path.exists():
-        index = faiss.read_index(str(index_path))
-        records = _load_metadata_records(metadata_path)
-        if index.ntotal != len(records):
-            raise ValueError(
-                f"FAISS index size ({index.ntotal}) does not match metadata size ({len(records)})."
-            )
-        return VectorStore(
-            records=records,
-            index=index,
-            embedding_model=embedding_model,
-            dimension=index.d,
-        )
-
-    records = _load_records(dataset_dir, include_patterns)
-    if not records:
-        empty_index = faiss.IndexFlatIP(output_dimensionality)
-        return VectorStore(
-            records=tuple(),
-            index=empty_index,
-            embedding_model=embedding_model,
-            dimension=output_dimensionality,
-        )
-
-    index = _build_index_from_records(
-        records,
-        embedding_model=embedding_model,
-        output_dimensionality=output_dimensionality,
-        mode=mode,
-    )
-    return VectorStore(
-        records=records,
-        index=index,
-        embedding_model=embedding_model,
-        dimension=index.d,
-    )
-
-
-class JsonQaRetriever:
-    def __init__(
-        self,
-        dataset_dir: str | Path,
-        *,
-        embedding_model: str = EMBEDDING_MODEL_NAME,
-        output_dimensionality: int = EMBEDDING_DIMENSION,
-        include_patterns: Iterable[str] | None = None,
-        index_filename: str | None = FAISS_INDEX_FILENAME,
-        qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
-        metadata_filename: str | None = FAISS_METADATA_FILENAME,
-    ):
-        self.dataset_dir = Path(dataset_dir)
-        self.embedding_model = embedding_model
-        self.output_dimensionality = output_dimensionality
-        self.include_patterns = _normalize_patterns(include_patterns)
-        self.index_filename = index_filename
-        self.qa_index_filename = qa_index_filename
-        self.metadata_filename = metadata_filename
-
-    def warmup(self) -> None:
-        _load_vector_store(
-            str(self.dataset_dir.resolve()),
-            self.embedding_model,
-            self.output_dimensionality,
-            self.include_patterns,
-            self.index_filename,
-            self.qa_index_filename,
-            self.metadata_filename,
-            "question",
-        )
-        _load_vector_store(
-            str(self.dataset_dir.resolve()),
-            self.embedding_model,
-            self.output_dimensionality,
-            self.include_patterns,
-            self.index_filename,
-            self.qa_index_filename,
-            self.metadata_filename,
-            "question_answer",
-        )
-
-    def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
-        if not matches:
-            return [
-                "No strong example was retrieved, so stay in Megumin's persona without inventing unsupported canon facts.",
-            ]
-
-        notes = [
-            "Answer in first person as Megumin, with respectful but dramatic confidence.",
-            "Use the retrieved cases to mirror tone and answer shape, but do not copy them verbatim.",
-            "Prefer the retrieved answers as evidence for facts, relationships, and recurring phrasing.",
-        ]
-
-        long_answers = sum(
-            1 for match in matches if len(match.get("answer", "")) >= 180
-        )
-        if long_answers >= max(1, math.ceil(len(matches) / 2)):
-            notes.append(
-                "The retrieved examples skew narrative, so a short anecdotal lead-in is acceptable."
-            )
-        else:
-            notes.append(
-                "The retrieved examples are compact, so keep the answer concise and pointed."
-            )
-        return notes
-
-    def retrieve(self, query: str, top_k: int = 3) -> dict[str, Any]:
-        question_store = _load_vector_store(
-            str(self.dataset_dir.resolve()),
-            self.embedding_model,
-            self.output_dimensionality,
-            self.include_patterns,
-            self.index_filename,
-            self.qa_index_filename,
-            self.metadata_filename,
-            "question",
-        )
-        qa_store = _load_vector_store(
-            str(self.dataset_dir.resolve()),
-            self.embedding_model,
-            self.output_dimensionality,
-            self.include_patterns,
-            self.index_filename,
-            self.qa_index_filename,
-            self.metadata_filename,
-            "question_answer",
-        )
-        if not question_store.records:
-            return {
-                "query": query,
-                "match_count": 0,
-                "matches": [],
-                "style_notes": [
-                    "No processed JSON dataset was found for retrieval.",
-                ],
-            }
-
-        query_vector = _embed_texts(
-            [_normalize_text(query) or query],
-            task_type="RETRIEVAL_QUERY",
-            embedding_model=question_store.embedding_model,
-            output_dimensionality=question_store.dimension,
-        )
-        search_k = max(1, min(top_k, len(question_store.records)))
-
-        candidates: dict[int, dict[str, Any]] = {}
-        for store_name, store in (("question", question_store), ("question_answer", qa_store)):
-            scores, indices = store.index.search(query_vector, search_k)
-            for score, index in zip(scores[0], indices[0]):
-                if index < 0:
                     continue
-                record = store.records[int(index)]
-                current = candidates.get(int(index))
-                score_value = round(float(score), 6)
-                if current is None or score_value > current["score"]:
-                    candidates[int(index)] = {
-                        "question": record.question,
-                        "answer": _safe_excerpt(record.answer),
-                        "score": score_value,
-                        "source_file": record.source_file,
-                        "metadata": record.metadata,
-                        "matched_via": store_name,
-                    }
-
-        matches = sorted(
-            candidates.values(),
-            key=lambda item: item["score"],
-            reverse=True,
-        )[:top_k]
-
-        return {
-            "query": query,
-            "match_count": len(matches),
-            "matches": matches,
-            "style_notes": self._style_notes(matches),
-        }
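The block above is the old bootstrap.py being removed wholesale; the slimmed-down replacement below keeps only environment setup and dataset download, while the indexing and retrieval helpers move into retrieval.py. One detail worth noting from the removed helpers (and from the retrieval.py version further down): _embed_texts runs faiss.normalize_L2 on every embedding batch before the vectors reach an IndexFlatIP, so the inner-product scores the code works with are cosine similarities. A small self-contained sketch of that equivalence, using synthetic vectors rather than the project's data:

import faiss
import numpy as np

# Synthetic vectors only; shows why the code L2-normalizes embeddings before
# adding them to an inner-product index: unit-length inner product == cosine.
rng = np.random.default_rng(0)
docs = rng.normal(size=(4, 768)).astype("float32")
query = rng.normal(size=(1, 768)).astype("float32")

faiss.normalize_L2(docs)
faiss.normalize_L2(query)

index = faiss.IndexFlatIP(768)
index.add(docs)
scores, ids = index.search(query, 4)

cosine = query @ docs.T  # manual cosine similarity, same values as `scores`
assert np.allclose(np.sort(cosine[0])[::-1], scores[0], atol=1e-5)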
 
 from __future__ import annotations

 import os
+import sys
 from pathlib import Path

+from dotenv import load_dotenv
+from huggingface_hub import hf_hub_download


+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+ADK_SRC = PROJECT_ROOT / "adk-python" / "src"
+LOCAL_DATASET_DIR = PROJECT_ROOT / "data" / "processed"
+RUNTIME_DATASET_DIR = PROJECT_ROOT / "data" / "_runtime_processed"


+def _dataset_repo_id() -> str:
+    return os.getenv("MEGUMIN_HF_DATASET_REPO_ID", "Junhoee/megumin-chat")


+def _dataset_filename() -> str:
+    return os.getenv("MEGUMIN_HF_DATASET_FILENAME", "megumin_qa_dataset.json")


+def _index_filename() -> str:
+    return os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")


+def _qa_index_filename() -> str:
+    return os.getenv("MEGUMIN_FAISS_QA_INDEX_FILENAME", "megumin_question_answer.faiss")


+def _metadata_filename() -> str:
+    return os.getenv("MEGUMIN_FAISS_METADATA_FILENAME", "megumin_questions_meta.json")


+def _fact_dataset_filename() -> str:
+    return os.getenv("MEGUMIN_HF_FACT_DATASET_FILENAME", "namuwiki_qa.json")


+def _fact_index_filename() -> str:
+    return os.getenv("MEGUMIN_HF_FACT_INDEX_FILENAME", "namuwiki_questions.faiss")


+def _fact_qa_index_filename() -> str:
+    return os.getenv("MEGUMIN_HF_FACT_QA_INDEX_FILENAME", "namuwiki_question_answer.faiss")


+def _fact_metadata_filename() -> str:
+    return os.getenv("MEGUMIN_HF_FACT_METADATA_FILENAME", "namuwiki_questions_meta.json")


+def bootstrap_environment() -> None:
+    load_dotenv(PROJECT_ROOT / ".env", override=True)
+    if ADK_SRC.exists():
+        adk_src = str(ADK_SRC)
+        if adk_src not in sys.path:
+            sys.path.insert(0, adk_src)


+def resolve_dataset_dir() -> Path:
+    RUNTIME_DATASET_DIR.mkdir(parents=True, exist_ok=True)

     try:
+        hf_token = os.getenv("HF_TOKEN") or None
+        repo_id = _dataset_repo_id()
+        artifact_names = (
+            _dataset_filename(),
+            _index_filename(),
+            _qa_index_filename(),
+            _metadata_filename(),
+            _fact_dataset_filename(),
+            _fact_index_filename(),
+            _fact_qa_index_filename(),
+            _fact_metadata_filename(),
         )
+        for artifact_name in artifact_names:
+            try:
+                hf_hub_download(
+                    repo_id=repo_id,
+                    repo_type="dataset",
+                    filename=artifact_name,
+                    token=hf_token,
+                    local_dir=str(RUNTIME_DATASET_DIR),
+                )
+            except Exception:
+                if artifact_name not in {_dataset_filename(), _fact_dataset_filename()}:
                     continue
+                raise
+        return RUNTIME_DATASET_DIR
+    except Exception:
+        if LOCAL_DATASET_DIR.exists() and any(LOCAL_DATASET_DIR.glob("*.json")):
+            return LOCAL_DATASET_DIR
+        raise
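A hedged sketch of how the two public entry points above might be wired together at startup; the call site and import path are assumptions based on the package layout, not code from this commit:

# Hypothetical startup wiring (not part of this commit).
from megumin_agent.bootstrap import bootstrap_environment, resolve_dataset_dir

bootstrap_environment()              # load .env; add adk-python/src to sys.path if present
dataset_dir = resolve_dataset_dir()  # download HF dataset artifacts, else fall back to data/processed
print(f"retrieval artifacts live in: {dataset_dir}")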
megumin_agent/retrieval.py CHANGED
@@ -39,6 +39,10 @@ EMBEDDING_MODEL_NAME = os.getenv("MEGUMIN_EMBEDDING_MODEL", "gemini-embedding-00
 EMBEDDING_DIMENSION = int(os.getenv("MEGUMIN_EMBEDDING_DIM", "768"))
 EMBEDDING_BATCH_SIZE = int(os.getenv("MEGUMIN_EMBEDDING_BATCH_SIZE", "100"))
 FAISS_INDEX_FILENAME = os.getenv("MEGUMIN_FAISS_INDEX_FILENAME", "megumin_questions.faiss")
+FAISS_QA_INDEX_FILENAME = os.getenv(
+    "MEGUMIN_FAISS_QA_INDEX_FILENAME",
+    "megumin_question_answer.faiss",
+)
 FAISS_METADATA_FILENAME = os.getenv(
     "MEGUMIN_FAISS_METADATA_FILENAME",
     "megumin_questions_meta.json",
@@ -66,6 +70,12 @@ def _normalize_patterns(patterns: Iterable[str] | None) -> tuple[str, ...]:
     return normalized


+def _record_search_text(record: "QaRecord", mode: str) -> str:
+    if mode == "question_answer":
+        return f"{record.question}\n{record.answer}".strip()
+    return record.question
+
+
 @dataclass(frozen=True)
 class QaRecord:
     question: str
@@ -254,36 +264,60 @@ def _index_artifact_paths(dataset_dir: str | Path) -> tuple[Path, Path]:
     )


+def _build_index_from_records(
+    records: tuple[QaRecord, ...],
+    *,
+    embedding_model: str,
+    output_dimensionality: int,
+    mode: str,
+) -> faiss.IndexFlatIP:
+    search_texts = [_record_search_text(record, mode) for record in records]
+    vectors = _embed_texts(
+        search_texts,
+        task_type="RETRIEVAL_DOCUMENT",
+        embedding_model=embedding_model,
+        output_dimensionality=output_dimensionality,
+    )
+    if vectors.size == 0:
+        raise RuntimeError("No embeddings were generated for the dataset records.")
+
+    index = faiss.IndexFlatIP(int(vectors.shape[1]))
+    index.add(vectors)
+    return index
+
+
 def build_and_save_faiss_index(
     dataset_dir: str | Path,
     *,
     embedding_model: str = EMBEDDING_MODEL_NAME,
     output_dimensionality: int = EMBEDDING_DIMENSION,
     index_filename: str = FAISS_INDEX_FILENAME,
+    qa_index_filename: str = FAISS_QA_INDEX_FILENAME,
     metadata_filename: str = FAISS_METADATA_FILENAME,
     include_patterns: Iterable[str] | None = None,
-) -> tuple[Path, Path]:
+) -> tuple[Path, Path, Path]:
     root = Path(dataset_dir)
     records = _load_records(str(root.resolve()), _normalize_patterns(include_patterns))
     if not records:
         raise FileNotFoundError(f"No JSON records found under {root}")

-    questions = [record.normalized_question or record.question for record in records]
-    question_vectors = _embed_texts(
-        questions,
-        task_type="RETRIEVAL_DOCUMENT",
+    question_index = _build_index_from_records(
+        records,
         embedding_model=embedding_model,
         output_dimensionality=output_dimensionality,
+        mode="question",
+    )
+    qa_index = _build_index_from_records(
+        records,
+        embedding_model=embedding_model,
+        output_dimensionality=output_dimensionality,
+        mode="question_answer",
     )
-    if question_vectors.size == 0:
-        raise RuntimeError("No embeddings were generated for the dataset questions.")
-
-    index = faiss.IndexFlatIP(int(question_vectors.shape[1]))
-    index.add(question_vectors)
-
     index_path = root / index_filename
+    qa_index_path = root / qa_index_filename
     metadata_path = root / metadata_filename
-    faiss.write_index(index, str(index_path))
+    faiss.write_index(question_index, str(index_path))
+    faiss.write_index(qa_index, str(qa_index_path))
     metadata_payload = {
         "items": [
             {
@@ -299,7 +333,7 @@ def build_and_save_faiss_index(
         json.dumps(metadata_payload, ensure_ascii=False, indent=2),
         encoding="utf-8",
     )
-    return index_path, metadata_path
+    return index_path, qa_index_path, metadata_path


 @lru_cache(maxsize=8)
@@ -309,10 +343,13 @@ def _load_vector_store(
     output_dimensionality: int,
     include_patterns: tuple[str, ...] = (),
     index_filename: str | None = FAISS_INDEX_FILENAME,
+    qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
     metadata_filename: str | None = FAISS_METADATA_FILENAME,
+    mode: str = "question",
 ) -> VectorStore:
-    if index_filename and metadata_filename:
-        index_path = Path(dataset_dir) / index_filename
+    selected_index_filename = index_filename if mode == "question" else qa_index_filename
+    if selected_index_filename and metadata_filename:
+        index_path = Path(dataset_dir) / selected_index_filename
         metadata_path = Path(dataset_dir) / metadata_filename
     else:
         index_path = metadata_path = None
@@ -341,21 +378,17 @@ def _load_vector_store(
         dimension=output_dimensionality,
     )

-    questions = [record.normalized_question or record.question for record in records]
-    question_vectors = _embed_texts(
-        questions,
-        task_type="RETRIEVAL_DOCUMENT",
+    index = _build_index_from_records(
+        records,
         embedding_model=embedding_model,
         output_dimensionality=output_dimensionality,
+        mode=mode,
     )
-    dimension = int(question_vectors.shape[1])
-    index = faiss.IndexFlatIP(dimension)
-    index.add(question_vectors)
     return VectorStore(
         records=records,
         index=index,
         embedding_model=embedding_model,
-        dimension=dimension,
+        dimension=index.d,
     )
@@ -368,6 +401,7 @@ class JsonQaRetriever:
         output_dimensionality: int = EMBEDDING_DIMENSION,
         include_patterns: Iterable[str] | None = None,
         index_filename: str | None = FAISS_INDEX_FILENAME,
+        qa_index_filename: str | None = FAISS_QA_INDEX_FILENAME,
         metadata_filename: str | None = FAISS_METADATA_FILENAME,
     ):
         self.dataset_dir = Path(dataset_dir)
@@ -375,6 +409,7 @@ class JsonQaRetriever:
         self.output_dimensionality = output_dimensionality
         self.include_patterns = _normalize_patterns(include_patterns)
         self.index_filename = index_filename
+        self.qa_index_filename = qa_index_filename
         self.metadata_filename = metadata_filename

     def warmup(self) -> None:
@@ -384,7 +419,19 @@ class JsonQaRetriever:
             self.output_dimensionality,
             self.include_patterns,
             self.index_filename,
+            self.qa_index_filename,
+            self.metadata_filename,
+            "question",
+        )
+        _load_vector_store(
+            str(self.dataset_dir.resolve()),
+            self.embedding_model,
+            self.output_dimensionality,
+            self.include_patterns,
+            self.index_filename,
+            self.qa_index_filename,
             self.metadata_filename,
+            "question_answer",
         )

     def _style_notes(self, matches: list[dict[str, Any]]) -> list[str]:
@@ -413,15 +460,27 @@ class JsonQaRetriever:
         return notes

     def retrieve(self, query: str, top_k: int = 3) -> dict[str, Any]:
-        store = _load_vector_store(
+        question_store = _load_vector_store(
             str(self.dataset_dir.resolve()),
             self.embedding_model,
             self.output_dimensionality,
             self.include_patterns,
             self.index_filename,
+            self.qa_index_filename,
             self.metadata_filename,
+            "question",
         )
-        if not store.records:
+        qa_store = _load_vector_store(
+            str(self.dataset_dir.resolve()),
+            self.embedding_model,
+            self.output_dimensionality,
+            self.include_patterns,
+            self.index_filename,
+            self.qa_index_filename,
+            self.metadata_filename,
+            "question_answer",
+        )
+        if not question_store.records:
             return {
                 "query": query,
                 "match_count": 0,
@@ -434,26 +493,35 @@ class JsonQaRetriever:
         query_vector = _embed_texts(
             [_normalize_text(query) or query],
             task_type="RETRIEVAL_QUERY",
-            embedding_model=store.embedding_model,
-            output_dimensionality=store.dimension,
+            embedding_model=question_store.embedding_model,
+            output_dimensionality=question_store.dimension,
         )
-        search_k = max(1, min(top_k, len(store.records)))
-        scores, indices = store.index.search(query_vector, search_k)
-
-        matches: list[dict[str, Any]] = []
-        for score, index in zip(scores[0], indices[0]):
-            if index < 0:
-                continue
-            record = store.records[int(index)]
-            matches.append(
-                {
-                    "question": record.question,
-                    "answer": _safe_excerpt(record.answer),
-                    "score": round(float(score), 6),
-                    "source_file": record.source_file,
-                    "metadata": record.metadata,
-                }
-            )
+        search_k = max(1, min(top_k, len(question_store.records)))
+
+        candidates: dict[int, dict[str, Any]] = {}
+        for store_name, store in (("question", question_store), ("question_answer", qa_store)):
+            scores, indices = store.index.search(query_vector, search_k)
+            for score, index in zip(scores[0], indices[0]):
+                if index < 0:
+                    continue
+                record = store.records[int(index)]
+                current = candidates.get(int(index))
+                score_value = round(float(score), 6)
+                if current is None or score_value > current["score"]:
+                    candidates[int(index)] = {
+                        "question": record.question,
+                        "answer": _safe_excerpt(record.answer),
+                        "score": score_value,
+                        "source_file": record.source_file,
+                        "metadata": record.metadata,
+                        "matched_via": store_name,
+                    }
+
+        matches = sorted(
+            candidates.values(),
+            key=lambda item: item["score"],
+            reverse=True,
+        )[:top_k]

         return {
             "query": query,