m97j's picture
Initial commit
07a6690
import os, json
from typing import List, Dict, Any, Optional
from chromadb import PersistentClient
from chromadb.utils.embedding_functions import EmbeddingFunction
from config import CHROMA_DIR
# === ์ดˆ๊ธฐํ™” ===
CHROMA_DIR.mkdir(parents=True, exist_ok=True)
print(f"๐Ÿ“‚ ChromaDB ๊ฒฝ๋กœ: {CHROMA_DIR.resolve()}")
_client = PersistentClient(path=str(CHROMA_DIR))
_collection = _client.get_or_create_collection(name="game_docs")
_embedder: Optional[EmbeddingFunction] = None
# === Embedder ์„ค์ • ===
def set_embedder(embedder: Any):
global _embedder
_embedder = embedder
def chroma_initialized() -> bool:
return os.path.exists(str(CHROMA_DIR)) and len(os.listdir(str(CHROMA_DIR))) > 0
# === type๋ณ„ content ์ถ”์ถœ ===
def extract_content(doc: Dict[str, Any]) -> str:
"""๋ฌธ์„œ type์— ๋”ฐ๋ผ content ํ•„๋“œ๋ฅผ ์ƒ์„ฑ"""
if "content" in doc and isinstance(doc["content"], str):
return doc["content"]
t = doc.get("type", "").lower()
if t in ["description", "lore", "fallback", "main_res_validate", "npc_persona"]:
return doc.get("description", "") or doc.get("content", "")
elif t == "trigger_def":
return doc.get("description", json.dumps(doc.get("trigger", {}), ensure_ascii=False))
elif t == "dialogue_turn":
# player + npc ๋Œ€์‚ฌ๋ฅผ ํ•ฉ์ณ์„œ ์ €์žฅ
return f"PLAYER: {doc.get('player', '')}\nNPC: {doc.get('npc', '')}".strip()
elif t == "flag_def":
return "\n".join(doc.get("examples_positive", []))
elif t == "trigger_meta":
return doc.get("trigger", "")
else:
# ์•Œ ์ˆ˜ ์—†๋Š” type์ด๋ฉด ๊ฐ€๋Šฅํ•œ ๋ชจ๋“  ํ…์ŠคํŠธ ํ•„๋“œ ํ•ฉ์นจ
text_parts = []
for k, v in doc.items():
if isinstance(v, str):
text_parts.append(v)
return "\n".join(text_parts)
# === ๋””์Šคํฌ์—์„œ ๋ฌธ์„œ ๋กœ๋“œ ===
def load_game_docs_from_disk(path: str) -> List[Dict[str, Any]]:
docs = []
for filename in os.listdir(path):
full = os.path.join(path, filename)
if filename.endswith(".json"):
with open(full, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, list):
for i, doc in enumerate(data):
if "id" not in doc:
doc["id"] = f"{filename}_{i}"
doc["content"] = extract_content(doc)
docs.append(doc)
elif isinstance(data, dict):
if "id" not in data:
data["id"] = filename
data["content"] = extract_content(data)
docs.append(data)
elif filename.endswith(".txt"):
with open(full, "r", encoding="utf-8") as f:
content = f.read()
docs.append({
"id": filename,
"type": "text",
"content": content,
"metadata": {}
})
return docs
# === ๋ฌธ์„œ ์ถ”๊ฐ€ ===
def add_docs(docs: List[Dict[str, Any]], batch_size: int = 32):
assert _embedder is not None, "Embedder not initialized"
for i in range(0, len(docs), batch_size):
batch = docs[i:i+batch_size]
ids, contents, embeddings, metadatas = [], [], [], []
for doc in batch:
# id๋Š” ํ•„์ˆ˜, content๋Š” ์—†์œผ๋ฉด ๋นˆ ๋ฌธ์ž์—ด
doc_id = doc.get("id", f"doc_{i}")
content = doc.get("content", "")
ids.append(doc_id)
contents.append(content)
metadatas.append(doc) # ์›๋ณธ ์ „์ฒด ์ €์žฅ
emb = _embedder.encode(content).tolist() if content else []
embeddings.append(emb)
_collection.add(
documents=contents,
embeddings=embeddings,
metadatas=metadatas,
ids=ids
)
# === ๋ฌธ์„œ ๊ฒ€์ƒ‰ ===
def retrieve(query: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, top_k: int = 5) -> List[Dict[str, Any]]:
assert _embedder is not None, "Embedder not initialized"
if query:
q_emb = _embedder.encode(query).tolist()
res = _collection.query(
query_embeddings=[q_emb],
n_results=top_k,
where=filters or {}
)
docs = res.get("documents", [[]])[0]
metas = res.get("metadatas", [[]])[0]
else:
res = _collection.get(
where=filters or {},
limit=top_k
)
docs = res.get("documents", [])
metas = res.get("metadatas", [])
# ์›๋ณธ ๊ตฌ์กฐ ๋ณต์›
results = []
for d, m in zip(docs, metas):
if isinstance(m, dict):
results.append({
"id": m.get("id", ""),
"type": m.get("type", "unknown"),
"npc_id": m.get("npc_id", ""),
"quest_stage": m.get("quest_stage", ""),
"location": m.get("location", ""),
"content": d,
"metadata": m
})
else:
results.append({
"id": "",
"type": "unknown",
"npc_id": "",
"quest_stage": "",
"location": "",
"content": d,
"metadata": {}
})
return results