# styledrop / app.py
# zideliu
# Update
# 0b6e063
import os
import gradio as gr
import open_clip
import torch
import taming.models.vqgan
import ml_collections
import einops
import random
import pathlib
import subprocess
import shlex
import wget
# Model
from libs.muse import MUSE
import utils
import numpy as np
from PIL import Image
def d(**kwargs):
    """Wrap keyword arguments in an ``ml_collections.ConfigDict``.

    Shorthand for declaring nested config sections, e.g. ``d(name='adamw', lr=2e-4)``.
    """
    section = ml_collections.ConfigDict(initial_dictionary=kwargs)
    return section
def get_config():
    """Build the inference configuration for the StyleDrop / MUSE demo.

    Sections:
      * ``autoencoder``   -- VQGAN config file name.
      * ``optimizer`` / ``lr_scheduler`` -- training settings; presumably only
        needed so the train-state helper can be rebuilt before resuming from
        ``resume_root`` (TODO confirm against ``utils.initialize_train_state``).
      * ``nnet``          -- backbone hyper-parameters.
      * ``muse``          -- keyword arguments forwarded to ``MUSE(...)``.
      * ``sample``        -- sampling settings; ``lambdaA``, ``lambdaB`` and
        ``sample_steps`` are overwritten per request by ``process``.
    """
    autoencoder = d(config_file='vq-f16-jax.yaml')
    optimizer = d(
        name='adamw',
        lr=0.0002,
        weight_decay=0.03,
        betas=(0.99, 0.99),
    )
    lr_scheduler = d(name='customized', warmup_steps=5000)
    nnet = d(
        name='uvit_t2i_vq',
        img_size=16,
        codebook_size=1024,
        in_chans=4,
        embed_dim=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=False,
        clip_dim=1280,
        num_clip_token=77,
        use_checkpoint=True,
        skip=True,
        d_prj=32,
        is_shared=False,
    )
    muse = d(ignore_ind=-1, smoothing=0.1, gen_temp=4.5)
    sample = d(
        sample_steps=36,
        n_samples=50,
        mini_batch_size=8,
        cfg=True,
        linear_inc_scale=True,
        scale=10.,
        path='',
        lambdaA=2.0,  # Stage I: 2.0; Stage II: TODO
        lambdaB=5.0,  # Stage I: 5.0; Stage II: TODO
    )
    # Keys are listed in the same order as the sections are documented above.
    return d(
        seed=1234,
        z_shape=(8, 16, 16),
        autoencoder=autoencoder,
        resume_root="assets/ckpts/cc3m-285000.ckpt",
        adapter_path=None,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        nnet=nnet,
        muse=muse,
        sample=sample,
    )
# Environment diagnostics. The CUDA-specific query is guarded so the app can
# still start (and fall back to the CPU device selected below) on a machine
# without a GPU: torch.cuda.get_device_name(0) raises when no CUDA device exists.
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("cuda device name:", torch.cuda.get_device_name(0))
# os.system returns the shell exit status; the tools' own output goes to stdout.
print(os.system("nvidia-smi"))
print(os.system("nvcc --version"))

# Embedding fed to nnet_ema as the unconditional context in cfg_nnet.
empty_context = np.load("assets/contexts/empty_context.npy")
config = get_config()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
# Load the OpenCLIP ViT-bigG-14 text encoder (fp16) and its tokenizer, used by
# `process` to embed prompts.
prompt_model, _, _ = open_clip.create_model_and_transforms(
    'ViT-bigG-14', 'laion2b_s39b_b160k', precision='fp16'
)
prompt_model = prompt_model.to(device)
prompt_model.eval()
# Inference only: freeze the weights, as is done for vq_model and nnet_ema, so
# encode_text does not build an autograd graph and keep activations alive.
prompt_model.requires_grad_(False)
tokenizer = open_clip.get_tokenizer('ViT-bigG-14')
print("GPU memory:", torch.cuda.memory_allocated(0))
# Download the pretrained MUSE training state and the VQGAN weights.
# config.resume_root ("assets/ckpts/cc3m-285000.ckpt") is a *directory* that
# holds one .pth file per train-state component.
_CKPT_DIR = config.resume_root
_MUSE_CKPT_URL = (
    "https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt"
)
_DOWNLOADS = [
    *(
        (f"{_MUSE_CKPT_URL}/{name}", os.path.join(_CKPT_DIR, name))
        for name in (
            "lr_scheduler.pth",
            "optimizer.pth",
            "nnet.pth",
            "nnet_ema.pth",
            "step.pth",
        )
    ),
    (
        "https://huggingface.co/zideliu/vqgan/resolve/main/vqgan_jax_strongaug.ckpt",
        "assets/vqgan_jax_strongaug.ckpt",
    ),
]


