# This code is adapted from https://github.com/THUDM/CogView2/blob/4e55cce981eb94b9c8c1f19ba9f632fd3ee42ba8/cogview2_text2image.py
from __future__ import annotations
import argparse
import functools
import logging
import os
import pathlib
import random
import subprocess
import sys
import time
import zipfile
from typing import Any
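
# When running on Hugging Face Spaces, install pinned dependencies and patch
# the CogView2 submodule at import time.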
if os.getenv('SYSTEM') == 'spaces':
subprocess.run('pip install icetk==0.0.3'.split())
subprocess.run('pip install SwissArmyTransformer==0.2.4'.split())
subprocess.run(
'pip install git+https://github.com/Sleepychord/Image-Local-Attention@43fee31'
.split())
#subprocess.run('git clone https://github.com/NVIDIA/apex'.split())
#subprocess.run('git checkout 1403c21'.split(), cwd='apex')
#with open('patch.apex') as f:
# subprocess.run('patch -p1'.split(), cwd='apex', stdin=f)
#subprocess.run(
# 'pip install -v --disable-pip-version-check --no-cache-dir --global-option --cpp_ext --global-option --cuda_ext ./'
# .split(),
# cwd='apex')
#subprocess.run('rm -rf apex'.split())
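    # Apply the bundled patch to the CogView2 submodule.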
with open('patch') as f:
subprocess.run('patch -p1'.split(), cwd='CogView2', stdin=f)
from huggingface_hub import hf_hub_download
def download_and_extract_icetk_models() -> None:
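        """Fetch the icetk tokenizer weights from the THUDM/icetk Hub repo and
        unpack them into /home/user/.icetk_models."""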
icetk_model_dir = pathlib.Path('/home/user/.icetk_models')
        icetk_model_dir.mkdir(parents=True, exist_ok=True)  # tolerate re-runs
path = hf_hub_download('THUDM/icetk',
'models.zip',
use_auth_token=os.getenv('HF_TOKEN'))
with zipfile.ZipFile(path) as f:
f.extractall(path=icetk_model_dir.as_posix())
def download_and_extract_cogview2_models(name: str) -> None:
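        """Fetch one CogView2 checkpoint archive from THUDM/CogView2, unpack
        it into the working directory, and delete the archive."""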
path = hf_hub_download('THUDM/CogView2',
name,
use_auth_token=os.getenv('HF_TOKEN'))
with zipfile.ZipFile(path) as f:
f.extractall()
os.remove(path)
download_and_extract_icetk_models()
names = [
'coglm.zip',
'cogview2-dsr.zip',
'cogview2-itersr.zip',
]
for name in names:
download_and_extract_cogview2_models(name)
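    # SwissArmyTransformer resolves checkpoint paths relative to SAT_HOME.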
os.environ['SAT_HOME'] = '/home/user/app/sharefs/cogview-new'
import gradio as gr
import numpy as np
import torch
from icetk import icetk as tokenizer
from SwissArmyTransformer import get_args
from SwissArmyTransformer.arguments import set_random_seed
from SwissArmyTransformer.generation.autoregressive_sampling import (
    filling_sequence)
from SwissArmyTransformer.model import CachedAutoregressiveModel
app_dir = pathlib.Path(__file__).parent
submodule_dir = app_dir / 'CogView2'
sys.path.insert(0, submodule_dir.as_posix())
from coglm_strategy import CoglmStrategy
from sr_pipeline import SRGroup
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addHandler(stream_handler)
tokenizer.add_special_tokens(
['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])
def get_masks_and_position_ids_coglm(
seq: torch.Tensor, context_length: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
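    """Build (tokens, attention_mask, position_ids) for CogLM sampling.

    The mask is causal over generated positions but leaves the text context
    fully visible; position ids for the image tokens restart at offset 512.
    """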
tokens = seq.unsqueeze(0)
attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
attention_mask.tril_()
attention_mask[..., :context_length] = 1
attention_mask.unsqueeze_(1)
position_ids = torch.zeros(len(seq),
device=tokens.device,
dtype=torch.long)
torch.arange(0, context_length, out=position_ids[:context_length])
torch.arange(512,
512 + len(seq) - context_length,
out=position_ids[context_length:])
position_ids = position_ids.unsqueeze(0)
return tokens, attention_mask, position_ids
class InferenceModel(CachedAutoregressiveModel):
def final_forward(self, logits, **kwargs):
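        """Project hidden states onto the first 20,000 word embeddings (the
        image-token vocabulary) to produce output logits."""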
logits_parallel = logits
logits_parallel = torch.nn.functional.linear(
logits_parallel.float(),
self.transformer.word_embeddings.weight[:20000].float())
return logits_parallel
def get_recipe(name: str) -> dict[str, Any]:
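    """Return the sampling hyperparameters and query template for a style.

    Each template appends Chinese style keywords to the prompt ahead of the
    <start_of_image> token.
    """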
r = {
'attn_plus': 1.4,
'temp_all_gen': 1.15,
'topk_gen': 16,
'temp_cluster_gen': 1.,
'temp_all_dsr': 1.5,
'topk_dsr': 100,
'temp_cluster_dsr': 0.89,
'temp_all_itersr': 1.3,
'topk_itersr': 16,
'query_template': '{}<start_of_image>',
}
    if name == 'none':
        pass
    elif name == 'mainbody':
        # "high-definition photography, isolated subject"
        r['query_template'] = '{} 高清摄影 隔绝<start_of_image>'
    elif name == 'photo':
        # "high-definition photography"
        r['query_template'] = '{} 高清摄影<start_of_image>'
    elif name == 'flat':
        # "flat style"
        r['query_template'] = '{} 平面风格<start_of_image>'
        # r['attn_plus'] = 1.8
        # r['temp_cluster_gen'] = 0.75
        r['temp_all_gen'] = 1.1
        r['topk_dsr'] = 5
        r['temp_cluster_dsr'] = 0.4
        r['temp_all_itersr'] = 1
        r['topk_itersr'] = 5
    elif name == 'comics':
        # "comics, isolated subject"
        r['query_template'] = '{} 漫画 隔绝<start_of_image>'
        r['topk_dsr'] = 5
        r['temp_cluster_dsr'] = 0.4
        r['temp_all_gen'] = 1.1
        r['temp_all_itersr'] = 1
        r['topk_itersr'] = 5
    elif name == 'oil':
        # "oil painting style"
        r['query_template'] = '{} 油画风格<start_of_image>'
    elif name == 'sketch':
        # "pencil sketch style"
        r['query_template'] = '{} 素描风格<start_of_image>'
        r['temp_all_gen'] = 1.1
    elif name == 'isometric':
        # "isometric vector art"
        r['query_template'] = '{} 等距矢量图<start_of_image>'
        r['temp_all_gen'] = 1.1
    elif name == 'chinese':
        # "traditional Chinese ink-wash painting"
        r['query_template'] = '{} 水墨国画<start_of_image>'
        r['temp_all_gen'] = 1.12
    elif name == 'watercolor':
        # "watercolor style"
        r['query_template'] = '{} 水彩画风格<start_of_image>'
    return r
def get_default_args() -> argparse.Namespace:
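    """Build SwissArmyTransformer inference args merged with the app defaults
    and the default style recipe."""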
arg_list = ['--mode', 'inference', '--fp16']
args = get_args(arg_list)
known = argparse.Namespace(img_size=160,
only_first_stage=False,
inverse_prompt=False,
style='mainbody')
args = argparse.Namespace(**vars(args), **vars(known),
**get_recipe(known.style))
return args
class Model:
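    """Two-stage CogView2 pipeline: a CogLM first stage that samples 20x20
    image-token maps, followed by optional direct and iterative
    super-resolution stages."""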
def __init__(self,
max_inference_batch_size: int,
only_first_stage: bool = False):
self.args = get_default_args()
self.args.only_first_stage = only_first_stage
self.args.max_inference_batch_size = max_inference_batch_size
self.model, self.args = self.load_model()
self.strategy = self.load_strategy()
self.srg = self.load_srg()
self.query_template = self.args.query_template
self.style = self.args.style
self.device = torch.device(self.args.device)
self.fp16 = self.args.fp16
self.max_batch_size = self.args.max_inference_batch_size
self.only_first_stage = self.args.only_first_stage
def load_model(self) -> tuple[InferenceModel, argparse.Namespace]:
logger.info('--- load_model ---')
start = time.perf_counter()
model, args = InferenceModel.from_pretrained(self.args, 'coglm')
if not self.args.only_first_stage:
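            # Keep the first-stage transformer on the CPU for now; run()
            # moves it to the GPU only while sampling tokens.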
model.transformer.cpu()
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return model, args
def load_strategy(self) -> CoglmStrategy:
logger.info('--- load_strategy ---')
start = time.perf_counter()
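        # Forbid sampling anything but image tokens.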
invalid_slices = [slice(tokenizer.num_image_tokens, None)]
strategy = CoglmStrategy(invalid_slices,
temperature=self.args.temp_all_gen,
top_k=self.args.topk_gen,
top_k_cluster=self.args.temp_cluster_gen)
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return strategy
def load_srg(self) -> SRGroup:
logger.info('--- load_srg ---')
start = time.perf_counter()
srg = None if self.args.only_first_stage else SRGroup(self.args)
if srg is not None:
srg.dsr.max_bz = 2
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return srg
def update_style(self, style: str) -> None:
if style == self.style:
return
logger.info('--- update_style ---')
start = time.perf_counter()
self.style = style
self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
self.query_template = self.args.query_template
logger.debug(f'{self.query_template=}')
self.strategy.temperature = self.args.temp_all_gen
if self.srg is not None:
self.srg.dsr.strategy.temperature = self.args.temp_all_dsr
self.srg.dsr.strategy.topk = self.args.topk_dsr
self.srg.dsr.strategy.temperature2 = self.args.temp_cluster_dsr
self.srg.itersr.strategy.temperature = self.args.temp_all_itersr
self.srg.itersr.strategy.topk = self.args.topk_itersr
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
def run(self, text: str, style: str, seed: int, only_first_stage: bool,
num: int) -> list[np.ndarray] | None:
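        """Generate `num` images for `text`, shuttling the first-stage and
        super-resolution models between CPU and GPU to bound memory use."""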
logger.info('==================== run ====================')
start = time.perf_counter()
self.update_style(style)
set_random_seed(seed)
seq, txt_len = self.preprocess_text(text)
if seq is None:
return None
        # Super-resolution is only possible when the SR group was loaded.
        self.only_first_stage = only_first_stage or self.srg is None
        if self.srg is not None:
            # Move the SR models off the GPU to make room for the
            # first-stage transformer.
            self.srg.dsr.model.cpu()
            self.srg.itersr.model.cpu()
torch.cuda.empty_cache()
self.model.transformer.to(self.device)
tokens = self.generate_tokens(seq, txt_len, num)
if not self.only_first_stage:
self.model.transformer.cpu()
torch.cuda.empty_cache()
self.srg.dsr.model.to(self.device)
self.srg.itersr.model.to(self.device)
torch.cuda.empty_cache()
res = self.generate_images(seq, txt_len, tokens)
elapsed = time.perf_counter() - start
logger.info(f'Elapsed: {elapsed}')
logger.info('==================== done ====================')
return res
@torch.inference_mode()
def preprocess_text(
self, text: str) -> tuple[torch.Tensor, int] | tuple[None, None]:
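        """Apply the style's query template and tokenize the result.

        Returns the sequence padded with -1 placeholders for the 400 image
        tokens plus the text length, or (None, None) if the prompt is too
        long.
        """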
logger.info('--- preprocess_text ---')
start = time.perf_counter()
text = self.query_template.format(text)
logger.debug(f'{text=}')
seq = tokenizer.encode(text)
logger.info(f'{len(seq)=}')
if len(seq) > 110:
logger.info('The input text is too long.')
return None, None
txt_len = len(seq) - 1
seq = torch.tensor(seq + [-1] * 400, device=self.device)
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return seq, txt_len
@torch.inference_mode()
def generate_tokens(self,
seq: torch.Tensor,
txt_len: int,
num: int = 8) -> torch.Tensor:
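        """Sample `num` coarse 20x20 token maps for `seq`, splitting the work
        into batches of at most `max_batch_size`."""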
logger.info('--- generate_tokens ---')
start = time.perf_counter()
        # Up-weight attention to the text prompt tokens by `attn_plus`.
log_attention_weights = torch.zeros(
len(seq),
len(seq),
device=self.device,
dtype=torch.half if self.fp16 else torch.float32)
log_attention_weights[:, :txt_len] = self.args.attn_plus
get_func = functools.partial(get_masks_and_position_ids_coglm,
context_length=txt_len)
output_list = []
remaining = num
for _ in range((num + self.max_batch_size - 1) // self.max_batch_size):
self.strategy.start_pos = txt_len + 1
coarse_samples = filling_sequence(
self.model,
seq.clone(),
batch_size=min(remaining, self.max_batch_size),
strategy=self.strategy,
log_attention_weights=log_attention_weights,
get_masks_and_position_ids=get_func)[0]
output_list.append(coarse_samples)
remaining -= self.max_batch_size
output_tokens = torch.cat(output_list, dim=0)
logger.debug(f'{output_tokens.shape=}')
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return output_tokens
@staticmethod
def postprocess(tensor: torch.Tensor) -> np.ndarray:
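        """Convert a CHW float tensor with values in [0, 1] to an HWC uint8
        array."""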
return tensor.cpu().mul(255).add_(0.5).clamp_(0, 255).permute(
1, 2, 0).to(torch.uint8).numpy()
@torch.inference_mode()
def generate_images(self, seq: torch.Tensor, txt_len: int,
tokens: torch.Tensor) -> list[np.ndarray]:
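        """Decode token maps into 480x480 images: 20x20 maps directly when
        only the first stage runs, otherwise 60x60 super-resolved maps."""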
logger.info('--- generate_images ---')
start = time.perf_counter()
logger.debug(f'{self.only_first_stage=}')
res = []
if self.only_first_stage:
for i in range(len(tokens)):
seq = tokens[i]
decoded_img = tokenizer.decode(image_ids=seq[-400:])
decoded_img = torch.nn.functional.interpolate(decoded_img,
size=(480, 480))
decoded_img = self.postprocess(decoded_img[0])
res.append(decoded_img) # only the last image (target)
else: # sr
iter_tokens = self.srg.sr_base(tokens[:, -400:], seq[:txt_len])
for seq in iter_tokens:
decoded_img = tokenizer.decode(image_ids=seq[-3600:])
decoded_img = torch.nn.functional.interpolate(decoded_img,
size=(480, 480))
decoded_img = self.postprocess(decoded_img[0])
res.append(decoded_img) # only the last image (target)
elapsed = time.perf_counter() - start
logger.info(f'--- done ({elapsed=:.3f}) ---')
return res
class AppModel(Model):
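    """Model plus Gradio-app helpers: English-to-Chinese translation via the
    chinhon/translation_eng2ch Space, grid assembly, and the simple and
    advanced entry points."""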
def __init__(self, max_inference_batch_size: int, only_first_stage: bool):
super().__init__(max_inference_batch_size, only_first_stage)
self.translator = gr.Interface.load(
'spaces/chinhon/translation_eng2ch')
self.rng = random.Random()
def make_grid(self, images: list[np.ndarray] | None) -> np.ndarray | None:
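        """Tile the images into the smallest near-square grid that fits
        them."""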
if images is None or len(images) == 0:
return None
        ncols = 1
        while ncols**2 < len(images):
            ncols += 1
nrows = (len(images) + ncols - 1) // ncols
h, w = images[0].shape[:2]
grid = np.zeros((h * nrows, w * ncols, 3), dtype=np.uint8)
for i in range(nrows):
for j in range(ncols):
index = ncols * i + j
if index >= len(images):
break
grid[h * i:h * (i + 1), w * j:w * (j + 1)] = images[index]
return grid
def run_advanced(
self, text: str, translate: bool, style: str, seed: int,
only_first_stage: bool, num: int
) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
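        """Optionally translate `text` to Chinese, then return the
        translation, a grid image, and the individual images."""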
logger.info(
f'{text=}, {translate=}, {style=}, {seed=}, {only_first_stage=}, {num=}'
)
if translate:
text = translated_text = self.translator(text)
else:
translated_text = None
results = self.run(text, style, seed, only_first_stage, num)
grid_image = self.make_grid(results)
return translated_text, grid_image, results
def run_simple(self, text: str) -> np.ndarray | None:
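        """Generate four 'photo'-style images for `text` with a random seed
        and return them as a grid; ASCII input is assumed to be English and
        is translated to Chinese first."""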
logger.info(f'{text=}')
if text.isascii():
text = self.translator(text)
seed = self.rng.randint(0, 100000)
results = self.run(text, 'photo', seed, False, 4)
grid_image = self.make_grid(results)
return grid_image
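

# A minimal usage sketch (assumes a CUDA runtime with the checkpoints above
# already downloaded; PIL is not imported by this module and the filename is
# illustrative):
#
#     from PIL import Image
#
#     model = AppModel(max_inference_batch_size=8, only_first_stage=False)
#     grid = model.run_simple('a red panda sleeping on a tree branch')
#     if grid is not None:
#         Image.fromarray(grid).save('grid.png')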