from transformers import PretrainedConfig


class LoCoNetConfig(PretrainedConfig):
    """Configuration for a LoCoNet model, following the Hugging Face
    `PretrainedConfig` interface so it can be serialized with
    `save_pretrained` / `from_pretrained`."""

    model_type = "loconet"

    def __init__(
        self,
        num_speakers: int = 3,
        clip_length: int = 200,
        av: str = "speaker_temporal",
        av_layers: int = 3,
        adjust_attention: bool = False,
        **kwargs,
    ):
        # Model hyperparameters.
        self.num_speakers = num_speakers          # speakers modeled per clip
        self.clip_length = clip_length            # frames per input clip
        self.av = av                              # audio-visual attention variant
        self.av_layers = av_layers                # number of audio-visual layers
        self.adjust_attention = adjust_attention
        # Forward any remaining keyword arguments to PretrainedConfig.
        super().__init__(**kwargs)
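

# A minimal usage sketch (illustrative, not part of the original file): because
# LoCoNetConfig subclasses PretrainedConfig, it can be round-tripped with the
# standard save_pretrained / from_pretrained API. The directory name
# "loconet-config" is an arbitrary example.
if __name__ == "__main__":
    config = LoCoNetConfig(num_speakers=2, clip_length=100)
    config.save_pretrained("loconet-config")                  # writes config.json
    reloaded = LoCoNetConfig.from_pretrained("loconet-config")
    assert reloaded.num_speakers == 2 and reloaded.clip_length == 100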