raygiles3 commited on
Commit
d17d37a
1 Parent(s): 20865ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import login
5
  import os
6
 
@@ -18,8 +18,9 @@ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
18
  whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
19
 
20
  # Initialize the summarization model and tokenizer
21
- summarization_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
22
- summarization_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
 
23
 
24
  # Function to transcribe audio
25
  def transcribe_audio(audio_file):
@@ -32,7 +33,7 @@ def transcribe_audio(audio_file):
32
 
33
  # Function to summarize text
34
  def summarize_text(text):
35
- inputs = summarization_tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
36
  summary_ids = summarization_model.generate(inputs.input_ids, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
37
  summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
38
  return summary
@@ -46,7 +47,7 @@ def process_audio(audio_file):
46
  # Gradio UI
47
  iface = gr.Interface(
48
  fn=process_audio,
49
- inputs=gr.Audio(source="upload", type="file"),
50
  outputs=[
51
  gr.Textbox(label="Transcription"),
52
  gr.Textbox(label="Summary")
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, BartForConditionalGeneration, BartTokenizer
4
  from huggingface_hub import login
5
  import os
6
 
 
18
  whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
19
 
20
  # Initialize the summarization model and tokenizer
21
+ # Use BART model for summarization
22
+ summarization_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
23
+ summarization_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
24
 
25
  # Function to transcribe audio
26
  def transcribe_audio(audio_file):
 
33
 
34
  # Function to summarize text
35
  def summarize_text(text):
36
+ inputs = summarization_tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
37
  summary_ids = summarization_model.generate(inputs.input_ids, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
38
  summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
39
  return summary
 
47
  # Gradio UI
48
  iface = gr.Interface(
49
  fn=process_audio,
50
+ inputs=gr.Audio(type="file"),
51
  outputs=[
52
  gr.Textbox(label="Transcription"),
53
  gr.Textbox(label="Summary")