sanchit-gandhi committed on
Commit
ccb306d
·
1 Parent(s): 26faa5f

remove plot

Browse files
Files changed (1) hide show
  1. app.py +5 -29
app.py CHANGED
@@ -2,7 +2,6 @@ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
2
  from transformers.utils import is_flash_attn_2_available
3
  import torch
4
  import gradio as gr
5
- import matplotlib.pyplot as plt
6
  import time
7
  import os
8
 
@@ -64,6 +63,7 @@ def transcribe(inputs):
64
  start_time = time.time()
65
  result = distil_pipe_forward(*args, **kwargs)
66
  distil_runtime = time.time() - start_time
 
67
  return result
68
 
69
  distil_pipe._forward = _forward_distil_time
@@ -75,34 +75,13 @@ def transcribe(inputs):
75
  start_time = time.time()
76
  result = pipe_forward(*args, **kwargs)
77
  runtime = time.time() - start_time
 
78
  return result
79
 
80
  pipe._forward = _forward_time
81
  text = pipe(inputs, batch_size=BATCH_SIZE)["text"]
82
 
83
- # Create figure and axis
84
- fig, ax = plt.subplots(figsize=(5, 5))
85
-
86
- # Define bar width and positions
87
- bar_width = 0.1
88
- positions = [0, 0.1] # Adjusted positions to bring bars closer
89
-
90
- # Plot data
91
- ax.bar(positions[0], distil_runtime, bar_width, edgecolor='black')
92
- ax.bar(positions[1], runtime, bar_width, edgecolor='black')
93
-
94
- # Set title, labels, and xticks
95
- ax.set_ylabel('Transcription time (s)')
96
- ax.set_xticks(positions)
97
- ax.set_xticklabels(['Distil-Whisper', 'Whisper'])
98
-
99
- # Gridlines and other styling
100
- ax.grid(which='major', axis='y', linestyle='--', linewidth=0.5)
101
-
102
- # Use tight layout to avoid overlaps
103
- plt.tight_layout()
104
-
105
- yield distil_text, distil_runtime, text, runtime, plt
106
 
107
  if __name__ == "__main__":
108
  with gr.Blocks() as demo:
@@ -129,18 +108,15 @@ if __name__ == "__main__":
129
  )
130
  audio = gr.components.Audio(type="filepath", label="Audio input")
131
  button = gr.Button("Transcribe")
132
- plot = gr.components.Plot()
133
  with gr.Row():
134
  distil_runtime = gr.components.Textbox(label="Distil-Whisper Transcription Time (s)")
135
  runtime = gr.components.Textbox(label="Whisper Transcription Time (s)")
136
  with gr.Row():
137
  distil_transcription = gr.components.Textbox(label="Distil-Whisper Transcription", show_copy_button=True)
138
  transcription = gr.components.Textbox(label="Whisper Transcription", show_copy_button=True)
139
-
140
  button.click(
141
  fn=transcribe,
142
  inputs=audio,
143
- outputs=[distil_transcription, distil_runtime, transcription, runtime, plot],
144
  )
145
-
146
- demo.queue().launch()
 
2
  from transformers.utils import is_flash_attn_2_available
3
  import torch
4
  import gradio as gr
 
5
  import time
6
  import os
7
 
 
63
  start_time = time.time()
64
  result = distil_pipe_forward(*args, **kwargs)
65
  distil_runtime = time.time() - start_time
66
+ distil_runtime = round(distil_runtime, 2)
67
  return result
68
 
69
  distil_pipe._forward = _forward_distil_time
 
75
  start_time = time.time()
76
  result = pipe_forward(*args, **kwargs)
77
  runtime = time.time() - start_time
78
+ runtime = round(runtime, 2)
79
  return result
80
 
81
  pipe._forward = _forward_time
82
  text = pipe(inputs, batch_size=BATCH_SIZE)["text"]
83
 
84
+ yield distil_text, distil_runtime, text, runtime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  if __name__ == "__main__":
87
  with gr.Blocks() as demo:
 
108
  )
109
  audio = gr.components.Audio(type="filepath", label="Audio input")
110
  button = gr.Button("Transcribe")
 
111
  with gr.Row():
112
  distil_runtime = gr.components.Textbox(label="Distil-Whisper Transcription Time (s)")
113
  runtime = gr.components.Textbox(label="Whisper Transcription Time (s)")
114
  with gr.Row():
115
  distil_transcription = gr.components.Textbox(label="Distil-Whisper Transcription", show_copy_button=True)
116
  transcription = gr.components.Textbox(label="Whisper Transcription", show_copy_button=True)
 
117
  button.click(
118
  fn=transcribe,
119
  inputs=audio,
120
+ outputs=[distil_transcription, distil_runtime, transcription, runtime],
121
  )
122
+ demo.queue().launch()