import datetime
import gc
import json
import math
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

# TF_CPP_MIN_LOG_LEVEL must be set before TensorFlow is imported, or the
# C++ log filter is ignored.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import absl.logging
import faiss
import numpy as np
import tensorflow as tf
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

from chatbot_config import ChatbotConfig
from conversation_summarizer import DeviceAwareModel, Summarizer
from cross_encoder_reranker import CrossEncoderReranker
from logger_config import config_logger
from response_quality_checker import ResponseQualityChecker
from tf_data_pipeline import TFDataPipeline

absl.logging.set_verbosity(absl.logging.WARNING)
logger = config_logger(__name__)
logger.setLevel("WARNING")

# Note: this creates a single disabled progress bar; it does not suppress
# tqdm globally. Kept for parity with the original behavior.
tqdm(disable=True)


class RetrievalChatbot(DeviceAwareModel):
    """
    Retrieval-based learning chatbot model.
    Uses trained embeddings and FAISS for similarity search.
    """

    def __init__(
        self,
        config: ChatbotConfig,
        device: Optional[str] = None,
        strategy=None,
        reranker: Optional[CrossEncoderReranker] = None,
        summarizer: Optional[Summarizer] = None,
        mode: str = 'training'
    ):
        super().__init__()
        self.config = config
        self.strategy = strategy
        self.device = device or self._setup_default_device()
        self.mode = mode.lower()

        # Core components: a shared encoder/tokenizer plus the reranker and
        # summarizer helpers (defaults are built if none are injected).
        self.encoder = self._initialize_encoder()
        self.tokenizer = self.encoder.tokenizer
        self.reranker = reranker or self._initialize_reranker()
        self.summarizer = summarizer or self._initialize_summarizer()

        logger.info("Initializing TFDataPipeline.")
        self.data_pipeline = TFDataPipeline(
            config=self.config,
            tokenizer=self.tokenizer,
            encoder=self.encoder,
            response_pool=[],
            query_embeddings_cache={},
        )

        # Inference mode needs the FAISS index and response pool up front;
        # training mode builds them later.
        if self.mode == 'inference':
            logger.info("Mode set to 'inference'. Loading FAISS index and response pool.")
            self._load_faiss_index_and_responses()
        elif self.mode != 'training':
            logger.error(f"Unsupported mode in RetrievalChatbot init: {self.mode}")
            raise ValueError(f"Unsupported mode in RetrievalChatbot init: {self.mode}")

        self.history = {
            "train_loss": [],
            "val_loss": [],
            "train_metrics": {},
            "val_metrics": {}
        }

    def _setup_default_device(self) -> str:
        """Set up the default device if none is provided."""
        if tf.config.list_physical_devices('GPU'):
            return 'GPU'
        return 'CPU'

    def _initialize_reranker(self) -> CrossEncoderReranker:
        """Initialize the CrossEncoderReranker."""
        logger.info("Initializing default CrossEncoderReranker...")
        return CrossEncoderReranker(model_name=self.config.cross_encoder_model)

    def _initialize_summarizer(self) -> Summarizer:
        """Initialize the Summarizer."""
        return Summarizer(
            tokenizer=self.tokenizer,
            model_name=self.config.summarizer_model,
            max_summary_length=self.config.max_context_length // 4,
            device=self.device,
            max_summary_rounds=2
        )

    def _initialize_encoder(self) -> SentenceTransformer:
        """Initialize the SentenceTransformer encoder model."""
        logger.info("Initializing SentenceTransformer encoder model...")
        return SentenceTransformer(self.config.pretrained_model)

    def _load_faiss_index_and_responses(self) -> None:
        """Load the FAISS index and response pool for inference."""
        try:
            logger.info(f"Loading FAISS index from {self.data_pipeline.faiss_index_file_path}...")
            self.data_pipeline.load_faiss_index(self.data_pipeline.faiss_index_file_path)
            logger.info("FAISS index loaded successfully.")

            # The response pool lives next to the index file, with a
            # '_responses.json' suffix in place of '.index'.
            response_pool_path = self.data_pipeline.faiss_index_file_path.replace('.index', '_responses.json')
            if os.path.exists(response_pool_path):
                with open(response_pool_path, 'r', encoding='utf-8') as f:
                    self.data_pipeline.response_pool = json.load(f)
                logger.info(f"Loaded {len(self.data_pipeline.response_pool)} responses from {response_pool_path}.")
            else:
                logger.error(f"Response pool file not found at {response_pool_path}.")
                raise FileNotFoundError(f"Response pool file not found at {response_pool_path}.")

            # Sanity-check that index and response pool are consistent.
            self.data_pipeline.validate_faiss_index()
            logger.info("FAISS index and response pool validated successfully.")

        except Exception as e:
            logger.error(f"Failed to load FAISS index and response pool: {e}")
            raise

    @classmethod
    def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
        """Load chatbot model and configuration."""
        load_dir = Path(load_dir)

        config_path = load_dir / "config.json"
        if config_path.exists():
            with open(config_path, "r") as f:
                config = ChatbotConfig.from_dict(json.load(f))
            logger.info("Loaded ChatbotConfig from config.json.")
        else:
            raise FileNotFoundError(f"Config file not found at {config_path}. Please ensure it exists.")

        chatbot = cls(config, mode=mode)

        # Prefer a locally saved encoder; fall back to the hub model named
        # in the config.
        model_path = load_dir / "sentence_transformer"
        if model_path.exists():
            chatbot.encoder = SentenceTransformer(str(model_path))
            logger.info("Loaded SentenceTransformer model from local path successfully.")
        else:
            chatbot.encoder = SentenceTransformer(config.pretrained_model)
            logger.info(f"Loaded SentenceTransformer model '{config.pretrained_model}' from the hub successfully.")

        return chatbot
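
    # Usage sketch (hedged; the "models/chatbot" path is an assumption, not
    # something this module creates):
    #
    #     bot = RetrievalChatbot.load_model("models/chatbot", mode="inference")
    #
    # In 'inference' mode the constructor also loads the FAISS index and
    # response pool, so the returned instance is ready to retrieve responses.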

    @classmethod
    def _prepare_model_for_inference(cls, chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
        """Load inference components."""
        try:
            faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'
            if faiss_path.exists():
                chatbot.index = faiss.read_index(str(faiss_path))
                logger.info("FAISS index loaded successfully")
            else:
                raise FileNotFoundError(f"FAISS index not found at {faiss_path}")

            response_pool_path = load_dir / 'faiss_indices/faiss_index_production_responses.json'
            if response_pool_path.exists():
                with open(response_pool_path, 'r') as f:
                    chatbot.response_pool = json.load(f)
                logger.info(f"Loaded {len(chatbot.response_pool)} responses")
            else:
                raise FileNotFoundError(f"Response pool not found at {response_pool_path}")

            # The index dimensionality must match the encoder's embedding size.
            if chatbot.index.d != chatbot.config.embedding_dim:
                raise ValueError(
                    f"FAISS index dimension {chatbot.index.d} doesn't match "
                    f"model dimension {chatbot.config.embedding_dim}"
                )

        except Exception as e:
            logger.error(f"Error loading inference components: {e}")
            raise

    def save_models(self, save_dir: Union[str, Path]):
        """Save SentenceTransformer model and config."""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        with open(save_dir / "config.json", "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        self.encoder.save(save_dir / "sentence_transformer")
        logger.info(f"Model and config saved to {save_dir}.")
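
    # Round-trip sketch (hedged; assumes a writable local directory):
    #
    #     bot.save_models("checkpoints/demo")
    #     restored = RetrievalChatbot.load_model("checkpoints/demo")
    #
    # save_models() writes config.json plus a sentence_transformer/ directory,
    # which is exactly the layout load_model() looks for.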

    def retrieve_responses(
        self,
        query: str,
        top_k: int = 10,
        reranker: Optional[CrossEncoderReranker] = None,
        summarizer: Optional[Summarizer] = None,
        summarize_threshold: int = 512,
        boost_factor: float = 1.15
    ) -> List[Tuple[str, float]]:
        """
        Retrieve top-k responses using FAISS and cross-encoder re-ranking.

        Args:
            query: The user's input text.
            top_k: Number of responses to return.
            reranker: Optional reranker for refined scoring.
            summarizer: Optional summarizer for long queries.
            summarize_threshold: Word-count threshold above which long queries are summarized.
            boost_factor: Factor to boost scores for keyword matches.

        Returns:
            List of (response_text, final_score).
        """
        def sigmoid(x: float) -> float:
            return 1 / (1 + np.exp(-x))

        # Summarize overly long queries before retrieval.
        if summarizer and len(query.split()) > summarize_threshold:
            logger.info(f"Query is long ({len(query.split())} words). Summarizing...")
            query = summarizer.summarize_text(query)
            logger.info(f"Summarized query: {query}")

        detected_domain = self.detect_domain_from_query(query)

        # Over-fetch from FAISS so the re-ranker has a wide candidate pool.
        faiss_candidates = self.data_pipeline.retrieve_responses(query, top_k=top_k * 10)

        if not faiss_candidates:
            logger.warning("No candidates retrieved from FAISS.")
            return []

        # Prefer in-domain candidates, falling back to the full pool.
        if detected_domain != 'other':
            in_domain_candidates = [c for c in faiss_candidates if c[0]["domain"] == detected_domain]
            if in_domain_candidates:
                faiss_candidates = in_domain_candidates
            else:
                logger.info(f"No in-domain responses found for '{query}'. Using all candidates.")

        texts = [item[0]["text"] for item in faiss_candidates]
        faiss_scores = [item[1] for item in faiss_candidates]

        if reranker is None:
            reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)

        ce_logits = reranker.rerank(query, texts, max_length=256)

        # Keywords depend only on the query, so compute them once rather
        # than once per candidate.
        query_keywords = self.extract_keywords(query)

        final_candidates = []
        for resp_text, faiss_score, logit in zip(texts, faiss_scores, ce_logits):
            ce_prob = sigmoid(logit)
            # Map the FAISS score (assumed in [-1, 1], e.g. cosine/inner
            # product) into [0, 1], then blend with the cross-encoder
            # probability, weighted toward the cross-encoder.
            faiss_norm = (faiss_score + 1) / 2
            combined_score = 0.75 * ce_prob + 0.25 * faiss_norm

            # Boost responses that share domain keywords with the query.
            if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
                combined_score *= boost_factor

            # Penalize very short responses, lightly reward longer ones.
            length_adjusted_score = self.length_adjust_score(resp_text, combined_score)

            final_candidates.append((resp_text, length_adjusted_score))

        # Highest score first.
        final_candidates.sort(key=lambda x: x[1], reverse=True)

        return final_candidates[:top_k]
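
    # Worked example of the blend above (numbers are illustrative): with a
    # cross-encoder logit of 2.0 (sigmoid ~= 0.881) and a FAISS score of 0.6
    # (normalized to 0.8):
    #
    #     combined = 0.75 * 0.881 + 0.25 * 0.8  ~= 0.861
    #     boosted  = 0.861 * 1.15               ~= 0.990   (on keyword match)
    #
    # Typical call, assuming an inference-mode instance:
    #
    #     results = bot.retrieve_responses("any good pizza places nearby?", top_k=5)
    #     for text, score in results:
    #         print(f"{score:.3f}  {text}")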

    def extract_keywords(self, query: str) -> List[str]:
        """
        Return any domain keywords present in the query (lowercased).
        """
        domain_keywords = {
            'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
            'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
            'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
            'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
            'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
            'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
        }

        query_lower = query.lower()
        found = set()
        for domain, kw_list in domain_keywords.items():
            for kw in kw_list:
                if kw in query_lower:
                    found.add(kw)
        return list(found)
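
    # Example: matching is plain substring containment, so
    #
    #     bot.extract_keywords("Any good pizza delivery nearby?")
    #     # -> ['pizza', 'delivery']  (order may vary; it is built from a set)
    #
    # Note that substring matching can over-trigger (e.g. 'eat' is contained
    # in "theater"); the word-boundary regexes in detect_domain_from_query
    # below avoid that problem.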

    def length_adjust_score(self, text: str, base_score: float) -> float:
        """
        Penalize very short responses, reward longer ones.
        """
        words = text.split()
        wcount = len(words)

        # Very short responses take a flat 20% penalty.
        if wcount < 4:
            return base_score * 0.8

        # Longer responses get a small additive bonus, capped at +0.03.
        if wcount > 15:
            bonus = min(0.03, 0.001 * (wcount - 15))
            base_score += bonus

        return base_score
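
    # Worked example: a 20-word response earns min(0.03, 0.001 * 5) = 0.005,
    # so a base score of 0.861 becomes 0.866; a 3-word response with the same
    # base score would instead drop to 0.861 * 0.8 ~= 0.689.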

    def detect_domain_from_query(self, query: str) -> str:
        """
        Detect the domain of the query based on keywords. Used for filtering FAISS search.
        """
        domain_patterns = {
            'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
            'movie': r'\b(movie|movies?|cinema|cinemas?|film|films?|ticket|tickets?|showtime|showtimes?|showing|showings?|theater|theaters?|flick|flicks?|screening|screenings?|film\s?ticket|film\s?tickets?|film\s?show|film\s?shows?|blockbuster|blockbusters?|premiere|premieres?|trailer|trailers?|director|directors?|actor|actors?|actress|actresses?|plot|plots?|genre|genres?|screen|screens?|sequel|sequels?|animation|animations?|documentary|documentaries)\b',
            'ride_share': r'\b(ride|rides?|taxi|taxis?|uber|lyft|car\s?service|car\s?services?|pickup|pickups?|dropoff|dropoffs?|driver|drivers?|cab|cabs?|hailing|hailings?|rideshare|rideshares?|ride\s?hailing|ride\s?hailings?|carpool|carpools?|chauffeur|chauffeurs?|transit|transits?|transportation|transportations?|hail\s?ride|hail\s?rides?)\b',
            'coffee': r'\b(coffee|coffees?|café|cafés?|cafe|cafes?|starbucks|espresso|espressos?|latte|lattes?|mocha|mochas?|americano|americanos?|barista|baristas?|brew|brews?|cappuccino|cappuccinos?|macchiato|macchiatos?|iced\s?coffee|iced\s?coffees?|cold\s?brew|cold\s?brews?|espresso\s?machine|espresso\s?machines?|coffee\s?shop|coffee\s?shops?|tea|teas?|chai|chais?|java|javas?|bean|beans?|roast|roasts?|decaf)\b',
            'pizza': r'\b(pizza|pizzas?|delivery|deliveries|order\s?food|order\s?foods?|pepperoni|pepperonis?|topping|toppings?|pizzeria|pizzerias?|slice|slices?|pie|pies?|margherita|margheritas?|deep\s?dish|deep\s?dishes?|thin\s?crust|thin\s?crusts?|cheese|cheeses?|oven|ovens?|tossed|tosses?|sauce|sauces?|garlic\s?bread|garlic\s?breads?|calzone|calzones?)\b',
            'auto': r'\b(car|cars?|vehicle|vehicles?|repair|repairs?|maintenance|maintenances?|mechanic|mechanics?|oil\s?change|oil\s?changes?|garage|garages?|auto\s?shop|auto\s?shops?|tire|tires?|check\s?engine|check\s?engines?|battery|batteries?|transmission|transmissions?|brake|brakes?|engine\s?diagnostics|engine\s?diagnostic|carwash|carwashes?|detail|details?|alignment|alignments?|exhaust|exhausts?|spark\s?plug|spark\s?plugs?|dashboard|dashboards?)\b',
        }

        for domain, pattern in domain_patterns.items():
            if re.search(pattern, query.lower()):
                return domain

        return 'other'
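
    # Example: the patterns use \b word boundaries, so whole words are
    # required, and the first domain whose pattern matches wins (dict
    # insertion order):
    #
    #     bot.detect_domain_from_query("Can I book a table for two?")  # -> 'restaurant'
    #     bot.detect_domain_from_query("What's the weather like?")     # -> 'other'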

    def is_numeric_response(self, text: str) -> bool:
        """
        Return True if `text` consists only of digits, whitespace, and the
        separators '.' and ','.
        """
        pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
        return bool(re.match(pattern, text.strip()))
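
    # Examples:
    #
    #     bot.is_numeric_response("42")          # -> True
    #     bot.is_numeric_response(" 1,234.56 ")  # -> True
    #     bot.is_numeric_response("42nd")        # -> False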

    def introduction_message(self) -> None:
        """Print an introduction message to introduce the chatbot."""
        print(
            "\nAssistant: Hello! I'm a simple chatbot assistant. I've been trained to answer "
            "basic questions about topics including restaurants, movies, ride sharing, coffee, and pizza. "
            "Please ask me a question and I'll do my best to assist you."
        )

    def run_interactive_chat(self, quality_checker, show_alternatives=False):
        """Interactive chat loop."""
        self.introduction_message()

        while True:
            try:
                user_input = input("\nYou: ")
            except (KeyboardInterrupt, EOFError):
                print("\nAssistant: Goodbye!")
                break

            if user_input.lower() in ["quit", "exit", "bye"]:
                print("\nAssistant: Goodbye!")
                break

            response, candidates, metrics, top_response_score = self.chat(
                query=user_input,
                conversation_history=None,
                quality_checker=quality_checker,
                top_k=10
            )

            print(f"\nAssistant: {response}")

            if show_alternatives and candidates and metrics.get("is_confident", False):
                print("\n Alternative responses:")
                for resp, score in candidates[1:4]:
                    print(f"  Score: {score:.4f} - {resp}")
            elif top_response_score < 0.7:
                print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")

    def chat(
        self,
        query: str,
        conversation_history: Optional[List[Tuple[str, str]]] = None,
        quality_checker: Optional['ResponseQualityChecker'] = None,
        top_k: int = 10,
    ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any], float]:
        """
        Live chat with the chatbot. Uses the same processing flow as
        validation, except for context handling and quality checking.
        Returns (response, candidates, quality metrics, top score).
        """
        @self.run_on_device
        def get_response(self_arg, query_arg):
            # Fold the conversation history into a single context string.
            conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)

            responses = self_arg.retrieve_responses(
                query=conversation_str,
                top_k=top_k,
                reranker=self_arg.reranker,
                summarizer=self_arg.summarizer,
                summarize_threshold=512
            )

            if not responses:
                return ("I'm sorry, but I couldn't find a relevant response.", [], {}, 0.0)

            # quality_checker is optional; without one we fall back to empty
            # metrics, which routes to the clarification reply below.
            metrics = quality_checker.check_response_quality(query_arg, responses) if quality_checker else {}
            is_confident = metrics.get('is_confident', False)
            top_response_score = responses[0][1]

            if not is_confident or top_response_score < 0.5:
                return ("I need more information to provide a good answer. Could you please clarify?", responses, metrics, top_response_score)

            return responses[0][0], responses, metrics, top_response_score

        return get_response(self, query)
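
    # Minimal usage sketch (hedged: the ResponseQualityChecker constructor
    # signature is an assumption; it is defined elsewhere in this repo):
    #
    #     checker = ResponseQualityChecker(data_pipeline=bot.data_pipeline)
    #     response, candidates, metrics, score = bot.chat(
    #         query="Recommend an Italian restaurant",
    #         conversation_history=[("hi", "Hello! How can I help?")],
    #         quality_checker=checker,
    #     )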

    def _build_conversation_context(
        self,
        query: str,
        conversation_history: Optional[List[Tuple[str, str]]]
    ) -> str:
        """
        Build a conversation context string from the conversation history,
        using literal <USER> and <ASSISTANT> tokens (no tokenizer special index).
        """
        USER_TOKEN = "<USER>"
        ASSISTANT_TOKEN = "<ASSISTANT>"

        if not conversation_history:
            return f"{USER_TOKEN} {query}"

        conversation_parts = []
        for user_txt, assistant_txt in conversation_history:
            conversation_parts.append(f"{USER_TOKEN} {user_txt}")
            conversation_parts.append(f"{ASSISTANT_TOKEN} {assistant_txt}")

        conversation_parts.append(f"{USER_TOKEN} {query}")
        return "\n".join(conversation_parts)
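
    # Example output:
    #
    #     _build_conversation_context("and for tomorrow?", [("book a table", "Done! Anything else?")])
    #
    # produces:
    #
    #     <USER> book a table
    #     <ASSISTANT> Done! Anything else?
    #     <USER> and for tomorrow?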

    def train_model(
        self,
        tfrecord_file_path: str,
        epochs: int = 20,
        batch_size: int = 16,
        validation_split: float = 0.2,
        checkpoint_dir: str = "checkpoints/",
        use_lr_schedule: bool = True,
        peak_lr: float = 1e-5,
        warmup_steps_ratio: float = 0.1,
        early_stopping_patience: int = 3,
        min_delta: float = 1e-4,
        test_mode: bool = False,
        initial_epoch: int = 0
    ) -> None:
        """
        Train the retrieval model using a pre-prepared TFRecord dataset.
        Handles:
        - Checkpoint loading/restoring
        - LR scheduling
        - Epoch/iteration tracking
        - Training-history logging
        - Early stopping
        - Custom loss function (contrastive loss with hard negative sampling)
        """
        logger.info("Starting training with pre-prepared TFRecord dataset...")

        def parse_tfrecord_fn(example_proto, max_length, neg_samples):
            """
            Parse a single TFRecord example into (query_ids, positive_ids, negative_ids).
            """
            feature_description = {
                'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
                'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
                'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
            }
            parsed_features = tf.io.parse_single_example(example_proto, feature_description)

            query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
            positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
            negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
            # Negatives are stored flat; restore the [neg_samples, max_length] shape.
            negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])

            return query_ids, positive_ids, negative_ids
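
        # For reference, a record matching this schema could be serialized as
        # follows (hedged sketch; the actual writer lives in TFDataPipeline):
        #
        #     example = tf.train.Example(features=tf.train.Features(feature={
        #         'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=query_ids)),
        #         'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=positive_ids)),
        #         'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=negative_ids_flat)),
        #     }))
        #     writer.write(example.SerializeToString())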

        # Count pairs once to size the train/val split and shuffle buffer.
        raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
        total_pairs = sum(1 for _ in raw_dataset)
        logger.info(f"Total pairs in TFRecord: {total_pairs}")

        train_size = int(total_pairs * (1 - validation_split))
        val_size = total_pairs - train_size
        steps_per_epoch = math.ceil(train_size / batch_size)
        val_steps = math.ceil(val_size / batch_size)
        total_steps = steps_per_epoch * epochs
        buffer_size = max(1, total_pairs // 2)

        logger.info(f"Training pairs: {train_size}")
        logger.info(f"Validation pairs: {val_size}")
        logger.info(f"Steps per epoch: {steps_per_epoch}")
        logger.info(f"Validation steps: {val_steps}")
        logger.info(f"Total steps: {total_steps}")

        if use_lr_schedule:
            warmup_steps = int(total_steps * warmup_steps_ratio)
            lr_schedule = self._get_lr_schedule(
                total_steps=total_steps,
                peak_lr=tf.cast(peak_lr, tf.float32),
                warmup_steps=warmup_steps
            )
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
            logger.info("Using custom learning rate schedule.")
        else:
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
            logger.info("Using fixed learning rate.")

        # One dummy forward/backward pass so the optimizer creates its slot
        # variables; otherwise restoring optimizer state from a checkpoint
        # below would have nothing to map onto.
        dummy_input = tf.zeros((1, self.config.max_context_length), dtype=tf.int32)
        with tf.GradientTape() as tape:
            dummy_output = self.encoder(dummy_input)
            dummy_loss = tf.cast(tf.reduce_mean(dummy_output), tf.float32)
        dummy_grads = tape.gradient(dummy_loss, self.encoder.trainable_variables)
        self.optimizer.apply_gradients(zip(dummy_grads, self.encoder.trainable_variables))

        # The checkpoint tracks the epoch counter, optimizer state, and encoder.
        checkpoint = tf.train.Checkpoint(
            epoch=tf.Variable(0, dtype=tf.int32),
            optimizer=self.optimizer,
            model=self.encoder
        )

        manager = tf.train.CheckpointManager(
            checkpoint,
            directory=checkpoint_dir,
            max_to_keep=3,
            checkpoint_name='ckpt'
        )

        latest_checkpoint = manager.latest_checkpoint
        history_path = Path(checkpoint_dir) / 'training_history.json'

        if not hasattr(self, 'history'):
            self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}

        if latest_checkpoint and not test_mode:
            # Restore weights, optimizer state, and the epoch counter.
            status = checkpoint.restore(latest_checkpoint)
            status.assert_consumed()
            logger.info(f"Restored from checkpoint: {latest_checkpoint}")
            logger.info(f"Optimizer iterations after restore: {self.optimizer.iterations.numpy()}")

            if use_lr_schedule:
                current_lr = float(lr_schedule(self.optimizer.iterations))
            else:
                current_lr = float(self.optimizer.learning_rate.numpy())
            logger.info(f"Current learning rate after restore: {current_lr:.2e}")

            # Derive the resume epoch from the checkpoint filename (ckpt-N)
            # unless the caller supplied one explicitly.
            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
            if initial_epoch == 0:
                initial_epoch = ckpt_number

            checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
            logger.info(f"Resuming from epoch {initial_epoch}")

            if history_path.exists():
                try:
                    with open(history_path, 'r') as f:
                        self.history = json.load(f)
                    logger.info(f"Loaded previous training history from {history_path}")
                except Exception as e:
                    logger.warning(f"Could not load history, starting fresh: {e}")

            # Also export the restored weights via save_models().
            self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
            logger.info("Manually saved custom weights after restore.")
        else:
            logger.info("Starting training from scratch")
            checkpoint.epoch.assign(tf.cast(0, tf.int32))
            initial_epoch = 0

        log_dir = Path(checkpoint_dir) / "tensorboard_logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = str(log_dir / f"train_{current_time}")
        val_log_dir = str(log_dir / f"val_{current_time}")
        train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        val_summary_writer = tf.summary.create_file_writer(val_log_dir)
        logger.info(f"TensorBoard logs will be saved in {log_dir}")

        dataset = tf.data.TFRecordDataset(tfrecord_file_path)

        # In test mode, shrink the dataset and recompute the derived sizes.
        if test_mode:
            subset_size = 200
            dataset = dataset.take(subset_size)
            logger.info(f"TEST MODE: Using only {subset_size} examples")

            total_pairs = subset_size
            train_size = int(total_pairs * (1 - validation_split))
            val_size = total_pairs - train_size
            batch_size = min(batch_size, val_size)
            steps_per_epoch = math.ceil(train_size / batch_size)
            val_steps = math.ceil(val_size / batch_size)
            total_steps = steps_per_epoch * epochs
            buffer_size = max(1, total_pairs // 10)
            epochs = min(epochs, 5)
            early_stopping_patience = 2
            logger.info(f"New training pairs: {train_size}")
            logger.info(f"New validation pairs: {val_size}")

        dataset = dataset.map(
            lambda x: parse_tfrecord_fn(x, self.config.max_context_length, self.data_pipeline.neg_samples),
            num_parallel_calls=tf.data.AUTOTUNE
        )

        # Split before shuffling so the train/val partition stays
        # deterministic across epochs.
        train_dataset = dataset.take(train_size)
        val_dataset = dataset.skip(train_size).take(val_size)

        train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
        train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

        val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
        val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
        val_dataset = val_dataset.cache()
best_val_loss = float("inf") |
|
epochs_no_improve = 0 |
|
|
|
for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1): |
|
checkpoint.epoch.assign(epoch) |
|
logger.info(f"Starting Epoch {epoch}...") |
|
|
|
epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32) |
|
batches_processed = 0 |
|
|
|
try: |
|
train_pbar = tqdm( |
|
total=steps_per_epoch, |
|
desc=f"Training Epoch {epoch}", |
|
unit="batch" |
|
) |
|
is_tqdm_train = True |
|
except ImportError: |
|
train_pbar = None |
|
is_tqdm_train = False |
|
|
|
|
|
for q_batch, p_batch, n_batch in train_dataset: |
|
loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch) |
|
epoch_loss_avg(loss) |
|
batches_processed += 1 |
|
|
|
|
|
with train_summary_writer.as_default(): |
|
step = (epoch - 1) * steps_per_epoch + batches_processed |
|
tf.summary.scalar("loss", tf.cast(loss, tf.float32), step=step) |
|
tf.summary.scalar("gradient_norm_pre_clip", tf.cast(grad_norm, tf.float32), step=step) |
|
tf.summary.scalar("gradient_norm_post_clip", tf.cast(post_clip_norm, tf.float32), step=step) |
|
|
|
|
|
if use_lr_schedule: |
|
current_lr = float(lr_schedule(self.optimizer.iterations)) |
|
else: |
|
current_lr = float(self.optimizer.learning_rate.numpy()) |
|
|
|
if is_tqdm_train: |
|
train_pbar.update(1) |
|
train_pbar.set_postfix({ |
|
"loss": f"{loss.numpy():.4f}", |
|
"pre_clip": f"{grad_norm.numpy():.2e}", |
|
"post_clip": f"{post_clip_norm.numpy():.2e}", |
|
"lr": f"{current_lr:.2e}", |
|
"batches": f"{batches_processed}/{steps_per_epoch}" |
|
}) |
|
|
|
gc.collect() |
|
|
|
|
|
if batches_processed >= steps_per_epoch: |
|
break |
|
|
|
if is_tqdm_train and train_pbar: |
|
train_pbar.close() |
|
|
|
|
|

            val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
            val_batches_processed = 0

            val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")

            last_valid_val_loss = None
            valid_batches = False

            for q_batch, p_batch, n_batch in val_dataset:
                # The contrastive loss needs at least two examples per batch.
                if tf.shape(q_batch)[0] < 2:
                    logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
                    continue

                valid_batches = True
                val_loss = self.validation_step(q_batch, p_batch, n_batch)
                val_loss_avg(val_loss)
                last_valid_val_loss = val_loss
                val_batches_processed += 1

                val_pbar.update(1)
                val_pbar.set_postfix({
                    "val_loss": f"{val_loss.numpy():.4f}",
                    "batches": f"{val_batches_processed}/{val_steps}"
                })

                gc.collect()

                if val_batches_processed >= val_steps:
                    break

            if not valid_batches:
                # Every batch was skipped; fall back to the last valid value
                # or, failing that, the training loss.
                logger.warning("No valid validation batches in this epoch")
                if last_valid_val_loss is not None:
                    val_loss = last_valid_val_loss
                else:
                    val_loss = epoch_loss_avg.result()
                val_loss_avg(val_loss)

            val_pbar.close()

            # End-of-epoch summaries, checkpointing, and history logging.
            train_loss = epoch_loss_avg.result().numpy()
            val_loss = val_loss_avg.result().numpy()
            logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")

            with train_summary_writer.as_default():
                tf.summary.scalar("epoch_loss", train_loss, step=epoch)
            with val_summary_writer.as_default():
                tf.summary.scalar("val_loss", val_loss, step=epoch)

            manager.save()

            model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
            self.save_models(model_save_path)
            logger.info(f"Saved model for epoch {epoch} at {model_save_path}")

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history.setdefault('learning_rate', []).append(current_lr)

            def convert_to_py_floats(obj):
                """Recursively convert numpy/TF scalars to plain floats for JSON."""
                if isinstance(obj, dict):
                    return {k: convert_to_py_floats(v) for k, v in obj.items()}
                elif isinstance(obj, list):
                    return [convert_to_py_floats(x) for x in obj]
                elif isinstance(obj, (np.float32, np.float64)):
                    return float(obj)
                elif tf.is_tensor(obj):
                    return float(obj.numpy())
                else:
                    return obj

            json_history = convert_to_py_floats(self.history)

            with open(history_path, 'w') as f:
                json.dump(json_history, f)
            logger.info(f"Saved training history to {history_path}")

            # Early stopping on validation loss.
            if val_loss < best_val_loss - min_delta:
                best_val_loss = val_loss
                epochs_no_improve = 0
                logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
            else:
                epochs_no_improve += 1
                logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
                if epochs_no_improve >= early_stopping_patience:
                    logger.info("Early stopping triggered.")
                    break

        logger.info("Training completed!")

    @tf.function
    def train_step(
        self,
        q_batch: tf.Tensor,
        p_batch: tf.Tensor,
        n_batch: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Single training step using queries, positives, and hard negatives.
        Returns (loss, pre-clip gradient norm, post-clip gradient norm).
        """
        with tf.GradientTape() as tape:
            q_enc = self.encoder(q_batch, training=True)
            p_enc = self.encoder(p_batch, training=True)

            shape = tf.shape(n_batch)
            bs = shape[0]
            neg_samples = shape[1]

            # Encode all negatives in one pass by flattening the batch and
            # negative dimensions, then restore [bs, neg_samples, dim].
            n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
            n_enc_flat = self.encoder(n_batch_flat, training=True)
            n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])

            # Column 0 is the positive; columns 1..neg_samples are negatives.
            combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)

            # Dot products give logits over (1 + neg_samples) candidates;
            # label 0 marks the positive for softmax cross-entropy.
            dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
            labels = tf.zeros([bs], dtype=tf.int32)
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels,
                logits=dot_products
            )
            loss = tf.cast(tf.reduce_mean(loss), tf.float32)

        # Clip by global norm, logging the norm before and after the clip.
        gradients = tape.gradient(loss, self.encoder.trainable_variables)
        gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
        max_grad_norm = tf.constant(1.5, dtype=tf.float32)
        gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
        post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)

        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))

        return loss, gradients_norm, post_clip_norm
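
    # Shape sketch for the step above (illustrative): with bs=2 and
    # neg_samples=3, combined_p_n is [2, 4, dim] and dot_products is [2, 4];
    # labels = [0, 0] tells the softmax cross-entropy that column 0 (the
    # positive) should outscore the three hard negatives in each row.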

    @tf.function
    def validation_step(
        self,
        q_batch: tf.Tensor,
        p_batch: tf.Tensor,
        n_batch: tf.Tensor
    ) -> tf.Tensor:
        """
        Single validation step using queries, positives, and hard negatives.
        Same idea as train_step, but without gradient updates.
        """
        q_enc = self.encoder(q_batch, training=False)
        p_enc = self.encoder(p_batch, training=False)

        shape = tf.shape(n_batch)
        bs = shape[0]
        neg_samples = shape[1]

        n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
        n_enc_flat = self.encoder(n_batch_flat, training=False)
        n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])

        combined_p_n = tf.concat(
            [tf.expand_dims(p_enc, axis=1), n_enc],
            axis=1
        )

        dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
        labels = tf.zeros([bs], dtype=tf.int32)

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels,
            logits=dot_products
        )
        loss = tf.cast(tf.reduce_mean(loss), tf.float32)

        return loss

    def _get_lr_schedule(
        self,
        total_steps: int,
        peak_lr: float,
        warmup_steps: int
    ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
        """
        Custom learning rate schedule with warmup and cosine decay.
        """
        class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
            def __init__(
                self,
                total_steps: int,
                peak_lr: float,
                warmup_steps: int
            ):
                super().__init__()
                self.total_steps = tf.cast(total_steps, tf.float32)
                self.peak_lr = tf.cast(peak_lr, tf.float32)

                # Cap warmup at 10% of total steps so short runs still decay.
                adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
                self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)

                # Warm up from 10% of peak; never decay below 1% of peak.
                self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
                self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)

                logger.info("Learning rate schedule initialized:")
                logger.info(f"  Initial LR: {float(self.initial_lr):.6f}")
                logger.info(f"  Peak LR: {float(self.peak_lr):.6f}")
                logger.info(f"  Min LR: {float(self.min_lr):.6f}")
                logger.info(f"  Warmup steps: {int(self.warmup_steps)}")
                logger.info(f"  Total steps: {int(self.total_steps)}")

            def __call__(self, step):
                step = tf.cast(step, tf.float32)

                # Linear warmup from initial_lr to peak_lr.
                warmup_factor = tf.cast(tf.minimum(1.0, step / self.warmup_steps), tf.float32)
                warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor

                # Cosine decay from peak_lr down to min_lr after warmup.
                decay_steps = tf.cast(tf.maximum(1.0, self.total_steps - self.warmup_steps), tf.float32)
                decay_factor = tf.cast((step - self.warmup_steps) / decay_steps, tf.float32)
                decay_factor = tf.cast(tf.minimum(tf.maximum(0.0, decay_factor), 1.0), tf.float32)
                cosine_decay = tf.cast(0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor)), tf.float32)
                decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay

                final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)

                # Guard against underflow and non-finite values.
                final_lr = tf.maximum(self.min_lr, final_lr)
                final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)

                return final_lr

            def get_config(self):
                return {
                    "total_steps": self.total_steps,
                    "peak_lr": self.peak_lr,
                    "warmup_steps": self.warmup_steps,
                }

        return CustomSchedule(total_steps, peak_lr, warmup_steps)
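

# ---------------------------------------------------------------------------
# Standalone sketch of the warmup + cosine-decay schedule above, in plain
# Python (hedged: this re-derives the same math outside the class purely for
# illustration; it is not used by RetrievalChatbot).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def sketch_lr(step, total_steps=1000, peak_lr=1e-5, warmup_steps=100):
        initial_lr, min_lr = peak_lr * 0.1, peak_lr * 0.01
        if step < warmup_steps:
            # Linear warmup from initial_lr to peak_lr.
            return initial_lr + (peak_lr - initial_lr) * (step / warmup_steps)
        # Cosine decay from peak_lr to min_lr over the remaining steps.
        decay_factor = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * decay_factor))
        return min_lr + (peak_lr - min_lr) * cosine

    for s in (0, 50, 100, 500, 1000):
        print(f"step {s:>4}: lr = {sketch_lr(s):.2e}")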