Kr08 committed on
Commit
b360956
·
verified ·
1 Parent(s): ed2ee59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -99
app.py CHANGED
@@ -1,11 +1,21 @@
1
  import gradio as gr
2
  from audio_processing import process_audio, load_models
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering, pipeline
4
  import spaces
5
  import torch
6
  import logging
7
-
8
- logging.basicConfig(level=logging.INFO)
 
 
 
 
 
 
 
 
 
 
9
  logger = logging.getLogger(__name__)
10
 
11
  # Check if CUDA is available
@@ -13,12 +23,6 @@ cuda_available = torch.cuda.is_available()
13
  device = "cuda" if cuda_available else "cpu"
14
  logger.info(f"Using device: {device}")
15
 
16
- # Initialize model variables
17
- summarizer_model = None
18
- summarizer_tokenizer = None
19
- qa_model = None
20
- qa_tokenizer = None
21
-
22
  # Load Whisper model
23
  print("Loading Whisper model...")
24
  try:
@@ -30,109 +34,73 @@ except Exception as e:
30
  print("Whisper model loaded successfully.")
31
 
32
  def load_summarization_model():
33
- global summarizer_model, summarizer_tokenizer
34
- if summarizer_model is None:
35
- logger.info("Loading summarization model...")
36
- summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6").to(device)
37
- summarizer_tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
38
- logger.info("Summarization model loaded.")
39
-
40
- def load_qa_model():
41
- global qa_model, qa_tokenizer
42
- if qa_model is None:
43
- logger.info("Loading QA model...")
44
- qa_model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad").to(device)
45
- qa_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
46
- logger.info("QA model loaded.")
 
 
 
 
 
 
 
 
47
 
48
- @spaces.GPU(duration=120)
49
  def transcribe_audio(audio_file, translate, model_size, use_diarization):
50
- language_segments, final_segments = process_audio(audio_file, translate=translate, model_size=model_size, use_diarization=use_diarization)
51
-
52
- output = "Detected language changes:\n\n"
53
- for segment in language_segments:
54
- output += f"Language: {segment['language']}\n"
55
- output += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"
56
 
57
- output += f"Transcription with language detection {f'and speaker diarization' if use_diarization else ''} (using {model_size} model):\n\n"
58
- full_text = ""
59
- for segment in final_segments:
60
- output += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']})"
61
- if use_diarization:
62
- output += f" {segment['speaker']}:"
63
- output += f"\nOriginal: {segment['text']}\n"
64
- if translate:
65
- output += f"Translated: {segment['translated']}\n"
66
- full_text += segment['translated'] + " "
67
- else:
68
- full_text += segment['text'] + " "
69
- output += "\n"
70
-
71
- return output, full_text
72
-
73
- @spaces.GPU(duration=120)
74
  def summarize_text(text):
75
- load_summarization_model()
76
- inputs = summarizer_tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(device)
77
- summary_ids = summarizer_model.generate(inputs["input_ids"], max_length=150, min_length=50, do_sample=False)
78
- summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
 
 
79
  return summary
80
 
81
- @spaces.GPU(duration=120)
82
- def answer_question(context, question):
83
- load_qa_model()
84
- inputs = qa_tokenizer(question, context, return_tensors="pt").to(device)
85
- outputs = qa_model(**inputs)
86
- answer_start = torch.argmax(outputs.start_logits)
87
- answer_end = torch.argmax(outputs.end_logits) + 1
88
- answer = qa_tokenizer.decode(inputs["input_ids"][0][answer_start:answer_end])
89
- return answer
90
-
91
- @spaces.GPU(duration=120)
92
- def process_and_summarize(audio_file, translate, model_size, use_diarization):
93
  transcription, full_text = transcribe_audio(audio_file, translate, model_size, use_diarization)
94
- summary = summarize_text(full_text)
95
  return transcription, summary
96
 
97
- @spaces.GPU(duration=120)
98
- def qa_interface(audio_file, translate, model_size, use_diarization, question):
99
- _, full_text = transcribe_audio(audio_file, translate, model_size, use_diarization)
100
- answer = answer_question(full_text, question)
101
- return answer
102
-
103
  # Main interface
