|
import math |
|
import copy |
|
import torch |
|
from torch.nn import functional as F |
|
import torch.nn as nn |
|
|
|
from .model_proteinglm_clm import ProteinGLMForGeneration |
|
|
|
|
|
class MSAGPT(ProteinGLMForGeneration):
    """MSAGPT model built directly on ``ProteinGLMForGeneration``.

    The visible class body defines no layers or forward logic of its own:
    the constructor forwards everything to the parent, and the CLI hook
    registers an (as yet empty) ``MSAGPT-inference`` argument group before
    delegating to the parent's own argument registration.
    """

    def __init__(self, args, transformer=None, **kwargs):
        """Forward construction unchanged to the parent class.

        Args:
            args: Configuration object passed through to the parent as-is.
            transformer: Passed to the parent as the ``transformer`` keyword;
                presumably an optional pre-built backbone to reuse -- TODO
                confirm against ``ProteinGLMForGeneration.__init__``.
            **kwargs: Any further keyword arguments, passed through as-is.
        """
        super().__init__(args, transformer=transformer, **kwargs)

    @classmethod
    def add_model_specific_args(cls, parser):
        """Register MSAGPT's CLI argument group, then the parent's arguments.

        Args:
            parser: Parser being configured (assumed to be an
                ``argparse.ArgumentParser`` -- it must provide
                ``add_argument_group``).

        Returns:
            Whatever the next class in the MRO (``ProteinGLMForGeneration``)
            returns from its ``add_model_specific_args(parser)``.
        """
        # The group currently holds no options. It is still created so the
        # parser keeps its "MSAGPT-inference" section (title + description)
        # exactly as before. The returned group object was never used, so it
        # is no longer bound to a dead local name.
        # TODO(review): add MSAGPT-specific inference flags here, binding the
        # returned group again once it is actually populated.
        parser.add_argument_group(
            'MSAGPT-inference', 'MSAGPT inference Configurations'
        )
        return super().add_model_specific_args(parser)
|
|
|
class FineTuneMSAGPT(MSAGPT):
    """Subclass of ``MSAGPT``, named (and presumably intended) for fine-tuning.

    At present it adds no behavior: construction is delegated verbatim to
    ``MSAGPT``, so instances behave exactly like ``MSAGPT`` instances.
    """

    def __init__(self, args, transformer=None, **kwargs):
        """Delegate every constructor argument unchanged to ``MSAGPT``."""
        super().__init__(args, transformer=transformer, **kwargs)