#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/3/9 16:33
# @Author  : dongchao yang
# @File    : train.py
"""Sound event detection building blocks.

Contains CNN front-ends (Cnn10 / Cnn14 variants and multi-scale GLU stacks),
temporal pooling heads, and the CDur family of CRNN detectors whose frame
features are combined with a per-clip ``embedding`` before a bidirectional GRU.
"""
import copy
import io
import math
import os
import re
import time
import warnings
from collections import OrderedDict
from functools import partial
from itertools import zip_longest

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import ndimage

# Optional third-party dependencies are guarded so that a missing package only
# breaks the code that actually needs it (e.g. Cnn14 needs torchlibrosa)
# instead of making every model in this module unimportable.
try:
    from torchlibrosa.augmentation import SpecAugmentation
    from torchlibrosa.stft import LogmelFilterBank, Spectrogram
except ImportError:
    SpecAugmentation = Spectrogram = LogmelFilterBank = None

try:
    from sklearn.cluster import KMeans
except ImportError:
    KMeans = None

DEBUG = 0

event_labels = [
    'Alarm', 'Alarm_clock', 'Animal', 'Applause', 'Arrow', 'Artillery_fire',
    'Babbling', 'Baby_laughter', 'Bark', 'Basketball_bounce', 'Battle_cry',
    'Bell', 'Bird', 'Bleat', 'Bouncing', 'Breathing', 'Buzz', 'Camera',
    'Cap_gun', 'Car', 'Car_alarm', 'Cat', 'Caw', 'Cheering', 'Child_singing',
    'Choir', 'Chop', 'Chopping_(food)', 'Clapping', 'Clickety-clack',
    'Clicking', 'Clip-clop', 'Cluck', 'Coin_(dropping)', 'Computer_keyboard',
    'Conversation', 'Coo', 'Cough', 'Cowbell', 'Creak', 'Cricket', 'Croak',
    'Crow', 'Crowd', 'DTMF', 'Dog', 'Door', 'Drill', 'Drip', 'Engine',
    'Engine_starting', 'Explosion', 'Fart', 'Female_singing', 'Filing_(rasp)',
    'Finger_snapping', 'Fire', 'Fire_alarm', 'Firecracker', 'Fireworks',
    'Frog', 'Gasp', 'Gears', 'Giggle', 'Glass', 'Glass_shatter', 'Gobble',
    'Groan', 'Growling', 'Hammer', 'Hands', 'Hiccup', 'Honk', 'Hoot', 'Howl',
    'Human_sounds', 'Human_voice', 'Insect', 'Laughter', 'Liquid',
    'Machine_gun', 'Male_singing', 'Mechanisms', 'Meow', 'Moo', 'Motorcycle',
    'Mouse', 'Music', 'Oink', 'Owl', 'Pant', 'Pant_(dog)', 'Patter', 'Pig',
    'Plop', 'Pour', 'Power_tool', 'Purr', 'Quack', 'Radio', 'Rain_on_surface',
    'Rapping', 'Rattle', 'Reversing_beeps', 'Ringtone', 'Roar', 'Run',
    'Rustle', 'Scissors', 'Scrape', 'Scratch', 'Screaming', 'Sewing_machine',
    'Shout', 'Shuffle', 'Shuffling_cards', 'Singing',
    'Single-lens_reflex_camera', 'Siren', 'Skateboard', 'Sniff', 'Snoring',
    'Speech', 'Speech_synthesizer', 'Spray', 'Squeak', 'Squeal', 'Steam',
    'Stir', 'Surface_contact', 'Tap', 'Tap_dance', 'Telephone_bell_ringing',
    'Television', 'Tick', 'Tick-tock', 'Tools', 'Train', 'Train_horn',
    'Train_wheels_squealing', 'Truck', 'Turkey', 'Typewriter', 'Typing',
    'Vehicle', 'Video_game_sound', 'Water', 'Whimper_(dog)', 'Whip',
    'Whispering', 'Whistle', 'Whistling', 'Whoop', 'Wind', 'Writing', 'Yip',
    'and_pans', 'bird_song', 'bleep', 'clink', 'cock-a-doodle-doo',
    'crinkling', 'dove', 'dribble', 'eructation', 'faucet', 'flapping_wings',
    'footsteps', 'gunfire', 'heartbeat', 'infant_cry', 'kid_speaking',
    'man_speaking', 'mastication', 'mice', 'river', 'rooster', 'silverware',
    'skidding', 'smack', 'sobbing', 'speedboat', 'splatter', 'surf', 'thud',
    'thwack', 'toot', 'truck_horn', 'tweet', 'vroom', 'waterfowl',
    'woman_speaking',
]