104
  with gr.Blocks() as iface:
105
- gr.Markdown("# WhisperX Audio Transcription, Translation, Summarization, and QA (with ZeroGPU support)")
106
 
107
- with gr.Tab("Transcribe and Summarize"):
108
- audio_input = gr.Audio(type="filepath")
109
- translate_checkbox = gr.Checkbox(label="Enable Translation")
110
- model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
111
- diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
112
- transcribe_button = gr.Button("Transcribe and Summarize")
113
- transcription_output = gr.Textbox(label="Transcription")
114
- summary_output = gr.Textbox(label="Summary")
115
-
116
- transcribe_button.click(
117
- process_and_summarize,
118
- inputs=[audio_input, translate_checkbox, model_dropdown, diarization_checkbox],
119
- outputs=[transcription_output, summary_output]
120
- )
121
 
122
- with gr.Tab("Question Answering"):
123
- qa_audio_input = gr.Audio(type="filepath")
124
- qa_translate_checkbox = gr.Checkbox(label="Enable Translation")
125
- qa_model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
126
- qa_diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
127
- question_input = gr.Textbox(label="Ask a question about the audio")
128
- qa_button = gr.Button("Get Answer")
129
- answer_output = gr.Textbox(label="Answer")
130
-
131
- qa_button.click(
132
- qa_interface,
133
- inputs=[qa_audio_input, qa_translate_checkbox, qa_model_dropdown, qa_diarization_checkbox, question_input],
134
- outputs=answer_output
135
- )
136
 
