therealcyberlord committed
Commit 4c8b94b
1 Parent(s): d4f2f17

Upload app.py

Files changed (1)
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+ import numpy as np
+ from pyannote.audio import Pipeline
+ from dotenv import load_dotenv
+ import os
+
+ load_dotenv()
+
+ # Check and set device
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ # Model and pipeline setup
+ model_id = "distil-whisper/distil-small.en"
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
+     model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+ )
+ model.to(device)
+
+ processor = AutoProcessor.from_pretrained(model_id)
+
+ pipe = pipeline(
+     "automatic-speech-recognition",
+     model=model,
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     max_new_tokens=128,
+     torch_dtype=torch_dtype,
+     device=device,
+ )
+
+ # diarization pipeline (renamed to avoid conflict)
+ diarization_pipeline = Pipeline.from_pretrained(
+     "pyannote/speaker-diarization-3.0", use_auth_token=os.getenv("HF_KEY")
+ )
+
+
+ def transcribe(audio):
+     sr, data = audio
+     processed_data = np.array(data).astype(np.float32) / 32767.0
+     waveform_tensor = torch.tensor(processed_data[np.newaxis, :])
+
+     # results from the pipeline
+     transcription_res = pipe({"sampling_rate": sr, "raw": processed_data})["text"]
+     diarization_res = diarization_pipeline(
+         {"waveform": waveform_tensor, "sample_rate": sr}
+     )
+
+     return transcription_res, diarization_res
+
+
+ demo = gr.Interface(
+     fn=transcribe,
+     inputs=gr.Audio(sources=["upload", "microphone"]),
+     outputs=[
+         gr.Textbox(lines=3, info="audio transcription"),
+         gr.Textbox(info="speaker diarization"),
+     ],
+     title="Automatic Speech Recognition 🗣️",
+     description="Transcribe your speech to text with distilled whisper",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
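
Note: the diarization checkpoint, pyannote/speaker-diarization-3.0, is gated on the Hugging Face Hub, so Pipeline.from_pretrained authenticates with the token that load_dotenv() reads from a local .env file. A minimal sketch of that file, assuming HF_KEY is the only variable this app needs (the token value is a placeholder, not a real token):

    # .env -- read by load_dotenv(); keep this file out of version control
    HF_KEY=hf_your_access_token_here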
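transcribe() assumes mono int16 input. A stereo upload arrives from gr.Audio as a 2-D (n_samples, n_channels) array, so the [np.newaxis, :] reshape would hand pyannote a (1, n, 2) tensor. A small guard, not part of this commit, that downmixes before normalization:

    def to_mono(data: np.ndarray) -> np.ndarray:
        # Average the channels of a (n_samples, n_channels) array; pass mono through
        if data.ndim > 1:
            data = data.mean(axis=1)
        return data

    # In transcribe(), before normalization: data = to_mono(np.array(data))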
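As committed, the second Textbox receives the raw pyannote Annotation object, which Gradio renders through str(). If readable speaker turns are wanted instead, pyannote's itertracks API supports a sketch like the following (format_diarization is a hypothetical helper, not in the commit):

    def format_diarization(diarization_res) -> str:
        # itertracks yields (segment, track_id, speaker_label) triples
        lines = [
            f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s"
            for turn, _, speaker in diarization_res.itertracks(yield_label=True)
        ]
        return "\n".join(lines)

transcribe() could then return format_diarization(diarization_res) in place of the raw object.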