Update custom_interface.py
Browse files- custom_interface.py +58 -8
custom_interface.py
CHANGED
@@ -85,14 +85,64 @@ class ASR(Pretrained):
|
|
85 |
return seq
|
86 |
|
87 |
|
88 |
-
def classify_file(self, path):
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
return outputs
|
|
|
85 |
return seq
|
86 |
|
87 |
|
88 |
+
# def classify_file(self, path):
|
89 |
+
# # waveform = self.load_audio(path)
|
90 |
+
# waveform, sr = librosa.load(path, sr=16000)
|
91 |
+
# waveform = torch.tensor(waveform)
|
92 |
|
93 |
+
# # Fake a batch:
|
94 |
+
# batch = waveform.unsqueeze(0)
|
95 |
+
# rel_length = torch.tensor([1.0])
|
96 |
+
# outputs = self.encode_batch(batch, rel_length)
|
97 |
|
98 |
+
# return outputs
|
99 |
+
|
100 |
+
def classify_file(self, path):
    """Run the model on an audio file, one silence-trimmed segment at a time.

    The file is loaded at 16 kHz, its non-silent intervals are detected with
    ``librosa.effects.split`` and packed greedily, in order, into segments of
    at most 20 seconds. Each segment is passed to ``self.encode_batch`` as a
    batch of one.

    Arguments
    ---------
    path : str
        Path of the audio file to load.

    Returns
    -------
    list
        One ``self.encode_batch`` result per segment, in segment order.
        The list is empty when no non-silent interval is detected.
    """
    # Load the audio file
    waveform, sr = librosa.load(path, sr=16000)

    # Get audio length in seconds
    audio_length = len(waveform) / sr
    print(f"Audio length: {audio_length:.2f} seconds")

    # Detect non-silent segments
    non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity

    max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)

    segments = []
    pending_parts = []
    pending_length = 0

    for start, end in non_silent_intervals:
        start, end = int(start), int(end)

        # An interval longer than the limit is cut into limit-sized pieces;
        # otherwise it could never fit into a segment of max_duration samples.
        for piece_start in range(start, end, max_duration):
            piece = waveform[piece_start:min(piece_start + max_duration, end)]

            # If adding the next piece exceeds max duration, store the segment
            # and start a new one. Only flush when something is pending:
            # np.concatenate raises ValueError on an empty list.
            if pending_parts and pending_length + len(piece) > max_duration:
                segments.append(np.concatenate(pending_parts))
                pending_parts = []
                pending_length = 0

            pending_parts.append(piece)
            pending_length += len(piece)

    # Append the last segment if it's not empty
    if pending_parts:
        segments.append(np.concatenate(pending_parts))

    # Process each segment
    outputs = []
    for i, segment in enumerate(segments):
        print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")

        # Fake a batch of one for the segment
        batch = torch.tensor(segment).unsqueeze(0)
        rel_length = torch.tensor([1.0])  # Adjust if necessary

        # Pass the segment through the ASR model
        outputs.append(self.encode_batch(batch, rel_length))

    return outputs
|