habdine commited on
Commit
b83bac8
1 Parent(s): 192c05b

Update modeling_prot2text.py

Browse files
Files changed (1) hide show
  1. modeling_prot2text.py +2 -1
modeling_prot2text.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
11
  from transformers.generation.configuration_utils import GenerationConfig
12
  from transformers.generation.logits_process import LogitsProcessorList
13
  from transformers.generation.stopping_criteria import StoppingCriteriaList
 
14
 
15
  from .pdb2graph import PDB2Graph, download_alphafold_structure
16
  from .graphs import *
@@ -78,7 +79,7 @@ class EncoderRGCN(PreTrainedModel):
78
  return out.unsqueeze(1)
79
 
80
  class Prot2TextModel(PreTrainedModel):
81
- config_class = PretrainedConfig
82
  _keys_to_ignore_on_load_missing = [r"transformer"]
83
  base_model_prefix = "decoder"
84
  def __init__(self, config):
 
11
  from transformers.generation.configuration_utils import GenerationConfig
12
  from transformers.generation.logits_process import LogitsProcessorList
13
  from transformers.generation.stopping_criteria import StoppingCriteriaList
14
+ from .configuration_prot2text import Prot2TextConfig
15
 
16
  from .pdb2graph import PDB2Graph, download_alphafold_structure
17
  from .graphs import *
 
79
  return out.unsqueeze(1)
80
 
81
  class Prot2TextModel(PreTrainedModel):
82
+ config_class = Prot2TextConfig
83
  _keys_to_ignore_on_load_missing = [r"transformer"]
84
  base_model_prefix = "decoder"
85
  def __init__(self, config):