def _download(url, dest):
    """Fetch *url* to *dest* with wget, skipping files that are already present.

    The data is written to ``dest + ".part"`` and renamed into place only after
    wget exits successfully, so an interrupted download is never mistaken for
    a complete file on the next start-up. Failures are reported, not raised,
    keeping the original best-effort behaviour; a missing file will presumably
    surface when the checkpoint is loaded.
    """
    if os.path.isfile(dest) and os.path.getsize(dest) > 0:
        print(f"{dest} already exists, skipping download")
        return
    tmp = dest + ".part"
    try:
        # Argument list, no shell: nothing here is interpreted by /bin/sh.
        result = subprocess.run(["wget", url, "-O", tmp], check=False)
    except OSError as err:  # e.g. wget is not installed
        print(f"WARNING: could not run wget for {url}: {err}")
        return
    if result.returncode == 0:
        os.replace(tmp, dest)
    else:
        print(f"WARNING: downloading {url} failed (wget exit code {result.returncode})")


print("downloading cc3m-285000.ckpt")
os.makedirs(_CKPT_DIR, exist_ok=True)
for _url, _dest in _DOWNLOADS:
    _download(_url, _dest)
print(sorted(os.listdir(_CKPT_DIR)))
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU + every CUDA device) with *seed*."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def cfg_nnet(x, context, scale=None, lambdaA=None, lambdaB=None):
    """Guided prediction of ``nnet_ema``; passed to ``muse.generate`` by ``process``.

    ``nnet_ema`` is run conditionally on *context* and unconditionally on the
    module-level ``empty_context``, repeated over the batch dimension of *x*.

    * ``lambdaA is None`` (no style adapter): classifier-free guidance,
      ``cond + scale * (cond - uncond)``.
    * otherwise, an extra pass with ``use_adapter=True`` is combined as
      ``cond_adapter + lambdaA * (cond_adapter - cond) + lambdaB * (cond - uncond)``;
      *scale* is not used in this case.

    The adapter pass only runs when its result is needed. This saves one full
    backbone forward pass per call in the no-style case. The relative order of
    the ``nnet_ema`` calls is unchanged.
    """
    use_adapter = lambdaA is not None
    _cond = nnet_ema(x, context=context)
    if use_adapter:
        _cond_w_adapter = nnet_ema(x, context=context, use_adapter=True)
    _empty_context = torch.tensor(empty_context, device=device)
    _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
    _uncond = nnet_ema(x, context=_empty_context)
    if not use_adapter:
        return _cond + scale * (_cond - _uncond)
    return (
        _cond_w_adapter
        + lambdaA * (_cond_w_adapter - _cond)
        + lambdaB * (_cond - _uncond)
    )
def unprocess(x):
    """Clip tensor *x* into [0, 1] in place and hand back that same tensor."""
    return x.clamp_(0., 1.)
# VQGAN: eval mode, frozen, on the target device. nn.Module.eval(),
# requires_grad_() and to() all return the module itself, so they chain.
vq_model = (
    taming.models.vqgan.get_model('vq-f16-jax.yaml')
    .eval()
    .requires_grad_(False)
    .to(device)
)

# MUSE sampler whose codebook size is taken from the loaded VQGAN.
muse = MUSE(codebook_size=vq_model.n_embed, device=device, **config.muse)

# Rebuild the train state, restore it from config.resume_root, and use its EMA
# network (eval mode, frozen, on device) for inference.
train_state = utils.initialize_train_state(config, device)
train_state.resume(ckpt_root=config.resume_root)
nnet_ema = train_state.nnet_ema.eval().requires_grad_(False).to(device)
# One row per selectable style: (adapter weights path or None, prompt suffix).
# style_ref / style_postfix are derived from this single table so their keys
# (and key order) can never drift apart.
_STYLE_TABLE = {
    "None": (None, ""),
    "0102": ("style_adapter/0102.pth", " in watercolor painting style"),
    "0103": ("style_adapter/0103.pth", " in watercolor painting style"),
    "0106": ("style_adapter/0106.pth", " in line drawing style"),
    "0108": ("style_adapter/0108.pth", " in oil painting style"),
    "0301": ("style_adapter/0301.pth", " in 3d rendering style"),
    "0305": ("style_adapter/0305.pth", " in kid crayon drawing style"),
}
style_ref = {name: path for name, (path, _suffix) in _STYLE_TABLE.items()}
style_postfix = {name: suffix for name, (_path, suffix) in _STYLE_TABLE.items()}
def decode(_batch):
    """Decode *_batch* (presumably VQ codebook indices) with ``vq_model.decode_code``.

    Passed to ``muse.generate`` by ``process`` as its decode callback.
    """
    decode_fn = vq_model.decode_code
    return decode_fn(_batch)
