|
import torch |
|
def get_config_phase1():
    """Build and return the settings dictionary for training phase 1.

    A new dict is constructed on every call, so a caller may mutate the
    result without affecting later calls. ``device`` is resolved at call
    time: ``cuda`` when ``torch.cuda.is_available()`` is true, else ``cpu``.

    Relative to ``get_config_phase2``, this phase uses ``train_batch_size=2``,
    ``epochs=2``, ``max_tokens=20`` and ``num_workers=4``, and it has no
    ``vocab_size`` entry.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"

    return dict(
        data_dir="./data",
        clip_model_name="openai/clip-vit-base-patch16",
        phi2_model_name="microsoft/phi-2",
        train_batch_size=2,
        val_batch_size=1,
        device=torch.device(backend),
        epochs=2,
        # NOTE(review): presumably a token-length cap for generation/targets — confirm with consumer.
        max_tokens=20,
        # NOTE(review): presumably the CLIP embedding width — confirm against the model.
        clip_embed=768,
        # NOTE(review): presumably the Phi-2 hidden size — confirm against the model.
        phi_embed=2560,
        # NOTE(review): presumably the DataLoader worker count — confirm with consumer.
        num_workers=4,
        ckpts="./ckpts",
    )
|
|
|
def get_config_phase2():
    """Build and return the settings dictionary for training phase 2.

    A new dict is constructed on every call, so a caller may mutate the
    result without affecting later calls. ``device`` is resolved at call
    time: ``cuda`` when ``torch.cuda.is_available()`` is true, else ``cpu``.

    Relative to ``get_config_phase1``, this phase uses ``train_batch_size=1``,
    ``epochs=10``, ``max_tokens=100`` and ``num_workers=0``, and it adds a
    ``vocab_size`` entry.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"

    return dict(
        data_dir="./data",
        clip_model_name="openai/clip-vit-base-patch16",
        phi2_model_name="microsoft/phi-2",
        train_batch_size=1,
        val_batch_size=1,
        device=torch.device(backend),
        epochs=10,
        # NOTE(review): presumably a token-length cap for generation/targets — confirm with consumer.
        max_tokens=100,
        # NOTE(review): presumably the CLIP embedding width — confirm against the model.
        clip_embed=768,
        # NOTE(review): presumably the Phi-2 hidden size — confirm against the model.
        phi_embed=2560,
        # NOTE(review): presumably the DataLoader worker count (0 would mean
        # loading in the main process) — confirm with consumer.
        num_workers=0,
        ckpts="./ckpts",
        # NOTE(review): presumably the Phi-2 tokenizer/output vocabulary size — confirm.
        vocab_size=51200,
    )