import argparse
import torch.cuda
from datasets import load_dataset
import json
import os
import transformers
from torch.utils.data import Subset
import wandb
import numpy as np
import gc

from models.personality_clustering import PersonalityClustering
from util.bloom_trainer import BloomTrainer
from util.data import PersonaChatDataset
from util.metrics import perplexity

from petals.client.remote_model import DistributedBloomForCausalLM
"""Пример запуска | |
python -m scripts.train_bloom_personachat --persona-ids 6 --config scripts/config.json --prompt-path data/models/ | |
""" | |
DEFAULT_CLUSTERING_MODEL = './data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl'
MAX_VAL_DATA_SIZE = 4


def load_config(path):
    with open(path, 'r') as f:
        config = json.load(f)
    return argparse.Namespace(**config)

def main():
    args = parse_args()
    persona_clustering = PersonalityClustering()
    persona_clustering.load(args.clustering_model_path)

    config = load_config(args.config)
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = config.PADDING_SIDE
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    dataset = load_dataset(config.PERSONACHAT_DATASET_NAME)
    personachat_train_dataset = PersonaChatDataset(persona_clustering,
                                                   dataset['train'],
                                                   tokenizer)
    personachat_val_dataset = PersonaChatDataset(persona_clustering,
                                                 dataset['validation'],
                                                 tokenizer)

    for id in args.persona_ids:
        prompt_path = os.path.join(args.prompt_path, f'{id}_persona_prompt_embedding.pt')
        train_dataset = personachat_train_dataset[id]
        val_dataset = personachat_val_dataset[id]

        # fall back to the training split when this persona has too few validation examples
        honest_validation = True
        if len(val_dataset) < 4:
            val_dataset = personachat_train_dataset[id]
            honest_validation = False
        # to speed things up, cap the validation set at MAX_VAL_DATA_SIZE examples
        if len(val_dataset) > MAX_VAL_DATA_SIZE:
            subset_indexes = np.random.choice(len(val_dataset), MAX_VAL_DATA_SIZE, replace=False)
            val_dataset = Subset(val_dataset, subset_indexes)

        # train_dataset.shuffle()
        wandb_run = wandb.init(
            project=args.wandb_project,
            config={
                'lr': config.LR,
                'batch_size': config.BATCH_SIZE,
                'persona_id': id,
                'device': config.DEVICE,
                'model_name': config.MODEL_NAME,
                'n_epoch': config.N_EPOCH,
                'honest_validation': honest_validation
            },
            name=f'id{id}',
            reinit=True
        )
        # use the library's default peers unless explicit initial peers are configured
        if len(config.INITIAL_PEERS) == 0:
            model = DistributedBloomForCausalLM.from_pretrained(
                config.MODEL_NAME,
                pre_seq_len=config.NUM_PREFIX_TOKENS,
                tuning_mode=config.TUNING_MODE
            ).to(config.DEVICE)
        else:
            model = DistributedBloomForCausalLM.from_pretrained(
                config.MODEL_NAME,
                initial_peers=config.INITIAL_PEERS,
                pre_seq_len=config.NUM_PREFIX_TOKENS,
                tuning_mode=config.TUNING_MODE
            ).to(config.DEVICE)
        trainer = BloomTrainer(model, config, train_dataset, val_dataset, wandb_run, prompt_path)
        trainer.train()
        eval_perplexity = trainer.evaluate(perplexity)
        trainer.save_model(prompt_path)
        wandb_run.log({'perplexity': eval_perplexity, 'model_path': prompt_path})

        # release the model before moving on to the next persona
        del model
        gc.collect()
        torch.cuda.empty_cache()

def parse_args(args=None):
    parser = argparse.ArgumentParser(add_help=True,
                                     description="bloom training script")
    parser.add_argument('--persona-ids', type=int, nargs='+',
                        help='IDs of personas to train on')
    parser.add_argument('-clustering-model-path', '--clustering-model-path', type=str,
                        default=DEFAULT_CLUSTERING_MODEL,
                        help='Path to clustering model')
    parser.add_argument('--config', type=str, help='Path to training config file')
    parser.add_argument('--prompt-path', type=str,
                        help='Path to dir with trained soft prompts')
    parser.add_argument('--wandb-project', type=str, default='test_bloom_personachat_176b_v3')
    args = parser.parse_args(args)
    return args

if __name__ == '__main__':
    main()