File size: 10,593 Bytes
c569b48
444b9c9
bbacfdf
aaac499
 
3a346c4
b360956
 
 
 
 
 
 
444b9c9
b360956
 
3a346c4
aaac499
bbacfdf
 
 
 
127f69f
 
6241aa9
 
 
eaed8a8
 
 
1c663ae
6241aa9
bbacfdf
 
127f69f
 
 
95af989
127f69f
 
 
 
 
 
 
 
 
9352d35
127f69f
 
 
6db9237
b360956
 
444b9c9
b360956
444b9c9
 
b360956
9352d35
b360956
444b9c9
 
b360956
127f69f
b360956
 
 
 
 
 
 
 
 
 
 
 
aaac499
127f69f
8055777
abe9454
 
444b9c9
9dfd91e
444b9c9
 
 
 
 
f36e52e
127f69f
8055777
aaac499
444b9c9
b360956
444b9c9
b360956
444b9c9
 
b360956
444b9c9
b360956
444b9c9
aaac499
8055777
abe9454
 
444b9c9
abe9454
6241aa9
 
acfe3c0
9116075
444b9c9
 
 
 
 
 
 
 
 
 
6241aa9
 
 
 
 
444b9c9
 
 
 
 
 
acfe3c0
444b9c9
9116075
444b9c9
 
aaac499
3c60188
28d3de2
127f69f
 
 
 
 
 
 
 
 
 
 
 
bbacfdf
 
 
 
 
127f69f
bbacfdf
 
 
 
 
 
 
 
 
 
 
 
 
28d3de2
127f69f
28d3de2
 
 
 
 
 
 
 
 
 
 
9fb4314
 
3c60188
127f69f
 
 
3c60188
 
28d3de2
aaac499
 
127f69f
8055777
b360956
 
 
8055777
b360956
 
127f69f
b360956
acfe3c0
127f69f
 
 
b360956
bbacfdf
 
 
b360956
 
 
 
aaac499
b360956
 
abe9454
acfe3c0
b360956
127f69f
 
 
 
 
 
aaac499
bbacfdf
 
0890b0e
bbacfdf
 
 
aaac499
3a346c4
 
444b9c9
 
3a346c4
aaac499
 
3a346c4
aaac499
 
f36e52e
81e4ee2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import gradio as gr
from audio_processing import process_audio
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import spaces
import torch
import logging
import traceback
import sys

# Route INFO-and-above records from every logger to stdout with a timestamped layout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
# Module-level logger shared by every helper in this file.
logger = logging.getLogger(__name__)

# Loaded (model, tokenizer) pairs keyed by model name, so repeated requests in this
# process reuse the translation checkpoint instead of reloading it on every call.
_TRANSLATION_MODEL_CACHE = {}


def load_translation_model(model_name="facebook/nllb-200-distilled-600M"):
    """Load (or reuse) a seq2seq translation model together with its tokenizer.

    Args:
        model_name: Hugging Face model id to load. Defaults to the distilled
            NLLB-200 checkpoint used throughout this app.

    Returns:
        A ``(model, tokenizer)`` tuple. Calls with the same ``model_name`` in the
        same process return the same cached objects.
    """
    cached = _TRANSLATION_MODEL_CACHE.get(model_name)
    if cached is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        cached = (model, tokenizer)
        _TRANSLATION_MODEL_CACHE[model_name] = cached
    return cached


def alternate_translation(translation_model, translation_tokenizer, inputs, tgt_lang="eng_Latn", max_length=100):
    """Translate ``inputs`` with a seq2seq model and return the first decoded string.

    Args:
        translation_model: Model exposing ``generate`` (e.g. from ``load_translation_model``).
        translation_tokenizer: Matching tokenizer; it must know the ``tgt_lang`` token.
        inputs: Text to translate.
        tgt_lang: Token forced as the first generated token to select the output
            language. Defaults to ``"eng_Latn"`` (English).
        max_length: Maximum length of the generated sequence, in tokens.

    Returns:
        The decoded translation of the first input, with special tokens removed.
    """
    tokenized_inputs = translation_tokenizer(inputs, return_tensors='pt')
    # NLLB-style models pick the output language from the forced first (BOS) token.
    target_token_id = translation_tokenizer.convert_tokens_to_ids(tgt_lang)
    translated_tokens = translation_model.generate(
        **tokenized_inputs,
        forced_bos_token_id=target_token_id,
        max_length=max_length,
    )
    return translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
    
    