def load_checkpoint(model, filename, map_location=None, strict=False,
                    logger=None, revise_keys=((r'^module\.', ''),)):
    """Load a checkpoint file into ``model``.

    Adapted from ``mmcv.runner.load_checkpoint`` but implemented with plain
    :func:`torch.load` and :meth:`nn.Module.load_state_dict`, so it does not
    depend on mmcv.

    Before loading, every key is rewritten with ``revise_keys`` and any
    ``'backbone.'`` substring is removed. The stem weight
    ``patch_embed1.proj.weight`` (if present) is then summed over its
    input-channel axis, so a multi-channel stem can initialise a model whose
    stem takes a single input channel.

    Args:
        model (nn.Module): Module to load the weights into.
        filename (str): Local path of the checkpoint file.
        map_location: Forwarded to :func:`torch.load`.
        strict (bool): Forwarded to :meth:`nn.Module.load_state_dict`.
        logger (logging.Logger or None): Receives a warning listing missing /
            unexpected keys; ``warnings.warn`` is used when ``None``.
        revise_keys (iterable of (pattern, replacement)): Regex substitutions
            applied to every key. The default strips a leading ``'module.'``.

    Returns:
        dict: The checkpoint object as loaded from disk.

    Raises:
        RuntimeError: If the file does not hold a dict, or (``strict=True``)
            if the keys do not match the model.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    # OrderedDict is a subclass of dict, so this accepts both.
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    metadata = getattr(state_dict, '_metadata', OrderedDict())
    for pattern, replacement in revise_keys:
        state_dict = OrderedDict(
            (re.sub(pattern, replacement, k), v) for k, v in state_dict.items())
    state_dict = OrderedDict(
        (k.replace('backbone.', ''), v) for k, v in state_dict.items())

    # Collapse the stem's input channels into one. This runs after key
    # normalisation so that wrapped / prefixed checkpoints are handled as well.
    stem_key = 'patch_embed1.proj.weight'
    if stem_key in state_dict:
        state_dict[stem_key] = state_dict[stem_key].sum(dim=1, keepdim=True)

    # Keep metadata (used by load_state_dict for versioned modules).
    state_dict._metadata = metadata
    incompatible = model.load_state_dict(state_dict, strict=strict)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        msg = (f'missing keys in source state_dict: {incompatible.missing_keys}; '
               f'unexpected keys in source state_dict: {incompatible.unexpected_keys}')
        if logger is not None:
            logger.warning(msg)
        else:
            warnings.warn(msg)
    return checkpoint


def init_weights(m):
    """Initialise one module in place; intended for ``Module.apply``.

    Conv1d/Conv2d get Kaiming-normal weights, BatchNorm2d gets unit weights,
    Linear gets Kaiming-uniform weights; all existing biases are zeroed.
    """
    if isinstance(m, (nn.Conv2d, nn.Conv1d)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


def init_layer(layer):
    """Initialize a Linear or Convolutional layer (Xavier-uniform, zero bias)."""
    nn.init.xavier_uniform_(layer.weight)
    if getattr(layer, 'bias', None) is not None:
        layer.bias.data.fill_(0.)


def init_bn(bn):
    """Initialize a BatchNorm layer to the identity affine transform."""
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


class MaxPool(nn.Module):
    """Max pooling of ``decision`` over ``pooldim`` (``logits`` is ignored)."""

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, decision):
        return torch.max(decision, dim=self.pooldim)[0]


class LinearSoftPool(nn.Module):
    """Linear softmax pooling.

    Takes frame probabilities and returns a clip probability close to the
    maximum: ``sum(p**2) / sum(p)``. From "A Comparison of Five Multiple
    Instance Learning Pooling Functions for Sound Event Detection with Weak
    Labeling", https://arxiv.org/abs/1810.09050
    """

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, time_decision):
        # The epsilon avoids division by zero for all-zero frame probabilities.
        return (time_decision ** 2).sum(self.pooldim) / (
            time_decision.sum(self.pooldim) + 1e-7)


class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) layers followed by 2-D pooling."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply the block; ``pool_type`` is 'max', 'avg' or 'avg+max'."""
        x = F.relu_(self.bn1(self.conv1(input)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg+max':
            return (F.avg_pool2d(x, kernel_size=pool_size)
                    + F.max_pool2d(x, kernel_size=pool_size))
        raise ValueError('Incorrect argument!')


class ConvBlock_GLU(nn.Module):
    """Conv -> BatchNorm -> gated linear unit -> optional 2-D pooling.

    The conv produces ``out_channels`` maps; the sigmoid of the first half
    gates the second half, so the block outputs half as many channels.
    NOTE: padding is fixed at (1, 1) regardless of ``kernel_size``, so a 1x1
    kernel grows and a 5x5 kernel shrinks each spatial dim by 2; callers crop
    or pad the results to compensate.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
        super(ConvBlock_GLU, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def init_weight(self):
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply the block.

        ``pool_type`` is 'max', 'avg', 'avg+max', or 'None' / 'LP' (both of
        which skip pooling).
        """
        x = self.bn1(self.conv1(input))
        half = x.shape[1] // 2
        x = self.sigmoid(x[:, :half, :, :]) * x[:, half:, :, :]
        if pool_type == 'max':
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(x, kernel_size=pool_size)
        if pool_type == 'avg+max':
            return (F.avg_pool2d(x, kernel_size=pool_size)
                    + F.max_pool2d(x, kernel_size=pool_size))
        if pool_type in ('None', 'LP'):
            # 'LP' is a placeholder (an LPPool2d was intended) and does nothing.
            return x
        raise ValueError('Incorrect argument!')


class Mul_scale_GLU(nn.Module):
    """Multi-scale (1x1 / 3x3 / 5x5) GLU stem followed by six GLU blocks."""

    def __init__(self):
        super(Mul_scale_GLU, self).__init__()
        self.conv_block1_1 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(1, 1))
        self.conv_block1_2 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(3, 3))
        self.conv_block1_3 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(5, 5))
        # Three 32-channel GLU branches are concatenated -> 96 input channels.
        self.conv_block2 = ConvBlock_GLU(in_channels=96, out_channels=128 * 2)
        self.conv_block3 = ConvBlock_GLU(in_channels=128, out_channels=128 * 2)
        self.conv_block4 = ConvBlock_GLU(in_channels=128, out_channels=256 * 2)
        self.conv_block5 = ConvBlock_GLU(in_channels=256, out_channels=256 * 2)
        self.conv_block6 = ConvBlock_GLU(in_channels=256, out_channels=512 * 2)
        self.conv_block7 = ConvBlock_GLU(in_channels=512, out_channels=512 * 2)
        # Restores the row/column the 5x5 branch loses to its padding of 1.
        self.padding = nn.ReplicationPad2d((0, 1, 0, 1))

    def forward(self, input, fi=None):
        """Encode a (B, 1, T, F) input; ``fi`` is accepted but ignored.

        NOTE(review): the fixed ``[:500, :32]`` crop of the 1x1 branch only
        lines the three branches up for specific input sizes (e.g. T=1000,
        F=64); other sizes make ``torch.cat`` fail — confirm the intended size.
        """
        x1 = self.conv_block1_1(input, pool_size=(2, 2), pool_type='avg')
        x1 = x1[:, :, :500, :32]
        x2 = self.conv_block1_2(input, pool_size=(2, 2), pool_type='avg')
        x3 = self.padding(self.conv_block1_3(input, pool_size=(2, 2), pool_type='avg'))
        x = torch.cat([x1, x2, x3], dim=1)

        x = self.conv_block2(x, pool_size=(2, 2), pool_type='None')
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 4), pool_type='None')
        x = self.conv_block5(x, pool_size=(2, 4), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 4), pool_type='None')
        x = self.conv_block7(x, pool_size=(1, 4), pool_type='avg')
        x = F.dropout(x, p=0.2, training=self.training)
        return x


