cotienbot / train.py
Anothervin1's picture
Update train.py
4c6157e verified
# File: train.py
import json
from src.embedding import Embedding
from src.firestore_db import FirestoreDB
from src.search import Search
from src.lottery import Lottery
from src.logger import logger
def train_data(db: FirestoreDB, search: Search, embedding: Embedding, lottery: Lottery, param: str = None, lottery_data: dict = None):
# FAISS (đối thoại)
if param and not lottery_data:
logger.info(f"Nhập dữ liệu đối thoại: {param}")
try:
try:
with open("data/training_data.json", "r", encoding="utf-8") as f:
content = f.read().strip()
data = json.loads(content) if content else []
if not isinstance(data, list):
logger.warning("File training_data.json không chứa mảng, khởi tạo lại")
data = []
except FileNotFoundError:
logger.info("File training_data.json không tồn tại, tạo mới")
data = []
except json.JSONDecodeError as e:
logger.error(f"File training_data.json không hợp lệ: {str(e)}, khởi tạo lại")
data = []
data.append({"text": param})
with open("data/training_data.json", "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
logger.info("Đã lưu dữ liệu đối thoại vào training_data.json")
except Exception as e:
logger.error(f"Lỗi khi lưu dữ liệu đối thoại: {str(e)}")
raise
for item in data:
text = item["text"]
vector = embedding.generate(text)
db.save_document(text, vector)
search.add(text, vector)
logger.info("Đào tạo FAISS hoàn tất")
return
# Lottery (xổ số)
if lottery_data:
logger.info(f"Nhập dữ liệu xổ số: {lottery_data}")
try:
if isinstance(lottery_data, list):
entries = lottery_data
else:
entries = [lottery_data]
try:
with open("data/lottery_data.json", "r", encoding="utf-8") as f:
content = f.read().strip()
existing_data = json.loads(content) if content else []
if not isinstance(existing_data, list):
logger.warning("File lottery_data.json không chứa mảng, khởi tạo lại")
existing_data = []
except FileNotFoundError:
logger.info("File lottery_data.json không tồn tại, tạo mới")
existing_data = []
except json.JSONDecodeError as e:
logger.error(f"File lottery_data.json không hợp lệ: {str(e)}, khởi tạo lại")
existing_data = []
existing_data.extend(entries)
with open("data/lottery_data.json", "w", encoding="utf-8") as f:
json.dump(existing_data, f, ensure_ascii=False, indent=4)
logger.info("Đã lưu dữ liệu xổ số vào lottery_data.json")
except Exception as e:
logger.error(f"Lỗi khi lưu dữ liệu xổ số: {str(e)}")
raise
try:
with open("data/lottery_data.json", "r", encoding="utf-8") as f:
content = f.read().strip()
data = json.loads(content) if content else []
if not isinstance(data, list):
logger.warning("File lottery_data.json không chứa mảng, khởi tạo lại")
data = []
except FileNotFoundError:
logger.error("Không tìm thấy file: data/lottery_data.json")
raise FileNotFoundError("Không tìm thấy file: data/lottery_data.json")
except json.JSONDecodeError as e:
logger.error(f"File lottery_data.json không hợp lệ: {str(e)}, khởi tạo lại")
data = []
logger.info(f"Xử lý {len(data)} bản ghi xổ số từ file")
for entry in data:
logger.info(f"Xử lý entry: {entry}")
try:
if not isinstance(entry, dict) or "ngay" not in entry or "dai" not in entry or "giai" not in entry:
logger.warning(f"Bỏ qua entry không hợp lệ: {entry}")
continue
db.save_lottery(entry)
lottery.add_lottery_vectors(entry)
except Exception as e:
logger.error(f"Lỗi khi xử lý entry xổ số {entry}: {str(e)}")
continue
logger.info("Đào tạo xổ số hoàn tất")