def load_qa_model(model_id="meta-llama/Meta-Llama-3-8B-Instruct"):
    """Build a text-generation pipeline used for question answering.

    Args:
        model_id: Hugging Face id of the chat model to load. Defaults to
            Meta Llama 3 8B Instruct.

    Returns:
        The pipeline, loaded with ``torch_dtype=torch.bfloat16`` and
        ``device_map="auto"``, or ``None`` if loading fails. On failure a
        warning is logged with the full traceback.
    """
    logger.info("Loading Q&A model...")
    try:
        qa_pipeline = pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
    except Exception as e:
        # The caller only sees None ("model could not be loaded"), so keep the
        # traceback in the log; otherwise the root cause is lost.
        logger.warning(f"Failed to load Q&A model. Error: \n{str(e)}", exc_info=True)
        return None
    logger.info("Q&A model loaded successfully")
    return qa_pipeline


def load_summarization_model(model_name="sshleifer/distilbart-cnn-12-6"):
    """Create a summarization pipeline, preferring the GPU when CUDA is available.

    Args:
        model_name: Hugging Face id of the summarization model. Defaults to
            ``sshleifer/distilbart-cnn-12-6``.

    Returns:
        The summarization pipeline, on GPU 0 when CUDA is available, otherwise
        on the CPU.

    Raises:
        Exception: Whatever ``pipeline`` raises when loading on the CPU fails.
            A GPU load failure is retried once on the CPU first.
    """
    logger.info("Loading summarization model...")
    cuda_available = torch.cuda.is_available()
    try:
        summarizer = pipeline("summarization", model=model_name, device=0 if cuda_available else -1)
        logger.info(f"Summarization model loaded successfully on {'GPU' if cuda_available else 'CPU'}")
        return summarizer
    except Exception as e:
        if not cuda_available:
            # The failed attempt already ran on the CPU, so there is no GPU
            # failure to fall back from; surface the real error instead.
            raise
        logger.warning(f"Failed to load summarization model on GPU. Falling back to CPU. Error: \n{str(e)}")
        summarizer = pipeline("summarization", model=model_name, device=-1)
        logger.info("Summarization model loaded successfully on CPU")
        return summarizer


