# This code is adapted from https://github.com/THUDM/CogView2/blob/4e55cce981eb94b9c8c1f19ba9f632fd3ee42ba8/cogview2_text2image.py from __future__ import annotations import argparse import functools import logging import os import pathlib import random import subprocess import sys import time import zipfile from typing import Any if os.getenv('SYSTEM') == 'spaces': subprocess.run('pip install icetk==0.0.3'.split()) subprocess.run('pip install SwissArmyTransformer==0.2.4'.split()) subprocess.run( 'pip install git+https://github.com/Sleepychord/Image-Local-Attention@43fee31' .split()) #subprocess.run('git clone https://github.com/NVIDIA/apex'.split()) #subprocess.run('git checkout 1403c21'.split(), cwd='apex') #with open('patch.apex') as f: # subprocess.run('patch -p1'.split(), cwd='apex', stdin=f) #subprocess.run( # 'pip install -v --disable-pip-version-check --no-cache-dir --global-option --cpp_ext --global-option --cuda_ext ./' # .split(), # cwd='apex') #subprocess.run('rm -rf apex'.split()) with open('patch') as f: subprocess.run('patch -p1'.split(), cwd='CogView2', stdin=f) from huggingface_hub import hf_hub_download def download_and_extract_icetk_models() -> None: icetk_model_dir = pathlib.Path('/home/user/.icetk_models') icetk_model_dir.mkdir() path = hf_hub_download('THUDM/icetk', 'models.zip', use_auth_token=os.getenv('HF_TOKEN')) with zipfile.ZipFile(path) as f: f.extractall(path=icetk_model_dir.as_posix()) def download_and_extract_cogview2_models(name: str) -> None: path = hf_hub_download('THUDM/CogView2', name, use_auth_token=os.getenv('HF_TOKEN')) with zipfile.ZipFile(path) as f: f.extractall() os.remove(path) download_and_extract_icetk_models() names = [ 'coglm.zip', 'cogview2-dsr.zip', 'cogview2-itersr.zip', ] for name in names: download_and_extract_cogview2_models(name) os.environ['SAT_HOME'] = '/home/user/app/sharefs/cogview-new' import gradio as gr import numpy as np import torch from icetk import icetk as tokenizer from SwissArmyTransformer import get_args from SwissArmyTransformer.arguments import set_random_seed from SwissArmyTransformer.generation.autoregressive_sampling import \ filling_sequence from SwissArmyTransformer.model import CachedAutoregressiveModel app_dir = pathlib.Path(__file__).parent submodule_dir = app_dir / 'CogView2' sys.path.insert(0, submodule_dir.as_posix()) from coglm_strategy import CoglmStrategy from sr_pipeline import SRGroup formatter = logging.Formatter( '[%(asctime)s] %(name)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') stream_handler = logging.StreamHandler(stream=sys.stdout) stream_handler.setLevel(logging.INFO) stream_handler.setFormatter(formatter) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger.propagate = False logger.addHandler(stream_handler) tokenizer.add_special_tokens( ['', '', '']) def get_masks_and_position_ids_coglm( seq: torch.Tensor, context_length: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: tokens = seq.unsqueeze(0) attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device) attention_mask.tril_() attention_mask[..., :context_length] = 1 attention_mask.unsqueeze_(1) position_ids = torch.zeros(len(seq), device=tokens.device, dtype=torch.long) torch.arange(0, context_length, out=position_ids[:context_length]) torch.arange(512, 512 + len(seq) - context_length, out=position_ids[context_length:]) position_ids = position_ids.unsqueeze(0) return tokens, attention_mask, position_ids class InferenceModel(CachedAutoregressiveModel): def final_forward(self, logits, **kwargs): logits_parallel = logits logits_parallel = torch.nn.functional.linear( logits_parallel.float(), self.transformer.word_embeddings.weight[:20000].float()) return logits_parallel def get_recipe(name: str) -> dict[str, Any]: r = { 'attn_plus': 1.4, 'temp_all_gen': 1.15, 'topk_gen': 16, 'temp_cluster_gen': 1., 'temp_all_dsr': 1.5, 'topk_dsr': 100, 'temp_cluster_dsr': 0.89, 'temp_all_itersr': 1.3, 'topk_itersr': 16, 'query_template': '{}', } if name == 'none': pass elif name == 'mainbody': r['query_template'] = '{} 高清摄影 隔绝' elif name == 'photo': r['query_template'] = '{} 高清摄影' elif name == 'flat': r['query_template'] = '{} 平面风格' # r['attn_plus'] = 1.8 # r['temp_cluster_gen'] = 0.75 r['temp_all_gen'] = 1.1 r['topk_dsr'] = 5 r['temp_cluster_dsr'] = 0.4 r['temp_all_itersr'] = 1 r['topk_itersr'] = 5 elif name == 'comics': r['query_template'] = '{} 漫画 隔绝' r['topk_dsr'] = 5 r['temp_cluster_dsr'] = 0.4 r['temp_all_gen'] = 1.1 r['temp_all_itersr'] = 1 r['topk_itersr'] = 5 elif name == 'oil': r['query_template'] = '{} 油画风格' pass elif name == 'sketch': r['query_template'] = '{} 素描风格' r['temp_all_gen'] = 1.1 elif name == 'isometric': r['query_template'] = '{} 等距矢量图' r['temp_all_gen'] = 1.1 elif name == 'chinese': r['query_template'] = '{} 水墨国画' r['temp_all_gen'] = 1.12 elif name == 'watercolor': r['query_template'] = '{} 水彩画风格' return r def get_default_args() -> argparse.Namespace: arg_list = ['--mode', 'inference', '--fp16'] args = get_args(arg_list) known = argparse.Namespace(img_size=160, only_first_stage=False, inverse_prompt=False, style='mainbody') args = argparse.Namespace(**vars(args), **vars(known), **get_recipe(known.style)) return args class Model: def __init__(self, max_inference_batch_size: int, only_first_stage: bool = False): self.args = get_default_args() self.args.only_first_stage = only_first_stage self.args.max_inference_batch_size = max_inference_batch_size self.model, self.args = self.load_model() self.strategy = self.load_strategy() self.srg = self.load_srg() self.query_template = self.args.query_template self.style = self.args.style self.device = torch.device(self.args.device) self.fp16 = self.args.fp16 self.max_batch_size = self.args.max_inference_batch_size self.only_first_stage = self.args.only_first_stage def load_model(self) -> tuple[InferenceModel, argparse.Namespace]: logger.info('--- load_model ---') start = time.perf_counter() model, args = InferenceModel.from_pretrained(self.args, 'coglm') if not self.args.only_first_stage: model.transformer.cpu() elapsed = time.perf_counter() - start logger.info(f'--- done ({elapsed=:.3f}) ---') return model, args def load_strategy(self) -> CoglmStrategy: logger.info('--- load_strategy ---') start = time.perf_counter() invalid_slices = [slice(tokenizer.num_image_tokens, None)] strategy = CoglmStrategy(invalid_slices, temperature=self.args.temp_all_gen, top_k=self.args.topk_gen, top_k_cluster=self.args.temp_cluster_gen) elapsed = time.perf_counter() - start logger.info(f'--- done ({elapsed=:.3f}) ---') return strategy def load_srg(self) -> SRGroup: logger.info('--- load_srg ---') start = time.perf_counter() srg = None if self.args.only_first_stage else SRGroup(self.args) if srg is not None: srg.dsr.max_bz = 2 elapsed = time.perf_counter() - start logger.info(f'--- done ({elapsed=:.3f}) ---') return srg def update_style(self, style: str) -> None: if style == self.style: return logger.info('--- update_style ---') start = time.perf_counter() self.style = style self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style))) self.query_template = self.args.query_template logger.debug(f'{self.query_template=}') self.strategy.temperature = self.args.temp_all_gen if self.srg is not None: self.srg.dsr.strategy.temperature = self.args.temp_all_dsr self.srg.dsr.strategy.topk = self.args.topk_dsr self.srg.dsr.strategy.temperature2 = self.args.temp_cluster_dsr self.srg.itersr.strategy.temperature = self.args.temp_all_itersr self.srg.itersr.strategy.topk = self.args.topk_itersr elapsed = time.perf_counter() - start logger.info(f'--- done ({elapsed=:.3f}) ---') def run(self, text: str, style: str, seed: int, only_first_stage: bool, num: int) -> list[np.ndarray] | None: logger.info('==================== run ====================') start = time.perf_counter() self.update_style(style) set_random_seed(seed) seq, txt_len = self.preprocess_text(text) if seq is None: return None self.only_first_stage = only_first_stage if not self.only_first_stage or self.srg is not None: self.srg.dsr.model.cpu() self.srg.itersr.model.cpu() torch.cuda.empty_cache() self.model.transformer.to(self.device) tokens = self.generate_tokens(seq, txt_len, num) if not self.only_first_stage: self.model.transformer.cpu() torch.cuda.empty_cache() self.srg.dsr.model.to(self.device) self.srg.itersr.model.to(self.device) torch.cuda.empty_cache() res = self.generate_images(seq, txt_len, tokens) elapsed = time.perf_counter() - start logger.info(f'Elapsed: {elapsed}') logger.info('==================== done ====================') return res @torch.inference_mode() def preprocess_text( self, text: str) -> tuple[torch.Tensor, int] | tuple[None, None]: logger.info('--- preprocess_text ---') start = time.perf_counter() text = self.query_template.format(text) logger.debug(f'{text=}') seq = tokenizer.encode(text) logger.info(f'{len(seq)=}') if len(seq) > 110: logger.info('The input text is too long.') return None, None txt_len = len(seq) - 1 seq = torch.tensor(seq + [-1] * 400, device=self.device) elapsed = time.perf_counter() - start logger.info(f'--- done ({elapsed=:.3f}) ---') return seq, txt_len @torch.inference_mode() def generate_tokens(self, seq: torch.Tensor, txt_len: int, num: int = 8) -> torch.Tensor: logger.info('--- generate_tokens ---') start = time.perf_counter() # calibrate text length log_attention_weights = torch.zeros( len(seq), len(seq), device=self.device, dtype=torch.half if self.fp16 else torch.float32) log_attention_weights[:, :txt_len] = self.args.attn_plus get_func = functools.partial(get_masks_and_position_ids_coglm, context_length=txt_len) output_list = [] remaining = num for _ in range((num + self.max_batch_size - 1) // self.max_batch_size): self.strategy.start_pos = txt_len + 1 coarse_samples = filling_sequence( self.model, seq.clone(), batch_size=min(remaining, self.max_batch_size), strategy=self.strategy, log_attention_weights=log_attention_weights, get_masks_and_position_ids=get_func)[0] output_list.append(coarse_samples) remaining -= self.max_batch_size output_tokens = torch.cat(output_list, dim=0) logger.debug(f'{output_tokens.shape=}') elapsed = time.perf_counter() - start logger.info(f'--- done ({elapsed=:.3f}) ---') return output_tokens @staticmethod def postprocess(tensor: torch.Tensor) -> np.ndarray: return tensor.cpu().mul(255).add_(0.5).clamp_(0, 255).permute( 1, 2, 0).to(torch.uint8).numpy() @torch.inference_mode() def generate_images(self, seq: torch.Tensor, txt_len: int, tokens: torch.Tensor) -> list[np.ndarray]: logger.info('--- generate_images ---') start = time.perf_counter() logger.debug(f'{self.only_first_stage=}') res = [] if self.only_first_stage: for i in range(len(tokens)): seq = tokens[i] decoded_img = tokenizer.decode(image_ids=seq[-400:]) decoded_img = torch.nn.functional.interpolate(decoded_img, size=(480, 480)) decoded_img = self.postprocess(decoded_img[0]) res.append(decoded_img) # only the last image (target) else: # sr iter_tokens = self.srg.sr_base(tokens[:, -400:], seq[:txt_len]) for seq in iter_tokens: decoded_img = tokenizer.decode(image_ids=seq[-3600:]) decoded_img = torch.nn.functional.interpolate(decoded_img, size=(480, 480)) decoded_img = self.postprocess(decoded_img[0]) res.append(decoded_img) # only the last image (target) elapsed = time.perf_counter() - start logger.info(f'--- done ({elapsed=:.3f}) ---') return res class AppModel(Model): def __init__(self, max_inference_batch_size: int, only_first_stage: bool): super().__init__(max_inference_batch_size, only_first_stage) self.translator = gr.Interface.load( 'spaces/chinhon/translation_eng2ch') self.rng = random.Random() def make_grid(self, images: list[np.ndarray] | None) -> np.ndarray | None: if images is None or len(images) == 0: return None ncols = 1 while True: if ncols**2 >= len(images): break ncols += 1 nrows = (len(images) + ncols - 1) // ncols h, w = images[0].shape[:2] grid = np.zeros((h * nrows, w * ncols, 3), dtype=np.uint8) for i in range(nrows): for j in range(ncols): index = ncols * i + j if index >= len(images): break grid[h * i:h * (i + 1), w * j:w * (j + 1)] = images[index] return grid def run_advanced( self, text: str, translate: bool, style: str, seed: int, only_first_stage: bool, num: int ) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]: logger.info( f'{text=}, {translate=}, {style=}, {seed=}, {only_first_stage=}, {num=}' ) if translate: text = translated_text = self.translator(text) else: translated_text = None results = self.run(text, style, seed, only_first_stage, num) grid_image = self.make_grid(results) return translated_text, grid_image, results def run_simple(self, text: str) -> np.ndarray | None: logger.info(f'{text=}') if text.isascii(): text = self.translator(text) seed = self.rng.randint(0, 100000) results = self.run(text, 'photo', seed, False, 4) grid_image = self.make_grid(results) return grid_image