Update custom_interface.py
custom_interface.py  CHANGED  (+10 -58)
@@ -85,64 +85,16 @@ class ASR(Pretrained):
         return seq
 
 
-    # def classify_file(self, path):
-    #     # waveform = self.load_audio(path)
-    #     waveform, sr = librosa.load(path, sr=16000)
-    #     waveform = torch.tensor(waveform)
-
-    #     # Fake a batch:
-    #     batch = waveform.unsqueeze(0)
-    #     rel_length = torch.tensor([1.0])
-    #     outputs = self.encode_batch(batch, rel_length)
-
-    #     return outputs
-
     def classify_file(self, path):
-        #
+        # waveform = self.load_audio(path)
         waveform, sr = librosa.load(path, sr=16000)
-
-
-
-
-
-
-
-
-        segments = []
-        current_segment = []
-        current_length = 0
-        max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
-
-        for interval in non_silent_intervals:
-            start, end = interval
-            segment_part = waveform[start:end]
-
-            # If adding the next part exceeds max duration, store the segment and start a new one
-            if current_length + len(segment_part) > max_duration:
-                segments.append(np.concatenate(current_segment))
-                current_segment = []
-                current_length = 0
-
-            current_segment.append(segment_part)
-            current_length += len(segment_part)
-
-        # Append the last segment if it's not empty
-        if current_segment:
-            segments.append(np.concatenate(current_segment))
-
-        # Process each segment
-        outputs = []
-        for i, segment in enumerate(segments):
-            print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
-
-            segment_tensor = torch.tensor(segment)
-
-            # Fake a batch for the segment
-            batch = segment_tensor.unsqueeze(0)
-            rel_length = torch.tensor([1.0])  # Adjust if necessary
-
-            # Pass the segment through the ASR model
-            segment_output = self.encode_batch(batch, rel_length)
-            outputs.append(segment_output)
-
+        waveform = torch.tensor(waveform)
+
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        outputs = self.encode_batch(batch, rel_length)
+
         return outputs
+
+
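For context, a custom SpeechBrain interface file like this one is normally loaded with foreign_class, and the file is then transcribed through classify_file. The snippet below is a minimal usage sketch and is not part of the commit: the repository id is a placeholder, and the import path assumes SpeechBrain 1.0 or later (older releases expose foreign_class under speechbrain.pretrained.interfaces). It also assumes librosa is installed, since the new classify_file loads audio with librosa at 16 kHz.

    # Usage sketch (assumed, not part of this commit): load the custom ASR
    # interface from its repo and run the simplified classify_file on one file.
    from speechbrain.inference.interfaces import foreign_class  # SpeechBrain >= 1.0

    asr_model = foreign_class(
        source="<repo-id>",                   # placeholder: the Hub repo that ships custom_interface.py
        pymodule_file="custom_interface.py",
        classname="ASR",
    )

    # After this change, classify_file encodes the whole 16 kHz waveform as a
    # single batch and returns the result of encode_batch directly.
    outputs = asr_model.classify_file("path/to/audio.wav")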