Spaces:
Sleeping
Sleeping
# sagan_inference.py | |
import torch | |
import torchaudio | |
import math | |
from sagan_model import SAGANModel | |
# --- 2) Age-group normative statistics for z-scoring (proxy values from literature) ---
# Layout: STATS[age_group][feature] -> {"mu": mean, "sigma": standard deviation}.
# Age groups: "kindergarten", "grade_6", "adult".
# Features: "pitch", "rhythm", "timbre".
# Only the kindergarten pitch/rhythm rows carry a citation; the remaining
# values have no source given here. Units are not stated — confirm before use.
STATS = dict(
    kindergarten=dict(
        pitch=dict(mu=30.0, sigma=29.0),   # Wise & Sloboda (2008)
        rhythm=dict(mu=60.0, sigma=15.0),  # Demorest & Pfordresher (2015)
        timbre=dict(mu=0.65, sigma=0.10),
    ),
    grade_6=dict(
        pitch=dict(mu=43.0, sigma=26.0),
        rhythm=dict(mu=75.0, sigma=10.0),
        timbre=dict(mu=0.75, sigma=0.08),
    ),
    adult=dict(
        pitch=dict(mu=32.0, sigma=19.0),
        rhythm=dict(mu=80.0, sigma=8.0),
        timbre=dict(mu=0.85, sigma=0.05),
    ),
)
def z_score_standardize(
    waveform: torch.Tensor,
    age_group: str = "adult",
    feature: str = "pitch",
) -> torch.Tensor:
    """Z-score ``waveform`` against the normative stats of an age group.

    Computes ``(waveform - mu) / (sigma + 1e-9)`` element-wise, where ``mu``
    and ``sigma`` are read from ``STATS[age_group][feature]``.

    Args:
        waveform: Tensor to standardize; the scalar stats broadcast over
            every element, so any shape is accepted.
        age_group: Key into ``STATS``. Unknown keys fall back to ``"adult"``.
            Defaults to ``"adult"`` (the same group used as the fallback) so
            that single-argument calls work.
        feature: Which feature's stats to apply: ``"pitch"`` (default, the
            original behaviour), ``"rhythm"`` or ``"timbre"``.

    Returns:
        A new tensor with the same shape as ``waveform``.

    Raises:
        KeyError: if ``feature`` is not a key of the selected age group.
    """
    group_stats = STATS.get(age_group, STATS["adult"])
    feature_stats = group_stats[feature]
    mu, sigma = feature_stats["mu"], feature_stats["sigma"]
    # NOTE(review): these stats presumably describe extracted features
    # (pitch/rhythm/timbre scores), not raw sample amplitudes; applying them
    # directly to a waveform may not be meaningful — confirm intent.
    # The epsilon keeps the division finite should a sigma ever be 0.
    return (waveform - mu) / (sigma + 1e-9)
def run_sagan(audio_path: str, checkpoint_path: str, device='cpu', age_group: str = "adult"):
    """Generate an image from an audio file with a pretrained SAGAN generator.

    Pipeline: load the audio with ``torchaudio.load``, standardize it with
    ``z_score_standardize``, build ``SAGANModel(z_dim=128)`` and load its
    weights from ``checkpoint_path``, map the audio to a latent vector with
    ``encode_to_z``, then run the generator under ``torch.no_grad()``.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        checkpoint_path: Path to a checkpoint whose loaded object has a
            ``'model_state_dict'`` entry.
        device: Torch device the audio tensor and model are moved to.
        age_group: Age-group key forwarded to ``z_score_standardize``.
            Defaults to ``"adult"``.

    Returns:
        The generator output for the latent vector derived from the audio.
    """
    # 1) Load audio. The sample rate is not used by this pipeline yet.
    waveform, _sample_rate = torchaudio.load(audio_path)
    # Bug fix: the original call omitted the required ``age_group`` argument
    # of z_score_standardize and raised TypeError on every invocation.
    waveform = z_score_standardize(waveform, age_group).to(device)

    # 2) Instantiate the model and load its weights.
    model = SAGANModel(z_dim=128).to(device)
    # SECURITY NOTE(review): torch.load unpickles the file and can execute
    # arbitrary code. Only load trusted checkpoints, or pass
    # ``weights_only=True`` if the checkpoint holds only tensors/primitives.
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # 3) Map the audio to a latent vector. encode_to_z is a placeholder to be
    # replaced with a real feature extractor. For a 2-D (channels, samples)
    # waveform it yields (1, 128), which the two unsqueezes turn into
    # (1, 128, 1, 1).
    z = encode_to_z(waveform).unsqueeze(-1).unsqueeze(-1)

    # 4) Generate without tracking gradients.
    # Presumably (1, 3, 64, 64) for a 64x64 SAGAN — depends on SAGANModel.
    with torch.no_grad():
        fake_img = model(z)
    return fake_img
# Placeholder: your own mapping from waveform → z
def encode_to_z(wf, z_dim: int = 128):
    """Map an audio tensor to a latent row vector (placeholder implementation).

    Averages ``wf`` over its last dimension, then over the resulting last
    dimension, and tiles that value ``z_dim`` times. For a 2-D
    ``(channels, samples)`` input the two means reduce to one scalar (the
    per-channel means averaged across channels). The output is then a
    ``(1, z_dim)`` tensor whose entries are all that scalar.

    Replace this with a real feature extractor, e.g. a small CNN or an MLP
    producing ``z_dim`` features from the audio.

    Args:
        wf: Audio tensor; written for 2-D ``(channels, samples)`` input.
        z_dim: Width of the latent vector. Defaults to 128, matching the
            ``z_dim`` that ``run_sagan`` builds ``SAGANModel`` with.

    Returns:
        A ``(1, z_dim)`` tensor for 2-D input.
    """
    # NOTE(review): a 3-D batched input (B, C, T) leaves a (B,) tensor after
    # the two means, so the result would be (1, B * z_dim), not (B, z_dim).
    pooled = wf.mean(dim=-1).mean(dim=-1)
    return pooled.unsqueeze(0).repeat(1, z_dim)