class Cnn14(nn.Module):
    """PANNs-style 6-block CNN that projects each time step to 128 dims.

    The spectrogram / log-mel extractors, spec augmenter and ``bn0`` are built
    (and so appear in the state dict) but are not used by ``forward``.
    """

    def __init__(self, sample_rate=32000, window_size=1024, hop_size=320,
                 mel_bins=64, fmin=50, fmax=14000, classes_num=527):
        super(Cnn14, self).__init__()
        if Spectrogram is None:
            raise ImportError('Cnn14 requires the torchlibrosa package')
        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size, hop_length=hop_size, win_length=window_size,
            window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate, n_fft=window_size, n_mels=mel_bins, fmin=fmin,
            fmax=fmax, ref=ref, amin=amin, top_db=top_db,
            freeze_parameters=True)
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64, time_stripes_num=2, freq_drop_width=8,
            freq_stripes_num=2)
        self.bn0 = nn.BatchNorm2d(64)
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
        self.fc1 = nn.Linear(2048, 128, bias=True)
        self.fc_audioset = nn.Linear(128, classes_num, bias=True)
        self.init_weight()

    def init_weight(self):
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input_, mixup_lambda=None):
        """Encode a (B, T, F) input; ``mixup_lambda`` is unused.

        Returns a (B, T', 128) tensor. ``fc1`` expects ``2048 * F'`` features,
        so F must pool down to 1 (e.g. F=64).
        """
        x = input_.unsqueeze(1)
        stages = (
            (self.conv_block1, (2, 2)), (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)), (self.conv_block4, (1, 2)),
            (self.conv_block5, (1, 2)), (self.conv_block6, (1, 2)),
        )
        for block, pool in stages:
            x = block(x, pool_size=pool, pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        x = x.transpose(1, 2).contiguous().flatten(-2)
        return self.fc1(x)


class Cnn10_fi(nn.Module):
    """4-block CNN whose block outputs can be FiLM-modulated by ``fi``."""

    def __init__(self):
        super(Cnn10_fi, self).__init__()
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)

    @staticmethod
    def _film(x, fi):
        """Return ``gamma * x + beta`` with per-sample ``fi[:, 0]`` / ``fi[:, 1]``."""
        if fi is None:
            return x
        gamma = fi[:, 0].unsqueeze(1).unsqueeze(2).unsqueeze(3)
        beta = fi[:, 1].unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return gamma * x + beta

    def forward(self, input, fi=None):
        """Encode a (B, 1, T, F) input.

        ``fi`` (optional) is indexed as ``fi[:, 0]`` (scale) and ``fi[:, 1]``
        (shift); presumably shape (B, 2) — TODO confirm with callers.
        """
        x = input
        stages = ((self.conv_block1, (2, 2)), (self.conv_block2, (2, 2)),
                  (self.conv_block3, (2, 4)), (self.conv_block4, (1, 4)))
        for block, pool in stages:
            x = block(x, pool_size=pool, pool_type='avg')
            x = self._film(x, fi)
            x = F.dropout(x, p=0.2, training=self.training)
        return x


# Pool sizes for the four Cnn10 stages, keyed by the time down-sampling factor.
_CNN10_POOL_SIZES = {
    8: ((2, 2), (2, 2), (2, 4), (1, 4)),
    4: ((2, 2), (2, 2), (1, 4), (1, 4)),
    2: ((2, 2), (1, 2), (1, 4), (1, 4)),
}
# Used for any other scale: no time pooling at all.
_CNN10_DEFAULT_POOL_SIZES = ((1, 2), (1, 2), (1, 4), (1, 4))


