from collections import OrderedDict
from typing import List, Union, Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

import fairseq
from packaging import version

# class Model(nn.Module):
#     def __init__(self):
#         super().__init__()
#         # The model needs to be a nn.Module for finetuning, not required for representation extraction
#         self.model1 = nn.Linear(1, HIDDEN_DIM)
#         self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)

#     def forward(self, wavs, upstream_feature_selection="hidden_states"):
#         # You can do task-specific pre- / post-processing based on upstream_feature_selection
#         hidden = self.model1(wavs)
#         # hidden: (batch_size, max_len, hidden_dim)

#         feature = self.model2(hidden)
#         # feature: (batch_size, max_len, hidden_dim)

#         return [hidden, feature]

class UpstreamExpert(nn.Module):
    def __init__(
        self,
        ckpt: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as model.pt for SUPERB Challenge.
            upstream_feature_selection:
                The value could be
                'hidden_states', 'PR', 'SID', 'ER', 'ASR', 'QbE', 'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or others (new tasks).
                You can use it to control which task-specific pre- / post-processing to do.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # # You can use ckpt to load your pretrained weights
        # ckpt = torch.load(ckpt, map_location="cpu")
        # self.model = Model()
        # self.model.load_state_dict(ckpt)

        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."

        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [ckpt]
        )
        self.model = model[0]
        self.task = task

    def get_downsample_rates(self, key: str) -> int:
        """
        Since we do not do any downsampling in this example upstream
        All keys' corresponding representations have downsample rate of 1
        Eg. 10ms stride representation has the downsample rate 160 (input wavs are all in 16kHz)
        """
        return 320

    def forward(self, wavs: List[Tensor]) -> Dict[str, Union[Tensor, List[Tensor]]]:
        """
        When the returned Dict contains a List with more than one Tensor,
        those Tensors should share the same shape so a weighted sum can be trained over them.
        """
        # Total 7 settings: the original wavs, plus copies with silence of
        # 1/5, 1/10, or 1/20 of the wav length padded to the front or the end.

        # original (copy the list so appending below does not grow the input list)
        wavs_silence = list(wavs)

        # silence at the front: 1/5, 1/10, 1/20 of the wav length
        for ratio in (5, 10, 20):
            for wav in wavs:
                silence = torch.zeros(len(wav) // ratio, dtype=wav.dtype, device=wav.device)
                wavs_silence.append(torch.cat((silence, wav)))

        # silence at the end: 1/5, 1/10, 1/20 of the wav length
        for ratio in (5, 10, 20):
            for wav in wavs:
                silence = torch.zeros(len(wav) // ratio, dtype=wav.dtype, device=wav.device)
                wavs_silence.append(torch.cat((wav, silence)))

        wavs = wavs_silence

        device = wavs[0].device
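        # Build a padding mask that is True at padded positions (the fairseq convention)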
        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
            wav_lengths.unsqueeze(1),
        )
        padded_wav = pad_sequence(wavs, batch_first=True)

        features, feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,
        )
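        # features: (7 * batch_size, num_frames, hidden_dim); the batch dimension
        # is 7x the input batch because of the silence-padded copies above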


        # Deprecated! Do not do any task-specific post-processing below.
        # You can use the init arg "upstream_feature_selection" to control which task-specific pre- / post-processing to do.
        # The "hidden_states" key will be used as the default in many cases.
        # Other keys in this example are presented for the SUPERB Challenge.
        return {
            "hidden_states": features,
        }
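

if __name__ == "__main__":
    # Minimal usage sketch (not part of the SUPERB interface): instantiate the
    # expert and run a forward pass on random waveforms. The checkpoint path
    # below is a placeholder; point it to a locally downloaded HuBERT checkpoint.
    expert = UpstreamExpert(ckpt="/path/to/hubert_base_ls960.pt")
    expert.eval()

    # Two fake 16 kHz waveforms of different lengths
    wavs = [torch.randn(16000), torch.randn(16000 * 2)]
    with torch.no_grad():
        outputs = expert(wavs)

    # The batch dimension is expanded by the 7 silence-padding settings
    print(outputs["hidden_states"].shape)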