""" Prot2Text configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers import AutoConfig
from transformers.utils import logging


# Module-level logger created through the transformers logging utilities.
logger = logging.get_logger(__name__)


class Prot2TextConfig(PretrainedConfig):
    """Configuration for the Prot2Text model.

    Stores the top-level Prot2Text options together with two nested
    configuration dictionaries: ``gpt_config`` (built from ``gpt_model_name``)
    and ``esm_config`` (built from ``esm_model_name``). Either nested
    dictionary can be supplied explicitly; when it is ``None`` it is loaded
    with ``AutoConfig.from_pretrained`` and converted with ``to_dict()``.

    Args:
        cross_esm_graph: Stored as ``self.cross_esm_graph``; consumed by the
            model code (not visible in this file).
        decoder_start_token_id: Stored on this config and forwarded to the
            default ``gpt_config``.
        early_stopping: Stored as ``self.early_stopping``.
        eos_token_id: Stored on this config, forwarded to the default
            ``gpt_config`` and to ``PretrainedConfig.__init__``.
        bos_token_id: Forwarded to the default ``gpt_config`` and to
            ``PretrainedConfig.__init__``.
        esm: Stored as ``self.esm``; consumed by the model code (not visible
            in this file).
        esm_model_name: Checkpoint name used to load the default
            ``esm_config``.
        gpt_model_name: Checkpoint name used to load the default
            ``gpt_config``.
        length_penalty: Stored as ``self.length_penalty``.
        max_new_tokens: Stored on this config and forwarded to the default
            ``gpt_config``.
        no_repeat_ngram_size: Stored as ``self.no_repeat_ngram_size``.
        pad_token_id: Stored as ``self.pad_token_id``.
        prot2text_version: Stored as ``self.prot2text_version``.
        rgcn: Stored as ``self.rgcn``; consumed by the model code (not
            visible in this file).
        rgc_input_dim: Stored as ``self.rgc_input_dim``.
        rgcn_n_layers: Stored as ``self.rgcn_n_layers``.
        gpt_config: Optional ready-made decoder configuration. When ``None``
            a default one is built from ``gpt_model_name``.
        esm_config: Optional ready-made ESM configuration. When ``None`` a
            default one is built from ``esm_model_name``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "prot2text"
    keys_to_ignore_at_inference = ["past_key_values"]
    _keys_to_ignore_on_load_missing = [r"transformer"]

    def __init__(
        self,
        cross_esm_graph=True,
        decoder_start_token_id=50257,
        early_stopping=True,
        eos_token_id=50258,
        bos_token_id=50257,
        esm=True,
        esm_model_name="facebook/esm2_t6_8M_UR50D",
        gpt_model_name="gpt2",
        length_penalty=2.0,
        max_new_tokens=256,
        no_repeat_ngram_size=3,
        pad_token_id=50256,
        prot2text_version="1.1",
        rgcn=True,
        rgc_input_dim=67,
        rgcn_n_layers=6,
        gpt_config=None,
        esm_config=None,
        **kwargs,
    ):
        self.cross_esm_graph = cross_esm_graph
        self.decoder_start_token_id = decoder_start_token_id
        self.early_stopping = early_stopping
        self.eos_token_id = eos_token_id
        self.esm = esm
        self.esm_model_name = esm_model_name
        self.gpt_model_name = gpt_model_name
        self.length_penalty = length_penalty
        self.max_new_tokens = max_new_tokens
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.pad_token_id = pad_token_id
        self.prot2text_version = prot2text_version
        self.rgcn = rgcn
        self.rgc_input_dim = rgc_input_dim
        self.rgcn_n_layers = rgcn_n_layers

        if gpt_config is None:
            # Build the default decoder configuration from the named
            # checkpoint, overriding it with cross-attention enabled and the
            # special token ids given to this constructor.
            # NOTE(review): pad_token_id is hard-coded to 50256 here instead
            # of using the `pad_token_id` argument — confirm this is intended.
            self.gpt_config = AutoConfig.from_pretrained(
                gpt_model_name,
                _name_or_path=gpt_model_name,
                is_encoder_decoder=True,
                use_cache=False,
                add_cross_attention=True,
                bos_token_id=bos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                max_new_tokens=max_new_tokens,
                pad_token_id=50256,
                vocab_size=50259,
                num_beams=1,
                max_length=256,
                min_length=1,
            ).to_dict()
        else:
            self.gpt_config = gpt_config

        if esm_config is None:
            self.esm_config = AutoConfig.from_pretrained(esm_model_name).to_dict()
        else:
            # Keep the caller-supplied configuration. Without this `else` the
            # default loaded above was unconditionally overwritten with None.
            self.esm_config = esm_config

        # NOTE(review): attributes assigned above that are not forwarded here
        # (e.g. pad_token_id, decoder_start_token_id) may be reset by
        # PretrainedConfig.__init__ depending on the installed transformers
        # version — confirm against the version in use.
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)