Spaces:
Running
Running
from dataclasses import dataclass, field | |
from transformers import TrainingArguments | |
@dataclass
class LiveTrainingArguments(TrainingArguments):
    """Base training configuration shared by the live model variants.

    Extends ``transformers.TrainingArguments`` with additional options.  The
    class is decorated with ``@dataclass`` so that the annotated attributes
    below become real dataclass fields (constructor keywords), and so that
    ``field(default_factory=...)`` produces a fresh list per instance instead
    of leaving a raw ``dataclasses.Field`` object on the class.

    How each option is consumed is not visible in this file; the per-field
    comments only restate what the names and defaults show.  Variant-specific
    defaults are supplied by the subclasses returned from ``get_args_class``.
    """

    # Variant tag; the same strings are the keys accepted by get_args_class().
    live_version: str = 'live1+'
    # Prompt text; the two adjacent literals are concatenated into one string.
    system_prompt: str = (
        "A multimodal AI assistant is helping users with some activities."
        " Below is their conversation, interleaved with the list of video frames received by the assistant."
    )
    # Dataset name lists; None means "not specified".
    train_datasets: list[str] = None
    eval_datasets: list[str] = None
    # NOTE(review): presumably scales a streaming-specific loss term -- confirm in the trainer.
    stream_loss_weight: float = 1.0
    # Identifiers of the pretrained language and vision checkpoints.
    llm_pretrained: str = 'meta-llama/Meta-Llama-3-8B-Instruct'
    vision_pretrained: str = 'google/siglip-large-patch16-384'
    # Regex over module names (presumably selects the modules that receive LoRA adapters).
    lora_modules: str = "model.*(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)|lm_head$"
    lora_r: int = 128
    lora_alpha: int = 256
    # default_factory gives every instance its own list (no shared mutable default).
    finetune_modules: list[str] = field(default_factory=lambda: ['connector'])
    frame_fps: int = 2  # for training. inference can be 10
    # Left as None here; the variant subclasses set concrete values.
    frame_token_cls: bool = None
    frame_token_pooled: list[int] = None
    frame_resolution: int = 384
    # Left as None here; the variant subclasses set '' or ','.
    frame_token_interval: str = None
    frame_token_interval_threshold: float = 0.0
    augmentation: bool = False
    attn_implementation: str = 'flash_attention_2'
    # Redeclares the inherited output_dir field with a debug-friendly default.
    output_dir: str = 'outputs/debug'
@dataclass
class LiveOneTrainingArguments(LiveTrainingArguments):
    """Default arguments for the ``'live1'`` variant.

    Overrides some ``LiveTrainingArguments`` defaults and introduces
    ``frame_num_tokens``, ``embed_mark`` and ``max_num_frames``.  The
    ``@dataclass`` decorator is required so that these newly declared
    attributes are registered as dataclass fields (and therefore accepted as
    constructor keywords) rather than remaining plain class attributes.
    """

    live_version: str = 'live1'
    frame_token_cls: bool = True
    # One token per frame in this variant.
    frame_num_tokens: int = 1
    # Empty string here, versus ',' in the 'live1+' variant.
    frame_token_interval: str = ''
    # NOTE(review): looks like a tag encoding fps / resolution / tokens-per-frame -- confirm usage.
    embed_mark: str = '2fps_384_1'
    max_num_frames: int = 7200  # 1h, 2fps, 7200 frames
@dataclass
class LiveOnePlusTrainingArguments(LiveTrainingArguments):
    """Default arguments for the ``'live1+'`` variant.

    Overrides some ``LiveTrainingArguments`` defaults and introduces
    ``frame_num_tokens``, ``embed_mark`` and ``max_num_frames``.  The
    ``@dataclass`` decorator is required so that the new attributes become
    dataclass fields and so that ``field(default_factory=...)`` on
    ``frame_token_pooled`` is turned into a per-instance list instead of
    staying a raw ``dataclasses.Field`` object.
    """

    live_version: str = 'live1+'
    frame_token_cls: bool = True
    # default_factory gives every instance its own [3, 3] list.
    frame_token_pooled: list[int] = field(default_factory=lambda: [3, 3])
    frame_num_tokens: int = 10  # 1+3x3
    # NOTE(review): looks like a tag encoding fps / resolution / tokens-per-frame -- confirm usage.
    embed_mark: str = '2fps_384_1+3x3'
    # ',' here, versus the empty string in the 'live1' variant.
    frame_token_interval: str = ','
    max_num_frames: int = 1200  # 10min, 2fps, 1200 frames
def get_args_class(live_version: str) -> type:
    """Return the training-arguments class for the given variant tag.

    Args:
        live_version: Variant tag, either ``'live1'`` or ``'live1+'``.

    Returns:
        ``LiveOneTrainingArguments`` for ``'live1'`` and
        ``LiveOnePlusTrainingArguments`` for ``'live1+'``.  The class itself
        is returned, not an instance.

    Raises:
        NotImplementedError: If ``live_version`` is any other value.
    """
    if live_version == 'live1':
        return LiveOneTrainingArguments
    if live_version == 'live1+':
        return LiveOnePlusTrainingArguments
    # Name the rejected value so a misconfigured run is easy to diagnose.
    raise NotImplementedError(
        f"Unsupported live_version {live_version!r}; expected 'live1' or 'live1+'."
    )