#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Sanity check for the synthetic MELD-S corpus:
    • does the WAV file exist;
    • are the audio and text embedding sizes correct;
    • does the final feature-vector size match the expectation.
Result:
    GOOD / BAD in the console + CSV bad_synth_meld.csv (when problems are found).
"""
from __future__ import annotations
import csv
import logging
import sys
import traceback
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional
import pandas as pd
import torch
import torchaudio
from tqdm import tqdm
# ----------------------------------------------------------------------
# >>>>>>>>> USER SETTINGS (double-check the paths!) <<<<<<<<<<<
# ----------------------------------------------------------------------
USER_CONFIG = {
    # location of the synthetic corpus
    "synthetic_path": r"E:/MELD_S",
    "csv_name": "meld_s_train_labels.csv",
    "wav_subdir": "wavs",
    # models / checkpoints — identical to the ones in your config.toml
    "audio_model_name": "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim",
    "audio_ckpt": "best_audio_model_2.pt",
    "text_model_name": "jinaai/jina-embeddings-v3",
    "text_ckpt": "best_text_model.pth",
    # shared parameters
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "sample_rate": 16_000,
    "num_emotions": 7,  # anger, disgust, fear, happy, neutral, sad, surprise
}
# ----------------------------------------------------------------------
# import the project's own feature extractors
# ----------------------------------------------------------------------
try:
    from feature_extractor import (
        PretrainedAudioEmbeddingExtractor,
        PretrainedTextEmbeddingExtractor,
    )
except ModuleNotFoundError:
    try:
        # fall back to the data_loading/ package layout
        from data_loading.feature_extractor import (
            PretrainedAudioEmbeddingExtractor,
            PretrainedTextEmbeddingExtractor,
        )
    except ModuleNotFoundError:
        sys.exit(
            "❌ Не найден feature_extractor.py. "
            "Убедитесь, что он в PYTHONPATH или лежит рядом со скриптом."
        )
# ----------------------------------------------------------------------
# helper functions
# ----------------------------------------------------------------------
def build_audio_cfg() -> SimpleNamespace:
    """Assemble the config object consumed by PretrainedAudioEmbeddingExtractor."""
    cfg = USER_CONFIG
    return SimpleNamespace(
        audio_model_name=cfg["audio_model_name"],
        emb_device=cfg["device"],
        audio_pooling="mean",  # must match the pooling used at training time
        emb_normalize=False,
        max_audio_frames=0,
        audio_classifier_checkpoint=cfg["audio_ckpt"],
        sample_rate=cfg["sample_rate"],
        wav_length=4,
    )
def build_text_cfg() -> SimpleNamespace:
    """Assemble the config object consumed by PretrainedTextEmbeddingExtractor."""
    cfg = USER_CONFIG
    return SimpleNamespace(
        text_model_name=cfg["text_model_name"],
        emb_device=cfg["device"],
        text_pooling="mean",
        emb_normalize=False,
        max_tokens=95,
        text_classifier_checkpoint=cfg["text_ckpt"],
    )
def get_dims(audio_extractor, text_extractor) -> Dict[str, int]:
    """Probe both extractors with dummy inputs and return the actual
    embedding sizes as {"audio_dim": ..., "text_dim": ...}."""
    rate = USER_CONFIG["sample_rate"]
    with torch.no_grad():
        # one second of silence is enough to discover the embedding width
        silence = torch.zeros(1, rate)
        _, audio_emb = audio_extractor.extract(silence[0], rate)
        _, text_emb = text_extractor.extract("dummy text")
    return {
        "audio_dim": audio_emb[0].shape[-1],
        "text_dim": text_emb[0].shape[-1],
    }
def check_row(
    row: pd.Series,
    feats: Dict[str, object],
    dims: Dict[str, int],
    wav_dir: Path,
) -> Optional[str]:
    """Validate one corpus row.

    Checks, in order: the WAV file exists, the audio embedding width matches
    dims["audio_dim"], the text embedding width matches dims["text_dim"], and
    the concatenated vector (embeddings + both prediction heads) has the
    expected total width.

    Returns None when the sample is correct, otherwise a short reason string
    ("file_missing", "audio_dim_<n>", "text_dim_<n>", "concat_dim_<n>",
    or "exception_<ClassName>" when any step raised).
    """
    video = row["video_name"]
    wav_path = wav_dir / f"{video}.wav"
    # NOTE(review): row.get may yield NaN for a missing "text" cell in pandas;
    # such rows end up classified as "exception_*" below — confirm if intended.
    text = row.get("text", "")
    try:
        if not wav_path.exists():
            return "file_missing"
        # ---------- audio ----------
        wf, sr = torchaudio.load(str(wav_path))
        if sr != USER_CONFIG["sample_rate"]:
            wf = torchaudio.transforms.Resample(sr, USER_CONFIG["sample_rate"])(wf)
        a_pred, a_emb = feats["audio"].extract(wf[0], USER_CONFIG["sample_rate"])
        a_emb = a_emb[0]
        if a_emb.shape[-1] != dims["audio_dim"]:
            return f"audio_dim_{a_emb.shape[-1]}"
        # ---------- text ----------
        t_pred, t_emb = feats["text"].extract(text)
        t_emb = t_emb[0]
        if t_emb.shape[-1] != dims["text_dim"]:
            return f"text_dim_{t_emb.shape[-1]}"
        # ---------- concatenation ----------
        full_vec = torch.cat(
            [a_emb, t_emb, a_pred[0], t_pred[0]],
            dim=-1,
        )
        expected_all = (
            dims["audio_dim"]
            + dims["text_dim"]
            + 2 * USER_CONFIG["num_emotions"]
        )
        if full_vec.shape[-1] != expected_all:
            return f"concat_dim_{full_vec.shape[-1]}"
    except Exception as e:
        # Lazy %-style args: the message is formatted only if the record
        # is actually emitted (fixes the eager f-string in logging call).
        logging.error("%s: %s", video, traceback.format_exc(limit=2))
        return "exception_" + e.__class__.__name__
    return None
# ----------------------------------------------------------------------
# main script
# ----------------------------------------------------------------------
def main() -> None:
    """Walk the synthetic corpus CSV, validate every row, print a summary,
    and dump the problematic rows to bad_synth_meld.csv when any exist."""
    syn_root = Path(USER_CONFIG["synthetic_path"])
    csv_path = syn_root / USER_CONFIG["csv_name"]
    wav_dir = syn_root / USER_CONFIG["wav_subdir"]

    # fail fast on missing inputs
    if not csv_path.exists():
        sys.exit(f"CSV не найден: {csv_path}")
    if not wav_dir.exists():
        sys.exit(f"WAV-директория не найдена: {wav_dir}")

    # 1. build the extractors
    feats = {
        "audio": PretrainedAudioEmbeddingExtractor(build_audio_cfg()),
        "text": PretrainedTextEmbeddingExtractor(build_text_cfg()),
    }

    # 2. discover the actual embedding widths
    dims = get_dims(feats["audio"], feats["text"])
    expected_total = (
        dims["audio_dim"] + dims["text_dim"] + 2 * USER_CONFIG["num_emotions"]
    )
    print(
        f"Audio dim = {dims['audio_dim']}, "
        f"Text dim = {dims['text_dim']}, "
        f"Expected concat = {expected_total}"
    )

    # 3. validate every CSV row
    df = pd.read_csv(csv_path)
    bad_rows: List[Dict[str, str]] = []
    good_cnt = 0
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Checking"):
        reason = check_row(row, feats, dims, wav_dir)
        if reason is None:
            good_cnt += 1
            continue
        bad_rows.append({
            "video_name": row["video_name"],
            "reason": reason,
            "wav_path": str(wav_dir / f"{row['video_name']}.wav"),
        })

    # 4. report
    print("\n========== SUMMARY ==========")
    print(f"✅ GOOD : {good_cnt}")
    print(f"❌ BAD : {len(bad_rows)}")
    if bad_rows:
        out_csv = Path(__file__).with_name("bad_synth_meld.csv")
        with out_csv.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(
                f, fieldnames=["video_name", "reason", "wav_path"]
            )
            writer.writeheader()
            writer.writerows(bad_rows)
        print(f"\nСписок проблемных примеров сохранён: {out_csv.resolve()}")
# Script entry point: run the corpus check only when executed directly.
if __name__ == "__main__":
    main()