Update modeling_prot2text.py
Browse files- modeling_prot2text.py +2 -1
modeling_prot2text.py
CHANGED
@@ -11,6 +11,7 @@ import numpy as np
|
|
11 |
from transformers.generation.configuration_utils import GenerationConfig
|
12 |
from transformers.generation.logits_process import LogitsProcessorList
|
13 |
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
|
|
14 |
|
15 |
from .pdb2graph import PDB2Graph, download_alphafold_structure
|
16 |
from .graphs import *
|
@@ -78,7 +79,7 @@ class EncoderRGCN(PreTrainedModel):
|
|
78 |
return out.unsqueeze(1)
|
79 |
|
80 |
class Prot2TextModel(PreTrainedModel):
|
81 |
-
config_class =
|
82 |
_keys_to_ignore_on_load_missing = [r"transformer"]
|
83 |
base_model_prefix = "decoder"
|
84 |
def __init__(self, config):
|
|
|
11 |
from transformers.generation.configuration_utils import GenerationConfig
|
12 |
from transformers.generation.logits_process import LogitsProcessorList
|
13 |
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
14 |
+
from .configuration_prot2text import Prot2TextConfig
|
15 |
|
16 |
from .pdb2graph import PDB2Graph, download_alphafold_structure
|
17 |
from .graphs import *
|
|
|
79 |
return out.unsqueeze(1)
|
80 |
|
81 |
class Prot2TextModel(PreTrainedModel):
|
82 |
+
config_class = Prot2TextConfig
|
83 |
_keys_to_ignore_on_load_missing = [r"transformer"]
|
84 |
base_model_prefix = "decoder"
|
85 |
def __init__(self, config):
|