# NeMo_Canary / batch_infer.py
# Uploaded by Respair — commit 7e5c8d4 ("Create batch_infer.py", verified)
# from concurrent.futures import ProcessPoolExecutor, as_completed
# import time
# from datetime import timedelta
# import pandas as pd
# import torch
# import warnings
# import logging
# import os
# import traceback
# # --- Load and filter dataframe ---
# df = pd.read_csv("/home/ubuntu/ttsar/ASR_DATA/train_large.csv")
# print('before filtering: ')
# print(df.shape)
# df = df[~df['filename'].str.contains("Sakura, Moyu")]
# print('after filtering: ')
# print(df.shape)
# total_samples = len(df)
# # --- PyTorch settings ---
# torch.set_float32_matmul_precision('high')
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
# def process_batch(batch_data):
# """Process a batch of audio files"""
# batch_id, start_idx, audio_files, config_path, checkpoint_path = batch_data
# model = None # Initialize model to None for the finally block
# try:
# # Import and configure libraries within the worker process
# import torch
# import nemo.collections.asr as nemo_asr
# from omegaconf import OmegaConf, open_dict
# import warnings
# import logging
# # Suppress logs within the worker process to keep the main output clean
# logging.getLogger('nemo_logger').setLevel(logging.ERROR)
# logging.disable(logging.CRITICAL)
# warnings.filterwarnings('ignore')
# # Load model for this worker
# config = OmegaConf.load(config_path)
# with open_dict(config.cfg):
# for ds in ['train_ds', 'validation_ds', 'test_ds']:
# if ds in config.cfg:
# config.cfg[ds].defer_setup = True
# model = nemo_asr.models.EncDecMultiTaskModel(cfg=config.cfg)
# checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False)
# model.load_state_dict(checkpoint['state_dict'], strict=False)
# model = model.eval().cuda()
# decode_cfg = model.cfg.decoding
# decode_cfg.beam.beam_size = 4
# model.change_decoding_strategy(decode_cfg)
# # Transcribe
# start = time.time()
# hypotheses = model.transcribe(
# audio=audio_files,
# batch_size=64,
# source_lang='ja',
# target_lang='ja',
# task='asr',
# pnc='no',
# verbose=False,
# num_workers=0,
# channel_selector=0
# )
# results = [hyp.text for hyp in hypotheses]
# return batch_id, start_idx, results, len(audio_files), time.time() - start
# finally:
# # NEW: Ensure GPU memory is cleared in the worker process
# if model is not None:
# del model
# import torch
# torch.cuda.empty_cache()
# # --- Parameters ---
# chunk_size = 512 * 4
# n_workers = 4
# checkpoint_interval = 250_000
# config_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02/version_4/hparams.yaml"
# checkpoint_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02_plus/checkpoints/Higurashi_ASR_v.02_plus--step=174650.0000-epoch=8-last.ckpt"
# # --- Prepare data chunks ---
# audio_files = df['filename'].tolist()
# chunks = []
# for i in range(0, total_samples, chunk_size):
# end_idx = min(i + chunk_size, total_samples)
# chunk_files = audio_files[i:end_idx]
# chunks.append({
# 'batch_id': len(chunks),
# 'start_idx': i,
# 'files': chunk_files,
# 'config_path': config_path,
# 'checkpoint_path': checkpoint_path
# })
# print(f"Processing {total_samples:,} samples")
# print(f"Chunks: {len(chunks)} Γ— ~{chunk_size} samples")
# print(f"Workers: {n_workers}")
# print(f"Checkpoint interval: every {checkpoint_interval:,} samples")
# print("-" * 50)
# # --- Initialize tracking variables ---
# all_results = {}
# failed_chunks = []
# start_time = time.time()
# samples_done = 0
# last_checkpoint = 0
# interrupted = False
# # Initialize 'text' column with a placeholder
# df['text'] = pd.NA
# # --- Main Processing Loop with Graceful Shutdown ---
# try:
# with ProcessPoolExecutor(max_workers=n_workers) as executor:
# future_to_chunk = {
# executor.submit(process_batch,
# (chunk['batch_id'], chunk['start_idx'], chunk['files'], chunk['config_path'], chunk['checkpoint_path'])): chunk
# for chunk in chunks
# }
# for future in as_completed(future_to_chunk):
# original_chunk = future_to_chunk[future]
# batch_id = original_chunk['batch_id']
# try:
# _batch_id, start_idx, results, count, batch_time = future.result()
# all_results[start_idx] = results
# samples_done += count
# end_idx = start_idx + len(results)
# if len(df.iloc[start_idx:end_idx]) == len(results):
# df.loc[start_idx:end_idx-1, 'text'] = results
# else:
# raise ValueError(f"Length mismatch: DataFrame slice vs results")
# elapsed = time.time() - start_time
# speed = samples_done / elapsed if elapsed > 0 else 0
# remaining = total_samples - samples_done
# eta = remaining / speed if speed > 0 else 0
# print(f"βœ“ Batch {batch_id}/{len(chunks)-1} done ({count} samples in {batch_time:.1f}s) | "
# f"Total: {samples_done:,}/{total_samples:,} ({100*samples_done/total_samples:.1f}%) | "
# f"Speed: {speed:.1f} samples/s | "
# f"ETA: {timedelta(seconds=int(eta))}")
# if samples_done - last_checkpoint >= checkpoint_interval or samples_done == total_samples:
# checkpoint_file = f"/home/ubuntu/ttsar/ASR_DATA/transcribed_checkpoint_{samples_done}.csv"
# df.to_csv(checkpoint_file, index=False)
# print(f" βœ“ Checkpoint saved: {checkpoint_file}")
# last_checkpoint = samples_done
# except Exception:
# failed_chunks.append(original_chunk)
# print("-" * 20 + " ERROR " + "-" * 20)
# print(f"βœ— Batch {batch_id} FAILED. Start index: {original_chunk['start_idx']}. Files: {len(original_chunk['files'])}")
# traceback.print_exc()
# print("-" * 47)
# except KeyboardInterrupt:
# interrupted = True
# print("\n\n" + "="*50)
# print("! KEYBOARD INTERRUPT DETECTED !")
# print("Stopping workers and saving all completed progress...")
# print("The script will exit shortly.")
# print("="*50 + "\n")
# # The `with ProcessPoolExecutor` context manager will automatically
# # handle shutting down the worker processes when we exit this block.
# # --- Finalization and Reporting (this block now runs on completion OR interruption) ---
# total_time = time.time() - start_time
# print("-" * 50)
# if interrupted:
# print(f"PROCESS INTERRUPTED")
# else:
# print(f"TRANSCRIPTION COMPLETE!")
# print(f"Total time elapsed: {timedelta(seconds=int(total_time))}")
# if total_time > 0 and samples_done > 0:
# print(f"Average speed (on completed work): {samples_done/total_time:.1f} samples/second")
# # Save final result
# final_output = "/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv"
# df.to_csv(final_output, index=False)
# print(f"Final progress saved to: {final_output}")
# print("-" * 50)
# # --- Summary and Verification ---
# successful_transcriptions = df['text'].notna().sum()
# print("Final Run Summary:")
# print(f" - Successfully transcribed: {successful_transcriptions:,} samples")
# print(f" - Failed batches: {len(failed_chunks)}")
# print(f" - Total samples in failed batches: {sum(len(c['files']) for c in failed_chunks):,}")
# if failed_chunks:
# failed_files_path = "/home/ubuntu/ttsar/ASR_DATA/failed_transcription_files.txt"
# with open(failed_files_path, 'w') as f:
# for chunk in failed_chunks:
# for file_path in chunk['files']:
# f.write(f"{file_path}\n")
# print(f"\nList of files from failed batches saved to: {failed_files_path}")
# print("-" * 50)
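# NOTE: The commented-out script above is the earlier full-dataset run (it transcribes every row of
# train_large.csv). The active script below resumes from an existing CSV and only transcribes rows
# that do not yet have a transcription.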
import logging
import multiprocessing
import os
import time
import traceback
import warnings
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import timedelta

import pandas as pd
import torch
# --- LOAD CHECKPOINT ---
# Resume source: a CSV with a 'filename' column and, optionally, a 'text'
# column produced by an earlier run.
checkpoint_file = "/home/ubuntu/ttsar/csv_kanad/sing/cg_shani_sing.csv"
# When True, rows that a previous run marked "[FAILED]" are queued again
# instead of being treated as done. Set to False to keep skipping them.
retry_failed = True
print(f"Loading checkpoint from: {checkpoint_file}")
df = pd.read_csv(checkpoint_file)
print(f"Checkpoint loaded. Shape: {df.shape}")
# Check if 'text' column exists, if not create it
if 'text' not in df.columns:
    df['text'] = pd.NA
# Keep transcripts in an object column. read_csv parses an all-empty 'text'
# column as float64, and writing strings into it later is an
# incompatible-dtype assignment (FutureWarning since pandas 2.1).
df['text'] = df['text'].astype(object)

# --- FIND ALL MISSING TRANSCRIPTIONS ---
missing_mask = df['text'].isna()
if retry_failed:
    missing_mask |= df['text'].eq("[FAILED]")
missing_indices = df[missing_mask].index.tolist()
already_done = (~missing_mask).sum()
print(f"Already transcribed: {already_done:,} samples")
print(f"Missing transcriptions: {len(missing_indices):,} samples")
print("-" * 50)
if len(missing_indices) == 0:
    print("All samples already transcribed!")
    # SystemExit rather than the site-provided exit(), which is absent under `python -S`.
    raise SystemExit(0)

# --- PyTorch settings ---
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def _transcribe_isolating_failures(model, audio_files, batch_id):
    """Transcribe ``audio_files`` with ``model``, isolating files that make it fail.

    The whole list is sent to ``model.transcribe`` in one call first. If that
    call raises, the list is split in half and each half is retried
    recursively, so one bad file costs only its own transcript rather than
    every transcript in the chunk.

    Args:
        model: object with a NeMo-style ``transcribe(audio=..., ...)`` method
            whose results carry a ``.text`` attribute.
        audio_files: list of audio file paths.
        batch_id: chunk id, used only in the error message.

    Returns:
        List of transcripts in input order. A file that still fails on its
        own is reported and yields ``None``.
    """
    try:
        hypotheses = model.transcribe(
            audio=audio_files,
            batch_size=64,
            source_lang='ja',
            target_lang='ja',
            task='asr',
            pnc='no',
            verbose=False,
            num_workers=0,
            channel_selector=0,
        )
        return [hyp.text for hyp in hypotheses]
    except Exception as e:
        if len(audio_files) <= 1:
            print(f"Transcription error in batch {batch_id} for {audio_files}: {str(e)}")
            return [None] * len(audio_files)
        mid = len(audio_files) // 2
        return (_transcribe_isolating_failures(model, audio_files[:mid], batch_id)
                + _transcribe_isolating_failures(model, audio_files[mid:], batch_id))


def process_batch(batch_data):
    """Load the ASR model in this worker process and transcribe one chunk of files.

    Args:
        batch_data: tuple ``(batch_id, indices, audio_files, config_path,
            checkpoint_path)``. ``indices`` are DataFrame index labels aligned
            one-to-one with ``audio_files``.

    Returns:
        ``(batch_id, pairs, success_count, elapsed_seconds)``, where ``pairs``
        is ``list(zip(indices, texts))``. A text is ``None`` for any file that
        could not be transcribed.

    Errors while building or loading the model are not caught here. They
    propagate to the caller through the future. Transcription errors are
    isolated per file by ``_transcribe_isolating_failures``.
    """
    batch_id, indices, audio_files, config_path, checkpoint_path = batch_data
    model = None
    try:
        # Import and configure libraries within the worker process
        import torch
        import nemo.collections.asr as nemo_asr
        from omegaconf import OmegaConf, open_dict
        import warnings
        import logging

        # Suppress logs within the worker process
        logging.getLogger('nemo_logger').setLevel(logging.ERROR)
        logging.disable(logging.CRITICAL)
        warnings.filterwarnings('ignore')

        # Build the model from the training config. Every dataset section is
        # flagged defer_setup=True (presumably so no train/val/test data
        # loaders are built for inference).
        config = OmegaConf.load(config_path)
        with open_dict(config.cfg):
            for ds in ['train_ds', 'validation_ds', 'test_ds']:
                if ds in config.cfg:
                    config.cfg[ds].defer_setup = True
        model = nemo_asr.models.EncDecMultiTaskModel(cfg=config.cfg)

        # Load the checkpoint on CPU. With map_location='cuda', every tensor in
        # the checkpoint lived on the GPU, and because `checkpoint` stayed
        # referenced until return, a duplicate of the weights occupied GPU
        # memory through the whole transcription. load_state_dict copies the
        # values into the model, so the checkpoint can be dropped right away.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        del checkpoint
        model = model.eval().cuda().bfloat16()

        decode_cfg = model.cfg.decoding
        decode_cfg.beam.beam_size = 1
        model.change_decoding_strategy(decode_cfg)

        # Transcribe
        start = time.time()
        results = _transcribe_isolating_failures(model, audio_files, batch_id)
        # Pad with None if fewer results came back than files were sent, so
        # zip() below still reaches every index.
        results = list(results) + [None] * (len(audio_files) - len(results))
        success_count = sum(r is not None for r in results)
        return batch_id, list(zip(indices, results)), success_count, time.time() - start
    finally:
        if model is not None:
            del model
        import torch
        torch.cuda.empty_cache()
# --- Parameters ---
chunk_size = 2048  # files per submitted task (512 * 4)
n_workers = 6  # number of worker processes in the pool
checkpoint_interval = 250_000  # successful samples between intermediate CSV saves
config_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02/version_4/hparams.yaml"
checkpoint_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02_plus/checkpoints/Higurashi_ASR_v.02_plus--step=174650.0000-epoch=8-last.ckpt"

# --- Create batches from missing indices ---
# Each chunk carries everything process_batch needs: its DataFrame index
# labels, the matching filenames and the model paths.
chunks = []
for batch_no, offset in enumerate(range(0, len(missing_indices), chunk_size)):
    members = missing_indices[offset:offset + chunk_size]
    chunks.append(dict(
        batch_id=batch_no,
        indices=members,
        files=df.loc[members, 'filename'].tolist(),
        config_path=config_path,
        checkpoint_path=checkpoint_path,
    ))

for line in (
    f"Total batches to process: {len(chunks)}",
    f"Batch size: ~{chunk_size} samples",
    f"Workers: {n_workers}",
    f"Checkpoint interval: every {checkpoint_interval:,} samples",
    "-" * 50,
):
    print(line)

# --- Initialize tracking variables ---
all_results = {}  # kept for compatibility; not read by the processing loop
failed_chunks = []  # chunks whose worker call raised
failed_files_list = []  # filenames of individual samples that failed
samples_done = samples_failed = last_checkpoint = 0
total_to_process = len(missing_indices)
interrupted = False
start_time = time.time()
# --- Main Processing Loop ---
# Pin the fork start method. This script has no `if __name__ == "__main__":`
# guard, so a "spawn"/"forkserver" worker (forkserver is the POSIX default
# from Python 3.14) would re-import this module and re-run the whole script,
# re-reading the CSV and trying to start its own pool.
mp_context = multiprocessing.get_context("fork")
try:
    with ProcessPoolExecutor(max_workers=n_workers, mp_context=mp_context) as executor:
        try:
            future_to_chunk = {
                executor.submit(process_batch,
                                (chunk['batch_id'], chunk['indices'], chunk['files'],
                                 chunk['config_path'], chunk['checkpoint_path'])): chunk
                for chunk in chunks
            }
            for future in as_completed(future_to_chunk):
                original_chunk = future_to_chunk[future]
                batch_id = original_chunk['batch_id']
                try:
                    _batch_id, index_result_pairs, success_count, batch_time = future.result()
                    # Update DataFrame with results; None marks a per-file failure.
                    failed_in_batch = 0
                    for idx, result in index_result_pairs:
                        if result is not None:
                            df.loc[idx, 'text'] = result
                        else:
                            df.loc[idx, 'text'] = "[FAILED]"
                            failed_in_batch += 1
                            failed_files_list.append(df.loc[idx, 'filename'])
                    samples_done += success_count
                    samples_failed += failed_in_batch
                    elapsed = time.time() - start_time
                    speed = samples_done / elapsed if elapsed > 0 else 0
                    remaining = total_to_process - samples_done - samples_failed
                    eta = remaining / speed if speed > 0 else 0
                    current_total = already_done + samples_done
                    status = f"βœ“ Batch {batch_id}/{len(chunks)-1} done ({success_count} success"
                    if failed_in_batch > 0:
                        status += f", {failed_in_batch} failed"
                    status += f" in {batch_time:.1f}s)"
                    print(f"{status} | "
                          f"Processed: {samples_done:,}/{total_to_process:,} | "
                          f"Total: {current_total:,}/{len(df):,} ({100*current_total/len(df):.1f}%) | "
                          f"Speed: {speed:.1f} samples/s | "
                          f"ETA: {timedelta(seconds=int(eta))}")
                    # Save checkpoint
                    if samples_done - last_checkpoint >= checkpoint_interval or (samples_done + samples_failed) >= total_to_process:
                        checkpoint_file = f"/home/ubuntu/ttsar/ASR_DATA/transcribed_checkpoint_{current_total}.csv"
                        df.to_csv(checkpoint_file, index=False)
                        print(f" βœ“ Checkpoint saved: {checkpoint_file}")
                        last_checkpoint = samples_done
                except Exception as e:
                    # The whole chunk failed (e.g. model load error or a broken
                    # pool); its rows stay NaN and are picked up on the next run.
                    failed_chunks.append(original_chunk)
                    print("-" * 20 + " ERROR " + "-" * 20)
                    print(f"βœ— Batch {batch_id} FAILED. Indices count: {len(original_chunk['indices'])}")
                    print(f"Error: {str(e)}")
                    traceback.print_exc()
                    print("-" * 47)
        except KeyboardInterrupt:
            # Cancel chunks that have not started yet. Otherwise leaving the
            # `with` block (shutdown(wait=True)) would run every queued chunk
            # before the interrupt is handled.
            executor.shutdown(wait=False, cancel_futures=True)
            raise
except KeyboardInterrupt:
    interrupted = True
    print("\n\n" + "="*50)
    print("! KEYBOARD INTERRUPT DETECTED !")
    print("Stopping workers and saving progress...")
    print("="*50 + "\n")
# --- Finalization ---
# Runs after normal completion and after a handled Ctrl+C alike.
total_time = time.time() - start_time
print("-" * 50)
print("PROCESS INTERRUPTED" if interrupted else "PROCESSING COMPLETE!")
for line in (
    f"Session time: {timedelta(seconds=int(total_time))}",
    f"Samples successfully processed: {samples_done:,}",
    f"Samples failed: {samples_failed:,}",
):
    print(line)
if samples_done > 0 and total_time > 0:
    print(f"Average speed: {samples_done/total_time:.1f} samples/second")

# Save final result
final_output = "/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv"
df.to_csv(final_output, index=False)
print(f"Final output saved to: {final_output}")
print("-" * 50)

# --- Summary ---
# "[FAILED]" is a non-null placeholder, so it is subtracted from the non-null count.
text_col = df['text']
failed_transcriptions = (text_col == "[FAILED]").sum()
successful_transcriptions = text_col.notna().sum() - failed_transcriptions
remaining_missing = text_col.isna().sum()
print("Summary:")
for line in (
    f" - Total dataset size: {len(df):,} samples",
    f" - Successfully transcribed: {successful_transcriptions:,} samples",
    f" - Failed transcriptions: {failed_transcriptions:,} samples",
    f" - Still missing (NaN): {remaining_missing:,} samples",
    f" - Processed this session: {samples_done:,} successful, {samples_failed:,} failed",
    f" - Failed batches (entire batch): {len(failed_chunks)}",
):
    print(line)

# Save list of failed files (individual per-sample failures from this session)
if failed_files_list:
    failed_files_path = "/home/ubuntu/ttsar/ASR_DATA/failed_transcription_files.txt"
    with open(failed_files_path, 'w') as f:
        f.writelines(f"{file_path}\n" for file_path in failed_files_list)
    print(f"\nFailed files saved to: {failed_files_path}")

# Record chunks whose worker call raised as a whole
if failed_chunks:
    failed_batches_path = "/home/ubuntu/ttsar/ASR_DATA/failed_batches.txt"
    with open(failed_batches_path, 'w') as f:
        f.writelines(
            f"Batch {chunk['batch_id']}: indices {chunk['indices'][:5]}... ({len(chunk['indices'])} total)\n"
            for chunk in failed_chunks
        )
    print(f"Failed batch info saved to: {failed_batches_path}")
print("-" * 50)