# -*- encoding: utf-8 -*-
'''
@File    :   pretrain_cogvideo.py
@Time    :   2021/10/06 00:58:32
@Author  :   Wenyi Hong
@Contact :   hwy22@mails.tsinghua.edu.cn
'''

# here put the import lib
import os
import sys
import math
import random
import torch
import argparse
import numpy as np

from icetk import icetk as tokenizer
tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])

from models.cogvideo_model import CogVideoModel
from SwissArmyTransformer import mpu, get_args
from SwissArmyTransformer.training.deepspeed_training import training_main
from SwissArmyTransformer.data_utils import BinaryDataset


def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
    # Extract batch size and sequence length.
    batch_size, seq_length = data.size()
    assert attention_mask_totxt is not None
    layout = args.layout
    assert seq_length == layout[-1]
    # Number of leading <pad> tokens in each sample's text segment.
    n_pads = layout[0] - attention_mask_totxt.sum(dim=-1).long()
    position_ids = torch.zeros(batch_size, layout[2], dtype=torch.long,
                               device=data.device)
    for i in range(batch_size):
        # Text positions count up from 0, right-aligned against the pads.
        torch.arange(layout[0] - n_pads[i], out=position_ids[i, n_pads[i]:layout[0]],
                     dtype=torch.long, device=data.device)
        # Frame positions use a separate range starting at 512.
        torch.arange(512, 512 + layout[2] - layout[0],
                     out=position_ids[i, layout[0]:],
                     dtype=torch.long, device=data.device)
    return position_ids


def get_batch(data_iterator, args, timers):
    # Items and their type.
    keys = ['text', 'loss_mask', 'attention_mask_totxt']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()

    data_b = mpu.broadcast_data(keys, data, datatype)
    # Unpack.
    tokens_ = data_b['text'].long()
    loss_mask = data_b['loss_mask'].float()
    attention_mask_totxt = data_b['attention_mask_totxt'].float()

    # Shift by one: predict token t+1 from tokens up to t.
    labels = tokens_[:, 1:].clone().contiguous()
    loss_mask = loss_mask[:, 1:].contiguous()
    tokens = tokens_[:, :-1].clone().contiguous()

    # Mark the start of every 400-token frame with <start_of_image>.
    for idx in range(args.layout[0], args.layout[2], 400):
        tokens[:, idx] = tokenizer['<start_of_image>']
    # Get the masks and position ids.
    position_ids = get_masks_and_position_ids_video(
        tokens,
        attention_mask_totxt=attention_mask_totxt,
        args=args
    )
    attention_mask_totxt = attention_mask_totxt.unsqueeze(1).unsqueeze(1)
    # Convert
    if args.fp16:
        attention_mask_totxt = attention_mask_totxt.half()
    return tokens, labels, loss_mask, attention_mask_totxt, position_ids


def forward_step(data_iterator, model, args, timers):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask_totxt, position_ids = get_batch(
        data_iterator, args, timers)
    timers('batch generator').stop()

    # Forward model.
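    # Per the layout [64, 464, 2064], tokens and labels are [bs, 2064]:
    # 64 (left-padded) text positions followed by 5 frames of 400 image
    # tokens each. The model returns logits plus any cached memories;
    # only the logits are used here.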
    logits, *mems = model(tokens, position_ids, attention_mask_totxt)

    # ======= hyper params ======= #
    perframe_len = 400
    text_len = 64
    frame_num = 5

    # Only image tokens are trained on: restrict the vocabulary to the image
    # token range and drop the text positions.
    logits_img_tokens = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
    losses = mpu.vocab_parallel_cross_entropy(logits_img_tokens, labels[:, text_len:])
    # scaling loss mask
    loss_mask = loss_mask[:, text_len:].reshape(-1)

    losses_1d = losses.reshape(-1) * loss_mask
    loss = torch.sum(losses_1d) / loss_mask.sum()

    # ===================== Log partial losses ======================== #
    log_loss_dict = {}
    bs = losses.shape[0]
    if args.cogvideo_stage == 1:
        # Stage 1 (autoregressive generation): log the per-frame loss of all frames.
        for i in range(frame_num):
            log_loss_dict[f'AR_f{i}_loss'] = losses[:, i * perframe_len:(i + 1) * perframe_len] \
                .contiguous().reshape(-1).detach().sum() / max(perframe_len * bs, 1)
    else:
        # Stage 2 (frame interpolation): only the interior frames are predicted.
        for i in range(1, frame_num - 1):
            log_loss_dict[f'ITP_f{i}_loss'] = losses[:, i * perframe_len:(i + 1) * perframe_len] \
                .contiguous().reshape(-1).detach().sum() / max(perframe_len * bs, 1)
    # ===================== END OF BLOCK ======================= #
    return loss, log_loss_dict


def create_dataset_function(path, args):
    dataset_layout = [64, 464, 2064]
    input_layout = [64, 464, 2064]
    # frame_num = 6
    # frame_interval = 2 # DEBUG!!!

    def process_fn(row):
        row = row.astype(np.int64)
        text = row[:dataset_layout[0]]
        frames = row[dataset_layout[0]:]

        if text[0] == tokenizer['<pad>']:
            text = text[1:]  # due to our way of data processing
        if args.cogvideo_stage == 1:
            text, loss_mask, frames = make_text_video_generation(text, frames)
        else:
            text, loss_mask, frames = mask_video_frame_interpolation(text, frames)

        # Left-pad the text to input_layout[0] tokens and prepend
        # <start_of_image> to the frame tokens.
        n_pad = input_layout[0] - len(text)
        parts = [
            np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64),
            text,
            np.array([tokenizer['<start_of_image>']], dtype=np.int64),
            frames,
        ]
        ret = np.concatenate(parts, axis=0)
        attention_mask_totxt = np.array([0] * n_pad + [1] * (input_layout[0] - n_pad))
        return {'text': ret,
                'loss_mask': loss_mask,
                'attention_mask_totxt': attention_mask_totxt,
                }
    return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])


def make_text_video_generation(text, frames):
    input_layout = [64, 464, 2064]
    # dataset text format: 1.0秒<n>{text} <pad> <pad> ... ("1.0秒" = "1.0 seconds")
    text = text[text != tokenizer['<pad>']][:input_layout[0]]
    # Laid out against the input; loss_mask is shifted left by one in get_batch.
    loss_mask = np.array([0] * (input_layout[1] + 1)
                         + [1] * (input_layout[2] - input_layout[1]))
    return text, loss_mask, frames


def mask_video_frame_interpolation(text, frames):
    input_layout = [64, 464, 2064]
    frame_len = input_layout[1] - input_layout[0]
    # dataset text format: <pad> 1.0秒 <n> {text} <pad> <pad>
    text = text[text != tokenizer['<pad>']][:input_layout[0]]
    # Train on frames 1 and 3; frames 0, 2 and 4 serve as context.
    # Laid out against the input; loss_mask is shifted left by one in get_batch.
    loss_mask = np.array([0] * (input_layout[1] + 1)
                         + [1] * frame_len
                         + [0] * frame_len
                         + [1] * frame_len
                         + [0] * frame_len)
    return text, loss_mask, frames


if __name__ == '__main__':
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument('--txt-loss-scale', type=float, default=1)
    CogVideoModel.add_model_specific_args(py_parser)

    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))

    args.layout = [int(x) for x in args.layout.split(',')]

    training_main(args, model_cls=CogVideoModel,
                  forward_step_function=forward_step,
                  create_dataset_function=create_dataset_function)
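

# --------------------------------------------------------------------------- #
# Illustrative sanity check (not part of the original training script; the
# helper name below is hypothetical): both loss-mask layouts above should
# line up with the token layout [64, 464, 2064].
def _check_loss_mask_lengths():
    text_end, frame1_end, total = 64, 464, 2064
    frame_len = frame1_end - text_end                     # 400 tokens per frame
    # Stage 1 (generation): text + frame 0 masked out, frames 1..4 trained.
    stage1_len = (frame1_end + 1) + (total - frame1_end)
    # Stage 2 (interpolation): frames 1 and 3 trained, frames 0/2/4 as context.
    stage2_len = (frame1_end + 1) + 4 * frame_len
    # Both masks carry one extra position: get_batch() drops the first entry
    # (loss_mask[:, 1:]) to align the mask with the shifted labels.
    assert stage1_len == stage2_len == total + 1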