Gradio demo

#1
by kamalkraj - opened

Hi @akhaliq,

To run the CogView2 model, an A100 GPU is recommended. As I said in this issue, it is better to run the model prediction service separately and host a Gradio/Streamlit demo in Spaces.
We can configure the URL of the prediction service using a secret.
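
As a rough sketch of that split, the Space would only hold a thin Gradio front end that forwards requests to the service. Everything here is illustrative: the secret name PREDICTION_URL and the JSON response shape are assumptions following Cog's HTTP conventions, and should be adapted to the actual deployment.

import base64
import os
import tempfile

import gradio as gr
import requests

# URL of the remote prediction service, stored as a Space secret.
PREDICTION_URL = os.environ["PREDICTION_URL"]

STYLES = ["none", "mainbody", "photo", "flat", "comics", "oil",
          "sketch", "isometric", "chinese", "watercolor"]


def generate(text, style):
    # Cog-style HTTP API: POST /predictions with an "input" payload.
    resp = requests.post(
        PREDICTION_URL,
        json={"input": {"text": text, "style": style}},
        timeout=600,
    )
    resp.raise_for_status()
    # Assumed response shape: {"output": [{"image": "data:image/png;base64,..."}]}
    paths = []
    for i, item in enumerate(resp.json()["output"]):
        data = base64.b64decode(item["image"].split(",", 1)[1])
        path = os.path.join(tempfile.mkdtemp(), f"output_{i}.png")
        with open(path, "wb") as f:
            f.write(data)
        paths.append(path)
    return paths


demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Text", value="a tiger wearing VR glasses"),
        gr.Dropdown(STYLES, value="mainbody", label="Style"),
    ],
    outputs=gr.Gallery(label="Generated images"),
)
demo.launch()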

I don't think we will be able to load the model on a T4 GPU; I have already tried on a TITAN RTX GPU, which has 24 GB of memory. The model takes around 36 GB of GPU memory.
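
To check the footprint on other hardware, torch reports it directly; run this after the model has loaded:

import torch

# Bytes currently allocated by CUDA tensors on the default device, shown in GiB.
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.1f} GiB")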

Prediction service code from the Docker image:

import os

os.environ["SAT_HOME"] = "checkpoints"

from typing import List
import tempfile
import torch
import argparse
from functools import partial
import numpy as np
from torchvision.utils import save_image
from cog import BasePredictor, Path, Input, BaseModel
from SwissArmyTransformer import get_args
from SwissArmyTransformer.model import CachedAutoregressiveModel
from SwissArmyTransformer.generation.autoregressive_sampling import (
    filling_sequence,
    evaluate_perplexity,
)

from coglm_strategy import CoglmStrategy
from sr_pipeline import SRGroup

from icetk import icetk as tokenizer

# register CogLM's special separator tokens with the icetk tokenizer
tokenizer.add_special_tokens(
    ["<start_of_image>", "<start_of_english>", "<start_of_chinese>"]
)


class InferenceModel(CachedAutoregressiveModel):
    def final_forward(self, logits, **kwargs):
        # Project hidden states onto the first 20,000 rows of the word
        # embedding, i.e. restrict the output distribution to image tokens.
        return torch.nn.functional.linear(
            logits.float(),
            self.transformer.word_embeddings.weight[:20000].float(),
        )


class ModelOutput(BaseModel):
    image: Path


class Predictor(BasePredictor):
    def setup(self):
        py_parser = argparse.ArgumentParser(add_help=False)
        py_parser.add_argument("--img-size", type=int, default=160)
        py_parser.add_argument("--only-first-stage", action="store_true")
        py_parser.add_argument("--inverse-prompt", action="store_true")

        known, args_list = py_parser.parse_known_args(
            [
                "--mode",
                "inference",
                "--batch-size",
                "16",
                "--max-inference-batch-size",
                "8",
                "--fp16",
            ]
        )
        args = get_args(args_list)
        # Sampling hyperparameters for generation (gen), direct
        # super-resolution (dsr), and iterative super-resolution (itersr).
        r = {
            "attn_plus": 1.4,
            "temp_all_gen": 1.15,
            "topk_gen": 16,
            "temp_cluster_gen": 1.0,
            "temp_all_dsr": 1.5,
            "topk_dsr": 100,
            "temp_cluster_dsr": 0.89,
            "temp_all_itersr": 1.3,
            "topk_itersr": 16,
            "query_template": "{}<start_of_image>",
        }

        args = argparse.Namespace(**vars(args), **vars(known), **r)

        self.model, self.args = InferenceModel.from_pretrained(args, "coglm")
        self.text_model = CachedAutoregressiveModel(
            self.args, transformer=self.model.transformer
        )
        self.srg = SRGroup(self.args)
        self.invalid_slices = [slice(tokenizer.num_image_tokens, None)]

    def predict(
        self,
        text: str = Input(
            default="a tiger wearing VR glasses",
            description="Text for generating image.",
        ),
        style: str = Input(
            choices=[
                "none",
                "mainbody",
                "photo",
                "flat",
                "comics",
                "oil",
                "sketch",
                "isometric",
                "chinese",
                "watercolor",
            ],
            default="mainbody",
            description="Choose the image style.",
        ),
    ) -> List[ModelOutput]:

        torch.cuda.empty_cache()

        args = adapt_to_style(self.args, style)

        strategy = CoglmStrategy(
            self.invalid_slices,
            temperature=args.temp_all_gen,
            top_k=args.topk_gen,
            top_k_cluster=args.temp_cluster_gen,
        )

        query_template = args.query_template

        # encode the prompt, then sample image tokens
        with torch.no_grad():
            text = query_template.format(text)
            seq = tokenizer.encode(text)
            if len(seq) > 110:
                raise ValueError("text too long.")

            txt_len = len(seq) - 1
            seq = torch.tensor(seq + [-1] * 400, device=args.device)
            # calibrate text length
            log_attention_weights = torch.zeros(
                len(seq),
                len(seq),
                device=args.device,
                dtype=torch.half if args.fp16 else torch.float32,
            )
            log_attention_weights[:, :txt_len] = args.attn_plus
            # generation
            mbz = args.max_inference_batch_size
            assert args.batch_size < mbz or args.batch_size % mbz == 0
            get_func = partial(get_masks_and_position_ids_coglm, context_length=txt_len)
            output_list, score_list = [], []

            for _ in range(max(args.batch_size // mbz, 1)):
                strategy.start_pos = txt_len + 1
                coarse_samples = filling_sequence(
                    self.model,
                    seq.clone(),
                    batch_size=min(args.batch_size, mbz),
                    strategy=strategy,
                    log_attention_weights=log_attention_weights,
                    get_masks_and_position_ids=get_func,
                )[0]

                # get ppl for inverse prompting
                if args.inverse_prompt:
                    image_text_seq = torch.cat(
                        (
                            coarse_samples[:, -400:],
                            torch.tensor(
                                [tokenizer["<start_of_chinese>"]]
                                + tokenizer.encode(text),
                                device=args.device,
                            ).expand(coarse_samples.shape[0], -1),
                        ),
                        dim=1,
                    )
                    seqlen = image_text_seq.shape[1]
                    attention_mask = torch.zeros(
                        seqlen, seqlen, device=args.device, dtype=torch.long
                    )
                    attention_mask[:, :400] = 1
                    attention_mask[400:, 400:] = 1
                    attention_mask[400:, 400:].tril_()
                    position_ids = torch.zeros(
                        seqlen, device=args.device, dtype=torch.long
                    )
                    torch.arange(513, 513 + 400, out=position_ids[:400])
                    torch.arange(0, seqlen - 400, out=position_ids[400:])
                    loss_mask = torch.zeros(
                        seqlen, device=args.device, dtype=torch.long
                    )
                    loss_mask[401:] = 1
                    scores = evaluate_perplexity(
                        self.text_model,
                        image_text_seq,
                        attention_mask,
                        position_ids,
                        loss_mask,  # , invalid_slices=[slice(0, 20000)], reduction='mean'
                    )
                    score_list.extend(scores.tolist())
                    # ---------------------

                output_list.append(coarse_samples)
            output_tokens = torch.cat(output_list, dim=0)

            if args.inverse_prompt:
                order_list = np.argsort(score_list)[::-1]
                print(sorted(score_list))
            else:
                order_list = range(output_tokens.shape[0])

            imgs = []
            if args.only_first_stage:
                for i in order_list:
                    seq = output_tokens[i]
                    decoded_img = tokenizer.decode(image_ids=seq[-400:])
                    decoded_img = torch.nn.functional.interpolate(
                        decoded_img, size=(256, 256)
                    )
                    imgs.append(decoded_img)  # only the last image (target)
            else:
                # run the super-resolution pipeline on the coarse tokens
                iter_tokens = self.srg.sr_base(output_tokens[:, -400:], seq[:txt_len])
                for seq in iter_tokens:
                    decoded_img = tokenizer.decode(image_ids=seq[-3600:])
                    decoded_img = torch.nn.functional.interpolate(
                        decoded_img, size=(256, 256)
                    )
                    imgs.append(decoded_img)  # only the last image (target)

        # save each image to a temporary file
        output = []

        for i in range(len(imgs)):
            out_path = Path(tempfile.mkdtemp()) / f"output_{i}.png"
            save_image(imgs[i], str(out_path), normalize=True)
            output.append(ModelOutput(image=out_path))

        return output


def adapt_to_style(args, name):
    # Adjust the query template (Chinese style keywords appended to the
    # prompt) and the sampling hyperparameters for the chosen style.
    if name == "none":
        return args
    if name == "mainbody":
        args.query_template = "{} 高清摄影 隔绝<start_of_image>"  # "HD photography, isolated subject"
    elif name == "photo":
        args.query_template = "{} 高清摄影<start_of_image>"  # "HD photography"
    elif name == "flat":
        args.query_template = "{} 平面风格<start_of_image>"  # "flat style"
        args.temp_all_gen = 1.1
        args.topk_dsr = 5
        args.temp_cluster_dsr = 0.4
        args.temp_all_itersr = 1
        args.topk_itersr = 5
    elif name == "comics":
        args.query_template = "{} 漫画 隔绝<start_of_image>"  # "comics, isolated subject"
        args.topk_dsr = 5
        args.temp_cluster_dsr = 0.4
        args.temp_all_gen = 1.1
        args.temp_all_itersr = 1
        args.topk_itersr = 5
    elif name == "oil":
        args.query_template = "{} 油画风格<start_of_image>"  # "oil painting style"
    elif name == "sketch":
        args.query_template = "{} 素描风格<start_of_image>"  # "sketch style"
        args.temp_all_gen = 1.1
    elif name == "isometric":
        args.query_template = "{} 等距矢量图<start_of_image>"  # "isometric vector art"
        args.temp_all_gen = 1.1
    elif name == "chinese":
        args.query_template = "{} 水墨国画<start_of_image>"  # "traditional Chinese ink painting"
        args.temp_all_gen = 1.12
    elif name == "watercolor":
        args.query_template = "{} 水彩画风格<start_of_image>"  # "watercolor style"
    return args


def get_masks_and_position_ids_coglm(seq, context_length):
    tokens = seq.unsqueeze(0)

    # Causal mask over generated tokens; every position may additionally
    # attend to the full text context.
    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
    attention_mask.tril_()
    attention_mask[..., :context_length] = 1
    attention_mask.unsqueeze_(1)

    # Text tokens use positions [0, context_length); image tokens are
    # offset to start at position 512.
    position_ids = torch.zeros(len(seq), device=tokens.device, dtype=torch.long)
    torch.arange(0, context_length, out=position_ids[:context_length])
    torch.arange(
        512, 512 + len(seq) - context_length, out=position_ids[context_length:]
    )

    position_ids = position_ids.unsqueeze(0)
    return tokens, attention_mask, position_ids
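
Once the image is running, the service can be smoke-tested over Cog's HTTP API (POST /predictions; port 5000 below is Cog's default, but the mapping depends on how the container is started, and the output shape mirrors the ModelOutput list above):

import requests

resp = requests.post(
    "http://localhost:5000/predictions",
    json={"input": {"text": "a tiger wearing VR glasses", "style": "mainbody"}},
    timeout=600,
)
resp.raise_for_status()
# Path outputs come back JSON-encoded; one entry per generated image.
print(len(resp.json()["output"]), "images returned")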
