# PuLID for FLUX.1-dev demo (Hugging Face Space, running on ZeroGPU)
import spaces
import time
import os

import gradio as gr
import torch
from einops import rearrange
from PIL import Image

from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import load_ae, load_clip, load_flow_model, load_t5
from pulid.pipeline_flux import PuLIDPipeline
from pulid.utils import resize_numpy_image_long
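

# Model loading: the T5 and CLIP text encoders, the FLUX flow-matching transformer,
# and the autoencoder. With offloading enabled, the large models start on CPU and
# are moved to the GPU only while they are needed.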
def get_models(name: str, device: torch.device, offload: bool):
    t5 = load_t5(device, max_length=128)
    clip = load_clip(device)
    model = load_flow_model(name, device="cpu" if offload else device)
    model.eval()
    ae = load_ae(name, device="cpu" if offload else device)
    return model, ae, t5, clip
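

# Holds the FLUX models plus the PuLID ID-encoder pipeline for the lifetime of the
# process, so everything is loaded only once at startup.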
class FluxGenerator:
    def __init__(self):
        self.device = torch.device('cuda')
        self.offload = False
        self.model_name = 'flux-dev'
        self.model, self.ae, self.t5, self.clip = get_models(
            self.model_name,
            device=self.device,
            offload=self.offload,
        )
        self.pulid_model = PuLIDPipeline(self.model, 'cuda', weight_dtype=torch.bfloat16)
        self.pulid_model.load_pretrain()


flux_generator = FluxGenerator()
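

# Assumption: on a ZeroGPU Space (see the `spaces` import above), the GPU-using entry
# point must be decorated with `spaces.GPU`; the decorator is a no-op elsewhere.
@spaces.GPU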
def generate_image(
    width,
    height,
    num_steps,
    start_step,
    guidance,
    seed,
    prompt,
    id_image=None,
    id_weight=1.0,
    neg_prompt="",
    true_cfg=1.0,
    timestep_to_start_cfg=1,
    max_sequence_length=128,
):
    flux_generator.t5.max_length = max_sequence_length

    seed = int(seed)
    if seed == -1:
        seed = None
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if opts.seed is None:
        opts.seed = torch.Generator(device="cpu").seed()

    t0 = time.perf_counter()

    use_true_cfg = abs(true_cfg - 1.0) > 1e-2

    if id_image is not None:
        id_image = resize_numpy_image_long(id_image, 1024)
        id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
    else:
        id_embeddings = None
        uncond_id_embeddings = None
    # prepare input
    x = get_noise(
        1,
        opts.height,
        opts.width,
        device=flux_generator.device,
        dtype=torch.bfloat16,
        seed=opts.seed,
    )
    timesteps = get_schedule(
        opts.num_steps,
        x.shape[-1] * x.shape[-2] // 4,
        shift=True,
    )
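
    # Move the text encoders onto the GPU if they were offloaded, then prepare() packs
    # the latent into the token layout the transformer expects and encodes the prompt
    # with T5 and CLIP (the txt / txt_ids / vec entries of the returned dict).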
    if flux_generator.offload:
        flux_generator.t5, flux_generator.clip = flux_generator.t5.to(flux_generator.device), flux_generator.clip.to(flux_generator.device)
    inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
    inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt) if use_true_cfg else None

    # offload TEs to CPU, load model to gpu
    if flux_generator.offload:
        flux_generator.t5, flux_generator.clip = flux_generator.t5.cpu(), flux_generator.clip.cpu()
        torch.cuda.empty_cache()
        flux_generator.model = flux_generator.model.to(flux_generator.device)
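
    # When true_cfg is above 1.0, denoise() additionally evaluates the negative prompt
    # (classic classifier-free guidance) on top of FLUX's built-in distilled guidance.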
    # denoise initial noise
    x = denoise(
        flux_generator.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
        start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
        timestep_to_start_cfg=timestep_to_start_cfg,
        neg_txt=inp_neg["txt"] if use_true_cfg else None,
        neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
        neg_vec=inp_neg["vec"] if use_true_cfg else None,
    )
    # offload model, load autoencoder to gpu
    if flux_generator.offload:
        flux_generator.model.cpu()
        torch.cuda.empty_cache()
        flux_generator.ae.decoder.to(x.device)

    # decode latents to pixel space
    x = unpack(x.float(), opts.height, opts.width)
    with torch.autocast(device_type=flux_generator.device.type, dtype=torch.bfloat16):
        x = flux_generator.ae.decode(x)

    if flux_generator.offload:
        flux_generator.ae.decoder.cpu()
        torch.cuda.empty_cache()

    t1 = time.perf_counter()
    print(f"Generation took {t1 - t0:.1f}s")

    # bring into PIL format
    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

    return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
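

# Minimal sketch of calling the generator directly, bypassing the UI (hypothetical
# values; the Gradio handler below passes these same arguments from hidden components):
#
#   import numpy as np
#   face = np.array(Image.open("example_inputs/liuyifei.png"))
#   img, used_seed, _ = generate_image(
#       width=896, height=1152, num_steps=20, start_step=0, guidance=4.0,
#       seed=-1, prompt="portrait, color, cinematic", id_image=face,
#   )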
css = """ | |
footer { | |
visibility: hidden; | |
} | |
""" | |
def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
                offload: bool = False):
    with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
        gr.Markdown("## AI Photo Zone: Space")
        gr.Markdown("### How to use: 1) Pick one of the examples. 2) Click the camera button and, once your face is visible, click the capture button. 3) Click 'Generate' and wait.")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
                id_image = gr.Image(label="ID image", sources=["webcam", "upload"], type="numpy")
                generate_btn = gr.Button("Generate")
            with gr.Column():
                output_image = gr.Image(label="Generated image")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Examples")
                all_examples = [
                    ['A woman is holding a glowing green sign that reads "PuLID for FLUX"', 'example_inputs/liuyifei.png'],
                    ['Side-view portrait', 'example_inputs/liuyifei.png'],
                    ['White-haired woman with a VR-technology atmosphere', 'example_inputs/liuyifei.png'],
                    ['A young child is eating ice cream', 'example_inputs/liuyifei.png'],
                    ['A man is holding a sign that reads "PuLID for FLUX", winter, snowing', 'example_inputs/pengwei.jpg'],
                    ['Portrait, candlelight', 'example_inputs/pengwei.jpg'],
                    ['Dark profile photo of a 25-year-old man, with smoke coming out', 'example_inputs/pengwei.jpg'],
                    ['American comic style, one person', 'example_inputs/pengwei.jpg'],
                    ['Portrait, Pixar style', 'example_inputs/pengwei.jpg'],
                    ['Portrait, ice sculpture', 'example_inputs/lecun.jpg'],
                ]

                example_images = [example[1] for example in all_examples]
                example_captions = [example[0] for example in all_examples]
                gallery = gr.Gallery(
                    value=list(zip(example_images, example_captions)),
                    label="Example gallery",
                    show_label=False,
                    elem_id="gallery",
                    columns=5,
                    rows=2,
                    object_fit="contain",
                    height="auto",
                )
        # Clicking a gallery item copies its caption into the prompt box and its image
        # into the ID image input.
        def fill_example(evt: gr.SelectData):
            return [all_examples[evt.index][i] for i in [0, 1]]

        gallery.select(
            fill_example,
            None,
            [prompt, id_image],
        )
        generate_btn.click(
            # generate_image also returns the seed and PuLID debug images; only the
            # generated image is displayed, so the extra return values are dropped here.
            fn=lambda *args: generate_image(*args)[0],
            inputs=[
                gr.Slider(256, 1536, 896, step=16, visible=False),   # width
                gr.Slider(256, 1536, 1152, step=16, visible=False),  # height
                gr.Slider(1, 20, 20, step=1, visible=False),         # num_steps
                gr.Slider(0, 10, 0, step=1, visible=False),          # start_step
                gr.Slider(1.0, 10.0, 4, step=0.1, visible=False),    # guidance
                gr.Textbox("-1", visible=False),                     # seed (-1 = random)
                prompt,
                id_image,
                gr.Slider(0.0, 3.0, 1, step=0.05, visible=False),    # id_weight
                gr.Textbox("low quality, worst quality, text, signature, watermark, extra limbs", visible=False),  # neg_prompt
                gr.Slider(1.0, 10.0, 1, step=0.1, visible=False),    # true_cfg
                gr.Slider(0, 20, 1, step=1, visible=False),          # timestep_to_start_cfg
                gr.Slider(128, 512, 128, step=128, visible=False),   # max_sequence_length
            ],
            outputs=[output_image],
        )
    return demo
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
    parser.add_argument("--name", type=str, default="flux-dev", choices=["flux-dev"],
                        help="currently only flux-dev is supported")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="device to use")
    parser.add_argument("--offload", action="store_true", help="offload models to CPU when not in use")
    parser.add_argument("--port", type=int, default=8080, help="port to use")
    parser.add_argument("--dev", action="store_true", help="development mode")
    parser.add_argument("--pretrained_model", type=str, help="for development")
    args = parser.parse_args()

    import huggingface_hub
    huggingface_hub.login(os.getenv('HF_TOKEN'))

    demo = create_demo(args, args.name, args.device, args.offload)
    demo.launch()
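
# Example local launch (assuming this file is saved as app.py and HF_TOKEN grants
# access to the gated FLUX.1-dev weights):
#   python app.py --offload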