from concurrent.futures import ProcessPoolExecutor, as_completed
import time
from datetime import timedelta
import pandas as pd
import torch
import traceback
# --- LOAD CHECKPOINT ---
checkpoint_file = "/home/ubuntu/ttsar/csv_kanad/sing/cg_shani_sing.csv"
print(f"Loading checkpoint from: {checkpoint_file}")
df = pd.read_csv(checkpoint_file)
print(f"Checkpoint loaded. Shape: {df.shape}")

# Create the 'text' column if the checkpoint predates it
if 'text' not in df.columns:
    df['text'] = pd.NA

# --- FIND ALL MISSING TRANSCRIPTIONS ---
missing_mask = df['text'].isna()
missing_indices = df[missing_mask].index.tolist()
already_done = (~missing_mask).sum()
print(f"Already transcribed: {already_done:,} samples")
print(f"Missing transcriptions: {len(missing_indices):,} samples")
print("-" * 50)

if len(missing_indices) == 0:
    print("All samples already transcribed!")
    raise SystemExit(0)
# --- PyTorch settings ---
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
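
# Note: the settings above allow TF32 math for matmuls and cuDNN convolutions
# on Ampere-or-newer GPUs, trading a little float32 precision for throughput;
# with the default fork start method the flags carry over into worker processes.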

def process_batch(batch_data):
    """Transcribe one batch of audio files in a worker process."""
    batch_id, indices, audio_files, config_path, checkpoint_path = batch_data
    model = None  # checked in the finally block in case loading fails
    try:
        # Import and configure libraries within the worker process
        import torch
        import nemo.collections.asr as nemo_asr
        from omegaconf import OmegaConf, open_dict
        import warnings
        import logging

        # Suppress logs within the worker process
        logging.getLogger('nemo_logger').setLevel(logging.ERROR)
        logging.disable(logging.CRITICAL)
        warnings.filterwarnings('ignore')

        # Load model for this worker
        config = OmegaConf.load(config_path)
        with open_dict(config.cfg):
            for ds in ['train_ds', 'validation_ds', 'test_ds']:
                if ds in config.cfg:
                    config.cfg[ds].defer_setup = True
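        # defer_setup=True asks NeMo to skip building the train/val/test
        # dataloaders (and resolving their manifest paths) at construction
        # time; this worker only needs the model for inference.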
        model = nemo_asr.models.EncDecMultiTaskModel(cfg=config.cfg)
        checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        model = model.eval().cuda().bfloat16()

        decode_cfg = model.cfg.decoding
        decode_cfg.beam.beam_size = 1
        model.change_decoding_strategy(decode_cfg)
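        # beam_size=1 makes beam search effectively greedy decoding: the
        # fastest setting, at some possible cost in transcription quality
        # compared with a wider beam.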
        # Transcribe
        start = time.time()
        try:
            hypotheses = model.transcribe(
                audio=audio_files,
                batch_size=64,
                source_lang='ja',
                target_lang='ja',
                task='asr',
                pnc='no',
                verbose=False,
                num_workers=0,
                channel_selector=0
            )
            results = [hyp.text for hyp in hypotheses]
        except Exception as e:
            print(f"Transcription error in batch {batch_id}: {e}")
            # Fall through with an empty list; the padding below then marks
            # every file in this batch as failed
            results = []

        # Pad with None so every input file has a paired result, even after
        # a partial or failed transcription
        while len(results) < len(audio_files):
            results.append(None)

        # Count successful transcriptions
        success_count = sum(r is not None for r in results)

        # Pair each DataFrame index with its result so the parent process can
        # write results back without relying on positional alignment
        return batch_id, list(zip(indices, results)), success_count, time.time() - start
    finally:
        # Release GPU memory in this worker even if loading or transcription raised
        if model is not None:
            del model
            torch.cuda.empty_cache()
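
# Note: each worker process loads its own full copy of the model onto the GPU,
# so peak VRAM is roughly n_workers times (bfloat16 weights plus activations
# for batch_size=64). Lower n_workers or batch_size if the card runs short.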

# --- Parameters ---
chunk_size = 512 * 4  # 2048 files per worker task
n_workers = 6
checkpoint_interval = 250_000
config_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02/version_4/hparams.yaml"
checkpoint_path = "/home/ubuntu/NeMo_Canary/canary_results/Higurashi_ASR_v.02_plus/checkpoints/Higurashi_ASR_v.02_plus--step=174650.0000-epoch=8-last.ckpt"

# --- Create batches from missing indices ---
chunks = []
for i in range(0, len(missing_indices), chunk_size):
    batch_indices = missing_indices[i:i + chunk_size]
    batch_files = df.loc[batch_indices, 'filename'].tolist()
    chunks.append({
        'batch_id': len(chunks),
        'indices': batch_indices,
        'files': batch_files,
        'config_path': config_path,
        'checkpoint_path': checkpoint_path
    })
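
# Because batches are built from missing_indices rather than contiguous row
# ranges, the row labels inside a batch need not be consecutive; carrying them
# through process_batch is what lets each result land back on the right row.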
print(f"Total batches to process: {len(chunks)}") | |
print(f"Batch size: ~{chunk_size} samples") | |
print(f"Workers: {n_workers}") | |
print(f"Checkpoint interval: every {checkpoint_interval:,} samples") | |
print("-" * 50) | |
# --- Initialize tracking variables --- | |
all_results = {} | |
failed_chunks = [] | |
failed_files_list = [] | |
start_time = time.time() | |
samples_done = 0 | |
samples_failed = 0 | |
last_checkpoint = 0 | |
interrupted = False | |
total_to_process = len(missing_indices) | |

# --- Main Processing Loop ---
try:
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        future_to_chunk = {
            executor.submit(process_batch,
                            (chunk['batch_id'], chunk['indices'], chunk['files'],
                             chunk['config_path'], chunk['checkpoint_path'])): chunk
            for chunk in chunks
        }
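        # Every batch is submitted up front; as_completed then yields futures
        # in whatever order they finish, not submission order. This module-level
        # executor relies on the default 'fork' start method on Linux; under
        # 'spawn' (macOS/Windows) the script body would need an
        # if __name__ == "__main__": guard to avoid re-executing itself.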
        for future in as_completed(future_to_chunk):
            original_chunk = future_to_chunk[future]
            batch_id = original_chunk['batch_id']
            try:
                _batch_id, index_result_pairs, success_count, batch_time = future.result()

                # Write results back to the DataFrame row by row
                failed_in_batch = 0
                for idx, result in index_result_pairs:
                    if result is not None:
                        df.loc[idx, 'text'] = result
                    else:
                        df.loc[idx, 'text'] = "[FAILED]"
                        failed_in_batch += 1
                        failed_files_list.append(df.loc[idx, 'filename'])

                samples_done += success_count
                samples_failed += failed_in_batch

                elapsed = time.time() - start_time
                speed = samples_done / elapsed if elapsed > 0 else 0
                remaining = total_to_process - samples_done - samples_failed
                eta = remaining / speed if speed > 0 else 0
                current_total = already_done + samples_done

                status = f"✓ Batch {batch_id}/{len(chunks) - 1} done ({success_count} success"
                if failed_in_batch > 0:
                    status += f", {failed_in_batch} failed"
                status += f" in {batch_time:.1f}s)"
                print(f"{status} | "
                      f"Processed: {samples_done:,}/{total_to_process:,} | "
                      f"Total: {current_total:,}/{len(df):,} ({100 * current_total / len(df):.1f}%) | "
                      f"Speed: {speed:.1f} samples/s | "
                      f"ETA: {timedelta(seconds=int(eta))}")

                # Save checkpoint ('checkpoint_out' so we do not shadow the
                # input checkpoint_file loaded at the top of the script)
                if (samples_done - last_checkpoint >= checkpoint_interval
                        or (samples_done + samples_failed) >= total_to_process):
                    checkpoint_out = f"/home/ubuntu/ttsar/ASR_DATA/transcribed_checkpoint_{current_total}.csv"
                    df.to_csv(checkpoint_out, index=False)
                    print(f"  → Checkpoint saved: {checkpoint_out}")
                    last_checkpoint = samples_done
            except Exception as e:
                failed_chunks.append(original_chunk)
                print("-" * 20 + " ERROR " + "-" * 20)
                print(f"✗ Batch {batch_id} FAILED. Indices count: {len(original_chunk['indices'])}")
                print(f"Error: {e}")
                traceback.print_exc()
                print("-" * 47)
except KeyboardInterrupt:
    interrupted = True
    print("\n\n" + "=" * 50)
    print("! KEYBOARD INTERRUPT DETECTED !")
    print("Stopping workers and saving progress...")
    print("=" * 50 + "\n")

# --- Finalization ---
total_time = time.time() - start_time
print("-" * 50)
if interrupted:
    print("PROCESS INTERRUPTED")
else:
    print("PROCESSING COMPLETE!")
print(f"Session time: {timedelta(seconds=int(total_time))}")
print(f"Samples successfully processed: {samples_done:,}")
print(f"Samples failed: {samples_failed:,}")
if total_time > 0 and samples_done > 0:
    print(f"Average speed: {samples_done / total_time:.1f} samples/second")

# Save final result
final_output = "/home/ubuntu/ttsar/ASR_DATA/transcribed_manifest_final.csv"
df.to_csv(final_output, index=False)
print(f"Final output saved to: {final_output}")
print("-" * 50)

# --- Summary ---
failed_transcriptions = (df['text'] == "[FAILED]").sum()
successful_transcriptions = df['text'].notna().sum() - failed_transcriptions
remaining_missing = df['text'].isna().sum()
print("Summary:")
print(f"  - Total dataset size: {len(df):,} samples")
print(f"  - Successfully transcribed: {successful_transcriptions:,} samples")
print(f"  - Failed transcriptions: {failed_transcriptions:,} samples")
print(f"  - Still missing (NaN): {remaining_missing:,} samples")
print(f"  - Processed this session: {samples_done:,} successful, {samples_failed:,} failed")
print(f"  - Failed batches (entire batch): {len(failed_chunks)}")

# Save list of individual failed files
if failed_files_list:
    failed_files_path = "/home/ubuntu/ttsar/ASR_DATA/failed_transcription_files.txt"
    with open(failed_files_path, 'w') as f:
        for file_path in failed_files_list:
            f.write(f"{file_path}\n")
    print(f"\nFailed files saved to: {failed_files_path}")

if failed_chunks:
    failed_batches_path = "/home/ubuntu/ttsar/ASR_DATA/failed_batches.txt"
    with open(failed_batches_path, 'w') as f:
        for chunk in failed_chunks:
            f.write(f"Batch {chunk['batch_id']}: indices {chunk['indices'][:5]}... ({len(chunk['indices'])} total)\n")
    print(f"Failed batch info saved to: {failed_batches_path}")
print("-" * 50)