cdactvm's picture
Update app.py
ed97bcc verified
raw
history blame
2.46 kB
# import torch
# import torchaudio
# import gradio as gr
# from transformers import Wav2Vec2BertProcessor, Wav2Vec2BertForCTC
# # Set device
# device = "cuda" if torch.cuda.is_available() else "cpu"
# # Load processor & model
# model_name = "cdactvm/w2v-bert-punjabi" # Change if using a Punjabi ASR model
# processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
# model = Wav2Vec2BertForCTC.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
# def transcribe(audio_path):
# # Load audio file
# waveform, sample_rate = torchaudio.load(audio_path)
# # Convert stereo to mono (if needed)
# if waveform.shape[0] > 1:
# waveform = torch.mean(waveform, dim=0, keepdim=True)
# # Resample to 16kHz
# if sample_rate != 16000:
# waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
# # Process audio
# inputs = processor(waveform.squeeze(0), sampling_rate=16000, return_tensors="pt")
# inputs = {key: val.to(device, dtype=torch.bfloat16) for key, val in inputs.items()}
# # Get logits & transcribe
# with torch.no_grad():
# logits = model(**inputs).logits
# predicted_ids = torch.argmax(logits, dim=-1)
# transcription = processor.batch_decode(predicted_ids)[0]
# return transcription
# # Gradio Interface
# app = gr.Interface(
# fn=transcribe,
# inputs=gr.Audio(sources="upload", type="filepath"),
# outputs="text",
# title="Punjabi Speech-to-Text",
# description="Upload an audio file and get the transcription in Punjabi."
# )
# if __name__ == "__main__":
# app.launch()
import gradio as gr
import torch
from transformers import pipeline
# Choose the compute device once: CUDA when a GPU is visible, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the speech-recognition pipeline at import time so every request
# handled by the UI reuses the same loaded model.
asr_pipeline = pipeline(
    task="automatic-speech-recognition",
    model="cdactvm/w2v-bert-punjabi",  # checkpoint id; change to use another ASR model
    torch_dtype=torch.bfloat16,
    # The pipeline takes a device index: 0 selects the first GPU, -1 the CPU.
    device=0 if device == "cuda" else -1,
)
def transcribe(audio_path):
    """Transcribe the audio file at *audio_path* and return the recognized text.

    Args:
        audio_path: Path of the audio file to transcribe. A falsy value
            (``None`` or an empty string) means no file was provided.

    Returns:
        The ``"text"`` entry of the ASR pipeline result, or an empty string
        when no audio was provided.
    """
    # The UI can invoke this callback without an uploaded file; skip
    # inference instead of letting the pipeline fail on a missing input.
    if not audio_path:
        return ""
    # Run inference on the file and keep only the transcribed text.
    result = asr_pipeline(audio_path)
    return result["text"]
# Web UI: a single audio-upload input wired to `transcribe`, with the
# resulting transcription shown as plain text.
app = gr.Interface(
    title="Punjabi Speech-to-Text",
    description="Upload an audio file and get the transcription in Punjabi.",
    fn=transcribe,
    # type="filepath" matches the path argument that `transcribe` expects.
    inputs=gr.Audio(type="filepath", sources="upload"),
    outputs="text",
)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    app.launch()