File size: 992 Bytes
60ffb3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
def get_config_phase1():
    """Build the configuration dictionary for phase 1.

    A new dict is constructed on every call, so callers may mutate the
    result without affecting later calls.

    Returns:
        dict: string keys mapped to settings. ``"device"`` is a
        ``torch.device`` that is ``"cuda"`` when ``torch.cuda.is_available()``
        is true at call time, otherwise ``"cpu"``.
    """
    # Resolve the compute device once, before assembling the settings.
    has_cuda = torch.cuda.is_available()
    compute_device = torch.device("cuda" if has_cuda else "cpu")

    # Key order matches the historical literal so iteration order is unchanged.
    settings = dict(
        data_dir="./data",
        clip_model_name="openai/clip-vit-base-patch16",
        phi2_model_name="microsoft/phi-2",
        train_batch_size=2,
        val_batch_size=1,
        device=compute_device,
        epochs=2,
        max_tokens=20,
        clip_embed=768,
        phi_embed=2560,
        num_workers=4,
        ckpts="./ckpts",
    )
    return settings

def get_config_phase2():
    """Build the configuration dictionary for phase 2.

    A new dict is constructed on every call, so callers may mutate the
    result without affecting later calls. Compared with the phase-1
    settings in this file, it also carries a ``"vocab_size"`` entry.

    Returns:
        dict: string keys mapped to settings. ``"device"`` is a
        ``torch.device`` that is ``"cuda"`` when ``torch.cuda.is_available()``
        is true at call time, otherwise ``"cpu"``.
    """
    # Resolve the compute device up front rather than inline in the mapping.
    has_cuda = torch.cuda.is_available()
    compute_device = torch.device("cuda" if has_cuda else "cpu")

    # Key order matches the historical literal so iteration order is unchanged.
    settings = dict(
        data_dir="./data",
        clip_model_name="openai/clip-vit-base-patch16",
        phi2_model_name="microsoft/phi-2",
        train_batch_size=1,
        val_batch_size=1,
        device=compute_device,
        epochs=10,
        max_tokens=100,
        clip_embed=768,
        phi_embed=2560,
        num_workers=0,  # NOTE(review): 0 vs 4 in phase 1 — reason not visible here
        ckpts="./ckpts",
        vocab_size=51200,
    )
    return settings