File size: 2,031 Bytes
e8e15b2
c98cb1d
e8e15b2
 
 
 
c98cb1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8e15b2
 
 
 
 
c98cb1d
e8e15b2
 
 
 
c98cb1d
e8e15b2
 
 
 
 
 
 
 
 
 
 
c98cb1d
e8e15b2
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# sagan_inference.py

import torch
import torchaudio
import math
from sagan_model import SAGANModel

### 2) Age-group Z-score stats (proxy values from literature) ###
# Lookup table: age group -> feature ("pitch", "rhythm", "timbre") ->
# {"mu": mean, "sigma": standard deviation}. The values are proxies; a source
# is cited next to an entry where one is known.
STATS = {
    "kindergarten": dict(
        pitch=dict(mu=30.0, sigma=29.0),   # Wise & Sloboda (2008)
        rhythm=dict(mu=60.0, sigma=15.0),  # Demorest & Pfordresher (2015)
        timbre=dict(mu=0.65, sigma=0.10),
    ),
    "grade_6": dict(
        pitch=dict(mu=43.0, sigma=26.0),
        rhythm=dict(mu=75.0, sigma=10.0),
        timbre=dict(mu=0.75, sigma=0.08),
    ),
    "adult": dict(
        pitch=dict(mu=32.0, sigma=19.0),
        rhythm=dict(mu=80.0, sigma=8.0),
        timbre=dict(mu=0.85, sigma=0.05),
    ),
}

def z_score_standardize(
    waveform: torch.Tensor,
    age_group: str = "adult",
    feature: str = "pitch",
) -> torch.Tensor:
    """Z-score standardize ``waveform`` using age-group statistics from ``STATS``.

    Computes ``(waveform - mu) / (sigma + 1e-9)`` elementwise, where ``mu`` and
    ``sigma`` come from ``STATS[age_group][feature]``.

    Args:
        waveform: Tensor to standardize. The operation is elementwise, so any
            shape is accepted and the output has the same shape.
        age_group: Key into ``STATS``. Unknown groups fall back to ``"adult"``.
            Defaults to ``"adult"`` so a call with only a waveform works.
        feature: Which statistics to apply: ``"pitch"`` (default, the previous
            hard-coded behavior), ``"rhythm"`` or ``"timbre"``.

    Returns:
        The standardized tensor.

    Raises:
        KeyError: If ``feature`` is not defined for the selected age group.
    """
    stats = STATS.get(age_group, STATS["adult"])
    params = stats[feature]
    # NOTE(review): these look like feature-level score statistics, yet they
    # are applied to raw audio samples here; confirm the intended input.
    # The epsilon guards against division by zero should a sigma ever be 0.
    return (waveform - params["mu"]) / (params["sigma"] + 1e-9)

def run_sagan(audio_path: str, checkpoint_path: str, device='cpu', age_group: str = "adult"):
    """Generate an image from an audio file with a pretrained SAGAN generator.

    Pipeline: load audio -> z-score standardize -> load ``SAGANModel`` weights
    -> map the waveform to a latent with ``encode_to_z`` -> run the generator.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        checkpoint_path: Path to a checkpoint whose ``'model_state_dict'``
            entry holds the ``SAGANModel`` weights.
        device: Torch device for the model and all tensors.
        age_group: Age-group key for the standardization statistics. It
            defaults to ``"adult"``. Previously no group was passed, which made
            the standardization call raise ``TypeError``.

    Returns:
        The generator output for the derived latent. Presumably an image batch
        of shape (1, 3, 64, 64) for a 64x64 SAGAN; this depends on
        ``SAGANModel``, so confirm it.
    """
    # 1) Load audio. The sample rate is not used by the placeholder encoder.
    waveform, _sample_rate = torchaudio.load(audio_path)
    waveform = z_score_standardize(waveform, age_group).to(device)

    # 2) Instantiate model & load weights
    model = SAGANModel(z_dim=128).to(device)
    # SECURITY: torch.load unpickles the file and can execute arbitrary code.
    # Only load trusted checkpoints (consider weights_only=True on torch >= 1.13).
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    with torch.no_grad():
        # 3) Latent vector from audio; replace `encode_to_z` with a real
        #    feature extractor. (1, 128) -> (1, 128, 1, 1).
        z = encode_to_z(waveform).unsqueeze(-1).unsqueeze(-1)
        # 4) Generate
        fake_img = model(z)
    return fake_img

# Placeholder: your own mapping from waveform → z
def encode_to_z(wf, z_dim: int = 128):
    """Map an audio tensor to a placeholder latent vector of shape ``(1, z_dim)``.

    This stands in for a real feature extractor, such as a small CNN or an MLP
    producing ``z_dim`` features. It collapses ``wf`` to its global mean and
    repeats that one value ``z_dim`` times, so every latent entry is identical.

    Args:
        wf: Audio tensor of any shape, e.g. a waveform from ``torchaudio.load``.
            It must have a floating-point dtype, because ``Tensor.mean``
            requires one.
        z_dim: Latent size. Defaults to 128, matching ``SAGANModel(z_dim=128)``.

    Returns:
        A new tensor of shape ``(1, z_dim)``.
    """
    # Global mean over all elements. For inputs of rank <= 2 this equals the
    # former mean-of-means. For higher ranks the former chain left a vector,
    # and the result had shape (1, k * z_dim) instead of (1, z_dim).
    return wf.mean().reshape(1, 1).repeat(1, z_dim)