Xubo-Liu commited on
Commit
11e99cd
1 Parent(s): 0a03745

Update models/resunet.py

Browse files
Files changed (1) hide show
  1. models/resunet.py +60 -0
models/resunet.py CHANGED
@@ -652,4 +652,64 @@ class ResUNet30(nn.Module):
652
 
653
  return output_dict
654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
 
 
652
 
653
  return output_dict
654
 
655
+
656
+ @torch.no_grad()
657
+ def chunk_inference(self, input_dict):
658
+ chunk_config = {
659
+ 'NL': 1.0,
660
+ 'NC': 3.0,
661
+ 'NR': 1.0,
662
+ 'RATE': self.sampling_rate
663
+ }
664
+
665
+ mixtures = input_dict['mixture']
666
+ conditions = input_dict['condition']
667
+
668
+ film_dict = self.film(
669
+ conditions=conditions,
670
+ )
671
+
672
+ NL = int(chunk_config['NL'] * chunk_config['RATE'])
673
+ NC = int(chunk_config['NC'] * chunk_config['RATE'])
674
+ NR = int(chunk_config['NR'] * chunk_config['RATE'])
675
+
676
+ L = mixtures.shape[2]
677
+
678
+ out_np = np.zeros([1, L])
679
+
680
+ WINDOW = NL + NC + NR
681
+ current_idx = 0
682
+
683
+ while current_idx + WINDOW < L:
684
+ chunk_in = mixtures[:, :, current_idx:current_idx + WINDOW]
685
+
686
+ chunk_out = self.base(
687
+ mixtures=chunk_in,
688
+ film_dict=film_dict,
689
+ )['waveform']
690
+
691
+ chunk_out_np = chunk_out.squeeze(0).cpu().data.numpy()
692
+
693
+ if current_idx == 0:
694
+ out_np[:, current_idx:current_idx+WINDOW-NR] = \
695
+ chunk_out_np[:, :-NR] if NR != 0 else chunk_out_np
696
+ else:
697
+ out_np[:, current_idx+NL:current_idx+WINDOW-NR] = \
698
+ chunk_out_np[:, NL:-NR] if NR != 0 else chunk_out_np[:, NL:]
699
+
700
+ current_idx += NC
701
+
702
+ if current_idx < L:
703
+ chunk_in = mixtures[:, :, current_idx:current_idx + WINDOW]
704
+ chunk_out = self.base(
705
+ mixtures=chunk_in,
706
+ film_dict=film_dict,
707
+ )['waveform']
708
+
709
+ chunk_out_np = chunk_out.squeeze(0).cpu().data.numpy()
710
+
711
+ seg_len = chunk_out_np.shape[1]
712
+ out_np[:, current_idx + NL:current_idx + seg_len] = \
713
+ chunk_out_np[:, NL:]
714
 
715
+ return out_np