def process_with_fallback(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, retrying once on the CPU after a GPU failure.

    Any exception from the first call is logged together with its traceback.
    If the error message mentions "CUDA" or "GPU", ``func`` is called again with
    the same arguments plus ``use_gpu=False`` (so ``func`` must accept that
    keyword). Any other exception is re-raised unchanged.
    """
    try:
        return func(*args, **kwargs)
    except Exception as exc:
        message = str(exc)
        logger.error(f"Error during processing: {message}")
        logger.error(traceback.format_exc())
        gpu_related = "CUDA" in message or "GPU" in message
        if not gpu_related:
            raise
        logger.info("Falling back to CPU processing...")
        return func(*args, **{**kwargs, 'use_gpu': False})


@spaces.GPU(duration=60)
def transcribe_audio(audio_file, translate, model_size):
    """Run ``process_audio`` on ``audio_file`` and report failures to the Gradio UI.

    ``process_audio`` is invoked through ``process_with_fallback``, so an error
    mentioning CUDA/GPU triggers one retry with ``use_gpu=False``.

    Args:
        audio_file: Path of the audio file to transcribe.
        translate: Forwarded to ``process_audio`` as ``translate``.
        model_size: Whisper model size forwarded to ``process_audio``.

    Returns:
        Whatever ``process_audio`` returns. The caller in this file unpacks it as
        ``(language_segments, final_segments)``.

    Raises:
        gr.Error: If transcription fails, carrying the underlying error message.
    """
    logger.info(f"Starting transcription: translate={translate}, model_size={model_size}")
    try:
        outcome = process_with_fallback(
            process_audio,
            audio_file,
            translate=translate,
            model_size=model_size,
        )
    except Exception as err:
        reason = str(err)
        logger.error(f"Transcription failed: {reason}")
        raise gr.Error(f"Transcription failed: {reason}")
    else:
        logger.info("Transcription completed successfully")
        return outcome


@spaces.GPU(duration=60)
def summarize_text(text, max_length=150, min_length=50):
    """Summarize ``text`` with the pipeline from ``load_summarization_model``.

    Args:
        text: Text to summarize. Input longer than the model's maximum input
            length is truncated instead of making the pipeline fail.
        max_length: Upper bound on the summary length, in tokens.
        min_length: Lower bound on the summary length, in tokens.

    Returns:
        The summary string, or ``""`` for empty or whitespace-only input. If
        summarization fails, a fixed error message is returned and the error
        and traceback are logged.
    """
    if not text or not text.strip():
        # Nothing to summarize; skip loading the model entirely.
        logger.info("Skipping summarization: no text to summarize")
        return ""
    logger.info("Starting text summarization")
    try:
        summarizer = load_summarization_model()
        # truncation=True: long transcripts can exceed the model's input window,
        # which otherwise makes generation fail instead of summarizing the start.
        result = summarizer(
            text,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            truncation=True,
        )
        summary = result[0]['summary_text']
        logger.info("Summarization completed successfully")
        return summary
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}")
        logger.error(traceback.format_exc())
        return "Error occurred during summarization. Please try again."

@spaces.GPU(duration=60)
def process_and_summarize(audio_file, translate, model_size, do_summarize=True):
    """Transcribe ``audio_file``, optionally translate each segment, and optionally summarize.

    Args:
        audio_file: Path of the uploaded audio file.
        translate: When truthy, each final segment's text is translated with
            ``alternate_translation``, and the translations form the full text.
        model_size: Whisper model size forwarded to ``transcribe_audio``.
        do_summarize: When truthy, ``summarize_text`` is run on the full text.

    Returns:
        A ``(transcription, full_text, summary)`` tuple:
        - ``transcription``: a human-readable per-segment report.
        - ``full_text``: the segment texts (translated if requested) joined with
          single spaces.
        - ``summary``: the summary, or ``""`` when summarization is disabled.

    Raises:
        gr.Error: If any step fails; the message includes the underlying error.
    """
    logger.info(f"Starting process_and_summarize: translate={translate}, model_size={model_size}, do_summarize={do_summarize}")
    try:
        language_segments, final_segments = transcribe_audio(audio_file, translate, model_size)

        # The translation model is only needed when translating, so don't load it otherwise.
        translation_model = translation_tokenizer = None
        if translate:
            translation_model, translation_tokenizer = load_translation_model()

        # Collect report pieces and join once instead of repeated string concatenation.
        report = []
        for segment in language_segments:
            report.append(f"Language: {segment['language']}\n")
            report.append(f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n")

        report.append(f"Transcription with language detection and speaker diarization (using {model_size} model):\n\n")
        texts = []
        for segment in final_segments:
            report.append(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:\n")
            report.append(f"Original: {segment['text']}\n")
            if translate:
                translated = alternate_translation(translation_model, translation_tokenizer, segment['text'])
                # Own line + trailing newline keeps the blank line between segments,
                # matching the layout of the untranslated report.
                report.append(f"Translated: {translated}\n")
                texts.append(translated)
            else:
                texts.append(segment['text'])
            report.append("\n")

        transcription = "".join(report)
        # Space-separate segments so the last word of one doesn't merge with the next.
        full_text = " ".join(texts)

        summary = summarize_text(full_text) if do_summarize else ""
        logger.info("Process and summarize completed successfully")
        return transcription, full_text, summary
    except Exception as e:
        logger.error(f"Process and summarize failed: {str(e)}\n")
        logger.error(traceback.format_exc())
        raise gr.Error(f"Processing failed: {str(e)}")



# Stricter, safety-oriented system prompt. It is NOT currently used by
# answer_question, which sends a short generic system prompt; it is kept here
# (instead of being rebuilt as an unused local on every call) so it can be
# swapped in deliberately.
_ALTERNATE_SYSTEM_MESSAGE = """
        You are an AI assistant designed to analyze speech transcriptions in a safe and responsible manner. 
        Your purpose is to assist people, not to monitor or detect threats.

        When responding to user queries, your primary goals are:
        
        1. To provide factual, accurate information to the best of your abilities.
        2. To guide users towards appropriate resources and authorities if they are facing an emergency or urgent situation.
        3. To refrain from speculating about or escalating potentially concerning situations without clear justification.
        4. To avoid making judgements or taking actions that could infringe on individual privacy or civil liberties.
        
        However, if the speech suggests someone may be in immediate danger or that a crime is being planned, you should:

        - Identify & report 
        - Identify any cryptic information and report it. 
        - Avoid probing for additional details or speculating about the nature of the potential threat.
        - Do not provide any information that could enable or encourage harmful, illegal or unethical acts.
        Your role is to be a helpful, informative assistant. 
        """


def _extract_assistant_answer(generated_text):
    """Return the assistant's reply from a text-generation pipeline's ``generated_text``.

    For chat-style input this is expected to be the message list with the
    model's reply appended; the last ``assistant`` message is returned. A
    plain-string result (assumed to happen for non-chat output — confirm) is
    returned unchanged.
    """
    if isinstance(generated_text, str):
        return generated_text
    for message in reversed(generated_text):
        if message.get('role') == 'assistant':
            return message['content']
    return "No assistant response found in the model's output."


@spaces.GPU(duration=60)
def answer_question(context, question):
    """Answer ``question`` about ``context`` (the transcript text) with the chat LLM.

    Args:
        context: Text the question is about, typically the "Full Text" output.
        question: The user's question.

    Returns:
        One of the following (errors during model loading or generation are
        returned as strings, not raised):
        - the assistant's answer;
        - a short message asking for the missing context or question;
        - an error description.
    """
    # Don't load an 8B model just to answer an empty request.
    if not context or not context.strip():
        return "Please process an audio file first so there is a transcription to ask about."
    if not question or not question.strip():
        return "Please enter a question."

    logger.info("Starting Q&A process")
    try:
        qa_pipeline = load_qa_model()
        if qa_pipeline is None:
            return "Error: Q&A model could not be loaded."

        messages = [
            {"role": "system", "content": "You are a helpful assistant who can answer questions based on the given context."},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"},
        ]
        out = qa_pipeline(messages, max_new_tokens=256)

        logger.info(f"Raw model output: {out}")

        answer = _extract_assistant_answer(out[0]['generated_text'])

        logger.info(f"Extracted answer: {answer}")
        return answer
    except Exception as e:
        logger.error(f"Q&A process failed: {str(e)}")
        logger.error(traceback.format_exc())
        return f"Error occurred during Q&A process. Please try again. Error: {str(e)}"


# Main interface
with gr.Blocks() as iface:
    gr.Markdown("# WhisperX Audio Transcription, Translation, Summarization, and Q&A (with ZeroGPU support)")

    audio_input = gr.Audio(type="filepath")
    translate_checkbox = gr.Checkbox(label="Enable Translation")
    summarize_checkbox = gr.Checkbox(label="Enable Summarization", interactive=False)
    model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
    process_button = gr.Button("Process Audio")
    transcription_output = gr.Textbox(label="Transcription/Translation")
    full_text_output = gr.Textbox(label="Full Text")
    summary_output = gr.Textbox(label="Summary")

    question_input = gr.Textbox(label="Ask a question about the transcription")
    answer_button = gr.Button("Get Answer")
    answer_output = gr.Textbox(label="Answer")

    translate_alternate = gr.Button("Alternate Translation")
    translate_alternate_output = gr.Textbox(label="Alternate Translation")

    def update_summarize_checkbox(translate):
        """Offer summarization only while translation is enabled.

        Summarization is gated on translation, presumably because the summarizer
        expects English text. When translation is switched off the box is also
        unticked; otherwise a stale tick on the now-disabled box would still
        trigger summarization.
        """
        if translate:
            return gr.Checkbox(interactive=True)
        return gr.Checkbox(interactive=False, value=False)

    def translate_summary(text):
        """Translate ``text`` for the "Alternate Translation" button.

        ``alternate_translation`` also needs the model and tokenizer, which a
        single Textbox input cannot supply, so load them here first.
        """
        if not text:
            return ""
        translation_model, translation_tokenizer = load_translation_model()
        return alternate_translation(translation_model, translation_tokenizer, text)

    translate_checkbox.change(update_summarize_checkbox, inputs=[translate_checkbox], outputs=[summarize_checkbox])

    process_button.click(
        process_and_summarize,
        inputs=[audio_input, translate_checkbox, model_dropdown, summarize_checkbox],
        outputs=[transcription_output, full_text_output, summary_output]
    )

    answer_button.click(
        answer_question,
        inputs=[full_text_output, question_input],
        outputs=[answer_output]
    )

    translate_alternate.click(
        translate_summary,
        inputs=[summary_output],
        outputs=[translate_alternate_output]
    )

    gr.Markdown(
        f"""
        ## System Information
        - Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
        - CUDA Available: {"Yes" if torch.cuda.is_available() else "No"}
        
        ## ZeroGPU Support
        This application supports ZeroGPU for Hugging Face Spaces pro users. 
        GPU-intensive tasks are automatically optimized for better performance when available.
        """
    )

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    iface.launch()