137
  gr.Markdown(
138
  f"""
 
1
  import gradio as gr
2
  from audio_processing import process_audio, load_models
3
+ from transformers import pipeline
4
  import spaces
5
  import torch
6
  import logging
7
+ import traceback
8
+ import sys
9
+
10
+ # Set up logging
11
+ logging.basicConfig(
12
+ level=logging.INFO,
13
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
14
+ handlers=[
15
+ logging.StreamHandler(sys.stdout),
16
+ logging.FileHandler('app.log')
17
+ ]
18
+ )
19
  logger = logging.getLogger(__name__)
20
 
21
  # Check if CUDA is available
 
23
  device = "cuda" if cuda_available else "cpu"
24
  logger.info(f"Using device: {device}")
25
 
 
 
 
 
 
 
26
  # Load Whisper model
27
  print("Loading Whisper model...")
28
  try:
 
34
  print("Whisper model loaded successfully.")
35
 
36
def load_summarization_model():
    """Load (and cache) the DistilBART summarization pipeline.

    Tries the GPU first when CUDA is available; if that fails, logs a
    warning and falls back to CPU. The constructed pipeline is cached on
    the function object so repeated calls (e.g. one per summarize request)
    do not re-initialize the model each time.

    Returns:
        A transformers summarization pipeline ready for inference.
    """
    # Reuse the pipeline built on a previous call, if any.
    cached = getattr(load_summarization_model, "_summarizer", None)
    if cached is not None:
        return cached

    logger.info("Loading summarization model...")
    try:
        summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=0 if cuda_available else -1)
    except Exception as e:
        logger.warning(f"Failed to load summarization model on GPU. Falling back to CPU. Error: {str(e)}")
        summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=-1)
    logger.info("Summarization model loaded.")

    load_summarization_model._summarizer = summarizer
    return summarizer
45
+
46
def process_with_fallback(func, *args, **kwargs):
    """Run *func*, retrying once on CPU if a CUDA/GPU error occurs.

    On failure the error and traceback are logged. If the message looks
    GPU-related ("CUDA" or "GPU" in its text), the call is retried with
    ``use_gpu=False`` — but only when *func* actually accepts that keyword
    (directly or via **kwargs); otherwise the original exception is
    re-raised instead of being masked by a TypeError from an unexpected
    keyword argument. Non-GPU errors are always re-raised unchanged.

    Args:
        func: The callable to invoke.
        *args: Positional arguments forwarded to *func*.
        **kwargs: Keyword arguments forwarded to *func*.

    Returns:
        Whatever *func* returns.

    Raises:
        Exception: the original error when no CPU fallback is possible.
    """
    import inspect

    try:
        return func(*args, **kwargs)
    except Exception as e:
        logger.error(f"Error during processing: {str(e)}")
        logger.error(traceback.format_exc())
        if "CUDA" in str(e) or "GPU" in str(e):
            # Only force CPU mode if the callee can take the keyword;
            # injecting it blindly would raise TypeError and hide the
            # real failure.
            params = inspect.signature(func).parameters
            accepts_use_gpu = "use_gpu" in params or any(
                p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
            )
            if accepts_use_gpu:
                logger.info("Falling back to CPU processing...")
                kwargs["use_gpu"] = False
                return func(*args, **kwargs)
        raise
59
 
60
@spaces.GPU
def transcribe_audio(audio_file, translate, model_size, use_diarization):
    """Transcribe an audio file, delegating to process_audio with the
    shared GPU-to-CPU error fallback wrapper.

    Args:
        audio_file: Path to the audio file to transcribe.
        translate: Whether to also translate the transcription.
        model_size: Whisper model size identifier (e.g. "small").
        use_diarization: Whether to run speaker diarization.
    """
    options = {
        "translate": translate,
        "model_size": model_size,
        "use_diarization": use_diarization,
    }
    return process_with_fallback(process_audio, audio_file, **options)
 
 
 
 
 
63
 
64
@spaces.GPU
def summarize_text(text):
    """Summarize *text* with the DistilBART summarization pipeline.

    Best-effort: any failure is logged (with traceback) and a fixed
    error-message string is returned instead of raising, so the UI
    always receives a displayable value.

    Args:
        text: The text to summarize.

    Returns:
        The summary string, or an error message on failure.
    """
    summarizer = load_summarization_model()
    try:
        result = summarizer(text, max_length=150, min_length=50, do_sample=False)
        summary = result[0]['summary_text']
    except Exception as e:
        logger.error(f"Error during summarization: {str(e)}")
        logger.error(traceback.format_exc())
        summary = "Error occurred during summarization. Please try again."
    return summary
74
 
75
@spaces.GPU
def process_and_summarize(audio_file, translate, model_size, use_diarization, do_summarize):
    """Transcribe an audio file and optionally summarize the result.

    Args:
        audio_file: Path to the audio file.
        translate: Whether to translate the transcription.
        model_size: Whisper model size identifier.
        use_diarization: Whether to run speaker diarization.
        do_summarize: Whether to produce a summary of the full text.

    Returns:
        A (transcription, summary) tuple; summary is "" when disabled.
    """
    transcription, full_text = transcribe_audio(audio_file, translate, model_size, use_diarization)
    if do_summarize:
        summary = summarize_text(full_text)
    else:
        summary = ""
    return transcription, summary
80
 
 
 
 
 
 
 
81
  # Main interface
82
  with gr.Blocks() as iface:
83
+ gr.Markdown("# WhisperX Audio Transcription, Translation, and Summarization (with ZeroGPU support)")
84
 
85
+ audio_input = gr.Audio(type="filepath")
86
+ translate_checkbox = gr.Checkbox(label="Enable Translation")
87
+ summarize_checkbox = gr.Checkbox(label="Enable Summarization", interactive=False)
88
+ model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
89
+ diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
90
+ process_button = gr.Button("Process Audio")
91
+ transcription_output = gr.Textbox(label="Transcription/Translation")
92
+ summary_output = gr.Textbox(label="Summary")
93
+
94
+ def update_summarize_checkbox(translate):
95
+ return gr.Checkbox(interactive=translate)
96
+
97
+ translate_checkbox.change(update_summarize_checkbox, inputs=[translate_checkbox], outputs=[summarize_checkbox])
 
98
 
99
+ process_button.click(
100
+ process_and_summarize,
101
+ inputs=[audio_input, translate_checkbox, model_dropdown, diarization_checkbox, summarize_checkbox],
102
+ outputs=[transcription_output, summary_output]
103
+ )
 
 
 
 
 
 
 
 
 
104
 
105
  gr.Markdown(
106
  f"""