Porjaz committed on
Commit
67b38a5
1 Parent(s): f153e02

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +58 -8
custom_interface.py CHANGED
@@ -85,14 +85,64 @@ class ASR(Pretrained):
85
  return seq
86
 
87
 
88
- def classify_file(self, path):
89
- # waveform = self.load_audio(path)
90
- waveform, sr = librosa.load(path, sr=16000)
91
- waveform = torch.tensor(waveform)
92
 
93
- # Fake a batch:
94
- batch = waveform.unsqueeze(0)
95
- rel_length = torch.tensor([1.0])
96
- outputs = self.encode_batch(batch, rel_length)
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  return outputs
 
85
  return seq
86
 
87
 
88
+ # def classify_file(self, path):
89
+ # # waveform = self.load_audio(path)
90
+ # waveform, sr = librosa.load(path, sr=16000)
91
+ # waveform = torch.tensor(waveform)
92
 
93
+ # # Fake a batch:
94
+ # batch = waveform.unsqueeze(0)
95
+ # rel_length = torch.tensor([1.0])
96
+ # outputs = self.encode_batch(batch, rel_length)
97
 
98
+ # return outputs
99
+
100
def classify_file(self, path):
    """Run the ASR model on an audio file, one silence-delimited segment at a time.

    The file is loaded at 16 kHz and split into non-silent intervals with
    ``librosa.effects.split``. The intervals are packed greedily, in order,
    into segments of at most 20 seconds; an interval that is itself longer
    than 20 seconds is cut into 20-second pieces so the cap always holds.
    Each segment is passed to ``self.encode_batch`` as a batch of one.

    Args:
        path: Path of the audio file to load.

    Returns:
        A list with one ``self.encode_batch`` result per segment, in file
        order. The list is empty when no non-silent audio is detected.
    """
    # Load the audio file
    waveform, sr = librosa.load(path, sr=16000)

    # Get audio length in seconds
    audio_length = len(waveform) / sr
    print(f"Audio length: {audio_length:.2f} seconds")

    # Detect non-silent segments
    non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity

    segments = []
    current_segment = []
    current_length = 0
    max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)

    for start, end in non_silent_intervals:
        start, end = int(start), int(end)

        # Walk the interval in steps of max_duration so that a single long
        # stretch of speech cannot produce a segment above the cap.
        for piece_start in range(start, end, max_duration):
            segment_part = waveform[piece_start:min(piece_start + max_duration, end)]

            # If adding the next part exceeds max duration, store the segment
            # and start a new one. Only flush when something was collected:
            # np.concatenate raises ValueError on an empty list.
            if current_segment and current_length + len(segment_part) > max_duration:
                segments.append(np.concatenate(current_segment))
                current_segment = []
                current_length = 0

            current_segment.append(segment_part)
            current_length += len(segment_part)

    # Append the last segment if it's not empty
    if current_segment:
        segments.append(np.concatenate(current_segment))

    # Process each segment
    outputs = []
    for i, segment in enumerate(segments):
        print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")

        segment_tensor = torch.tensor(segment)

        # Fake a batch for the segment
        batch = segment_tensor.unsqueeze(0)
        rel_length = torch.tensor([1.0])  # A batch of one always uses its full length

        # Pass the segment through the ASR model
        segment_output = self.encode_batch(batch, rel_length)
        outputs.append(segment_output)

    return outputs