Nathan Slaughter committed
Commit e8ecce6
1 Parent(s): 76ed6be

add Qwen2VL for action item inference

Files changed (2)
  1. app.py +84 -48
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,66 +1,103 @@
 import torch
-from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForSpeechSeq2Seq
 import gradio as gr
 import librosa

-# 1. Determine the device
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# 2. Load the processor and model
-processor = AutoProcessor.from_pretrained("openai/whisper-large", language='en')
-model = AutoModelForSpeechSeq2Seq.from_pretrained(
+# Determine the device
+if torch.cuda.is_available():                # for CUDA
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():     # for Apple MPS
+    device = torch.device("mps")
+else:                                        # fallback to CPU
+    device = torch.device("cpu")
+
+# Load the audio processor and model
+stt_processor = AutoProcessor.from_pretrained("openai/whisper-large", language='en')
+stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
     "openai/whisper-large"
 )

-# 3. Move the model to the device
-model.to(device)
+# Move the model to the device
+stt_model.to(device)

-def transcribe_audio(audio_path: str) -> str:
+def transcribe_audio(audio_path: str):
     try:
-        # Step 1: Load the audio file
-        # librosa.load returns a tuple (audio_data, sampling_rate)
-        audio, sr = librosa.load(audio_path, sr=16000)  # resample to 16000 Hz
-
-        # Step 2: Transcribe the audio
-        inputs = processor(audio, sampling_rate=16000, return_tensors="pt", language='en')
-        input_features = inputs.input_features.to(device)  # type: ignore
-
-        # Generate transcription
-        with torch.no_grad():  # type: ignore
-            predicted_ids = model.generate(input_features)  # type: ignore
-
-        # Decode the transcription
-        transcript = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-
-        return transcript
-
+        audio, sr = librosa.load(audio_path, sr=16000)
+        inputs = stt_processor(audio, sampling_rate=16000, return_tensors="pt", language='en')
+        input_features = inputs.input_features.to(device)
+        with torch.no_grad():
+            predicted_ids = stt_model.generate(input_features)
+        transcript = stt_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+        return transcript
     except Exception as e:
-        return f"Error during processing: {str(e)}"
+        return f"Error during transcription: {str(e)}"


 def extract_action_items(transcript: str) -> str:
-    action_keywords = ["action item", "todo", "task", "follow up", "need to"]
-    sentences = transcript.split('.')
-    action_items = [
-        sentence.strip() + '.'
-        for sentence in sentences
-        if any(keyword in sentence.lower() for keyword in action_keywords)
-    ]
-    return "\n".join(action_items) if action_items else "No action items found."
-
-
-def transcribe_and_extract_action_items(audio_path: str) -> tuple[str, str]:
+    """
+    Extracts action items from a transcript using the Qwen2-VL-7B-Instruct model.
+    See the example code in the model card: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
+    """
+    model_id = "Qwen/Qwen2-VL-7B-Instruct"
     try:
-        transcript = transcribe_audio(audio_path)
-
-        # Join action items into a single string, separated by newlines
-        action_items_text = extract_action_items(transcript)
-
-        return transcript, action_items_text
-
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            # attn_implementation="flash_attention_2"
+        )
+        # default processor
+        processor = AutoProcessor.from_pretrained(model_id)
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": f"""Infer the action items from the following meeting transcript
+                    and list them as a bulleted list in the format:\n- [item short title]: [item description]
+
+                    The [item short title] should be a short phrase that summarizes the action item.
+                    The [item description] should be a longer description of the action item.
+
+                    TRANSCRIPT:
+
+                    {transcript}
+                    """
+                    }
+                ],
+            }
+        ]
+
+        # Preparation for inference
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        inputs = processor(
+            text=[text],
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(device)
+        # Extract action items
+        generated_ids = model.generate(**inputs, max_new_tokens=128)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+        return output_text[0]
     except Exception as e:
-        return f"Error during processing: {str(e)}", ""
+        return f"Error during action item extraction: {str(e)}"
+
+def transcribe_and_extract_action_items(audio_path):
+    transcript = transcribe_audio(audio_path)
+    action_items_text = extract_action_items(transcript)
+    return transcript, action_items_text

+##################################################
+# Gradio Interface
+##################################################

 # Define the Gradio interface components
 input_audio = gr.Audio(
@@ -96,4 +133,3 @@ interface = gr.Interface(
 # 5. Launch the interface
 if __name__ == "__main__":
     interface.launch()
-
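
As a quick sanity check of the new wiring, the end-to-end helper can be called directly from Python. This is a sketch, not part of the commit: it assumes app.py is importable from the working directory and that a local recording named meeting.wav exists (both are placeholders).

# Hypothetical smoke test for the new pipeline; "meeting.wav" is a placeholder path.
from app import transcribe_and_extract_action_items

transcript, action_items = transcribe_and_extract_action_items("meeting.wav")
print("TRANSCRIPT:\n", transcript)
print("ACTION ITEMS:\n", action_items)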
 
requirements.txt CHANGED
@@ -3,3 +3,6 @@ pydantic
 openai
 librosa
 langchain
+transformers
+bitsandbytes
+accelerate
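
Of the new dependencies, transformers supplies the Whisper and Qwen2-VL classes and accelerate backs the device_map="auto" loading in extract_action_items. bitsandbytes is not exercised by this diff; it would only come into play if the Qwen2-VL model were loaded quantized, roughly as in the following sketch (an assumption about intent, not part of the commit).

# Sketch only: optional 4-bit loading of Qwen2-VL via bitsandbytes (not used by app.py as committed).
import torch
from transformers import Qwen2VLForConditionalGeneration, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    quantization_config=quant_config,
    device_map="auto",
)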