def process(prompt, num_samples, lambdaA, lambdaB, style, seed, sample_steps, image=None):
    """Gradio callback: generate ``num_samples`` images for *prompt* in *style*.

    Args:
        prompt: user prompt; ``style_postfix[style]`` is appended to it.
        num_samples: number of images to generate.
        lambdaA, lambdaB: adapter guidance weights; replaced by None when the
            style has no adapter (``style_ref[style] is None``).
        style: key into ``style_ref`` / ``style_postfix``.
        seed: RNG seed; ``-1`` picks a random seed in [0, 65535].
        sample_steps: number of sampling steps written to ``config.sample``.
        image: style reference image shown in the UI; not used here.

    Returns:
        A list of ``num_samples`` uint8 numpy arrays laid out channel-last
        (assumes ``muse.generate`` returns a (B, C, H, W) tensor in [0, 1]).

    NOTE(review): per-request settings are written into the shared module-level
    ``config``, so this is only safe while requests are processed one at a time.
    """
    # Gradio sends slider values as JSON numbers, so 2.0 may arrive as int 2.
    # Normalize types. config is a type-safe ConfigDict: once lambdaA has been
    # reset to None, an int would re-type the field and a later float (2.5)
    # would raise TypeError. torch.repeat / np.random.seed also need ints.
    num_samples = int(num_samples)
    seed = int(seed)
    sample_steps = int(sample_steps)
    lambdaA = float(lambdaA)
    lambdaB = float(lambdaB)

    config.sample.lambdaA = lambdaA
    config.sample.lambdaB = lambdaB
    config.sample.sample_steps = sample_steps
    print(style)
    adapter_path = style_ref[style]
    adapter_postfix = style_postfix[style]
    print(f"load adapter path: {adapter_path}")
    if adapter_path is not None:
        # map_location keeps loading working regardless of the device the
        # adapter weights were saved from.
        nnet_ema.adapter.load_state_dict(torch.load(adapter_path, map_location=device))
    else:
        # No adapter for this style; cfg_nnet treats lambdaA=None as plain CFG
        # (presumably these config values are forwarded to it by muse.generate).
        config.sample.lambdaA = None
        config.sample.lambdaB = None
    print("load adapter Done!")

    prompt = prompt + adapter_postfix
    print(f"lambdaA: {lambdaA}, lambdaB: {lambdaB}, sample_steps: {sample_steps}")
    if seed == -1:
        seed = random.randint(0, 65535)
    config.seed = seed
    print(f"seed: {seed}")

    # Inference only: skip autograd bookkeeping for encoding and sampling.
    with torch.no_grad():
        text_tokens = tokenizer(prompt).to(device)
        text_embedding = prompt_model.encode_text(text_tokens)
        text_embedding = text_embedding.repeat(num_samples, 1, 1)  # B 77 1280
        print(text_embedding.shape)
        set_seed(config.seed)
        res = muse.generate(config, num_samples, cfg_nnet, decode, is_eval=True, context=text_embedding)
    print(res.shape)
    # Scale to [0, 255] with rounding, move channels last, convert to uint8.
    res = (res * 255 + 0.5).clamp_(0, 255).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
    return list(res[:num_samples])
# Gradio UI: prompt/options column on the left, output gallery on the right,
# plus clickable examples. Both the Run button and the examples call `process`
# with the same inputs.
block = gr.Blocks()
with block:
    with gr.Row():
        gr.Markdown("## StyleDrop based on Muse (Inference Only) ")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            # A Button's caption is its `value` argument, not `label`.
            run_button = gr.Button(value="Run")
            num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=1234)
            # Offer every style known to style_ref (this includes "0301", which
            # a hard-coded list had left out), with "None" last.
            style = gr.Radio(
                choices=[name for name in style_ref if name != "None"] + ["None"],
                type="value",
                value="None",
                label="Style",
            )
            with gr.Accordion("Advanced options", open=False):
                lambdaA = gr.Slider(label="lambdaA", minimum=0.0, maximum=5.0, value=2.0, step=0.01)
                lambdaB = gr.Slider(label="lambdaB", minimum=0.0, maximum=10.0, value=5.0, step=0.01)
                sample_steps = gr.Slider(label="Sample steps", minimum=1, maximum=50, value=36, step=1)
            # Style reference image; only displayed (process ignores it).
            image = gr.Image(value=None)
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(columns=2, height='auto')
    with gr.Row():
        # Each row: prompt, num_samples, lambdaA, lambdaB, style, seed,
        # sample_steps, reference image -- same order as the inputs below.
        examples = [
            [
                "A banana on the table",
                1, 2.0, 5.0, "0103", 1234, 36,
                "data/image_01_03.jpg",
            ],
            [
                "A cow",
                1, 2.0, 5.0, "0102", 1234, 36,
                "data/image_01_02.jpg",
            ],
            [
                "A portrait of tabby cat",
                1, 2.0, 5.0, "0106", 1234, 36,
                "data/image_01_06.jpg",
            ],
            [
                "A church in the field",
                1, 2.0, 5.0, "0108", 1234, 36,
                "data/image_01_08.jpg",
            ],
            [
                "A Christmas tree",
                1, 2.0, 5.0, "0305", 1234, 36,
                "data/image_03_05.jpg",
            ],
        ]
        gr.Examples(
            examples=examples,
            fn=process,
            inputs=[
                prompt,
                num_samples, lambdaA, lambdaB, style, seed, sample_steps, image,
            ],
            outputs=result_gallery,
            # Pre-compute example outputs only when running on HF Spaces.
            cache_examples=os.getenv('SYSTEM') == 'spaces',
        )
    ips = [prompt, num_samples, lambdaA, lambdaB, style, seed, sample_steps, image]
    run_button.click(
        fn=process,
        inputs=ips,
        outputs=[result_gallery],
    )
block.queue().launch(share=False)