class Cnn10_mul_scale(nn.Module):
    """Cnn10 whose first block is replaced by a 1x1/3x3/5x5 GLU stem."""

    def __init__(self, scale=8):
        super(Cnn10_mul_scale, self).__init__()
        self.conv_block1_1 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(1, 1))
        self.conv_block1_2 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(3, 3))
        self.conv_block1_3 = ConvBlock_GLU(in_channels=1, out_channels=64, kernel_size=(5, 5))
        self.conv_block2 = ConvBlock(in_channels=96, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.scale = scale
        # Restores the row/column the 5x5 branch loses to its padding of 1.
        self.padding = nn.ReplicationPad2d((0, 1, 0, 1))

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Encode a (B, 1, T, F) input; ``pool_size``/``pool_type`` are unused."""
        pool1, pool2, pool3, pool4 = _CNN10_POOL_SIZES.get(
            self.scale, _CNN10_DEFAULT_POOL_SIZES)
        x1 = self.conv_block1_1(input, pool_size=pool1, pool_type='avg')
        x1 = x1[:, :, :500, :32]
        x2 = self.conv_block1_2(input, pool_size=pool1, pool_type='avg')
        x3 = self.padding(self.conv_block1_3(input, pool_size=pool1, pool_type='avg'))
        # Crop all branches to the shortest time axis before concatenating.
        n_steps = min(x1.shape[2], x2.shape[2], x3.shape[2])
        x = torch.cat([x1[:, :, :n_steps], x2[:, :, :n_steps], x3[:, :, :n_steps]], dim=1)
        x = F.dropout(x, p=0.2, training=self.training)
        for block, pool in ((self.conv_block2, pool2), (self.conv_block3, pool3),
                            (self.conv_block4, pool4)):
            x = block(x, pool_size=pool, pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        return x


class Cnn10(nn.Module):
    """4-block CNN; ``scale`` selects how much the time axis is down-sampled."""

    def __init__(self, scale=8):
        super(Cnn10, self).__init__()
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.scale = scale

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Encode a (B, 1, T, F) input; ``pool_size``/``pool_type`` are unused."""
        pools = _CNN10_POOL_SIZES.get(self.scale, _CNN10_DEFAULT_POOL_SIZES)
        blocks = (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4)
        x = input
        for block, pool in zip(blocks, pools):
            x = block(x, pool_size=pool, pool_type='avg')
            x = F.dropout(x, p=0.2, training=self.training)
        return x


class MeanPool(nn.Module):
    """Mean of ``decision`` over ``pooldim`` (``logits`` is ignored)."""

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim

    def forward(self, logits, decision):
        return torch.mean(decision, dim=self.pooldim)


class ResPool(nn.Module):
    """Holds a LinearSoftPool.

    NOTE(review): defines no ``forward``, so calling an instance fails.
    """

    def __init__(self, pooldim=1):
        super().__init__()
        self.pooldim = pooldim
        self.linPool = LinearSoftPool(pooldim=1)


class AutoExpPool(nn.Module):
    """Pooling weighted by ``exp(alpha * decision)`` with learnable ``alpha``."""

    def __init__(self, outputdim=10, pooldim=1):
        super().__init__()
        self.outputdim = outputdim
        # A float fill value is required: nn.Parameter rejects integer tensors.
        self.alpha = nn.Parameter(torch.full((outputdim, ), 1.0))
        self.pooldim = pooldim

    def forward(self, logits, decision):
        scaled = self.alpha * decision  # \alpha * P(Y|x) in the paper
        weights = torch.exp(scaled)
        return (logits * weights).sum(self.pooldim) / weights.sum(self.pooldim)


class SoftPool(nn.Module):
    """Softmax(decision / T)-weighted average of ``decision``."""

    def __init__(self, T=1, pooldim=1):
        super().__init__()
        self.pooldim = pooldim
        self.T = T

    def forward(self, logits, decision):
        w = torch.softmax(decision / self.T, dim=self.pooldim)
        return torch.sum(decision * w, dim=self.pooldim)


class AutoPool(nn.Module):
    """Auto-pool: softmax(alpha * decision)-weighted average, learnable alpha."""

    def __init__(self, outputdim=10, pooldim=1):
        super().__init__()
        self.outputdim = outputdim
        self.alpha = nn.Parameter(torch.ones(outputdim))
        self.dim = pooldim

    def forward(self, logits, decision):
        scaled = self.alpha * decision  # \alpha * P(Y|x) in the paper
        weight = torch.softmax(scaled, dim=self.dim)
        return torch.sum(decision * weight, dim=self.dim)  # B x C


class ExtAttentionPool(nn.Module):
    """Attention pooling of ``logits`` with per-class weights (zero-initialised)."""

    def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
        super().__init__()
        self.inputdim = inputdim
        self.outputdim = outputdim
        self.pooldim = pooldim
        self.attention = nn.Linear(inputdim, outputdim)
        nn.init.zeros_(self.attention.weight)
        nn.init.zeros_(self.attention.bias)
        self.activ = nn.Softmax(dim=self.pooldim)

    def forward(self, logits, decision):
        # Logits of shape (B, T, D), decision of shape (B, T, C)
        w_x = self.activ(self.attention(logits) / self.outputdim)
        h = (logits.permute(0, 2, 1).contiguous().unsqueeze(-2)
             * w_x.unsqueeze(-1)).flatten(-2).contiguous()
        return torch.sum(h, self.pooldim)


class AttentionPool(nn.Module):
    """Average of ``decision`` weighted by a softmax over projected ``logits``."""

    def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
        super().__init__()
        self.inputdim = inputdim
        self.outputdim = outputdim
        self.pooldim = pooldim
        self.transform = nn.Linear(inputdim, outputdim)
        self.activ = nn.Softmax(dim=self.pooldim)
        self.eps = 1e-7

    def forward(self, logits, decision):
        # Clamping keeps the softmax numerically stable.
        w = self.activ(torch.clamp(self.transform(logits), -15, 15))
        return (decision * w).sum(self.pooldim) / (w.sum(self.pooldim) + self.eps)


class Block2D(nn.Module):
    """BatchNorm -> Conv2d (no bias) -> LeakyReLU(0.1)."""

    def __init__(self, cin, cout, kernel_size=3, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(cin),
            nn.Conv2d(cin, cout, kernel_size=kernel_size, padding=padding, bias=False),
            nn.LeakyReLU(inplace=True, negative_slope=0.1))

    def forward(self, x):
        return self.block(x)


class AudioCNN(nn.Module):
    """4-block CNN clip classifier returning (128-d embedding, class logits)."""

    def __init__(self, classes_num):
        super(AudioCNN, self).__init__()
        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.fc1 = nn.Linear(512, 128, bias=True)
        self.fc = nn.Linear(128, classes_num, bias=True)
        self.init_weights()

    def init_weights(self):
        # Only the classifier is re-initialised; fc1 keeps PyTorch's default.
        init_layer(self.fc)

    def forward(self, input):
        """Input: (batch_size, time_steps, freq_bins) -> (embedding, logits)."""
        x = self.extract(input)
        return x, self.fc(x)

    def extract(self, input):
        """Input: (batch_size, time_steps, freq_bins) -> (batch_size, 128)."""
        x = input[:, None, :, :]  # (B, 1, T, F)
        for block in (self.conv_block1, self.conv_block2,
                      self.conv_block3, self.conv_block4):
            x = block(x, pool_size=(2, 2), pool_type='avg')
        x = torch.mean(x, dim=3)  # average over frequency -> (B, C, T')
        x, _ = torch.max(x, dim=2)  # max over time -> (B, C)
        return self.fc1(x)


def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
    """Build a temporal pooling module (pooling over dimension 1) by name.

    :param poolingfunction_name: case-insensitive name: 'mean', 'max',
        'linear', 'expalpha', 'soft', 'auto' or 'attention'.
    :param **kwargs: 'outputdim' (expalpha, auto, attention) and 'inputdim'
        (attention).
    :raises ValueError: for an unknown name (previously ``None`` was returned).
    """
    name = poolingfunction_name.lower()
    if name == 'mean':
        return MeanPool(pooldim=1)
    if name == 'max':
        return MaxPool(pooldim=1)
    if name == 'linear':
        return LinearSoftPool(pooldim=1)
    if name == 'expalpha':
        return AutoExpPool(outputdim=kwargs['outputdim'], pooldim=1)
    if name == 'soft':
        return SoftPool(pooldim=1)
    if name == 'auto':
        return AutoPool(outputdim=kwargs['outputdim'])
    if name == 'attention':
        return AttentionPool(inputdim=kwargs['inputdim'], outputdim=kwargs['outputdim'])
    raise ValueError(f'Unknown pooling function: {poolingfunction_name!r}')


class conv1d(nn.Module):
    """Conv1d + ReLU with 'VALID' (no) or 'SAME' padding."""

    def __init__(self, nin, nout, kernel_size=3, stride=1, padding='VALID', dilation=1):
        super(conv1d, self).__init__()
        if padding == 'VALID':
            dconv_pad = 0
        elif padding == 'SAME':
            dconv_pad = dilation * ((kernel_size - 1) // 2)
        else:
            raise ValueError("Padding Mode Error!")
        self.conv = nn.Conv1d(nin, nout, kernel_size=kernel_size, stride=stride,
                              padding=dconv_pad)
        self.act = nn.ReLU()
        self.init_layer(self.conv)

    def init_layer(self, layer, nonlinearity='relu'):
        """Kaiming-normal weights and a constant 0.1 bias."""
        nn.init.kaiming_normal_(layer.weight, nonlinearity=nonlinearity)
        nn.init.constant_(layer.bias, 0.1)

    def forward(self, x):
        return self.act(self.conv(x))


class Atten_1(nn.Module):
    """Attention from the frame at index ``context`` over all frames (residual).

    Input is indexed as (B, T, D); the output is
    ``input[:, context] + mean_t(alpha_t * input[:, t])``.
    """

    def __init__(self, input_dim, context=2, dropout_rate=0.2):
        super(Atten_1, self).__init__()
        self._matrix_k = nn.Linear(input_dim, input_dim // 4)
        self._matrix_q = nn.Linear(input_dim, input_dim // 4)
        self.relu = nn.ReLU()
        self.context = context
        self._dropout_layer = nn.Dropout(dropout_rate)
        self.init_layer(self._matrix_k)
        self.init_layer(self._matrix_q)

    def init_layer(self, layer, nonlinearity='leaky_relu'):
        """Kaiming-uniform weights and zero bias."""
        nn.init.kaiming_uniform_(layer.weight, nonlinearity=nonlinearity)
        if getattr(layer, 'bias', None) is not None:
            layer.bias.data.fill_(0.)

    def forward(self, input_x):
        keys = self._dropout_layer(self.relu(self._matrix_k(input_x)))  # (B, T, D/4)
        query = input_x[:, self.context, :][:, None, :]  # (B, 1, D)
        query = self._dropout_layer(self.relu(self._matrix_q(query)))  # (B, 1, D/4)
        scores = torch.matmul(keys, query.transpose(-2, -1) / math.sqrt(keys.size(-1)))
        alpha = F.softmax(scores.squeeze(2), dim=-1)  # (B, T)
        out = (alpha.unsqueeze(2) * input_x).mean(1)  # (B, D)
        return input_x[:, self.context, :] + out


class Fusion(nn.Module):
    """Multiplicative fusion of an embedding sequence with a feature sequence.

    Both (B, T, ·) inputs are projected to ``inputdim2 * n_fac`` channels by
    kernel-1 convolutions, multiplied element-wise and average-pooled by
    ``n_fac`` along the channel axis, giving (B, T, inputdim2).
    """

    def __init__(self, inputdim, inputdim2, n_fac):
        super().__init__()
        self.fuse_layer1 = conv1d(inputdim, inputdim2 * n_fac, 1)
        self.fuse_layer2 = conv1d(inputdim2, inputdim2 * n_fac, 1)
        self.avg_pool = nn.AvgPool1d(n_fac, stride=n_fac)  # pools the last dim

    def forward(self, embedding, mix_embed):
        # Expand the embedding's channel dim with a kernel-1 conv.
        fuse1_out = self.fuse_layer1(embedding.permute(0, 2, 1)).permute(0, 2, 1)
        # Same expansion for the mixture features.
        fuse2_out = self.fuse_layer2(mix_embed.permute(0, 2, 1)).permute(0, 2, 1)
        as_embs = torch.mul(fuse1_out, fuse2_out)  # element-wise product
        return self.avg_pool(as_embs)  # last dim divided by n_fac


def _upsample_frames(decision_time, n_frames):
    """Linearly interpolate (B, T', C) frame scores back to (B, n_frames, C)."""
    return F.interpolate(decision_time.transpose(1, 2), n_frames,
                         mode='linear', align_corners=False).transpose(1, 2)


def _detect(model, x, embedding, fuse=None):
    """Shared CDur forward: features -> embedding -> BiGRU -> frame softmax.

    Args:
        model: module with ``features``, ``gru``, ``fc`` and ``outputlayer``.
        x: (B, T, D) input features.
        embedding: (B, E) per-clip embedding, tiled over the pooled frames.
        fuse: optional ``fuse(tiled_embedding, frames)``; when ``None`` the
            embedding is concatenated to the frame features instead.

    Returns:
        Column 0 of the per-frame softmax, shape (B, T'), and the full
        (B, T, C) softmax interpolated back to the input length T.
    """
    n_frames = x.shape[1]
    x = model.features(x.unsqueeze(1))  # (B, C, T', F')
    # Make contiguous, then flatten channel and frequency: (B, T', C * F').
    x = x.transpose(1, 2).contiguous().flatten(-2)
    embedding = embedding.unsqueeze(1).repeat(1, x.shape[1], 1)
    if fuse is None:
        x = torch.cat((x, embedding), dim=2)
    else:
        x = fuse(embedding, x)
    if not hasattr(model, '_flattened'):
        model.gru.flatten_parameters()
    x, _ = model.gru(x)
    x = model.fc(x)
    decision_time = torch.softmax(model.outputlayer(x), dim=2)
    return decision_time[:, :, 0], _upsample_frames(decision_time, n_frames)


def _resolution_to_scale(time_resolution):
    """Map a frame count (125/250/500) to a Cnn10 scale; anything else -> 0."""
    return {125: 8, 250: 4, 500: 2}.get(time_resolution, 0)


class CDur_fusion(nn.Module):
    """CRNN detector that fuses the embedding via :class:`Fusion`.

    With inputdim=64 the CNN yields 128-dim frames, so the embedding must be
    128-dim as well (``Fusion(128, 128, 2)``).
    """

    def __init__(self, inputdim, outputdim, **kwargs):
        super().__init__()
        self.features = nn.Sequential(
            Block2D(1, 32),
            nn.LPPool2d(4, (2, 4)),
            Block2D(32, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 4)),
            Block2D(128, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (1, 4)),
            nn.Dropout(0.3),
        )
        # The result is unused, but this dummy pass also updates BatchNorm
        # running statistics and consumes RNG state; kept for fidelity.
        with torch.no_grad():
            rnn_input_dim = self.features(torch.randn(1, 1, 500, inputdim)).shape
            rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
        self.gru = nn.GRU(128, 128, bidirectional=True, batch_first=True)
        # Fusion needs three arguments; (128, 128, 2) yields the GRU's 128 inputs.
        self.fusion = Fusion(128, 128, 2)
        self.fc = nn.Linear(256, 256)
        self.outputlayer = nn.Linear(256, outputdim)
        self.features.apply(init_weights)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding):
        return _detect(self, x, embedding, fuse=self.fusion)


class CDur(nn.Module):
    """CRNN detector (Block2D CNN + BiGRU) with the embedding concatenated.

    ``time_resolution`` is accepted but unused. With inputdim=64 the CNN gives
    128-dim frames, so ``GRU(256, …)`` implies a 128-dim embedding.
    """

    def __init__(self, inputdim, outputdim, time_resolution, **kwargs):
        super().__init__()
        self.features = nn.Sequential(
            Block2D(1, 32),
            nn.LPPool2d(4, (2, 4)),
            Block2D(32, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 4)),
            Block2D(128, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 4)),
            nn.Dropout(0.3),
        )
        # Unused result; the pass still updates BatchNorm stats (kept as-is).
        with torch.no_grad():
            rnn_input_dim = self.features(torch.randn(1, 1, 500, inputdim)).shape
            rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
        self.gru = nn.GRU(256, 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, 256)
        self.outputlayer = nn.Linear(256, outputdim)
        self.features.apply(init_weights)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        return _detect(self, x, embedding)


