Kr08 committed on
Commit 62b6f11 · verified · 1 Parent(s): aaac499

Update audio_processing.py

Files changed (1)
  1. audio_processing.py +110 -75
audio_processing.py CHANGED
@@ -2,104 +2,139 @@ import whisperx
  import torch
  import numpy as np
  from scipy.signal import resample
- import numpy as np
- import whisperx
  from pyannote.audio import Pipeline
  import os
  from dotenv import load_dotenv
-
  load_dotenv()

  hf_token = os.getenv("HF_TOKEN")
- import whisperx
- import torch
- import numpy as np

- import whisperx
- import torch
- import numpy as np

- import whisperx
- import torch
- import numpy as np
- CHUNK_LENGTH= 30
-
-
- import whisperx
- import torch
- import numpy as np

- def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000): # 30 seconds at 16kHz
      chunks = []
-     for i in range(0, len(audio), chunk_size):
          chunk = audio[i:i+chunk_size]
          if len(chunk) < chunk_size:
              chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
          chunks.append(chunk)
      return chunks

- def process_audio(audio_file):
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     compute_type = "float32"
-     audio = whisperx.load_audio(audio_file)
-     model = whisperx.load_model("small", device, compute_type=compute_type)
-
-     # Initialize speaker diarization pipeline
-     diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
-     diarization_pipeline = diarization_pipeline.to(torch.device(device))

-     # Perform diarization on the entire audio
-     diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})

-     # Preprocess audio into consistent chunks
-     chunks = preprocess_audio(audio)

-     language_segments = []
-     final_segments = []
-
-     for i, chunk in enumerate(chunks):
-         # Detect language for this chunk
-         lang = model.detect_language(chunk)
-
-         # Transcribe this chunk
-         result = model.transcribe(chunk, language=lang)

-         chunk_start_time = i * 5 # Each chunk is 30 seconds
-
-         # Adjust timestamps and add language information
-         for segment in result["segments"]:
-             segment_start = chunk_start_time + segment["start"]
-             segment_end = chunk_start_time + segment["end"]
-             segment["start"] = segment_start
-             segment["end"] = segment_end
-             segment["language"] = lang

-             speakers = []
-             for turn, track, speaker in diarization_result.itertracks(yield_label=True):
-                 if turn.start <= segment_end and turn.end >= segment_start:
-                     speakers.append(speaker)
-             if speakers:
-                 segment["speaker"] = max(set(speakers), key=speakers.count)
-             else:
-                 segment["speaker"] = "Unknown"

-             final_segments.append(segment)
-         # Add language segment
-         language_segments.append({
-             "language": lang,
-             "start": chunk_start_time,
-             "end": chunk_start_time + 5
-         })

-     return language_segments, final_segments

- def print_results(language, language_probs, segments):
-     print(f"Detected Language: {language}")
-     print("Language Probabilities:")
-     for lang, prob in language_probs.items():
-         print(f" {lang}: {prob:.4f}")
-
-     print("\nTranscription:")
      for segment in segments:
-         print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] Speaker {segment['speaker']}: {segment['text']}")
 
 
 
 
 
  import torch
  import numpy as np
  from scipy.signal import resample
  from pyannote.audio import Pipeline
  import os
  from dotenv import load_dotenv
  load_dotenv()
+ import logging
+ import time
+ from difflib import SequenceMatcher
+ import spaces

  hf_token = os.getenv("HF_TOKEN")

+ CHUNK_LENGTH = 5
+ OVERLAP = 2

+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
      chunks = []
+     for i in range(0, len(audio), chunk_size - overlap):
          chunk = audio[i:i+chunk_size]
          if len(chunk) < chunk_size:
              chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
          chunks.append(chunk)
      return chunks

+ @spaces.GPU
+ def process_audio(audio_file, translate=False, model_size="small"):
+     start_time = time.time()
+
+     try:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         compute_type = "float16" if device == "cuda" else "float32"
+         audio = whisperx.load_audio(audio_file)
+         model = whisperx.load_model(model_size, device, compute_type=compute_type)

+         diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
+         diarization_pipeline = diarization_pipeline.to(torch.device(device))

+         diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})

+         chunks = preprocess_audio(audio)

+         language_segments = []
+         final_segments = []

+         overlap_duration = 2 # 2 seconds overlap
+         for i, chunk in enumerate(chunks):
+             chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
+             chunk_end_time = chunk_start_time + CHUNK_LENGTH
+             logger.info(f"Processing chunk {i+1}/{len(chunks)}")
+             lang = model.detect_language(chunk)
+             result_transcribe = model.transcribe(chunk, language=lang)
+             if translate:
+                 result_translate = model.transcribe(chunk, task="translate")
+             chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
+             for j, t_seg in enumerate(result_transcribe["segments"]):
+                 segment_start = chunk_start_time + t_seg["start"]
+                 segment_end = chunk_start_time + t_seg["end"]
+                 # Skip segments in the overlapping region of the previous chunk
+                 if i > 0 and segment_end <= chunk_start_time + overlap_duration:
+                     print(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
+                     continue

+                 # Skip segments in the overlapping region of the next chunk
+                 if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
+                     print(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
+                     continue
+
+                 speakers = []
+                 for turn, track, speaker in diarization_result.itertracks(yield_label=True):
+                     if turn.start <= segment_end and turn.end >= segment_start:
+                         speakers.append(speaker)
+
+                 segment = {
+                     "start": segment_start,
+                     "end": segment_end,
+                     "language": lang,
+                     "speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
+                     "text": t_seg["text"],
+                 }
+
+                 if translate:
+                     segment["translated"] = result_translate["segments"][j]["text"]
+
+                 final_segments.append(segment)

+             language_segments.append({
+                 "language": lang,
+                 "start": chunk_start_time,
+                 "end": chunk_start_time + CHUNK_LENGTH
+             })
+             chunk_end_time = time.time()
+             logger.info(f"Chunk {i+1} processed in {chunk_end_time - chunk_start_time:.2f} seconds")

+         final_segments.sort(key=lambda x: x["start"])
+         merged_segments = merge_nearby_segments(final_segments)

+         end_time = time.time()
+         logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
+
+         return language_segments, merged_segments
+     except Exception as e:
+         logger.error(f"An error occurred during audio processing: {str(e)}")
+         raise
+
+ def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
+     merged = []
+     for segment in segments:
+         if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
+             merged.append(segment)
+         else:
+             # Find the overlap
+             matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
+             match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
+
+             if match.size / len(segment['text']) > similarity_threshold:
+                 # Merge the segments
+                 merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
+                 merged_translated = merged[-1]['translated'] + segment['translated'][match.b + match.size:]
+
+                 merged[-1]['end'] = segment['end']
+                 merged[-1]['text'] = merged_text
+                 merged[-1]['translated'] = merged_translated
+             else:
+                 # If no significant overlap, append as a new segment
+                 merged.append(segment)
+     return merged
+
+ def print_results(segments):
      for segment in segments:
+         print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
+         print(f"Original: {segment['text']}")
+         if 'translated' in segment:
+             print(f"Translated: {segment['translated']}")
+         print()
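
The new merge_nearby_segments relies on difflib.SequenceMatcher to detect text repeated across the 2-second chunk overlap and to keep only the non-duplicated tail. Below is a minimal sketch of that idea, not part of the commit: the two strings are made up, and the 0.7 threshold mirrors the function's similarity_threshold default.

import difflib

# Hypothetical texts from two adjacent, overlapping chunks
prev_text = "the quick brown fox jumps over"
curr_text = "fox jumps over the"

matcher = difflib.SequenceMatcher(None, prev_text, curr_text)
match = matcher.find_longest_match(0, len(prev_text), 0, len(curr_text))

# Ratio of the shared substring to the newer segment, as in merge_nearby_segments
if match.size / len(curr_text) > 0.7:
    # Keep the earlier text and append only the non-overlapping tail
    merged_text = prev_text + curr_text[match.b + match.size:]
    print(merged_text)  # prints "the quick brown fox jumps over the"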
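For orientation on the updated signatures, here is a hedged usage sketch of the new entry points. The file name is a placeholder, and it assumes HF_TOKEN is set in the environment and the ZeroGPU Space context that @spaces.GPU targets; nothing below is part of the commit itself.

if __name__ == "__main__":
    # process_audio now returns (language_segments, merged_segments)
    language_segments, segments = process_audio("sample.wav", translate=True, model_size="small")

    for ls in language_segments:
        print(f"Detected {ls['language']} from {ls['start']}s to {ls['end']}s")

    # print_results now takes only the merged segments
    print_results(segments)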