chat-gradio / personalized-chat-bot /scripts /train_bloom_personachat.py
j.gilyazev
add personalized-chat-bot
0766044
raw history blame
No virus
4.71 kB
import argparse
import torch.cuda
from datasets import load_dataset
import json
import os
import transformers
from torch.utils.data import Subset
import wandb
import numpy as np
import gc
from models.personality_clustering import PersonalityClustering
from util.bloom_trainer import BloomTrainer
from util.data import PersonaChatDataset
from util.metrics import perplexity
from petals.client.remote_model import DistributedBloomForCausalLM
"""Example usage:
python -m scripts.train_bloom_personachat --persona-ids 6 --config scripts/config.json --prompt-path data/models/
"""
# Default path to the saved personality-clustering model; used as the default
# value of the --clustering-model-path command-line option.
DEFAULT_CLUSTERING_MODEL = './data/models/personality_clustering_500_paraphrase-MiniLM-L6-v2_k-means.pkl'
# Upper bound on the number of validation examples used per persona: larger
# validation sets are randomly subsampled down to this size in main().
MAX_VAL_DATA_SIZE = 4
def load_config(path):
    """Load a JSON training config and expose its keys as attributes.

    Args:
        path: Path to a JSON file whose top-level value is an object.

    Returns:
        argparse.Namespace with one attribute per top-level JSON key.
    """
    # JSON text is UTF-8 by specification, so read it as such instead of
    # relying on the platform's locale-dependent default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return argparse.Namespace(**config)
def main():
    """Train and evaluate one BLOOM model run per requested persona id.

    Reads command-line options via parse_args() and hyperparameters from the
    JSON config, builds the PersonaChat train/validation datasets once, and
    then, for every persona id: starts a wandb run, loads
    DistributedBloomForCausalLM, trains it with BloomTrainer, evaluates
    perplexity, calls ``trainer.save_model`` with
    ``<prompt_path>/<id>_persona_prompt_embedding.pt`` and logs the result.
    """
    args = parse_args()
    persona_clustering = PersonalityClustering()
    persona_clustering.load(args.clustering_model_path)
    config = load_config(args.config)

    tokenizer = transformers.BloomTokenizerFast.from_pretrained(config.MODEL_NAME)
    tokenizer.padding_side = config.PADDING_SIDE
    tokenizer.model_max_length = config.MODEL_MAX_LENGTH

    dataset = load_dataset(config.PERSONACHAT_DATASET_NAME)
    personachat_train_dataset = PersonaChatDataset(persona_clustering,
                                                   dataset['train'],
                                                   tokenizer)
    personachat_val_dataset = PersonaChatDataset(persona_clustering,
                                                 dataset['validation'],
                                                 tokenizer)

    # Below this many validation examples the persona is validated on its
    # training data instead (and the run is flagged as not "honest").
    min_val_data_size = 4

    # Loop-invariant model arguments. initial_peers is only passed when the
    # config actually lists peers, so the library default applies otherwise.
    model_kwargs = {
        'pre_seq_len': config.NUM_PREFIX_TOKENS,
        'tuning_mode': config.TUNING_MODE,
    }
    if config.INITIAL_PEERS:
        model_kwargs['initial_peers'] = config.INITIAL_PEERS

    for persona_id in args.persona_ids:
        prompt_path = os.path.join(args.prompt_path,
                                   f'{persona_id}_persona_prompt_embedding.pt')
        train_dataset = personachat_train_dataset[persona_id]
        val_dataset = personachat_val_dataset[persona_id]
        honest_validation = True
        if len(val_dataset) < min_val_data_size:
            val_dataset = personachat_train_dataset[persona_id]
            honest_validation = False
        # To speed up evaluation, cap the validation set at a fixed size.
        if len(val_dataset) > MAX_VAL_DATA_SIZE:
            subset_indexes = np.random.choice(len(val_dataset), MAX_VAL_DATA_SIZE, replace=False)
            val_dataset = Subset(val_dataset, subset_indexes)

        wandb_run = wandb.init(
            project=args.wandb_project,
            config={
                'lr': config.LR,
                'batch_size': config.BATCH_SIZE,
                'persona_id': persona_id,
                'device': config.DEVICE,
                'model_name': config.MODEL_NAME,
                'n_epoch': config.N_EPOCH,
                'honest_validation': honest_validation
            },
            name=f'id{persona_id}',
            reinit=True
        )

        model = DistributedBloomForCausalLM.from_pretrained(
            config.MODEL_NAME,
            **model_kwargs
        ).to(config.DEVICE)

        trainer = BloomTrainer(model, config, train_dataset, val_dataset, wandb_run, prompt_path)
        trainer.train()
        eval_perplexity = trainer.evaluate(perplexity)
        trainer.save_model(prompt_path)
        wandb_run.log({'perplexity': eval_perplexity, 'model_path': prompt_path})
        # Close this persona's run explicitly rather than relying on the next
        # wandb.init(reinit=True) call or on interpreter exit.
        wandb_run.finish()

        # Drop the trainer together with the model: the trainer was built
        # from the model and presumably keeps a reference to it, so deleting
        # only `model` would leave it alive while the next one is loaded.
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()
def parse_args(args=None):
    """Parse the command-line options of the training script.

    Args:
        args: Optional list of argument strings. When None, argparse reads
            the process command line.

    Returns:
        argparse.Namespace with ``persona_ids``, ``clustering_model_path``,
        ``config``, ``prompt_path`` and ``wandb_project``.
    """
    cli = argparse.ArgumentParser(
        add_help=True,
        description="bloom training script",
    )
    cli.add_argument(
        '--persona-ids',
        type=int,
        nargs='+',
        help='Ids of persona',
    )
    # Both the single-dash and double-dash spellings are accepted.
    cli.add_argument(
        '-clustering-model-path', '--clustering-model-path',
        type=str,
        default=DEFAULT_CLUSTERING_MODEL,
        help='Path to clustering model',
    )
    cli.add_argument(
        '--config',
        type=str,
        help='Path to training config file',
    )
    cli.add_argument(
        '--prompt-path',
        type=str,
        help='Path to dir with trained soft prompts',
    )
    cli.add_argument(
        '--wandb-project',
        type=str,
        default='test_bloom_personachat_176b_v3',
    )
    return cli.parse_args(args)
# Run the training loop only when this module is executed as a script
# (e.g. via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()