import os
import gc
import json
import math
import random
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple, Union

import faiss
import h5py
import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

from chatbot_config import ChatbotConfig
from logger_config import config_logger

logger = config_logger(__name__)

class TFDataPipeline:
    """Prepares retrieval training data: query embedding cache, FAISS index, hard negatives, and TFRecords."""

    def __init__(
        self,
        config: ChatbotConfig,
        tokenizer: AutoTokenizer,
        encoder: SentenceTransformer,
        response_pool: List[Dict[str, str]],
        query_embeddings_cache: dict,
        index_type: str = 'IndexFlatIP',
        faiss_index_file_path: str = 'models/faiss_indices/faiss_index_production.index',
    ):
        self.config = config
        self.tokenizer = tokenizer
        self.encoder = encoder
        self.faiss_index_file_path = faiss_index_file_path
        self.response_pool = response_pool
        self.query_embeddings_cache = query_embeddings_cache
        self.index_type = index_type
        self.neg_samples = config.neg_samples
        self.nlist = config.nlist
        self.dimension = config.embedding_dim
        self.max_context_length = config.max_context_length
        self.embedding_batch_size = config.embedding_batch_size
        self.search_batch_size = config.search_batch_size
        self.max_batch_size = config.max_batch_size
        self.max_retries = config.max_retries

        # Build the text -> domain lookup used for domain-aware negative sampling.
        self._text_domain_map = {}
        self.build_text_to_domain_map()

        # Load an existing FAISS index if available; otherwise start a fresh inner-product index.
        if os.path.exists(faiss_index_file_path):
            logger.info(f"Loading existing FAISS index from {faiss_index_file_path}...")
            self.index = faiss.read_index(faiss_index_file_path)
            self.validate_faiss_index()
            logger.info("FAISS index loaded and validated successfully.")
        else:
            self.index = faiss.IndexFlatIP(self.dimension)
            logger.info(f"Initialized FAISS IndexFlatIP with dimension {self.dimension}.")

        # Trainable index types (e.g. IVF) need a training pass; IndexFlatIP is always trained.
        if not self.index.is_trained and self.query_embeddings_cache:
            cached_embeddings = np.array(list(self.query_embeddings_cache.values())).astype(np.float32)
            self.index.train(cached_embeddings)
            self.index.add(cached_embeddings)

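    # A minimal construction sketch (hypothetical names: `my_config`, `my_tokenizer`, `my_encoder`,
    # and `responses` are placeholders, not part of this module):
    #
    #   pipeline = TFDataPipeline(
    #       config=my_config,              # ChatbotConfig instance
    #       tokenizer=my_tokenizer,        # AutoTokenizer.from_pretrained(...)
    #       encoder=my_encoder,            # SentenceTransformer(...)
    #       response_pool=responses,       # list of {"domain": ..., "text": ...} dicts
    #       query_embeddings_cache={},     # filled lazily by _compute_embeddings()
    #   )
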
    def save_embeddings_cache_hdf5(self, cache_file_path: str):
        """Save embeddings cache to HDF5 file."""
        with h5py.File(cache_file_path, 'w') as hf:
            for query, emb in self.query_embeddings_cache.items():
                hf.create_dataset(query, data=emb)
        logger.info(f"Embeddings cache saved to {cache_file_path}.")

    def load_embeddings_cache_hdf5(self, cache_file_path: str):
        """Load embeddings cache from HDF5 file."""
        with h5py.File(cache_file_path, 'r') as hf:
            for query in hf.keys():
                self.query_embeddings_cache[query] = hf[query][:]
        logger.info(f"Embeddings cache loaded from {cache_file_path}.")

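    # Round-trip sketch for the cache (path shown purely for illustration):
    #
    #   pipeline.save_embeddings_cache_hdf5('models/query_embeddings_cache/cache.h5')
    #   pipeline.load_embeddings_cache_hdf5('models/query_embeddings_cache/cache.h5')
    #
    # Note that query strings are used directly as HDF5 dataset names, so a query containing
    # '/' would be stored as a nested group rather than a root-level dataset.
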
    def save_faiss_index(self, faiss_index_file_path: str):
        """Save the FAISS index to the specified file path."""
        faiss.write_index(self.index, faiss_index_file_path)
        logger.info(f"FAISS index saved to {faiss_index_file_path}")

    def load_faiss_index(self, faiss_index_file_path: str):
        """Load FAISS index from specified file path."""
        if os.path.exists(faiss_index_file_path):
            self.index = faiss.read_index(faiss_index_file_path)
            logger.info(f"FAISS index loaded from {faiss_index_file_path}.")
        else:
            logger.error(f"FAISS index file not found at {faiss_index_file_path}.")
            raise FileNotFoundError(f"FAISS index file not found at {faiss_index_file_path}.")

    def validate_faiss_index(self):
        """Validates FAISS index dimensionality."""
        expected_dim = self.dimension
        if self.index.d != expected_dim:
            logger.error(f"FAISS index dimension {self.index.d} does not match encoder embedding dimension {expected_dim}.")
            raise ValueError("FAISS index dimensionality mismatch.")
        logger.info("FAISS index dimension validated successfully.")

    def save_tokenizer(self, tokenizer_dir: str):
        """Save the tokenizer to the given directory."""
        self.tokenizer.save_pretrained(tokenizer_dir)
        logger.info(f"Tokenizer saved to {tokenizer_dir}")

    def load_tokenizer(self, tokenizer_dir: str):
        """Load the tokenizer from the given directory."""
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        logger.info(f"Tokenizer loaded from {tokenizer_dir}")

    @staticmethod
    def load_json_training_data(data_path: Union[str, Path], debug_samples: Optional[int] = None) -> List[dict]:
        """
        Load training data from a JSON file.

        Args:
            data_path (Union[str, Path]): Path to the JSON file containing dialogues.
            debug_samples (Optional[int]): Number of samples to load for debugging.

        Returns:
            List[dict]: List of dialogue dictionaries.
        """
        logger.info(f"Loading training data from {data_path}...")
        data_path = Path(data_path)
        if not data_path.exists():
            logger.error(f"Data file {data_path} does not exist.")
            return []

        with open(data_path, 'r', encoding='utf-8') as f:
            dialogues = json.load(f)

        if debug_samples is not None:
            dialogues = dialogues[:debug_samples]
            logger.info(f"Debug mode: Limited to {debug_samples} dialogues")

        logger.info(f"Loaded {len(dialogues)} dialogues.")
        return dialogues

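    # Expected shape of the JSON file, inferred from how dialogues are consumed below
    # (a sketch, not a formal schema; the domain value is illustrative):
    #
    #   [
    #     {
    #       "domain": "restaurant",
    #       "turns": [
    #         {"speaker": "user", "text": "Book a table for two."},
    #         {"speaker": "assistant", "text": "Sure, what time works for you?"}
    #       ]
    #     },
    #     ...
    #   ]
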
    def collect_responses_with_domain(self, dialogues: List[dict]) -> List[Dict[str, str]]:
        """
        Extract unique assistant responses and their domains from dialogues.
        Returns a list of dicts of the form {"domain": <domain>, "text": <response text>}.
        """
        response_set = set()
        results = []

        for dialogue in tqdm(dialogues, desc="Processing Dialogues", unit="dialogue"):
            domain = dialogue.get('domain', 'other')
            turns = dialogue.get('turns', [])
            for turn in turns:
                speaker = turn.get('speaker')
                text = turn.get('text', '').strip()
                if speaker == 'assistant' and text:
                    if len(text) <= self.max_context_length:
                        # Deduplicate on (domain, text) so the same response may appear under different domains.
                        key = (domain, text)
                        if key not in response_set:
                            response_set.add(key)
                            results.append({
                                "domain": domain,
                                "text": text
                            })

        logger.info(f"Collected {len(results)} unique assistant responses from dialogues.")
        return results

    def _extract_pairs_from_dialogue(self, dialogue: dict) -> List[Tuple[str, str]]:
        """Extract query-response pairs from a dialogue."""
        pairs = []
        turns = dialogue.get('turns', [])

        for i in range(len(turns) - 1):
            current_turn = turns[i]
            next_turn = turns[i + 1]

            if (current_turn.get('speaker') == 'user' and
                    next_turn.get('speaker') == 'assistant' and
                    'text' in current_turn and
                    'text' in next_turn):
                query = current_turn['text'].strip()
                positive = next_turn['text'].strip()
                pairs.append((query, positive))

        return pairs

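    # For the sketch dialogue shown above, _extract_pairs_from_dialogue would yield:
    #   [("Book a table for two.", "Sure, what time works for you?")]
    # Only adjacent user -> assistant turns produce a pair; other turn orderings are skipped.
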
    def compute_and_index_response_embeddings(self):
        """
        Compute embeddings for the response pool using the SentenceTransformer encoder
        and add them to the FAISS index.
        """
        if not self.response_pool:
            logger.warning("Response pool is empty. No embeddings to compute.")
            return

        logger.info("Computing embeddings for the response pool...")
        texts = [resp["text"] for resp in self.response_pool]
        logger.debug(f"Total texts to embed: {len(texts)}")

        embeddings = []
        batch_size = self.embedding_batch_size

        with tqdm(total=len(texts), desc="Computing Embeddings", unit="response") as pbar:
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]

                # L2-normalized embeddings make inner-product search equivalent to cosine similarity.
                batch_embeddings = self.encoder.encode(
                    batch_texts,
                    batch_size=batch_size,
                    convert_to_numpy=True,
                    normalize_embeddings=True
                )

                embeddings.append(batch_embeddings)
                pbar.update(len(batch_texts))

        all_embeddings = np.vstack(embeddings).astype(np.float32)
        logger.info(f"Adding {len(all_embeddings)} response embeddings to FAISS index...")
        self.index.add(all_embeddings)

        # Keep the embeddings around so callers can map FAISS ids back to vectors if needed.
        self.response_embeddings = all_embeddings
        logger.info(f"FAISS index now contains {self.index.ntotal} vectors.")

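    # Assuming the index starts empty, the i-th vector added above corresponds to
    # self.response_pool[i]; the FAISS ids returned by index.search() are used elsewhere
    # in this class to look up response text and domain.
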
    def _find_hard_negatives(self, queries: List[str], positives: List[str], batch_size: int = 128) -> List[List[str]]:
        """
        Find hard negatives for a batch of queries using FAISS search.
        Fallback: in-domain random negatives, then random negatives from any domain.
        """
        retry_count = 0
        total_responses = len(self.response_pool)

        while retry_count < self.max_retries:
            try:
                # Look up cached query embeddings in sub-batches and L2-normalize them for inner-product search.
                query_embeddings = []
                for i in range(0, len(queries), batch_size):
                    sub_queries = queries[i:i + batch_size]
                    sub_embeds = [self.query_embeddings_cache[q] for q in sub_queries]
                    sub_embeds = np.vstack(sub_embeds).astype(np.float32)
                    faiss.normalize_L2(sub_embeds)
                    query_embeddings.append(sub_embeds)

                query_embeddings = np.vstack(query_embeddings)
                query_embeddings = np.ascontiguousarray(query_embeddings)

                # Retrieve the closest responses per query; top hits that are not the positive serve as hard negatives.
                distances, indices = self.index.search(query_embeddings, self.neg_samples)

                all_negatives = []
                for query_indices, query_text, pos_text in zip(indices, queries, positives):
                    negative_list = []

                    # Never return the positive itself as a negative.
                    seen = {pos_text.strip()}
                    domain_of_positive = self._detect_domain_for_text(pos_text)

                    for idx in query_indices:
                        if 0 <= idx < total_responses:
                            candidate_dict = self.response_pool[idx]
                            candidate_text = candidate_dict["text"].strip()
                            if candidate_text and candidate_text not in seen:
                                seen.add(candidate_text)
                                negative_list.append(candidate_text)
                                if len(negative_list) >= self.neg_samples:
                                    break

                    # Top up with random (preferably in-domain) negatives if FAISS returned too few.
                    if len(negative_list) < self.neg_samples:
                        needed = self.neg_samples - len(negative_list)
                        random_negatives = self._get_random_negatives(needed, seen, domain=domain_of_positive)
                        negative_list.extend(random_negatives)

                    all_negatives.append(negative_list)

                return all_negatives

            except KeyError as ke:
                retry_count += 1
                logger.warning(f"Hard negative search attempt {retry_count} failed due to missing embeddings: {ke}")
                if retry_count == self.max_retries:
                    logger.error("Max retries reached for hard negative search due to missing embeddings.")
                    return self._fallback_negatives(queries, positives, reason="key_error")
                gc.collect()
                if tf.config.list_physical_devices('GPU'):
                    tf.keras.backend.clear_session()

            except Exception as e:
                retry_count += 1
                logger.warning(f"Hard negative search attempt {retry_count} failed: {e}")
                if retry_count == self.max_retries:
                    logger.error("Max retries reached for hard negative search.")
                    return self._fallback_negatives(queries, positives, reason="generic_error")
                gc.collect()
                if tf.config.list_physical_devices('GPU'):
                    tf.keras.backend.clear_session()

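    # _find_hard_negatives returns one list of negative texts per query, topped up toward
    # self.neg_samples entries via random sampling when FAISS cannot supply enough candidates,
    # e.g. [["neg 1", "neg 2", ...], ...] aligned with the order of `queries`.
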
    def _detect_domain_for_text(self, text: str) -> Optional[str]:
        """Domain detection for related negatives."""
        stripped_text = text.strip()
        return self._text_domain_map.get(stripped_text, None)

    def _get_random_negatives(self, needed: int, seen: set, domain: Optional[str] = None) -> List[str]:
        """
        Return a list of negative texts, preferring the given domain and falling back to any domain.
        """
        if domain:
            domain_texts = [r["text"] for r in self.response_pool if r["domain"] == domain]
            # Fall back to the full pool if the domain is too small to sample from.
            if len(domain_texts) < needed * 2:
                domain_texts = [r["text"] for r in self.response_pool]
        else:
            domain_texts = [r["text"] for r in self.response_pool]

        negatives = []
        tries = 0
        max_tries = needed * 10
        while len(negatives) < needed and tries < max_tries:
            tries += 1
            candidate = random.choice(domain_texts).strip()
            if candidate and candidate not in seen:
                negatives.append(candidate)
                seen.add(candidate)

        if len(negatives) < needed:
            logger.warning(f"Could not find enough domain-based random negatives; needed {needed}, got {len(negatives)}.")

        return negatives

    def _fallback_negatives(self, queries: List[str], positives: List[str], reason: str) -> List[List[str]]:
        """
        Called if FAISS fails or embeddings are missing.
        We use entirely random negatives for each query, ignoring FAISS,
        but still attempt domain-based selection if possible.
        """
        logger.error(f"Falling back to random negatives due to: {reason}")
        all_negatives = []

        for pos_text in positives:
            seen = {pos_text.strip()}
            domain_of_positive = self._detect_domain_for_text(pos_text)
            negs = self._get_random_negatives(self.neg_samples, seen, domain=domain_of_positive)
            all_negatives.append(negs)

        return all_negatives

    def build_text_to_domain_map(self):
        """
        Build an O(1) lookup dict: text -> domain, for hard negative sampling.
        """
        self._text_domain_map = {}

        for item in self.response_pool:
            stripped_text = item["text"].strip()
            domain = item["domain"]

            # If the same response text appears in several domains, keep the first domain seen.
            if stripped_text in self._text_domain_map:
                continue
            self._text_domain_map[stripped_text] = domain

        logger.info(f"Built text -> domain map with {len(self._text_domain_map)} unique text entries.")

    def encode_query(self, query: str) -> np.ndarray:
        """Generate embedding for a query string."""
        # Normalize so inner-product search against the normalized response index yields cosine similarity.
        return self.encoder.encode(query, convert_to_numpy=True, normalize_embeddings=True)

    def encode_responses(
        self,
        responses: List[str],
        context: Optional[List[Tuple[str, str]]] = None
    ) -> np.ndarray:
        """
        Encode multiple response texts into embeddings, injecting <USER>/<ASSISTANT> markers literally.
        """
        USER_TOKEN = "<USER>"
        ASSISTANT_TOKEN = "<ASSISTANT>"

        if context:
            # Prepend the most recent conversation turns to each response.
            relevant_history = context[-self.config.max_context_turns:]
            prepared = []
            for resp in responses:
                context_str_parts = []
                for (u_text, a_text) in relevant_history:
                    context_str_parts.append(
                        f"{USER_TOKEN} {u_text} {ASSISTANT_TOKEN} {a_text}"
                    )
                context_str = " ".join(context_str_parts)
                full_resp = f"{context_str} {ASSISTANT_TOKEN} {resp}"
                prepared.append(full_resp)
        else:
            prepared = [f"{ASSISTANT_TOKEN} {r}" for r in responses]

        # Sanity check: tokenize and ensure no token id exceeds the tokenizer vocabulary size.
        encodings = self.tokenizer(
            prepared,
            padding='max_length',
            truncation=True,
            max_length=self.max_context_length,
            return_tensors='np'
        )
        input_ids = encodings['input_ids']

        max_id = np.max(input_ids)
        vocab_size = len(self.tokenizer)
        if max_id >= vocab_size:
            logger.error(f"Token ID {max_id} >= tokenizer vocab size {vocab_size}")
            raise ValueError("Token ID exceeds vocabulary size.")

        # The actual embeddings come from the SentenceTransformer encoder on the prepared strings.
        embeddings = self.encoder.encode(prepared, convert_to_numpy=True)
        return embeddings.astype('float32')

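    # Example of the prepared string format (assuming config.max_context_turns >= 1):
    # with context=[("Hi", "Hello!")] and responses=["Sure thing."], the prepared text is
    #   "<USER> Hi <ASSISTANT> Hello! <ASSISTANT> Sure thing."
    # and without context it is simply "<ASSISTANT> Sure thing."
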
    def retrieve_responses(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """
        Retrieve the top-k response texts for a query using FAISS, with their similarity scores.
        """
        query_embedding = self.encode_query(query).reshape(1, -1).astype("float32")
        distances, indices = self.index.search(query_embedding, top_k)

        results = []
        for idx, dist in zip(indices[0], distances[0]):
            # FAISS returns -1 for empty slots when fewer than top_k vectors are indexed.
            if idx < 0:
                continue
            response_text = self.response_pool[idx]["text"]
            results.append((response_text, float(dist)))

        return results

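    # Usage sketch (hypothetical query; the score shown is illustrative):
    #
    #   hits = pipeline.retrieve_responses("What time do you open?", top_k=3)
    #   # -> [("We open at 9am.", 0.87), ...]
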
    def prepare_and_save_data(self, dialogues: List[dict], tf_record_path: str, batch_size: int = 32):
        """
        Batch-process dialogues and save (query, positive, hard negatives) examples to a TFRecord file.
        """
        logger.info(f"Preparing and saving data to {tf_record_path}...")

        num_dialogues = len(dialogues)
        num_batches = math.ceil(num_dialogues / batch_size)

        with tf.io.TFRecordWriter(tf_record_path) as writer:
            with tqdm(total=num_batches, desc="Preparing Data Batches", unit="batch") as pbar:
                for i in range(num_batches):
                    start_idx = i * batch_size
                    end_idx = min(start_idx + batch_size, num_dialogues)
                    batch_dialogues = dialogues[start_idx:end_idx]

                    # Collect query/positive pairs for this batch of dialogues.
                    queries = []
                    positives = []
                    for dialogue in batch_dialogues:
                        pairs = self._extract_pairs_from_dialogue(dialogue)
                        for query, positive in pairs:
                            if len(query) <= self.max_context_length and len(positive) <= self.max_context_length:
                                queries.append(query)
                                positives.append(positive)

                    if not queries:
                        pbar.update(1)
                        continue

                    # Make sure every query has a cached embedding before mining negatives.
                    try:
                        self._compute_embeddings(queries)
                    except Exception as e:
                        logger.error(f"Error computing embeddings: {e}")
                        pbar.update(1)
                        continue

                    # Mine hard negatives via FAISS (with random fallback).
                    try:
                        hard_negatives = self._find_hard_negatives(queries, positives)
                    except Exception as e:
                        logger.error(f"Error finding hard negatives: {e}")
                        pbar.update(1)
                        continue

                    # Tokenize queries and positives to fixed-length id sequences.
                    try:
                        encoded_queries = self.tokenizer.batch_encode_plus(
                            queries,
                            max_length=self.config.max_context_length,
                            truncation=True,
                            padding='max_length',
                            return_tensors='tf'
                        )
                        encoded_positives = self.tokenizer.batch_encode_plus(
                            positives,
                            max_length=self.config.max_context_length,
                            truncation=True,
                            padding='max_length',
                            return_tensors='tf'
                        )
                    except Exception as e:
                        logger.error(f"Error during tokenization: {e}")
                        pbar.update(1)
                        continue

                    # Tokenize negatives in one flat batch, then reshape to
                    # [num_queries, neg_samples, max_context_length]; this assumes each query
                    # received exactly neg_samples negatives from _find_hard_negatives.
                    try:
                        flattened_negatives = [neg for sublist in hard_negatives for neg in sublist]
                        encoded_negatives = self.tokenizer.batch_encode_plus(
                            flattened_negatives,
                            max_length=self.config.max_context_length,
                            truncation=True,
                            padding='max_length',
                            return_tensors='tf'
                        )

                        num_negatives = self.config.neg_samples
                        reshaped_negatives = encoded_negatives['input_ids'].numpy().reshape(
                            -1, num_negatives, self.config.max_context_length
                        )
                    except Exception as e:
                        logger.error(f"Error during negatives tokenization: {e}")
                        pbar.update(1)
                        continue

                    # Serialize one tf.train.Example per (query, positive, negatives) triple.
                    for j in range(len(queries)):
                        try:
                            q_id = encoded_queries['input_ids'][j].numpy()
                            p_id = encoded_positives['input_ids'][j].numpy()
                            n_id = reshaped_negatives[j]

                            feature = {
                                'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=q_id)),
                                'positive_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=p_id)),
                                'negative_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=n_id.flatten())),
                            }
                            example = tf.train.Example(features=tf.train.Features(feature=feature))
                            writer.write(example.SerializeToString())
                        except Exception as e:
                            logger.error(f"Error serializing example {j} in batch {i}: {e}")
                            continue

                    pbar.update(1)

        logger.info("Data preparation complete. TFRecord saved.")

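    # A minimal sketch of how the TFRecord written above could be parsed back (assumes
    # `max_len` and `neg_samples` match the config values used when writing):
    #
    #   feature_spec = {
    #       'query_ids': tf.io.FixedLenFeature([max_len], tf.int64),
    #       'positive_ids': tf.io.FixedLenFeature([max_len], tf.int64),
    #       'negative_ids': tf.io.FixedLenFeature([neg_samples * max_len], tf.int64),
    #   }
    #   def _parse(record):
    #       ex = tf.io.parse_single_example(record, feature_spec)
    #       ex['negative_ids'] = tf.reshape(ex['negative_ids'], [neg_samples, max_len])
    #       return ex
    #   ds = tf.data.TFRecordDataset(tf_record_path).map(_parse)
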
    def _compute_embeddings(self, queries: List[str]) -> None:
        """
        Compute embeddings for queries not yet in the cache and add them to it.
        """
        new_queries = [q for q in queries if q not in self.query_embeddings_cache]
        if not new_queries:
            return

        # Encode in batches with the SentenceTransformer encoder, matching how the response
        # pool is embedded, and L2-normalize for inner-product search.
        new_embeddings = []
        for i in range(0, len(new_queries), self.embedding_batch_size):
            batch_queries = new_queries[i:i + self.embedding_batch_size]
            batch_embeddings = self.encoder.encode(
                batch_queries,
                batch_size=self.embedding_batch_size,
                convert_to_numpy=True,
                normalize_embeddings=True
            )
            new_embeddings.extend(batch_embeddings)

        for query, emb in zip(new_queries, new_embeddings):
            self.query_embeddings_cache[query] = emb

    def data_generator(self, dialogues: List[dict]) -> Generator[Tuple[str, str, List[str]], None, None]:
        """
        Generate training examples: (query, positive, [hard_negatives]).
        """
        total_dialogues = len(dialogues)
        logger.debug(f"Total dialogues to process: {total_dialogues}")

        with tqdm(total=total_dialogues, desc="Processing Dialogues", unit="dialogue") as pbar:
            for dialogue in dialogues:
                pairs = self._extract_pairs_from_dialogue(dialogue)
                for query, positive in pairs:
                    # Ensure the query embedding is cached, then mine negatives for this single pair.
                    self._compute_embeddings([query])
                    hard_negatives = self._find_hard_negatives([query], [positive])[0]
                    yield (query, positive, hard_negatives)
                pbar.update(1)

    def get_tf_dataset(self, dialogues: List[dict], batch_size: int) -> tf.data.Dataset:
        """
        Creates a tf.data.Dataset for streaming training.
        Yields (input_ids_query, input_ids_positive, input_ids_negatives) batches.
        """
        dataset = tf.data.Dataset.from_generator(
            lambda: self.data_generator(dialogues),
            output_signature=(
                tf.TensorSpec(shape=(), dtype=tf.string),
                tf.TensorSpec(shape=(), dtype=tf.string),
                tf.TensorSpec(shape=(self.neg_samples,), dtype=tf.string)
            )
        )

        # Batch the raw strings first, then tokenize each batch; the mapped function runs the
        # Python tokenizer via tf.py_function, so it is kept single-threaded here.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.map(
            lambda q, p, n: self._tokenize_triple(q, p, n),
            num_parallel_calls=1
        )

        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset

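    # Streaming-training sketch (hypothetical `dialogues` list and batch size):
    #
    #   dataset = pipeline.get_tf_dataset(dialogues, batch_size=16)
    #   for q_ids, p_ids, n_ids in dataset:
    #       # q_ids/p_ids: [16, max_context_length], n_ids: [16, neg_samples, max_context_length]
    #       ...
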
    def _tokenize_triple(
        self,
        q: tf.Tensor,
        p: tf.Tensor,
        n: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Wraps a Python function: converts tf.Tensors of strings -> Python lists of strings -> HF tokenizer -> tensors of IDs.
        q and p have shape [batch_size]; n has shape [batch_size, neg_samples].
        """
        q_ids, p_ids, n_ids = tf.py_function(
            func=self._tokenize_triple_py,
            inp=[q, p, n, tf.constant(self.max_context_length), tf.constant(self.neg_samples)],
            Tout=[tf.int32, tf.int32, tf.int32]
        )

        # tf.py_function loses static shape information, so restore it for downstream layers.
        q_ids.set_shape([None, self.max_context_length])
        p_ids.set_shape([None, self.max_context_length])
        n_ids.set_shape([None, self.neg_samples, self.max_context_length])

        return q_ids, p_ids, n_ids

    def _tokenize_triple_py(
        self,
        q: tf.Tensor,
        p: tf.Tensor,
        n: tf.Tensor,
        max_len: tf.Tensor,
        neg_samples: tf.Tensor
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Decodes tf.string tensors to Python lists of str, then tokenizes.
        Reshapes negatives to [batch_size, neg_samples, max_len].
        Returns int32 np.ndarrays (q_ids, p_ids, n_ids).

        q: shape [batch_size], p: shape [batch_size]
        n: shape [batch_size, neg_samples]
        max_len: scalar int tensor
        neg_samples: scalar int tensor
        """
        max_len = int(max_len.numpy())
        neg_samples = int(neg_samples.numpy())

        # Decode the byte strings coming out of the tf.data pipeline.
        q_list = [q_i.decode("utf-8") for q_i in q.numpy()]
        p_list = [p_i.decode("utf-8") for p_i in p.numpy()]

        n_list = []
        for row in n.numpy():
            decoded = [neg.decode("utf-8") for neg in row]
            n_list.append(decoded)

        # Tokenize queries and positives to fixed-length id sequences.
        q_enc = self.tokenizer(
            q_list,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="np"
        )
        p_enc = self.tokenizer(
            p_list,
            padding="max_length",
            truncation=True,
            max_length=max_len,
            return_tensors="np"
        )

        # Tokenize negatives as one flat list, then regroup per query and pad any short rows.
        flattened_negatives = [neg for row in n_list for neg in row]
        if len(flattened_negatives) == 0:
            n_ids = np.zeros((len(q_list), neg_samples, max_len), dtype=np.int32)
        else:
            n_enc = self.tokenizer(
                flattened_negatives,
                padding="max_length",
                truncation=True,
                max_length=max_len,
                return_tensors="np"
            )
            n_input_ids = n_enc["input_ids"]

            batch_size = len(q_list)
            n_ids_list = []
            for i in range(batch_size):
                start_idx = i * neg_samples
                end_idx = start_idx + neg_samples
                row_negs = n_input_ids[start_idx:end_idx]

                # Pad with all-zero rows if this query ended up with fewer than neg_samples negatives.
                if row_negs.shape[0] < neg_samples:
                    deficit = neg_samples - row_negs.shape[0]
                    pad_arr = np.zeros((deficit, max_len), dtype=np.int32)
                    row_negs = np.concatenate([row_negs, pad_arr], axis=0)

                n_ids_list.append(row_negs)

            n_ids = np.stack(n_ids_list, axis=0)

        q_ids = q_enc["input_ids"].astype(np.int32)
        p_ids = p_enc["input_ids"].astype(np.int32)
        n_ids = n_ids.astype(np.int32)

        return q_ids, p_ids, n_ids
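
# Typical call order once a pipeline instance exists (see the construction sketch near
# __init__ above; paths and batch size here are illustrative):
#
#   dialogues = TFDataPipeline.load_json_training_data('data/dialogues.json')
#   pipeline.compute_and_index_response_embeddings()                 # embed + index the response pool
#   pipeline.prepare_and_save_data(dialogues, 'data/train.tfrecord') # offline TFRecord route
#   dataset = pipeline.get_tf_dataset(dialogues, batch_size=16)      # or stream batches directly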