Capstone04 committed
Commit cc204c2 · verified · 1 Parent(s): 9227788

Upload folder using huggingface_hub

Files changed (1)
  1. asr_diarization/pipeline.py +76 -23
asr_diarization/pipeline.py CHANGED
@@ -48,9 +48,47 @@ class ASR_Diarization:
             for t, _, spk in diarization.itertracks(yield_label=True)
         ]
 
+    # def run_transcription(self, audio_path, diar_json):
+    #     audio, sr = torchaudio.load(audio_path)
+    #     merged_segments = []
+    #     speaker_segments = {}
+
+    #     for seg in diar_json:
+    #         segment_start, segment_end, spk = seg["segment_start"], seg["segment_end"], seg["speaker"]
+    #         start_sample, end_sample = int(segment_start * sr), int(segment_end * sr)
+    #         chunk = audio[0, start_sample:end_sample].numpy()
+
+    #         reduced = nr.reduce_noise(y=chunk, sr=sr)
+    #         result = self.asr_pipeline(reduced)
+
+    #         tokens = []
+    #         if "chunks" in result:
+    #             for word_info in result["chunks"]:
+    #                 start_ts, end_ts = word_info.get("timestamp", (None, None)) or (None, None)
+    #                 tokens.append({
+    #                     "tag": "w",
+    #                     "start": start_ts,
+    #                     "end": end_ts,
+    #                     "text": word_info["text"]
+    #                 })
+
+    #         seg_dict = {
+    #             "speaker": spk,
+    #             "segment_start": segment_start,
+    #             "segment_end": segment_end,
+    #             "tokens": tokens
+    #         }
+    #         merged_segments.append(seg_dict)
+
+    #         if spk not in speaker_segments:
+    #             speaker_segments[spk] = []
+    #         speaker_segments[spk].append(seg_dict)
+
+    #     return merged_segments, list(speaker_segments.keys())
+
     def run_transcription(self, audio_path, diar_json):
         audio, sr = torchaudio.load(audio_path)
-        merged_segments = []
+        all_word_segments = []
         speaker_segments = {}
 
         for seg in diar_json:
@@ -61,30 +99,45 @@ class ASR_Diarization:
             reduced = nr.reduce_noise(y=chunk, sr=sr)
             result = self.asr_pipeline(reduced)
 
-            tokens = []
             if "chunks" in result:
                 for word_info in result["chunks"]:
-                    start_ts, end_ts = word_info.get("timestamp", (None, None)) or (None, None)
-                    tokens.append({
-                        "tag": "w",
-                        "start": start_ts,
-                        "end": end_ts,
-                        "text": word_info["text"]
-                    })
-
-            seg_dict = {
-                "speaker": spk,
-                "segment_start": segment_start,
-                "segment_end": segment_end,
-                "tokens": tokens
-            }
-            merged_segments.append(seg_dict)
-
-            if spk not in speaker_segments:
-                speaker_segments[spk] = []
-            speaker_segments[spk].append(seg_dict)
-
-        return merged_segments, list(speaker_segments.keys())
+                    # Each word or token gets its own mini segment
+                    start_ts, end_ts = None, None
+
+                    if isinstance(word_info.get("timestamp"), (list, tuple)):
+                        start_ts, end_ts = word_info["timestamp"]
+                    elif isinstance(word_info.get("timestamp"), (float, int)):
+                        start_ts = word_info["timestamp"]
+                        end_ts = start_ts
+
+                    if start_ts is None:
+                        continue
+
+                    # Shift timestamps to align with full audio
+                    abs_start = segment_start + start_ts
+                    abs_end = segment_start + end_ts
+
+                    word_segment = {
+                        "speaker": spk,
+                        "segment_start": abs_start,
+                        "segment_end": abs_end,
+                        "tokens": [
+                            {
+                                "tag": "w",
+                                "start": abs_start,
+                                "end": abs_end,
+                                "text": word_info["text"].strip()
+                            }
+                        ]
+                    }
+
+                    all_word_segments.append(word_segment)
+
+                    if spk not in speaker_segments:
+                        speaker_segments[spk] = []
+                    speaker_segments[spk].append(word_segment)
+
+        return all_word_segments, list(speaker_segments.keys())
 
     def run_pipeline(self, audio_path, output_dir=None, base_name=None,
                      ref_rttm=None, ref_json=None):
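
With this change, run_transcription returns one mini segment per recognized word (timestamps shifted from chunk-relative to absolute positions in the recording) plus the list of speaker labels, instead of one segment per diarization turn. The following is a minimal sketch, not part of this commit, of how a caller could regroup that word-level output into per-speaker turns; merge_words_into_turns, max_gap, and the sample data are hypothetical and only assume the segment shape produced by the new method above.

def merge_words_into_turns(word_segments, max_gap=0.5):
    """Hypothetical helper: collapse consecutive word segments from the same
    speaker into longer turns when the pause between words is <= max_gap seconds."""
    turns = []
    for seg in sorted(word_segments, key=lambda s: s["segment_start"]):
        if (turns
                and turns[-1]["speaker"] == seg["speaker"]
                and seg["segment_start"] - turns[-1]["segment_end"] <= max_gap):
            # Extend the current turn with this word's span and token
            turns[-1]["segment_end"] = seg["segment_end"]
            turns[-1]["tokens"].extend(seg["tokens"])
        else:
            # Start a new turn on a speaker change or a long pause
            turns.append({
                "speaker": seg["speaker"],
                "segment_start": seg["segment_start"],
                "segment_end": seg["segment_end"],
                "tokens": list(seg["tokens"]),
            })
    return turns

# Example with two adjacent words from the same speaker (made-up values):
words = [
    {"speaker": "SPEAKER_00", "segment_start": 1.20, "segment_end": 1.45,
     "tokens": [{"tag": "w", "start": 1.20, "end": 1.45, "text": "hello"}]},
    {"speaker": "SPEAKER_00", "segment_start": 1.50, "segment_end": 1.90,
     "tokens": [{"tag": "w", "start": 1.50, "end": 1.90, "text": "there"}]},
]
print(merge_words_into_turns(words))
# -> a single turn spanning 1.20-1.90 containing both word tokens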