# TAI2T_Multimodel / configs.py
# Source: Hugging Face file view — commit 60ffb3e ("New space") by
# Vasudevakrishna; original file size 992 bytes.
import torch
def get_config_phase1():
    """Build the configuration dictionary for phase 1.

    A new ``dict`` is constructed on every call, so mutating the result
    does not affect later calls.

    Returns:
        dict: Keys, in order: ``data_dir``, ``clip_model_name``,
        ``phi2_model_name``, ``train_batch_size``, ``val_batch_size``,
        ``device``, ``epochs``, ``max_tokens``, ``clip_embed``,
        ``phi_embed``, ``num_workers``, ``ckpts``. Unlike phase 2, there
        is no ``vocab_size`` entry.
    """
    # The device is resolved at call time: CUDA when torch reports it is
    # available, otherwise the CPU.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"

    return dict(
        # Filesystem location of the data and the pretrained model ids.
        data_dir="./data",
        clip_model_name="openai/clip-vit-base-patch16",
        phi2_model_name="microsoft/phi-2",
        # Batch sizes for the train / validation splits.
        train_batch_size=2,
        val_batch_size=1,
        device=torch.device(device_type),
        epochs=2,
        max_tokens=20,
        # NOTE(review): presumably the embedding widths of the CLIP
        # encoder and of Phi-2 — confirm against the model configs.
        clip_embed=768,
        phi_embed=2560,
        # NOTE(review): presumably the DataLoader worker count — confirm.
        num_workers=4,
        ckpts="./ckpts",
    )
def get_config_phase2():
    """Build the configuration dictionary for phase 2.

    A new ``dict`` is constructed on every call, so mutating the result
    does not affect later calls.

    Returns:
        dict: Keys, in order: ``data_dir``, ``clip_model_name``,
        ``phi2_model_name``, ``train_batch_size``, ``val_batch_size``,
        ``device``, ``epochs``, ``max_tokens``, ``clip_embed``,
        ``phi_embed``, ``num_workers``, ``ckpts``, ``vocab_size``.
    """
    # The device is resolved at call time: CUDA when torch reports it is
    # available, otherwise the CPU.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"

    return dict(
        # Filesystem location of the data and the pretrained model ids.
        data_dir="./data",
        clip_model_name="openai/clip-vit-base-patch16",
        phi2_model_name="microsoft/phi-2",
        # Batch sizes for the train / validation splits.
        train_batch_size=1,
        val_batch_size=1,
        device=torch.device(device_type),
        epochs=10,
        max_tokens=100,
        # NOTE(review): presumably the embedding widths of the CLIP
        # encoder and of Phi-2 — confirm against the model configs.
        clip_embed=768,
        phi_embed=2560,
        # NOTE(review): presumably the DataLoader worker count — confirm.
        num_workers=0,
        ckpts="./ckpts",
        # NOTE(review): presumably the Phi-2 tokenizer vocabulary size —
        # confirm against the model config.
        vocab_size=51200,
    )