import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import BertTokenizer, BertModel, BertConfig, BertTokenizerFast
from transformers import XLMRobertaModel, XLMRobertaConfig
import os
"""
Configurations
"""
file_dirname = os.path.dirname(__file__)  # in case it is needed for relative paths
dataset_path = os.path.join(file_dirname, "data/Dataset-Merged.json") # dataset path for PoemTextModel training, validation and test
image_path = "" # base path prepended to the image filenames of the datasets used for CLIPModel training
random_seed = 3 # the seed used to shuffle dataset with
# What proportion of the dataset will be used for each split?
train_propotion = 0.85
val_propotion = 0.05
# The remainder will be used as the test set
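# Illustrative arithmetic only (10,000 is a hypothetical dataset size, not the real one):
# with these proportions, 10,000 samples would split into roughly
#   n_train = int(10_000 * train_propotion)   # 8,500 train samples
#   n_val   = int(10_000 * val_propotion)     # 500 validation samples
#   n_test  = 10_000 - n_train - n_val        # 1,000 test samples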
batch_size = 128
num_workers = 0 # parameter of the torch DataLoader
lr = 1e-3 # learning rate
weight_decay = 1e-3
patience = 2 # patience parameter for lr scheduler
factor = 0.5 # factor parameter for lr scheduler
epochs = 60
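# Sketch of how lr, weight_decay, patience and factor are presumably consumed by the
# training code (an assumption; the actual optimizer/scheduler setup lives elsewhere):
#   optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
#   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#       optimizer, mode="min", patience=patience, factor=factor)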
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained Hugging Face models selectable via poem_encoder_model
poem_encoder_dict = {
"Bert":{
"poem_encoder_pretrained_name": 'mitra-mir/BERT-Persian-Poetry',
},
"ALBERT":{
"poem_encoder_pretrained_name": 'mitra-mir/ALBERT-Persian-Poetry',
},
"ParsBERT":{
"poem_encoder_pretrained_name": 'HooshvareLab/bert-base-parsbert-uncased',
},
}
poem_encoder_model = "ParsBERT" ### Important! The base model for the poem encoder (one of "Bert", "ALBERT" and "ParsBERT")
# Keep this an empty string to use the pretrained weights from Hugging Face
# (poem_encoder_dict[poem_encoder_model]) or a freshly initialized model;
# otherwise set it to the path of a saved encoder
poem_encoder_load_path = ""
# path to save encoder to
poem_encoder_save_path = "{}-poem-encoder".format(poem_encoder_model)
if poem_encoder_load_path:
poem_encoder_pretrained_name = poem_encoder_load_path
poem_tokenizer = poem_encoder_load_path
else:
poem_encoder_pretrained_name = poem_encoder_dict[poem_encoder_model]['poem_encoder_pretrained_name']
poem_tokenizer = poem_encoder_dict[poem_encoder_model]['poem_encoder_pretrained_name']
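# Example (hypothetical path, for illustration only): to load a locally saved encoder
# instead of the Hugging Face weights, point the load path at its directory, e.g.
#   poem_encoder_load_path = "ParsBERT-poem-encoder"   # matches poem_encoder_save_path above
# which makes both poem_encoder_pretrained_name and poem_tokenizer resolve to that directory.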
poem_embedding = 768 # embedding dim of poem encoder's output (for one token)
poems_max_length = 64 # max_length parameter when padding/truncating poems using poem tokenizer
# Keep this an empty string to use a freshly initialized projection module; otherwise set it to the path of a saved projection model
poem_projection_load_path = os.path.join(file_dirname, "projections/{}_best_poem_projection.pt".format(poem_encoder_model))
# path to save projection to
poem_projection_save_path = "{}_best_poem_projection.pt".format(poem_encoder_model)
poem_encoder_trainable = False # if set to False, this encoder is frozen and its weights won't be saved at all
# Pretrained Hugging Face models selectable via text_encoder_model
text_encoder_dict = {
"M-Bert":{
"text_encoder_pretrained_name": 'bert-base-multilingual-cased',
},
"XLM-RoBERTa":{
"text_encoder_pretrained_name": 'xlm-roberta-base',
},
"LaBSE":{
"text_encoder_pretrained_name": 'setu4993/LaBSE',
}
}
text_encoder_model = 'LaBSE' ### Important! The base model for the text encoder (one of "M-Bert", "XLM-RoBERTa" and "LaBSE")
# Keep this an empty string to use the pretrained weights from Hugging Face or a freshly initialized model; otherwise set it to the path of a saved encoder
text_encoder_load_path = ""
# path to save encoder to
text_encoder_save_path = "{}-text-encoder".format(text_encoder_model)
if text_encoder_load_path:
text_encoder_pretrained_name = text_encoder_load_path
text_tokenizer = text_encoder_load_path
else:
text_encoder_pretrained_name = text_encoder_dict[text_encoder_model]["text_encoder_pretrained_name"]
text_tokenizer = text_encoder_dict[text_encoder_model]["text_encoder_pretrained_name"]
text_embedding = 768 # embedding dim of text encoder's output (for one token)
text_max_length = 200 # max_length parameter when padding/truncating text using text tokenizer
# Keep this an empty string to use a freshly initialized projection module; otherwise set it to the path of a saved projection model
text_projection_load_path = os.path.join(file_dirname, "projections/{}_best_text_projection.pt".format(text_encoder_model))
# path to save projection to
text_projection_save_path = "{}_best_text_projection.pt".format(text_encoder_model)
text_encoder_trainable = False # if set to False, this encoder is frozen and its weights won't be saved at all
image_encoder_model = 'resnet50' # image model name to load via timm library
# Keep this an empty string to use the pretrained weights (loaded via timm) or a freshly initialized model; otherwise set it to the path of saved encoder weights
image_encoder_weights_load_path = ""
# path to save encoder weights to
image_encoder_weights_save_path = "{}_best_image_encoder.pt".format(image_encoder_model)
image_embedding = 2048 # embedding dim of the image encoder's output (per image)
# keep this an empty string if you want to use a freshly initialized projection module. else give the path to projection model
image_projection_load_path = ""
# path to save projection to
image_projection_save_path = "{}_best_image_projection.pt".format(image_encoder_model)
image_encoder_trainable = False # if set to False, this encoder is frozen and its weights won't be saved at all
# Tokenizer, Model and Config classes to use for each text/poem encoder model
tokenizers = {"ALBERT": AutoTokenizer, "M-Bert": BertTokenizer, "XLM-RoBERTa": AutoTokenizer, "ParsBERT": AutoTokenizer, "Bert": AutoTokenizer, "LaBSE": BertTokenizerFast}
encoders = {"ALBERT": AutoModel, "M-Bert": BertModel, "XLM-RoBERTa": XLMRobertaModel, "ParsBERT": AutoModel, "Bert": AutoModel, "LaBSE": BertModel}
configs = {"ALBERT": AutoConfig, "M-Bert": BertConfig, "XLM-RoBERTa": XLMRobertaConfig, "ParsBERT": AutoConfig, "Bert": AutoConfig, "LaBSE": BertConfig}
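# Usage sketch (an assumption about downstream code, shown here only as a comment):
# the model/dataset modules presumably resolve the concrete classes from these maps as
#   tokenizer = tokenizers[poem_encoder_model].from_pretrained(poem_tokenizer)
#   encoder   = encoders[poem_encoder_model].from_pretrained(poem_encoder_pretrained_name)
#   batch     = tokenizer(poems, padding=True, truncation=True,
#                         max_length=poems_max_length, return_tensors="pt")
# and analogously with the text_* settings for the text encoder.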
temperature = 1.0 # temperature parameter for scaling dot similarities
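# Sketch of how temperature is typically used in a CLIP-style contrastive objective
# (an assumption about the loss code elsewhere in the project, shown only as a comment):
#   logits = (text_embeddings @ poem_embeddings.T) / temperature
# where both embedding matrices are outputs of the projection heads configured below.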
# image size (height and width, in pixels)
size = 224
# Settings for the projection head used by the poem, text and image encoders
projection_dim = 1024 # projection embedding dim (the models' output dim)
dropout = 0.1 # fraction of the fc layer's output in the projection head to be zeroed
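# A minimal sketch of a projection head that consumes projection_dim and dropout, in the
# spirit of common CLIP implementations (an assumption; the project's real module is
# defined elsewhere and may differ):
#   import torch.nn as nn
#   class ProjectionHead(nn.Module):
#       def __init__(self, embedding_dim, projection_dim=projection_dim, dropout=dropout):
#           super().__init__()
#           self.projection = nn.Linear(embedding_dim, projection_dim)  # e.g. 768 or 2048 -> 1024
#           self.gelu = nn.GELU()
#           self.fc = nn.Linear(projection_dim, projection_dim)
#           self.dropout = nn.Dropout(dropout)  # zeroes a fraction of the fc output
#           self.layer_norm = nn.LayerNorm(projection_dim)
#       def forward(self, x):
#           projected = self.projection(x)
#           x = self.gelu(projected)
#           x = self.fc(x)
#           x = self.dropout(x)
#           x = x + projected  # residual connection
#           return self.layer_norm(x)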