class CDur_big(nn.Module):
    """Wider CDur (up to 512 CNN channels); embedding concatenated."""

    def __init__(self, inputdim, outputdim, **kwargs):
        super().__init__()
        self.features = nn.Sequential(
            Block2D(1, 64),
            Block2D(64, 64),
            nn.LPPool2d(4, (2, 2)),
            Block2D(64, 128),
            Block2D(128, 128),
            nn.LPPool2d(4, (2, 2)),
            Block2D(128, 256),
            Block2D(256, 256),
            nn.LPPool2d(4, (2, 4)),
            Block2D(256, 512),
            Block2D(512, 512),
            nn.LPPool2d(4, (1, 4)),
            nn.Dropout(0.3),
        )
        # Unused result; the pass still updates BatchNorm stats (kept as-is).
        with torch.no_grad():
            rnn_input_dim = self.features(torch.randn(1, 1, 500, inputdim)).shape
            rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
        self.gru = nn.GRU(640, 512, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)
        self.outputlayer = nn.Linear(256, outputdim)
        self.features.apply(init_weights)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding):
        return _detect(self, x, embedding)


class CDur_GLU(nn.Module):
    """CDur with a :class:`Mul_scale_GLU` front-end; embedding concatenated."""

    def __init__(self, inputdim, outputdim, **kwargs):
        super().__init__()
        self.features = Mul_scale_GLU()
        self.gru = nn.GRU(640, 512, 1, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)
        self.outputlayer = nn.Linear(256, outputdim)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        return _detect(self, x, embedding)


