in2IN / config.py
pabloruizponce's picture
Upload model
f908e9d verified
raw history blame
No virus
1.63 kB
from transformers import PretrainedConfig
class in2INConfig(PretrainedConfig):
    """Configuration object holding the in2IN hyper-parameters.

    Each constructor argument is stored on the instance under the
    upper-cased form of its own name (for example ``num_layers`` becomes
    ``self.NUM_LAYERS``). Any extra keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.

    NOTE(review): the meaning of each option is defined by the code that
    consumes this config, which is not visible here; only the storage
    attribute and the default value are listed below.

    Args:
        num_layers: Stored as ``NUM_LAYERS`` (default ``8``).
        num_heads: Stored as ``NUM_HEADS`` (default ``8``).
        dropout: Stored as ``DROPOUT`` (default ``0.1``).
        input_dim: Stored as ``INPUT_DIM`` (default ``262``).
        latent_dim: Stored as ``LATENT_DIM`` (default ``1024``).
        ff_size: Stored as ``FF_SIZE`` (default ``2048``).
        activation: Stored as ``ACTIVATION`` (default ``"gelu"``).
        diffusion_steps: Stored as ``DIFFUSION_STEPS`` (default ``1000``).
        beta_scheduler: Stored as ``BETA_SCHEDULER`` (default ``"cosine"``).
        sampler: Stored as ``SAMPLER`` (default ``"uniform"``).
        motion_rep: Stored as ``MOTION_REP`` (default ``"global"``).
        finetune: Stored as ``FINETUNE`` (default ``False``).
        text_encoder: Stored as ``TEXT_ENCODER`` (default ``"clip"``).
        t_bar: Stored as ``T_BAR`` (default ``700``).
        control: Stored as ``CONTROL`` (default ``"text"``).
        strategy: Stored as ``STRATEGY`` (default ``"ddim50"``).
        cfg_weight: Stored as ``CFG_WEIGHT`` (default ``3``).
        cfg_weight_interaction: Stored as ``CFG_WEIGHT_INTERACTION``
            (default ``3``).
        cfg_weight_individual: Stored as ``CFG_WEIGHT_INDIVIDUAL``
            (default ``1``).
        mode: Stored as ``MODE`` (default ``"interaction"``).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        num_layers=8,
        num_heads=8,
        dropout=0.1,
        input_dim=262,
        latent_dim=1024,
        ff_size=2048,
        activation="gelu",
        diffusion_steps=1000,
        beta_scheduler="cosine",
        sampler="uniform",
        motion_rep="global",
        finetune=False,
        text_encoder="clip",
        t_bar=700,
        control="text",
        strategy="ddim50",
        cfg_weight=3,
        cfg_weight_interaction=3,
        cfg_weight_individual=1,
        mode="interaction",
        **kwargs,
    ):
        # Attribute name -> value, listed in the order they are assigned.
        hyperparameters = {
            "NUM_LAYERS": num_layers,
            "NUM_HEADS": num_heads,
            "DROPOUT": dropout,
            "INPUT_DIM": input_dim,
            "LATENT_DIM": latent_dim,
            "FF_SIZE": ff_size,
            "ACTIVATION": activation,
            "DIFFUSION_STEPS": diffusion_steps,
            "BETA_SCHEDULER": beta_scheduler,
            "SAMPLER": sampler,
            "MOTION_REP": motion_rep,
            "FINETUNE": finetune,
            "TEXT_ENCODER": text_encoder,
            "T_BAR": t_bar,
            "CONTROL": control,
            "STRATEGY": strategy,
            "CFG_WEIGHT": cfg_weight,
            "CFG_WEIGHT_INTERACTION": cfg_weight_interaction,
            "CFG_WEIGHT_INDIVIDUAL": cfg_weight_individual,
            "MODE": mode,
        }
        for attribute_name, attribute_value in hyperparameters.items():
            setattr(self, attribute_name, attribute_value)

        # The base initialiser must run last, after the values above are set.
        # NOTE(review): PretrainedConfig.__init__ appears to assign leftover
        # kwargs as attributes, so upper-case keys coming from a saved config
        # would overwrite the values set above -- confirm before reordering.
        super().__init__(**kwargs)