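# HowFar-Caarma / inference.py
# Uploaded by massabaali with huggingface_hub (commit 8ac3327, verified).
"""
Inference script for HowFar-Caarma: distance estimation from speech using HuBERT.

Loads a trained Lightning checkpoint, runs one audio file through the HuBERT
backbone, and prints the shape of the resulting mean-pooled embedding.

Usage:
    python inference.py --ckpt epoch18_val_acc7997.ckpt --audio path/to/audio.wav
"""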
import argparse
import torch
import torch.nn as nn
import torchaudio
import pytorch_lightning as pl
from transformers import HubertModel, AutoFeatureExtractor
class Hubert_Model(nn.Module):
    """HuBERT encoder followed by LayerNorm, mean pooling and BatchNorm.

    Args:
        hubert_model_name: Hugging Face model id (or local path) passed to
            ``HubertModel.from_pretrained``.
        cache_dir: Directory for downloaded weights. ``None`` (the default) or
            an empty string means "use the standard Hugging Face cache".
    """

    def __init__(self, hubert_model_name="facebook/hubert-large-ls960-ft", cache_dir=None):
        super().__init__()
        # from_pretrained only falls back to the shared HF cache when cache_dir is
        # None; an empty string would be used as a relative path, caching the
        # weights under the current working directory. Normalise "" to None.
        self.encoder = HubertModel.from_pretrained(
            hubert_model_name, cache_dir=cache_dir or None
        )
        hidden_size = self.encoder.config.hidden_size
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.bn = nn.BatchNorm1d(hidden_size)

    def forward(self, input_values, labels=None):
        """Encode a batch of inputs and return one pooled embedding per item.

        Args:
            input_values: Encoder inputs as produced by the HuBERT feature
                extractor. Presumably (batch, num_samples) float audio; TODO confirm.
            labels: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape (batch, hidden_size): the time-averaged, normalised
            encoder states.
        """
        hidden_states = self.encoder(input_values).last_hidden_state
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Average over the time axis (dim 1) to get one vector per utterance.
        pooled = torch.mean(hidden_states, dim=1)
        return self.bn(pooled)
class CassavaPLModule(pl.LightningModule):
    """Minimal LightningModule that wraps a backbone so checkpoints can be loaded.

    Args:
        hparams: Hyperparameter dict. Accepted so ``load_from_checkpoint`` can
            pass it through, but it is neither stored nor used.
        model: The ``nn.Module`` this wrapper delegates its forward pass to.
    """

    def __init__(self, hparams, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output unchanged."""
        output = self.model(x)
        return output
def load_model(ckpt_path, device="cuda"):
    """Build the HuBERT backbone and restore weights from a Lightning checkpoint.

    Args:
        ckpt_path: Path to the checkpoint file passed to
            ``CassavaPLModule.load_from_checkpoint``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        A ``CassavaPLModule`` on ``device``, in eval mode, with all parameters
        frozen.
    """
    backbone = Hubert_Model()
    # `model` and `hparams` are forwarded to CassavaPLModule.__init__.
    # strict=False: state-dict keys that do not match the wrapper are ignored
    # instead of raising. Presumably the checkpoint carries training-only
    # parameters. NOTE(review): confirm no backbone weights are silently skipped.
    lightning_module = CassavaPLModule.load_from_checkpoint(
        ckpt_path,
        map_location=device,
        strict=False,
        model=backbone,
        hparams={"lr": 0.001, "batch_size": 1},
    )
    lightning_module.eval()
    lightning_module.to(device)
    lightning_module.freeze()  # disables gradients for inference
    return lightning_module
# Feature extractors keyed by Hugging Face id, so each one is loaded only once.
_FEATURE_EXTRACTOR_CACHE = {}


def _get_feature_extractor(name):
    """Return the feature extractor for ``name``, loading it on first use only."""
    extractor = _FEATURE_EXTRACTOR_CACHE.get(name)
    if extractor is None:
        extractor = AutoFeatureExtractor.from_pretrained(name)
        _FEATURE_EXTRACTOR_CACHE[name] = extractor
    return extractor


def extract_embedding(model, audio_path, device="cuda",
                      feature_extractor_name="facebook/hubert-large-ls960-ft"):
    """Load one audio file and return the embedding ``model`` produces for it.

    The audio is resampled to 16 kHz and down-mixed to mono. It is then passed
    through the HuBERT feature extractor before reaching the model.

    Args:
        model: Module mapping extractor ``input_values`` to an embedding, e.g.
            the object returned by ``load_model``.
        audio_path: Path to the audio file (anything ``torchaudio.load`` accepts).
        device: Device the model lives on; the inputs are moved there.
        feature_extractor_name: Hugging Face id of the feature extractor. It
            should match the backbone the model was built from.

    Returns:
        The model output for this single utterance (batch size 1), on ``device``.
    """
    waveform, sr = torchaudio.load(audio_path)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    # torchaudio returns (channels, frames); average the channels to get mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Cached: from_pretrained re-reads the config (and may re-check the Hub) on
    # every call, which is wasteful when embedding many files in a loop.
    processor = _get_feature_extractor(feature_extractor_name)
    inputs = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=16000,
        return_tensors="pt",
    ).input_values.to(device)
    with torch.no_grad():
        embedding = model(inputs)
    return embedding
def main():
    """Command-line entry point: embed one audio file and report the embedding shape."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--ckpt", required=True)
    cli.add_argument("--audio", required=True)
    options = cli.parse_args()

    # Use the GPU when one is visible; both helpers receive the same device string.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model = load_model(options.ckpt, device=device)
    emb = extract_embedding(model, options.audio, device=device)
    print(f"Embedding shape: {emb.shape}")


if __name__ == "__main__":
    main()