# This code is adapted from https://github.com/THUDM/CogView2/blob/4e55cce981eb94b9c8c1f19ba9f632fd3ee42ba8/cogview2_text2image.py
from __future__ import annotations
import argparse
import functools
import logging
import pathlib
import sys
import tempfile
import time
from typing import Any
import gradio as gr
import imageio.v2 as iio
import numpy as np
import torch
from icetk import IceTokenizer
from SwissArmyTransformer import get_args
from SwissArmyTransformer.arguments import set_random_seed
from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
from SwissArmyTransformer.resources import auto_create
app_dir = pathlib.Path(__file__).parent
submodule_dir = app_dir / 'CogVideo'
sys.path.insert(0, submodule_dir.as_posix())
from coglm_strategy import CoglmStrategy
from models.cogvideo_cache_model import CogVideoCacheModel
from sr_pipeline import DirectSuperResolution
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addHandler(stream_handler)
ICETK_MODEL_DIR = app_dir / 'icetk_models'
def get_masks_and_position_ids_stage1(data, textlen, framelen):
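    '''Build tokens, a causal attention mask, and position ids for stage-1 generation.

    Text tokens are assigned positions [0, textlen); frame tokens are assigned
    positions starting at 512, so text and image positions occupy disjoint ranges.
    '''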
    # Keep the tokens and read off the sequence length.
tokens = data
seq_length = len(data[0])
# Attention mask (lower triangular).
attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
device=data.device)
attention_mask[:, :textlen, textlen:] = 0
attention_mask[:, textlen:, textlen:].tril_()
attention_mask.unsqueeze_(1)
# Unaligned version
position_ids = torch.zeros(seq_length,
dtype=torch.long,
device=data.device)
torch.arange(textlen,
out=position_ids[:textlen],
dtype=torch.long,
device=data.device)
torch.arange(512,
512 + seq_length - textlen,
out=position_ids[textlen:],
dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0)
return tokens, attention_mask, position_ids
def get_masks_and_position_ids_stage2(data, textlen, framelen):
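    '''Build tokens, a causal attention mask, and position ids for stage-2 (interpolation).

    The five frame slots in the sequence are mapped to position blocks (offset by 512)
    in temporal order 0, 2, 4, 1, 3: the three given key frames come first and the two
    frames to be interpolated follow.
    '''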
    # Keep the tokens and read off the sequence length.
tokens = data
seq_length = len(data[0])
# Attention mask (lower triangular).
attention_mask = torch.ones((1, textlen + framelen, textlen + framelen),
device=data.device)
attention_mask[:, :textlen, textlen:] = 0
attention_mask[:, textlen:, textlen:].tril_()
attention_mask.unsqueeze_(1)
# Unaligned version
position_ids = torch.zeros(seq_length,
dtype=torch.long,
device=data.device)
torch.arange(textlen,
out=position_ids[:textlen],
dtype=torch.long,
device=data.device)
frame_num = (seq_length - textlen) // framelen
assert frame_num == 5
torch.arange(512,
512 + framelen,
out=position_ids[textlen:textlen + framelen],
dtype=torch.long,
device=data.device)
torch.arange(512 + framelen * 2,
512 + framelen * 3,
out=position_ids[textlen + framelen:textlen + framelen * 2],
dtype=torch.long,
device=data.device)
torch.arange(512 + framelen * (frame_num - 1),
512 + framelen * frame_num,
out=position_ids[textlen + framelen * 2:textlen +
framelen * 3],
dtype=torch.long,
device=data.device)
torch.arange(512 + framelen * 1,
512 + framelen * 2,
out=position_ids[textlen + framelen * 3:textlen +
framelen * 4],
dtype=torch.long,
device=data.device)
torch.arange(512 + framelen * 3,
512 + framelen * 4,
out=position_ids[textlen + framelen * 4:textlen +
framelen * 5],
dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0)
return tokens, attention_mask, position_ids
def my_update_mems(hiddens, mems_buffers, mems_indexs,
limited_spatial_channel_mem, text_len, frame_len):
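    '''Copy the newly produced per-layer key/value hiddens into the preallocated
    mem buffers and advance the corresponding write indices.

    When limited_spatial_channel_mem is set, the first buffer only keeps the text
    prefix plus the key/values of the current frame.
    '''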
if hiddens is None:
return None, mems_indexs
mem_num = len(hiddens)
ret_mem = []
with torch.no_grad():
for id in range(mem_num):
if hiddens[id][0] is None:
ret_mem.append(None)
else:
if id == 0 and limited_spatial_channel_mem and mems_indexs[
id] + hiddens[0][0].shape[1] >= text_len + frame_len:
if mems_indexs[id] == 0:
for layer, hidden in enumerate(hiddens[id]):
mems_buffers[id][
layer, :, :text_len] = hidden.expand(
mems_buffers[id].shape[1], -1,
-1)[:, :text_len]
new_mem_len_part2 = (mems_indexs[id] +
hiddens[0][0].shape[1] -
text_len) % frame_len
if new_mem_len_part2 > 0:
for layer, hidden in enumerate(hiddens[id]):
mems_buffers[id][
layer, :, text_len:text_len +
new_mem_len_part2] = hidden.expand(
mems_buffers[id].shape[1], -1,
-1)[:, -new_mem_len_part2:]
mems_indexs[id] = text_len + new_mem_len_part2
else:
for layer, hidden in enumerate(hiddens[id]):
mems_buffers[id][layer, :,
mems_indexs[id]:mems_indexs[id] +
hidden.shape[1]] = hidden.expand(
mems_buffers[id].shape[1], -1, -1)
mems_indexs[id] += hidden.shape[1]
ret_mem.append(mems_buffers[id][:, :, :mems_indexs[id]])
return ret_mem, mems_indexs
def calc_next_tokens_frame_begin_id(text_len, frame_len, total_len):
    # The position id of the first token of the frame that the next token belongs to.
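    # e.g. text_len=74, frame_len=400, total_len=600 -> the frame containing the next token starts at 474.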
if total_len < text_len:
return None
return (total_len - text_len) // frame_len * frame_len + text_len
def my_filling_sequence(
model,
tokenizer,
args,
seq,
batch_size,
get_masks_and_position_ids,
text_len,
frame_len,
strategy=BaseStrategy(),
strategy2=BaseStrategy(),
mems=None,
log_text_attention_weights=0, # default to 0: no artificial change
mode_stage1=True,
enforce_no_swin=False,
guider_seq=None,
guider_text_len=0,
guidance_alpha=1,
        limited_spatial_channel_mem=False,  # limit the spatial-channel memory to the current frame
**kw_args):
    '''
        seq: [2, 3, 5, ..., -1 (to be generated), -1, ...]
        mems: [num_layers, batch_size, len_mems (index), mem_hidden_size]
            cache; should be the first mems.shape[1] parts of context_tokens.
            mems are first-class citizens here, but we make no assumption about what
            is memorized. Input mems are used for multi-phase generation.
    '''
if guider_seq is not None:
logger.debug('Using Guidance In Inference')
if limited_spatial_channel_mem:
logger.debug("Limit spatial-channel's mem to current frame")
assert len(seq.shape) == 2
# building the initial tokens, attention_mask, and position_ids
actual_context_length = 0
    while seq[-1][
            actual_context_length] >= 0:  # the last sequence has the fewest given tokens
        actual_context_length += 1  # [0, context_length - 1] are given
assert actual_context_length > 0
current_frame_num = (actual_context_length - text_len) // frame_len
assert current_frame_num >= 0
context_length = text_len + current_frame_num * frame_len
tokens, attention_mask, position_ids = get_masks_and_position_ids(
seq, text_len, frame_len)
tokens = tokens[..., :context_length]
input_tokens = tokens.clone()
if guider_seq is not None:
guider_index_delta = text_len - guider_text_len
guider_tokens, guider_attention_mask, guider_position_ids = get_masks_and_position_ids(
guider_seq, guider_text_len, frame_len)
guider_tokens = guider_tokens[..., :context_length -
guider_index_delta]
guider_input_tokens = guider_tokens.clone()
for fid in range(current_frame_num):
input_tokens[:, text_len + 400 * fid] = tokenizer['<start_of_image>']
if guider_seq is not None:
guider_input_tokens[:, guider_text_len +
400 * fid] = tokenizer['<start_of_image>']
attention_mask = attention_mask.type_as(next(
model.parameters())) # if fp16
# initialize generation
    counter = context_length - 1  # the last fixed index is 'counter'
index = 0 # Next forward starting index, also the length of cache.
mems_buffers_on_GPU = False
mems_indexs = [0, 0]
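    # Buffer capacities: 400 image tokens per frame (up to 5 frames) plus a small budget for the text prefix.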
mems_len = [(400 + 74) if limited_spatial_channel_mem else 5 * 400 + 74,
5 * 400 + 74]
mems_buffers = [
torch.zeros(args.num_layers,
batch_size,
mem_len,
args.hidden_size * 2,
dtype=next(model.parameters()).dtype)
for mem_len in mems_len
]
if guider_seq is not None:
guider_attention_mask = guider_attention_mask.type_as(
next(model.parameters())) # if fp16
guider_mems_buffers = [
torch.zeros(args.num_layers,
batch_size,
mem_len,
args.hidden_size * 2,
dtype=next(model.parameters()).dtype)
for mem_len in mems_len
]
guider_mems_indexs = [0, 0]
guider_mems = None
torch.cuda.empty_cache()
# step-by-step generation
while counter < len(seq[0]) - 1:
# we have generated counter+1 tokens
# Now, we want to generate seq[counter + 1],
# token[:, index: counter+1] needs forwarding.
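        # index == 0: first forward pass over the whole given context, filling the K/V mem
        # buffers; afterwards, only the newly appended tokens are fed using the cached mems.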
if index == 0:
group_size = 2 if (input_tokens.shape[0] == batch_size
and not mode_stage1) else batch_size
logits_all = None
for batch_idx in range(0, input_tokens.shape[0], group_size):
logits, *output_per_layers = model(
input_tokens[batch_idx:batch_idx + group_size, index:],
position_ids[..., index:counter + 1],
attention_mask, # TODO memlen
mems=mems,
text_len=text_len,
frame_len=frame_len,
counter=counter,
log_text_attention_weights=log_text_attention_weights,
enforce_no_swin=enforce_no_swin,
**kw_args)
logits_all = torch.cat(
(logits_all,
logits), dim=0) if logits_all is not None else logits
mem_kv01 = [[o['mem_kv'][0] for o in output_per_layers],
[o['mem_kv'][1] for o in output_per_layers]]
next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
text_len, frame_len, mem_kv01[0][0].shape[1])
for id, mem_kv in enumerate(mem_kv01):
for layer, mem_kv_perlayer in enumerate(mem_kv):
if limited_spatial_channel_mem and id == 0:
mems_buffers[id][
layer, batch_idx:batch_idx + group_size, :
text_len] = mem_kv_perlayer.expand(
min(group_size,
input_tokens.shape[0] - batch_idx), -1,
-1)[:, :text_len]
mems_buffers[id][layer, batch_idx:batch_idx+group_size, text_len:text_len+mem_kv_perlayer.shape[1]-next_tokens_frame_begin_id] =\
mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, next_tokens_frame_begin_id:]
else:
mems_buffers[id][
layer, batch_idx:batch_idx +
group_size, :mem_kv_perlayer.
shape[1]] = mem_kv_perlayer.expand(
min(group_size,
input_tokens.shape[0] - batch_idx), -1,
-1)
mems_indexs[0], mems_indexs[1] = mem_kv01[0][0].shape[
1], mem_kv01[1][0].shape[1]
if limited_spatial_channel_mem:
mems_indexs[0] -= (next_tokens_frame_begin_id - text_len)
mems = [
mems_buffers[id][:, :, :mems_indexs[id]] for id in range(2)
]
logits = logits_all
# Guider
if guider_seq is not None:
guider_logits_all = None
for batch_idx in range(0, guider_input_tokens.shape[0],
group_size):
guider_logits, *guider_output_per_layers = model(
guider_input_tokens[batch_idx:batch_idx + group_size,
max(index -
guider_index_delta, 0):],
guider_position_ids[
...,
max(index - guider_index_delta, 0):counter + 1 -
guider_index_delta],
guider_attention_mask,
mems=guider_mems,
text_len=guider_text_len,
frame_len=frame_len,
counter=counter - guider_index_delta,
log_text_attention_weights=log_text_attention_weights,
enforce_no_swin=enforce_no_swin,
**kw_args)
guider_logits_all = torch.cat(
(guider_logits_all, guider_logits), dim=0
) if guider_logits_all is not None else guider_logits
guider_mem_kv01 = [[
o['mem_kv'][0] for o in guider_output_per_layers
], [o['mem_kv'][1] for o in guider_output_per_layers]]
for id, guider_mem_kv in enumerate(guider_mem_kv01):
for layer, guider_mem_kv_perlayer in enumerate(
guider_mem_kv):
if limited_spatial_channel_mem and id == 0:
guider_mems_buffers[id][
layer, batch_idx:batch_idx + group_size, :
guider_text_len] = guider_mem_kv_perlayer.expand(
min(group_size,
input_tokens.shape[0] - batch_idx),
-1, -1)[:, :guider_text_len]
guider_next_tokens_frame_begin_id = calc_next_tokens_frame_begin_id(
guider_text_len, frame_len,
guider_mem_kv_perlayer.shape[1])
guider_mems_buffers[id][layer, batch_idx:batch_idx+group_size, guider_text_len:guider_text_len+guider_mem_kv_perlayer.shape[1]-guider_next_tokens_frame_begin_id] =\
guider_mem_kv_perlayer.expand(min(group_size, input_tokens.shape[0]-batch_idx), -1, -1)[:, guider_next_tokens_frame_begin_id:]
else:
guider_mems_buffers[id][
layer, batch_idx:batch_idx +
group_size, :guider_mem_kv_perlayer.
shape[1]] = guider_mem_kv_perlayer.expand(
min(group_size,
input_tokens.shape[0] - batch_idx),
-1, -1)
guider_mems_indexs[0], guider_mems_indexs[
1] = guider_mem_kv01[0][0].shape[1], guider_mem_kv01[
1][0].shape[1]
if limited_spatial_channel_mem:
guider_mems_indexs[0] -= (
guider_next_tokens_frame_begin_id -
guider_text_len)
guider_mems = [
guider_mems_buffers[id][:, :, :guider_mems_indexs[id]]
for id in range(2)
]
guider_logits = guider_logits_all
else:
if not mems_buffers_on_GPU:
if not mode_stage1:
torch.cuda.empty_cache()
for idx, mem in enumerate(mems):
mems[idx] = mem.to(next(model.parameters()).device)
if guider_seq is not None:
for idx, mem in enumerate(guider_mems):
guider_mems[idx] = mem.to(
next(model.parameters()).device)
else:
torch.cuda.empty_cache()
for idx, mem_buffer in enumerate(mems_buffers):
mems_buffers[idx] = mem_buffer.to(
next(model.parameters()).device)
mems = [
mems_buffers[id][:, :, :mems_indexs[id]]
for id in range(2)
]
if guider_seq is not None:
for idx, guider_mem_buffer in enumerate(
guider_mems_buffers):
guider_mems_buffers[idx] = guider_mem_buffer.to(
next(model.parameters()).device)
guider_mems = [
guider_mems_buffers[id]
[:, :, :guider_mems_indexs[id]] for id in range(2)
]
mems_buffers_on_GPU = True
logits, *output_per_layers = model(
input_tokens[:, index:],
position_ids[..., index:counter + 1],
attention_mask, # TODO memlen
mems=mems,
text_len=text_len,
frame_len=frame_len,
counter=counter,
log_text_attention_weights=log_text_attention_weights,
enforce_no_swin=enforce_no_swin,
limited_spatial_channel_mem=limited_spatial_channel_mem,
**kw_args)
mem_kv0, mem_kv1 = [o['mem_kv'][0] for o in output_per_layers
], [o['mem_kv'][1] for o in output_per_layers]
if guider_seq is not None:
guider_logits, *guider_output_per_layers = model(
guider_input_tokens[:,
max(index - guider_index_delta, 0):],
guider_position_ids[...,
max(index -
guider_index_delta, 0):counter +
1 - guider_index_delta],
guider_attention_mask,
mems=guider_mems,
text_len=guider_text_len,
frame_len=frame_len,
counter=counter - guider_index_delta,
log_text_attention_weights=0,
enforce_no_swin=enforce_no_swin,
limited_spatial_channel_mem=limited_spatial_channel_mem,
**kw_args)
guider_mem_kv0, guider_mem_kv1 = [
o['mem_kv'][0] for o in guider_output_per_layers
], [o['mem_kv'][1] for o in guider_output_per_layers]
if not mems_buffers_on_GPU:
torch.cuda.empty_cache()
for idx, mem_buffer in enumerate(mems_buffers):
mems_buffers[idx] = mem_buffer.to(
next(model.parameters()).device)
if guider_seq is not None:
for idx, guider_mem_buffer in enumerate(
guider_mems_buffers):
guider_mems_buffers[idx] = guider_mem_buffer.to(
next(model.parameters()).device)
mems_buffers_on_GPU = True
mems, mems_indexs = my_update_mems([mem_kv0, mem_kv1],
mems_buffers, mems_indexs,
limited_spatial_channel_mem,
text_len, frame_len)
if guider_seq is not None:
guider_mems, guider_mems_indexs = my_update_mems(
[guider_mem_kv0, guider_mem_kv1], guider_mems_buffers,
guider_mems_indexs, limited_spatial_channel_mem,
guider_text_len, frame_len)
counter += 1
index = counter
logits = logits[:, -1].expand(batch_size,
-1) # [batch size, vocab size]
tokens = tokens.expand(batch_size, -1)
if guider_seq is not None:
guider_logits = guider_logits[:, -1].expand(batch_size, -1)
guider_tokens = guider_tokens.expand(batch_size, -1)
if seq[-1][counter].item() < 0:
# sampling
guided_logits = guider_logits + (
logits - guider_logits
) * guidance_alpha if guider_seq is not None else logits
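            # In stage-1 mode the first frame is sampled with `strategy` (wired to the
            # CogView2 strategy by the callers); all later tokens, and all of stage 2,
            # use `strategy2`.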
if mode_stage1 and counter < text_len + 400:
tokens, mems = strategy.forward(guided_logits, tokens, mems)
else:
tokens, mems = strategy2.forward(guided_logits, tokens, mems)
if guider_seq is not None:
guider_tokens = torch.cat((guider_tokens, tokens[:, -1:]),
dim=1)
if seq[0][counter].item() >= 0:
for si in range(seq.shape[0]):
if seq[si][counter].item() >= 0:
tokens[si, -1] = seq[si, counter]
if guider_seq is not None:
guider_tokens[si,
-1] = guider_seq[si, counter -
guider_index_delta]
else:
tokens = torch.cat(
(tokens, seq[:, counter:counter + 1].clone().expand(
tokens.shape[0], 1).to(device=tokens.device,
dtype=tokens.dtype)),
dim=1)
if guider_seq is not None:
guider_tokens = torch.cat(
(guider_tokens,
guider_seq[:, counter - guider_index_delta:counter + 1 -
guider_index_delta].clone().expand(
guider_tokens.shape[0], 1).to(
device=guider_tokens.device,
dtype=guider_tokens.dtype)),
dim=1)
input_tokens = tokens.clone()
if guider_seq is not None:
guider_input_tokens = guider_tokens.clone()
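        # When generation crosses a frame boundary, overwrite the model-input token at each
        # new frame start with <start_of_image>, mirroring the initialization above.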
if (index - text_len - 1) // 400 < (input_tokens.shape[-1] - text_len -
1) // 400:
boi_idx = ((index - text_len - 1) // 400 + 1) * 400 + text_len
while boi_idx < input_tokens.shape[-1]:
input_tokens[:, boi_idx] = tokenizer['<start_of_image>']
if guider_seq is not None:
guider_input_tokens[:, boi_idx -
guider_index_delta] = tokenizer[
'<start_of_image>']
boi_idx += 400
if strategy.is_done:
break
return strategy.finalize(tokens, mems)
class InferenceModel_Sequential(CogVideoCacheModel):
def __init__(self, args, transformer=None, parallel_output=True):
super().__init__(args,
transformer=transformer,
parallel_output=parallel_output,
window_size=-1,
cogvideo_stage=1)
# TODO: check it
def final_forward(self, logits, **kwargs):
logits_parallel = logits
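        # Project hidden states onto the first 20000 word embeddings, i.e. the
        # image-token part of the icetk vocabulary.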
logits_parallel = torch.nn.functional.linear(
logits_parallel.float(),
self.transformer.word_embeddings.weight[:20000].float())
return logits_parallel
class InferenceModel_Interpolate(CogVideoCacheModel):
def __init__(self, args, transformer=None, parallel_output=True):
super().__init__(args,
transformer=transformer,
parallel_output=parallel_output,
window_size=10,
cogvideo_stage=2)
# TODO: check it
def final_forward(self, logits, **kwargs):
logits_parallel = logits
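        # Project hidden states onto the first 20000 word embeddings, i.e. the
        # image-token part of the icetk vocabulary.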
logits_parallel = torch.nn.functional.linear(
logits_parallel.float(),
self.transformer.word_embeddings.weight[:20000].float())
return logits_parallel
def get_default_args() -> argparse.Namespace:
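    '''Build the default argument namespace: SwissArmyTransformer CLI defaults parsed
    from `args_list`, merged with the CogVideo-specific options in `known`.'''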
known = argparse.Namespace(generate_frame_num=5,
coglm_temperature2=0.89,
use_guidance_stage1=True,
use_guidance_stage2=False,
guidance_alpha=3.0,
stage_1=True,
stage_2=False,
both_stages=False,
parallel_size=1,
stage1_max_inference_batch_size=-1,
multi_gpu=False,
layout='64, 464, 2064',
window_size=10,
additional_seqlen=2000,
cogvideo_stage=1)
args_list = [
'--tokenizer-type',
'fake',
'--mode',
'inference',
'--distributed-backend',
'nccl',
'--fp16',
'--model-parallel-size',
'1',
'--temperature',
'1.05',
'--top_k',
'12',
'--sandwich-ln',
'--seed',
'1234',
'--num-workers',
'0',
'--batch-size',
'1',
'--max-inference-batch-size',
'8',
]
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
args.layout = [int(x) for x in args.layout.split(',')]
args.do_train = False
return args
class Model:
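    '''Two-stage CogVideo pipeline: stage 1 generates key frames from text, stage 2
    interpolates additional frames, and CogView2 direct super-resolution upscales them.'''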
def __init__(self, only_first_stage: bool = False):
self.args = get_default_args()
if only_first_stage:
self.args.stage_1 = True
self.args.both_stages = False
else:
self.args.stage_1 = False
self.args.both_stages = True
self.tokenizer = self.load_tokenizer()
self.model_stage1, self.args = self.load_model_stage1()
self.model_stage2, self.args = self.load_model_stage2()
self.strategy_cogview2, self.strategy_cogvideo = self.load_strategies()
self.dsr = self.load_dsr()
self.device = torch.device(self.args.device)
def load_tokenizer(self) -> IceTokenizer:
logger.info('--- load_tokenizer ---')
start = time.perf_counter()
tokenizer = IceTokenizer(ICETK_MODEL_DIR.as_posix())
tokenizer.add_special_tokens(
['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return tokenizer
def load_model_stage1(
self) -> tuple[CogVideoCacheModel, argparse.Namespace]:
logger.info('--- load_model_stage1 ---')
start = time.perf_counter()
args = self.args
model_stage1, args = InferenceModel_Sequential.from_pretrained(
args, 'cogvideo-stage1')
model_stage1.eval()
if args.both_stages:
model_stage1 = model_stage1.cpu()
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return model_stage1, args
def load_model_stage2(
self) -> tuple[CogVideoCacheModel | None, argparse.Namespace]:
logger.info('--- load_model_stage2 ---')
start = time.perf_counter()
args = self.args
if args.both_stages:
model_stage2, args = InferenceModel_Interpolate.from_pretrained(
args, 'cogvideo-stage2')
model_stage2.eval()
if args.both_stages:
model_stage2 = model_stage2.cpu()
else:
model_stage2 = None
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return model_stage2, args
def load_strategies(self) -> tuple[CoglmStrategy, CoglmStrategy]:
logger.info('--- load_strategies ---')
start = time.perf_counter()
invalid_slices = [slice(self.tokenizer.num_image_tokens, None)]
strategy_cogview2 = CoglmStrategy(invalid_slices,
temperature=1.0,
top_k=16)
strategy_cogvideo = CoglmStrategy(
invalid_slices,
temperature=self.args.temperature,
top_k=self.args.top_k,
temperature2=self.args.coglm_temperature2)
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return strategy_cogview2, strategy_cogvideo
def load_dsr(self) -> DirectSuperResolution | None:
logger.info('--- load_dsr ---')
start = time.perf_counter()
if self.args.both_stages:
path = auto_create('cogview2-dsr', path=None)
dsr = DirectSuperResolution(self.args,
path,
max_bz=12,
onCUDA=False)
else:
dsr = None
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return dsr
@torch.inference_mode()
def process_stage1(self,
model,
seq_text,
duration,
video_raw_text=None,
video_guidance_text='视频',
image_text_suffix='',
batch_size=1):
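        '''Stage 1: generate the first frame from the text, then autoregressively
        generate `generate_frame_num` frames at a low frame rate conditioned on it.

        Returns the frame tokens (shape [batch_size, frame_num, 400]) and the decoded
        480x480 frames of the first sample.
        '''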
process_start_time = time.perf_counter()
generate_frame_num = self.args.generate_frame_num
tokenizer = self.tokenizer
use_guide = self.args.use_guidance_stage1
if next(model.parameters()).device != self.device:
move_start_time = time.perf_counter()
logger.debug('moving stage 1 model to cuda')
model = model.to(self.device)
elapsed = time.perf_counter() - move_start_time
logger.debug(f'moving in model1 takes time: {elapsed:.2f}')
if video_raw_text is None:
video_raw_text = seq_text
mbz = self.args.stage1_max_inference_batch_size if self.args.stage1_max_inference_batch_size > 0 else self.args.max_inference_batch_size
assert batch_size < mbz or batch_size % mbz == 0
frame_len = 400
# generate the first frame:
enc_text = tokenizer.encode(seq_text + image_text_suffix)
seq_1st = enc_text + [tokenizer['<start_of_image>']] + [-1] * 400
logger.info(
f'[Generating First Frame with CogView2] Raw text: {tokenizer.decode(enc_text):s}'
)
text_len_1st = len(seq_1st) - frame_len * 1 - 1
seq_1st = torch.tensor(seq_1st, dtype=torch.long,
device=self.device).unsqueeze(0)
output_list_1st = []
for tim in range(max(batch_size // mbz, 1)):
start_time = time.perf_counter()
output_list_1st.append(
my_filling_sequence(
model,
tokenizer,
self.args,
seq_1st.clone(),
batch_size=min(batch_size, mbz),
get_masks_and_position_ids=
get_masks_and_position_ids_stage1,
text_len=text_len_1st,
frame_len=frame_len,
strategy=self.strategy_cogview2,
strategy2=self.strategy_cogvideo,
log_text_attention_weights=1.4,
enforce_no_swin=True,
mode_stage1=True,
)[0])
elapsed = time.perf_counter() - start_time
logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
output_tokens_1st = torch.cat(output_list_1st, dim=0)
given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
401].unsqueeze(
1
) # given_tokens.shape: [bs, frame_num, 400]
# generate subsequent frames:
total_frames = generate_frame_num
enc_duration = tokenizer.encode(f'{float(duration)}秒')
if use_guide:
video_raw_text = video_raw_text + ' 视频'
enc_text_video = tokenizer.encode(video_raw_text)
seq = enc_duration + [tokenizer['<n>']] + enc_text_video + [
tokenizer['<start_of_image>']
] + [-1] * 400 * generate_frame_num
guider_seq = enc_duration + [tokenizer['<n>']] + tokenizer.encode(
video_guidance_text) + [tokenizer['<start_of_image>']
] + [-1] * 400 * generate_frame_num
logger.info(
f'[Stage1: Generating Subsequent Frames, Frame Rate {4/duration:.1f}] raw text: {tokenizer.decode(enc_text_video):s}'
)
text_len = len(seq) - frame_len * generate_frame_num - 1
guider_text_len = len(guider_seq) - frame_len * generate_frame_num - 1
seq = torch.tensor(seq, dtype=torch.long,
device=self.device).unsqueeze(0).repeat(
batch_size, 1)
guider_seq = torch.tensor(guider_seq,
dtype=torch.long,
device=self.device).unsqueeze(0).repeat(
batch_size, 1)
for given_frame_id in range(given_tokens.shape[1]):
seq[:, text_len + 1 + given_frame_id * 400:text_len + 1 +
(given_frame_id + 1) * 400] = given_tokens[:, given_frame_id]
guider_seq[:, guider_text_len + 1 +
given_frame_id * 400:guider_text_len + 1 +
(given_frame_id + 1) *
400] = given_tokens[:, given_frame_id]
output_list = []
if use_guide:
video_log_text_attention_weights = 0
else:
guider_seq = None
video_log_text_attention_weights = 1.4
for tim in range(max(batch_size // mbz, 1)):
input_seq = seq[:min(batch_size, mbz)].clone(
) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
guider_seq2 = (guider_seq[:min(batch_size, mbz)].clone()
if tim == 0 else guider_seq[mbz * tim:mbz *
(tim + 1)].clone()
) if guider_seq is not None else None
output_list.append(
my_filling_sequence(
model,
tokenizer,
self.args,
input_seq,
batch_size=min(batch_size, mbz),
get_masks_and_position_ids=
get_masks_and_position_ids_stage1,
text_len=text_len,
frame_len=frame_len,
strategy=self.strategy_cogview2,
strategy2=self.strategy_cogvideo,
log_text_attention_weights=video_log_text_attention_weights,
guider_seq=guider_seq2,
guider_text_len=guider_text_len,
guidance_alpha=self.args.guidance_alpha,
limited_spatial_channel_mem=True,
mode_stage1=True,
)[0])
output_tokens = torch.cat(output_list, dim=0)[:, 1 + text_len:]
if self.args.both_stages:
move_start_time = time.perf_counter()
logger.debug('moving stage 1 model to cpu')
model = model.cpu()
torch.cuda.empty_cache()
elapsed = time.perf_counter() - move_start_time
            logger.debug(f'moving out model1 takes time: {elapsed:.2f}')
# decoding
res = []
for seq in output_tokens:
decoded_imgs = [
self.postprocess(
torch.nn.functional.interpolate(tokenizer.decode(
image_ids=seq.tolist()[i * 400:(i + 1) * 400]),
size=(480, 480))[0])
for i in range(total_frames)
]
            res.append(decoded_imgs)  # all decoded frames for this sample
assert len(res) == batch_size
        tokens = output_tokens[:, :total_frames * 400].reshape(
            -1, total_frames, 400).cpu()
elapsed = time.perf_counter() - process_start_time
logger.info(f'--- done ({elapsed=:.3f}) ---')
return tokens, res[0]
@torch.inference_mode()
def process_stage2(self,
model,
seq_text,
duration,
parent_given_tokens,
video_raw_text=None,
video_guidance_text='视频',
gpu_rank=0,
gpu_parallel_size=1):
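        '''Stage 2: repeatedly interpolate between the given key frames, halving
        `duration` (i.e. doubling the frame rate) until it drops below 0.5 s, then
        upscale every frame with CogView2 direct super-resolution.

        Returns (success flag, upscaled frames of the first sample).
        '''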
process_start_time = time.perf_counter()
generate_frame_num = self.args.generate_frame_num
tokenizer = self.tokenizer
use_guidance = self.args.use_guidance_stage2
stage2_start_time = time.perf_counter()
if next(model.parameters()).device != self.device:
move_start_time = time.perf_counter()
logger.debug('moving stage-2 model to cuda')
model = model.to(self.device)
elapsed = time.perf_counter() - move_start_time
logger.debug(f'moving in stage-2 model takes time: {elapsed:.2f}')
try:
sample_num_allgpu = parent_given_tokens.shape[0]
sample_num = sample_num_allgpu // gpu_parallel_size
assert sample_num * gpu_parallel_size == sample_num_allgpu
parent_given_tokens = parent_given_tokens[gpu_rank *
sample_num:(gpu_rank +
1) *
sample_num]
        except Exception:
logger.critical('No frame_tokens found in interpolation, skip')
return False, []
# CogVideo Stage2 Generation
while duration >= 0.5: # TODO: You can change the boundary to change the frame rate
parent_given_tokens_num = parent_given_tokens.shape[1]
generate_batchsize_persample = (parent_given_tokens_num - 1) // 2
generate_batchsize_total = generate_batchsize_persample * sample_num
total_frames = generate_frame_num
frame_len = 400
enc_text = tokenizer.encode(seq_text)
enc_duration = tokenizer.encode(str(float(duration)) + '秒')
seq = enc_duration + [tokenizer['<n>']] + enc_text + [
tokenizer['<start_of_image>']
] + [-1] * 400 * generate_frame_num
text_len = len(seq) - frame_len * generate_frame_num - 1
logger.info(
f'[Stage2: Generating Frames, Frame Rate {int(4/duration):d}] raw text: {tokenizer.decode(enc_text):s}'
)
# generation
seq = torch.tensor(seq, dtype=torch.long,
device=self.device).unsqueeze(0).repeat(
generate_batchsize_total, 1)
for sample_i in range(sample_num):
for i in range(generate_batchsize_persample):
seq[sample_i * generate_batchsize_persample +
i][text_len + 1:text_len + 1 +
400] = parent_given_tokens[sample_i][2 * i]
seq[sample_i * generate_batchsize_persample +
i][text_len + 1 + 400:text_len + 1 +
800] = parent_given_tokens[sample_i][2 * i + 1]
seq[sample_i * generate_batchsize_persample +
i][text_len + 1 + 800:text_len + 1 +
1200] = parent_given_tokens[sample_i][2 * i + 2]
if use_guidance:
guider_seq = enc_duration + [
tokenizer['<n>']
] + tokenizer.encode(video_guidance_text) + [
tokenizer['<start_of_image>']
] + [-1] * 400 * generate_frame_num
guider_text_len = len(
guider_seq) - frame_len * generate_frame_num - 1
guider_seq = torch.tensor(
guider_seq, dtype=torch.long,
device=self.device).unsqueeze(0).repeat(
generate_batchsize_total, 1)
for sample_i in range(sample_num):
for i in range(generate_batchsize_persample):
guider_seq[sample_i * generate_batchsize_persample +
i][text_len + 1:text_len + 1 +
400] = parent_given_tokens[sample_i][2 *
i]
guider_seq[sample_i * generate_batchsize_persample +
i][text_len + 1 + 400:text_len + 1 +
800] = parent_given_tokens[sample_i][2 *
i +
1]
guider_seq[sample_i * generate_batchsize_persample +
i][text_len + 1 + 800:text_len + 1 +
1200] = parent_given_tokens[sample_i][2 *
i +
2]
video_log_text_attention_weights = 0
else:
guider_seq = None
guider_text_len = 0
video_log_text_attention_weights = 1.4
mbz = self.args.max_inference_batch_size
assert generate_batchsize_total < mbz or generate_batchsize_total % mbz == 0
output_list = []
start_time = time.perf_counter()
for tim in range(max(generate_batchsize_total // mbz, 1)):
input_seq = seq[:min(generate_batchsize_total, mbz)].clone(
) if tim == 0 else seq[mbz * tim:mbz * (tim + 1)].clone()
guider_seq2 = (
guider_seq[:min(generate_batchsize_total, mbz)].clone()
if tim == 0 else guider_seq[mbz * tim:mbz *
(tim + 1)].clone()
) if guider_seq is not None else None
output_list.append(
my_filling_sequence(
model,
tokenizer,
self.args,
input_seq,
batch_size=min(generate_batchsize_total, mbz),
get_masks_and_position_ids=
get_masks_and_position_ids_stage2,
text_len=text_len,
frame_len=frame_len,
strategy=self.strategy_cogview2,
strategy2=self.strategy_cogvideo,
log_text_attention_weights=
video_log_text_attention_weights,
mode_stage1=False,
guider_seq=guider_seq2,
guider_text_len=guider_text_len,
guidance_alpha=self.args.guidance_alpha,
limited_spatial_channel_mem=True,
)[0])
elapsed = time.perf_counter() - start_time
logger.info(f'Duration {duration:.2f}, Elapsed: {elapsed:.2f}\n')
output_tokens = torch.cat(output_list, dim=0)
output_tokens = output_tokens[:, text_len + 1:text_len + 1 +
(total_frames) * 400].reshape(
sample_num, -1,
400 * total_frames)
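            # Reassemble frames in temporal order: sequence slots correspond to times
            # 0, 2, 4, 1, 3, so take slots 0, 3, 1, 4 from each triple and append the
            # final slot-2 frame of the last group.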
output_tokens_merge = torch.cat(
(output_tokens[:, :, :1 * 400], output_tokens[:, :,
400 * 3:4 * 400],
output_tokens[:, :, 400 * 1:2 * 400],
output_tokens[:, :, 400 * 4:(total_frames) * 400]),
dim=2).reshape(sample_num, -1, 400)
output_tokens_merge = torch.cat(
(output_tokens_merge, output_tokens[:, -1:, 400 * 2:3 * 400]),
dim=1)
duration /= 2
parent_given_tokens = output_tokens_merge
if self.args.both_stages:
move_start_time = time.perf_counter()
logger.debug('moving stage 2 model to cpu')
model = model.cpu()
torch.cuda.empty_cache()
elapsed = time.perf_counter() - move_start_time
logger.debug(f'moving out model2 takes time: {elapsed:.2f}')
elapsed = time.perf_counter() - stage2_start_time
logger.info(f'CogVideo Stage2 completed. Elapsed: {elapsed:.2f}\n')
# direct super-resolution by CogView2
logger.info('[Direct super-resolution]')
dsr_start_time = time.perf_counter()
enc_text = tokenizer.encode(seq_text)
frame_num_per_sample = parent_given_tokens.shape[1]
parent_given_tokens_2d = parent_given_tokens.reshape(-1, 400)
text_seq = torch.tensor(enc_text, dtype=torch.long,
device=self.device).unsqueeze(0).repeat(
parent_given_tokens_2d.shape[0], 1)
sred_tokens = self.dsr(text_seq, parent_given_tokens_2d)
decoded_sr_videos = []
for sample_i in range(sample_num):
decoded_sr_imgs = []
for frame_i in range(frame_num_per_sample):
decoded_sr_img = tokenizer.decode(
image_ids=sred_tokens[frame_i + sample_i *
frame_num_per_sample][-3600:])
decoded_sr_imgs.append(
self.postprocess(
torch.nn.functional.interpolate(decoded_sr_img,
size=(480, 480))[0]))
decoded_sr_videos.append(decoded_sr_imgs)
elapsed = time.perf_counter() - dsr_start_time
logger.info(
f'Direct super-resolution completed. Elapsed: {elapsed:.2f}')
elapsed = time.perf_counter() - process_start_time
logger.info(f'--- done ({elapsed=:.3f}) ---')
return True, decoded_sr_videos[0]
@staticmethod
def postprocess(tensor: torch.Tensor) -> np.ndarray:
return tensor.cpu().mul(255).add_(0.5).clamp_(0, 255).permute(
1, 2, 0).to(torch.uint8).numpy()
def run(self, text: str, seed: int,
only_first_stage: bool) -> list[np.ndarray]:
logger.info('==================== run ====================')
start = time.perf_counter()
set_random_seed(seed)
if only_first_stage:
self.args.stage_1 = True
self.args.both_stages = False
else:
self.args.stage_1 = False
self.args.both_stages = True
parent_given_tokens, res = self.process_stage1(
self.model_stage1,
text,
duration=4.0,
video_raw_text=text,
video_guidance_text='视频',
image_text_suffix=' 高清摄影',
batch_size=self.args.batch_size)
if not only_first_stage:
_, res = self.process_stage2(
self.model_stage2,
text,
duration=2.0,
parent_given_tokens=parent_given_tokens,
video_raw_text=text + ' 视频',
video_guidance_text='视频',
gpu_rank=0,
                gpu_parallel_size=1)  # TODO: modify
elapsed = time.perf_counter() - start
logger.info(f'Elapsed: {elapsed:.3f}')
logger.info('==================== done ====================')
return res
class AppModel(Model):
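    '''Gradio-facing wrapper that adds English-to-Chinese prompt translation and
    mp4 encoding of the generated frames.'''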
def __init__(self, only_first_stage: bool):
super().__init__(only_first_stage)
self.translator = gr.Interface.load(
'spaces/chinhon/translation_eng2ch')
def to_video(self, frames: list[np.ndarray]) -> str:
out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
if self.args.stage_1:
fps = 4
else:
fps = 8
writer = iio.get_writer(out_file.name, fps=fps)
for frame in frames:
writer.append_data(frame)
writer.close()
return out_file.name
    def run_with_translation(
            self, text: str, translate: bool, seed: int, only_first_stage: bool
    ) -> tuple[str | None, str | None, list[np.ndarray] | None]:
logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=}')
if translate:
text = translated_text = self.translator(text)
else:
translated_text = None
frames = self.run(text, seed, only_first_stage)
video_path = self.to_video(frames)
return translated_text, video_path, frames
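

if __name__ == '__main__':
    # Minimal usage sketch; the Space itself drives AppModel from a Gradio UI instead.
    # Assumptions: a CUDA machine, network access for checkpoint downloads and the
    # wrapped translation Space, and only the faster first stage is run. The prompt
    # below is an arbitrary example; the backbone expects Chinese text.
    demo_model = AppModel(only_first_stage=True)
    demo_frames = demo_model.run('一只猫在弹钢琴',
                                 seed=1234,
                                 only_first_stage=True)
    print('video written to', demo_model.to_video(demo_frames))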