class CDur_CNN14(nn.Module):
    """CDur with a :class:`Cnn10` front-end (despite the name).

    ``time_resolution`` (125/250/500 frames) picks the Cnn10 time pooling.
    """

    def __init__(self, inputdim, outputdim, time_resolution, **kwargs):
        super().__init__()
        self.features = Cnn10(_resolution_to_scale(time_resolution))
        # Unused result; the pass still updates BatchNorm stats (kept as-is).
        with torch.no_grad():
            rnn_input_dim = self.features(torch.randn(1, 1, 500, inputdim)).shape
            rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
        self.gru = nn.GRU(640, 512, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)
        self.outputlayer = nn.Linear(256, outputdim)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        return _detect(self, x, embedding)


class CDur_CNN_mul_scale(nn.Module):
    """CDur with a :class:`Cnn10_mul_scale` front-end; embedding concatenated."""

    def __init__(self, inputdim, outputdim, time_resolution, **kwargs):
        super().__init__()
        self.features = Cnn10_mul_scale(_resolution_to_scale(time_resolution))
        self.gru = nn.GRU(640, 512, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)
        self.outputlayer = nn.Linear(256, outputdim)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        return _detect(self, x, embedding)


class CDur_CNN_mul_scale_fusion(nn.Module):
    """CDur with a :class:`Cnn10_mul_scale` front-end and :class:`Fusion`."""

    def __init__(self, inputdim, outputdim, time_resolution, **kwargs):
        super().__init__()
        self.features = Cnn10_mul_scale(_resolution_to_scale(time_resolution))
        self.gru = nn.GRU(512, 512, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, 256)
        self.fusion = Fusion(128, 512, 2)
        self.outputlayer = nn.Linear(256, outputdim)
        self.outputlayer.apply(init_weights)

    def forward(self, x, embedding, one_hot=None):
        """Run the fused detector.

        NOTE(review): presumably mirrors ``CDur_CNN_mul_scale.forward`` but
        combines the embedding through ``self.fusion`` — confirm.
        """
