# sagan_inference.py
import torch
import torchaudio

from sagan_model import SAGANModel

### Age-group z-score stats (proxy values from the literature) ###
STATS = {
    "kindergarten": {
        "pitch": {"mu": 30.0, "sigma": 29.0},   # Wise & Sloboda (2008)
        "rhythm": {"mu": 60.0, "sigma": 15.0},  # Demorest & Pfordresher (2015)
        "timbre": {"mu": 0.65, "sigma": 0.10},
    },
    "grade_6": {
        "pitch": {"mu": 43.0, "sigma": 26.0},
        "rhythm": {"mu": 75.0, "sigma": 10.0},
        "timbre": {"mu": 0.75, "sigma": 0.08},
    },
    "adult": {
        "pitch": {"mu": 32.0, "sigma": 19.0},
        "rhythm": {"mu": 80.0, "sigma": 8.0},
        "timbre": {"mu": 0.85, "sigma": 0.05},
    },
}
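
# Worked example (derived from the table above, added for illustration): a raw
# pitch score of 50 for a grade_6 singer standardizes to (50 - 43) / 26 ≈ 0.27,
# i.e. roughly a quarter of a standard deviation above the age-group norm.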

def z_score_standardize(waveform: torch.Tensor, age_group: str) -> torch.Tensor:
    """Standardize a waveform against the age-group norms (example shown for
    pitch; repeat for rhythm/timbre as needed)."""
    stats = STATS.get(age_group, STATS["adult"])
    mu, sigma = stats["pitch"]["mu"], stats["pitch"]["sigma"]
    return (waveform - mu) / (sigma + 1e-9)  # epsilon guards against sigma == 0
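
# A minimal sketch (an assumption, not part of the original file) of the
# "repeat for rhythm/timbre" idea above: z-score a dict of per-metric scores
# in one pass. `z_score_all_metrics` is a hypothetical helper name.
def z_score_all_metrics(scores: dict, age_group: str) -> dict:
    """Standardize {"pitch": ..., "rhythm": ..., "timbre": ...} raw scores."""
    stats = STATS.get(age_group, STATS["adult"])
    return {
        metric: (value - stats[metric]["mu"]) / (stats[metric]["sigma"] + 1e-9)
        for metric, value in scores.items()
    }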

def run_sagan(audio_path: str, checkpoint_path: str, age_group: str = "adult", device="cpu"):
    # 1) Load audio (sample rate is kept in case the feature extractor needs it)
    waveform, sr = torchaudio.load(audio_path)
    waveform = z_score_standardize(waveform, age_group).to(device)

    # 2) Instantiate model & load weights
    model = SAGANModel(z_dim=128).to(device)
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    # 3) Prepare latent vector from audio (example: mean-pool + linear proj).
    #    Replace `encode_to_z` with your custom feature extractor.
    z = encode_to_z(waveform).unsqueeze(-1).unsqueeze(-1)  # -> (1, 128, 1, 1)

    # 4) Generate
    with torch.no_grad():
        fake_img = model(z)  # -> (1, 3, 64, 64) for a 64×64 SAGAN
    return fake_img

# Placeholder: your own mapping from waveform → z
def encode_to_z(wf: torch.Tensor) -> torch.Tensor:
    # e.g., a small CNN or an MLP extracting 128-d features from audio;
    # here we just mean-pool channels and time to a scalar, then tile it
    # across 128 dims to get a (1, 128) latent.
    return wf.mean(dim=-1).mean(dim=-1).unsqueeze(0).repeat(1, 128)
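
# A more realistic encoder sketch (an assumption, not from this repo): a tiny
# MLP over pooled mel-spectrogram bands, projecting to the 128-d latent.
# `MEL_BANDS`, `MelZEncoder`, and the 16 kHz sample rate are all illustrative.
MEL_BANDS = 64

class MelZEncoder(torch.nn.Module):
    def __init__(self, z_dim: int = 128):
        super().__init__()
        self.mel = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=MEL_BANDS)
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(MEL_BANDS, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, z_dim),
        )

    def forward(self, wf: torch.Tensor) -> torch.Tensor:
        spec = self.mel(wf)            # (C, n_mels, frames)
        feats = spec.mean(dim=(0, 2))  # pool channels and time -> (n_mels,)
        return self.proj(feats).unsqueeze(0)  # -> (1, z_dim), same shape as encode_to_z

# Illustrative driver; the file names below are placeholders, not repo assets.
if __name__ == "__main__":
    out = run_sagan("sample_recording.wav", "sagan_checkpoint.pth", age_group="grade_6")
    print(tuple(out.shape))  # e.g., (1, 3, 64, 64)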