import os, torch, gradio as gr
from transformers import (
    AutoModelForCTC,
    AutoProcessor,            # happy path
    Wav2Vec2Processor,        # fallback
    Wav2Vec2FeatureExtractor,
    Wav2Vec2CTCTokenizer,
)
MODEL_ID = os.getenv("MODEL_ID", "Reihaneh/wav2vec2_fy_nl_best_frisian_1")
HF_TOKEN = os.getenv("HF_TOKEN") # only if private
device = "cuda" if torch.cuda.is_available() else "cpu"
# ---- Try to load the processor; if the feature extractor config is missing, build it manually
processor = None
try:
    processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
except Exception as e:
    print("AutoProcessor failed, building Wav2Vec2Processor manually:", e)
    # Load the tokenizer (the repo must contain vocab.json + tokenizer_config.json)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    # Minimal, safe defaults; adjust if your training used different settings
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,  # <-- set to your training SR
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
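    # Optional follow-up (a sketch; assumes you have write access to the model repo):
    # persist the rebuilt processor so AutoProcessor works directly next time, e.g.
    #   processor.save_pretrained("rebuilt_processor")
    # and then upload the generated preprocessor_config.json to the model repo.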
model = AutoModelForCTC.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval()
# Try to read SR from processor if present
target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)
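# Optional sanity check (a sketch; assumes the config exposes `vocab_size`): the CTC head's
# output dimension should match the tokenizer vocabulary, otherwise decoding yields garbage.
# Note the two counts may legitimately differ slightly if extra tokens were added after training.
if getattr(model.config, "vocab_size", None) not in (None, len(processor.tokenizer)):
    print(f"Note: model vocab_size {model.config.vocab_size} != tokenizer size {len(processor.tokenizer)}")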
def _cheap_resample(wav, sr, target_sr):
    # Nearest-neighbour resampling: no extra dependencies, good enough for a demo.
    if sr == target_sr:
        return wav
    import numpy as np, math
    ratio = target_sr / sr
    idx = (np.arange(int(math.ceil(wav.shape[0] * ratio))) / ratio).astype(int)
    idx = idx[idx < wav.shape[0]]
    return wav[idx]
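# Optional higher-quality alternative (a sketch, unused by default; assumes torchaudio is
# installed in the Space): band-limited resampling instead of the nearest-neighbour trick above.
def _torchaudio_resample(wav, sr, target_sr):
    import torchaudio.functional as AF  # extra dependency, not required by the rest of this app
    if sr == target_sr:
        return wav
    t = torch.from_numpy(wav).float()
    return AF.resample(t, orig_freq=sr, new_freq=target_sr).numpy()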
# Earlier, simpler version of transcribe(); kept for reference only, never executed.
'''
def transcribe(audio):
    if audio is None:
        return ""
    sr, x = audio
    if x.ndim == 2:  # stereo -> mono
        x = x[:, 0]
    x = _cheap_resample(x, sr, target_sr)
    inputs = processor(x, sampling_rate=target_sr, return_tensors="pt", padding=True)
    with torch.inference_mode():
        logits = model(inputs.input_values.to(device)).logits
    ids = torch.argmax(logits, dim=-1)
    text = processor.batch_decode(ids)[0]
    return text
'''
def transcribe(a):
    try:
        if a is None:
            return ""
        sr, x = a  # if you use a helper, just make sure you end up with (sr, np.ndarray)
        # 1) mono + sanitize + FORCE float32
        import numpy as np, math
        if x.ndim == 2:
            x = x.mean(axis=1)
        x = np.nan_to_num(x).astype(np.float32)
        # 2) (optional) cheap resample to your processor's SR
        target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)
        if sr != target_sr:
            ratio = target_sr / float(sr)
            n = int(math.ceil(len(x) * ratio))
            idx = (np.arange(n) / ratio).astype(np.int64)
            idx = np.clip(idx, 0, len(x) - 1)
            x = x[idx]
        # 3) tokenize -> cast inputs to DEVICE + MODEL DTYPE
        inputs = processor(x, sampling_rate=target_sr, return_tensors="pt", padding=True)
        input_values = inputs.input_values.to(device)
        # >>> KEY LINE: match model dtype (prevents "Input type (double) and bias type should be the same")
        input_values = input_values.to(model.dtype)
        with torch.inference_mode():
            logits = model(input_values).logits
        ids = torch.argmax(logits, dim=-1)
        text = processor.batch_decode(ids)[0]
        return text
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return f"⚠️ Error: {e}"
with gr.Blocks(title="Frisian ASR") as demo:
    gr.Markdown("## 🎙️ Frisian ASR")
    audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio")
    out = gr.Textbox(label="Transcript")
    gr.Button("Transcribe").click(transcribe, inputs=audio, outputs=out)

demo.queue().launch()
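# To debug outside the Space (an assumption about your setup, not required by the app):
#   MODEL_ID=Reihaneh/wav2vec2_fy_nl_best_frisian_1 python app.py
# then open the local URL Gradio prints; demo.queue() lets longer transcriptions finish
# without hitting the default request timeout.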