|
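# UltraPixel Gradio demo: Stage C (WurstCoreC) turns a text prompt into a
# compressed image latent, and Stage B (WurstCoreB) plus the Stage A decoder
# turn that latent into the final high-resolution image.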
import argparse
import os
import random
import sys

import gradio as gr
import numpy as np
import torch
import yaml

sys.path.append(os.path.abspath('./'))

from inference.utils import *
from core.utils import load_or_fail
from train import WurstCoreB
from train import WurstCore_t2i as WurstCoreC
from gdf import DDPMSampler
|
|
|
|
|
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--height', type=int, default=2560, help='image height')
    parser.add_argument('--width', type=int, default=5120, help='image width')
    parser.add_argument('--seed', type=int, default=123, help='random seed')
    parser.add_argument('--dtype', type=str, default='bf16', help='compute dtype; if bf16 does not work, change it to float32')
    parser.add_argument('--config_c', type=str,
                        default='configs/training/t2i.yaml', help='config file for stage c, latent generation')
    parser.add_argument('--config_b', type=str,
                        default='configs/inference/stage_b_1b.yaml', help='config file for stage b, latent decoding')
    parser.add_argument('--prompt', type=str,
                        default='A photo-realistic image of a west highland white terrier in the garden, high quality, detail rich, 8K', help='text prompt')
    parser.add_argument('--num_image', type=int, default=1, help='number of images to generate')
    parser.add_argument('--output_dir', type=str, default='figures/output_results/', help='output directory for generated images')
    parser.add_argument('--stage_a_tiled', action='store_true', help='whether or not to use tiled decoding for stage a to save memory')
    parser.add_argument('--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='path to the newly added UltraPixel parameters')
    args = parser.parse_args()
    return args
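
# Example invocation (the filename app.py is an assumption; use whatever this
# script is saved as):
#   python app.py --height 2560 --width 5120 --prompt 'a dog in the garden'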
|
|
|
def clear_image():
    return None


def load_message(height, width, seed, prompt, args, stage_a_tiled):
    # Overwrite the CLI defaults with the values chosen in the Gradio UI.
    args.height = height
    args.width = width
    args.seed = seed
    args.prompt = prompt + ' rich detail, 4k, high quality'
    args.stage_a_tiled = stage_a_tiled
    return args
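
# get_image runs the full pipeline for one UI request: seed the RNGs, compute
# latent shapes for the target and low-resolution guidance sizes, sample a
# Stage C latent, then decode it with Stage B/A into the output image.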
|
def get_image(height, width, seed, prompt, cfg, timesteps, stage_a_tiled):
    global args
    args = load_message(height, width, seed, prompt, args, stage_a_tiled)
    # Seed every RNG in use so a given seed reproduces the same image.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float

    captions = [args.prompt] * args.num_image
    height, width = args.height, args.width
    batch_size = 1
    # Low-resolution size with the same aspect ratio, used to guide generation.
    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
    stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)

    # Stage C sampling setup, driven by the CFG and Timesteps sliders.
    extras.sampling_configs['cfg'] = cfg
    extras.sampling_configs['shift'] = 1
    extras.sampling_configs['timesteps'] = int(timesteps)
    extras.sampling_configs['t_start'] = 1.0
    extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf)

    # Stage B sampling setup; decoding needs only light guidance and few steps.
    extras_b.sampling_configs['cfg'] = 1.1
    extras_b.sampling_configs['shift'] = 1
    extras_b.sampling_configs['timesteps'] = 10
    extras_b.sampling_configs['t_start'] = 1.0
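
    # Sampling-config notes: 'cfg' is the classifier-free guidance scale,
    # 'timesteps' the number of denoising steps for the DDPM sampler, and
    # 't_start' the noise level at which sampling begins (1.0 = pure noise).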
|
|
|
    for caption in captions:
        batch = {'captions': [caption] * batch_size}
        with torch.no_grad():
            # Stage C: generate the compressed image latent on the GPU.
            models.generator.cuda()
            print('STAGE C GENERATION***************************')
            with torch.cuda.amp.autocast(dtype=dtype):
                sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)

            # Move the Stage C generator off the GPU before Stage B decoding.
            models.generator.cpu()
            torch.cuda.empty_cache()

            conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
            unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
            # Condition Stage B on the Stage C latent; zeros for the unconditional branch.
            conditions_b['effnet'] = sampled_c
            unconditions_b['effnet'] = torch.zeros_like(sampled_c)
            print('STAGE B + A DECODING***************************')

            with torch.cuda.amp.autocast(dtype=dtype):
                sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled)

            torch.cuda.empty_cache()
            imgs = show_images(sampled)

    return imgs[0]
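
# Minimal programmatic call that bypasses the UI (a sketch: it assumes the
# __main__ setup below has already populated core/models/extras, and that
# show_images returns PIL images):
#   img = get_image(2304, 4096, 123, 'a happy cat', cfg=4, timesteps=20, stage_a_tiled=False)
#   img.save('figures/output_results/cat.png')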
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                height = gr.Slider(value=2304, step=32, minimum=1536, maximum=4096, label='Height')
                width = gr.Slider(value=4096, step=32, minimum=1536, maximum=5120, label='Width')
                seed = gr.Number(value=123, step=1, label='Random Seed')
                prompt = gr.Textbox(value='', max_lines=4, label='Text Prompt')
                cfg = gr.Slider(value=4, step=0.1, minimum=3, maximum=10, label='CFG')
                timesteps = gr.Slider(value=20, step=1, minimum=10, maximum=50, label='Timesteps')
                stage_a_tiled = gr.Checkbox(value=False, label='Stage_a_tiled')
                with gr.Row():
                    clear_button = gr.Button("Clear!")
                    polish_button = gr.Button("Submit!")
            with gr.Column():
                output_img = gr.Image(label='Output Image', sources=None)
        with gr.Column():
            prompt2 = gr.Textbox(
                value='''
                1. a happy cat
                2. a happy girl
                ''', label='Text prompt examples'
            )

    polish_button.click(get_image, inputs=[height, width, seed, prompt, cfg, timesteps, stage_a_tiled], outputs=output_img)
    # Reset the output image when Clear! is pressed.
    clear_button.click(clear_image, inputs=[], outputs=output_img)
|
|
|
if __name__ == "__main__":
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Stage C (latent generation) setup.
    config_file = args.config_c
    with open(config_file, "r", encoding="utf-8") as file:
        loaded_config = yaml.safe_load(file)

    core = WurstCoreC(config_dict=loaded_config, device=device, training=False)

    # Stage B (latent decoding) setup.
    config_file_b = args.config_b
    with open(config_file_b, "r", encoding="utf-8") as file:
        loaded_config_b = yaml.safe_load(file)

    core_b = WurstCoreB(config_dict=loaded_config_b, device=device, training=False)

    extras = core.setup_extras_pre()
    models = core.setup_models(extras)
    models.generator.eval().requires_grad_(False)
    print("STAGE C READY")

    extras_b = core_b.setup_extras_pre()
    models_b = core_b.setup_models(extras_b, skip_clip=True)
    # Stage B reuses Stage C's tokenizer and text encoder (hence skip_clip=True),
    # so only one copy of the text model is kept in memory.
    models_b = WurstCoreB.Models(
        **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
    )
    models_b.generator.bfloat16().eval().requires_grad_(False)
    print("STAGE B READY")
|
|
|
    # Load the newly added UltraPixel parameters; load_or_fail (from core.utils)
    # handles .safetensors checkpoints, which torch.load cannot read.
    pretrained_path = args.pretrained_path
    sdd = load_or_fail(pretrained_path)
    # Drop the first 7 characters of each key (the 'module.' prefix added by
    # DataParallel/DDP training) so names match the local modules.
    collect_sd = {}
    for k, v in sdd.items():
        collect_sd[k[7:]] = v

    models.train_norm.load_state_dict(collect_sd)
    models.generator.eval()
    models.train_norm.eval()

    demo.launch(debug=True, share=True)