batch, time, dim = x.shape x = x.unsqueeze(1) # (b,1,t,d) x = self.features(x) # x = x.transpose(1, 2).contiguous().flatten(-2) # 重新拷贝一份x,之后推平-2:-1之间的维度 # (b,125,512) # print('x ',x.shape) # assert 1==2 embedding = embedding.unsqueeze(1) embedding = embedding.repeat(1, x.shape[1], 1) x = self.fusion(embedding, x) #x = torch.cat((x, embedding), dim=2) # [B, T, 128 + emb_dim] if not hasattr(self, '_flattened'): self.gru.flatten_parameters() x, _ = self.gru(x) # x torch.Size([16, 125, 256]) # x = self.gru(x) # x torch.Size([16, 125, 256]) x = self.fc(x) decision_time = torch.softmax(self.outputlayer(x),dim=2) # x torch.Size([16, 125, 2]) decision_up = torch.nn.functional.interpolate( decision_time.transpose(1, 2), # [16, 2, 125] time, # 501 mode='linear', align_corners=False).transpose(1, 2) # 从125插值回 501 ?--> (16,501,2) return decision_time[:,:,0],decision_up class RaDur_fusion(nn.Module): def __init__(self, model_config, inputdim, outputdim, time_resolution, **kwargs): super().__init__() self.encoder = Cnn14() self.detection = CDur_CNN_mul_scale_fusion(inputdim, outputdim, time_resolution) self.softmax = nn.Softmax(dim=2) #self.temperature = 5 # if model_config['pre_train']: # self.encoder.load_state_dict(torch.load(model_config['encoder_path'])['model']) # self.detection.load_state_dict(torch.load(model_config['CDur_path'])) self.q = nn.Linear(128,128) self.k = nn.Linear(128,128) self.q_ee = nn.Linear(128, 128) self.k_ee = nn.Linear(128, 128) self.temperature = 11.3 # sqrt(128) self.att_pool = model_config['att_pool'] self.enhancement = model_config['enhancement'] self.tao = model_config['tao'] self.top = model_config['top'] self.bn = nn.BatchNorm1d(128) self.EE_fusion = Fusion(128, 128, 4) def get_w(self,q,k): q = self.q(q) k = self.k(k) q = q.unsqueeze(1) attn = torch.bmm(q, k.transpose(1, 2)) attn = attn/self.temperature attn = self.softmax(attn) return attn def get_w_ee(self,q,k): q = self.q_ee(q) k = self.k_ee(k) q = q.unsqueeze(1) attn = torch.bmm(q, 
k.transpose(1, 2)) attn = attn/self.temperature attn = self.softmax(attn) return attn def attention_pooling(self, embeddings, mean_embedding): att_pool_w = self.get_w(mean_embedding,embeddings) embedding = torch.bmm(att_pool_w, embeddings).squeeze(1) # print(embedding.shape) # print(att_pool_w.shape) # print(att_pool_w[0]) # assert 1==2 return embedding def select_topk_embeddings(self, scores, embeddings, k): _, idx_DESC = scores.sort(descending=True, dim=1) # 根据分数进行排序 top_k = _[:,:k] # print('top_k ', top_k) # top_k = top_k.mean(1) idx_topk = idx_DESC[:, :k] # 取top_k个 # print('index ', idx_topk) idx_topk = idx_topk.unsqueeze(2).expand([-1, -1, embeddings.shape[2]]) selected_embeddings = torch.gather(embeddings, 1, idx_topk) return selected_embeddings,top_k def sum_with_attention(self, embedding, top_k, selected_embeddings): # print('embedding ',embedding) # print('selected_embeddings ',selected_embeddings.shape) att_1 = self.get_w_ee(embedding, selected_embeddings) att_1 = att_1.squeeze(1) #print('att_1 ',att_1.shape) larger = top_k > self.tao # print('larger ',larger) top_k = top_k*larger # print('top_k ',top_k.shape) # print('top_k ',top_k) att_1 = att_1*top_k #print('att_1 ',att_1.shape) # assert 1==2 att_2 = att_1.unsqueeze(2).repeat(1,1,128) Es = selected_embeddings*att_2 return Es def orcal_EE(self, x, embedding, label): batch, time, dim = x.shape mixture_embedding = self.encoder(x) # 8, 125, 128 mixture_embedding = mixture_embedding.transpose(1,2) mixture_embedding = self.bn(mixture_embedding) mixture_embedding = mixture_embedding.transpose(1,2) x = x.unsqueeze(1) # (b,1,t,d) x = self.detection.features(x) # x = x.transpose(1, 2).contiguous().flatten(-2) # 重新拷贝一份x,之后推平-2:-1之间的维度 # (b,125,128) embedding_pre = embedding.unsqueeze(1) embedding_pre = embedding_pre.repeat(1, x.shape[1], 1) f = self.detection.fusion(embedding_pre, x) # the first stage results #f = torch.cat((x, embedding_pre), dim=2) # [B, T, 128 + emb_dim] if not hasattr(self, '_flattened'): 
self.detection.gru.flatten_parameters() f, _ = self.detection.gru(f) # x torch.Size([16, 125, 256]) f = self.detection.fc(f) decision_time = torch.softmax(self.detection.outputlayer(f),dim=2) # x torch.Size([16, 125, 2]) selected_embeddings, top_k = self.select_topk_embeddings(decision_time[:,:,0], mixture_embedding, self.top) selected_embeddings = self.sum_with_attention(embedding, top_k, selected_embeddings) # add the weight mix_embedding = selected_embeddings.mean(1).unsqueeze(1) # mix_embedding = mix_embedding.repeat(1, x.shape[1], 1) embedding = embedding.unsqueeze(1) embedding = embedding.repeat(1, x.shape[1], 1) mix_embedding = self.EE_fusion(mix_embedding, embedding) # 使用神经网络进行融合 # mix_embedding2 = selected_embeddings2.mean(1) #mix_embedding = embedding + mix_embedding # 直接相加 # new detection results # embedding_now = mix_embedding.unsqueeze(1) # embedding_now = embedding_now.repeat(1, x.shape[1], 1) f_now = self.detection.fusion(mix_embedding, x) #f_now = torch.cat((x, embedding_now), dim=2) # f_now, _ = self.detection.gru(f_now) # x torch.Size([16, 125, 256]) f_now = self.detection.fc(f_now) decision_time_now = torch.softmax(self.detection.outputlayer(f_now), dim=2) # x torch.Size([16, 125, 2]) top_k = top_k.mean(1) # get avg score,higher score will have more weight larger = top_k > self.tao top_k = top_k * larger top_k = top_k/2.0 # print('top_k ',top_k) # assert 1==2 # print('tok_k[ ',top_k.shape) # print('decision_time ',decision_time.shape) # print('decision_time_now ',decision_time_now.shape) neg_w = top_k.unsqueeze(1).unsqueeze(2) neg_w = neg_w.repeat(1, decision_time_now.shape[1], decision_time_now.shape[2]) # print('neg_w ',neg_w.shape) #print('neg_w ',neg_w[:,0:10,0]) pos_w = 1-neg_w #print('pos_w ',pos_w[:,0:10,0]) decision_time_final = decision_time*pos_w + neg_w*decision_time_now #print('decision_time_final ',decision_time_final[0,0:10,0]) # print(decision_time_final[0,:,:]) #assert 1==2 return decision_time_final def forward(self, x, ref, 
label=None): batch, time, dim = x.shape logit = torch.zeros(1).cuda() embeddings = self.encoder(ref) mean_embedding = embeddings.mean(1) if self.att_pool == True: mean_embedding = self.bn(mean_embedding) embeddings = embeddings.transpose(1,2) embeddings = self.bn(embeddings) embeddings = embeddings.transpose(1,2) embedding = self.attention_pooling(embeddings, mean_embedding) else: embedding = mean_embedding if self.enhancement == True: decision_time = self.orcal_EE(x, embedding, label) decision_up = torch.nn.functional.interpolate( decision_time.transpose(1, 2), # [16, 2, 125] time, # 501 mode='linear', align_corners=False).transpose(1, 2) # 从125插值回 501 ?--> (16,501,2) return decision_time[:,:,0], decision_up, logit x = x.unsqueeze(1) # (b,1,t,d) x = self.detection.features(x) # x = x.transpose(1, 2).contiguous().flatten(-2) # 重新拷贝一份x,之后推平-2:-1之间的维度 # (b,125,128) embedding = embedding.unsqueeze(1) embedding = embedding.repeat(1, x.shape[1], 1) # x = torch.cat((x, embedding), dim=2) # [B, T, 128 + emb_dim] x = self.detection.fusion(embedding, x) # embedding = embedding.unsqueeze(1) # embedding = embedding.repeat(1, x.shape[1], 1) # x = torch.cat((x, embedding), dim=2) # [B, T, 128 + emb_dim] if not hasattr(self, '_flattened'): self.detection.gru.flatten_parameters() x, _ = self.detection.gru(x) # x torch.Size([16, 125, 256]) x = self.detection.fc(x) decision_time = torch.softmax(self.detection.outputlayer(x),dim=2) # x torch.Size([16, 125, 2]) decision_up = torch.nn.functional.interpolate( decision_time.transpose(1, 2), time, # 501 mode='linear', align_corners=False).transpose(1, 2) # 从125插值回 501 ?--> (16,501,2) return decision_time[:,:,0], decision_up, logit