# Hugging Face Space: running on Zero.
from __future__ import annotations | |
import math | |
import random | |
import spaces | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline, EDMEulerScheduler, StableDiffusionXLInstructPix2PixPipeline, AutoencoderKL | |
from custom_pipeline import CosStableDiffusionXLInstructPix2PixPipeline | |
from huggingface_hub import hf_hub_download | |
from huggingface_hub import InferenceClient | |
from diffusers import StableDiffusion3Pipeline, SD3Transformer2DModel, FlowMatchEulerDiscreteScheduler | |
# `device` records whether CUDA is visible; the pipelines below are still
# moved to "cuda" explicitly, exactly as before.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is used for every model loaded in this section.
dtype = torch.float16


def _attach_lora_and_move(pipeline):
    """Load the shared LoRA into *pipeline*, activate it, and move it to CUDA."""
    pipeline.load_lora_weights(
        "KingNish/Better-Image-XL-Lora",
        weight_name="example-03.safetensors",
        adapter_name="lora",
    )
    pipeline.set_adapters("lora")
    pipeline.to("cuda")
    return pipeline


# One VAE instance is shared by the base pipeline, the refiner and the editor.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)

# Base text-to-image pipeline.
repo = "fluently/Fluently-XL-Final"
pipe = _attach_lora_and_move(
    StableDiffusionXLPipeline.from_pretrained(repo, torch_dtype=dtype, vae=vae)
)

# Refiner applied to the latents produced by either the generator or the editor.
refiner = _attach_lora_and_move(
    DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=vae,
        torch_dtype=dtype,
        use_safetensors=True,
        variant="fp16",
    )
)

# Markdown shown below the controls in the UI.
help_text = """
To optimize image results:
- Adjust the **Image CFG weight** if the image isn't changing enough or is changing too much. Lower it to allow bigger changes, or raise it to preserve original details.
- Modify the **Text CFG weight** to influence how closely the edit follows text instructions. Increase it to adhere more to the text, or decrease it for subtler changes.
- Experiment with different **random seeds** and **CFG values** for varied outcomes.
- **Rephrase your instructions** for potentially better results.
- **Increase the number of steps** for enhanced edits.
"""
def set_timesteps_patched(self, num_inference_steps: int, device=None):
    """Set a log-linear sigma schedule on an EDM-style scheduler.

    Intended to be assigned as ``EDMEulerScheduler.set_timesteps``. It builds
    ``num_inference_steps`` sigmas evenly spaced in log-space between
    ``self.config.sigma_min`` and ``self.config.sigma_max``, ordered from the
    largest sigma to the smallest.

    Args:
        self: Scheduler instance; must expose ``config.sigma_min``,
            ``config.sigma_max`` and ``precondition_noise``.
        num_inference_steps: Number of sigmas / timesteps to generate.
        device: Device on which the sigmas are created before being turned
            into timesteps. ``None`` leaves the tensor on its default device.

    Side effects:
        Sets ``num_inference_steps``, ``timesteps``, ``sigmas`` (with one
        trailing zero appended, stored on CPU), and resets ``_step_index`` and
        ``_begin_index`` to ``None``.
    """
    self.num_inference_steps = num_inference_steps

    # Evenly spaced in log(sigma), then exponentiated and reversed so the
    # schedule runs from sigma_max down to sigma_min.
    log_sigmas = torch.linspace(
        math.log(self.config.sigma_min),
        math.log(self.config.sigma_max),
        num_inference_steps,
    )
    sigmas = log_sigmas.exp().flip(0).to(dtype=torch.float32, device=device)

    # Timesteps are derived from the sigmas on the requested device.
    self.timesteps = self.precondition_noise(sigmas)

    # Append a final sigma of zero and keep the stored schedule on the CPU.
    final_zero = torch.zeros(1, device=sigmas.device)
    self.sigmas = torch.cat([sigmas, final_zero]).to("cpu")

    self._step_index = None
    self._begin_index = None
# Image Editor
# Download the CosXL edit checkpoint; the returned value is passed to
# `from_single_file` below.
edit_file = hf_hub_download(
    repo_id="stabilityai/cosxl",
    filename="cosxl_edit.safetensors",
)

# Replace the scheduler's timestep setup with the patched log-linear schedule.
# This is a class-level patch, so it applies to every EDMEulerScheduler.
EDMEulerScheduler.set_timesteps = set_timesteps_patched

pipe_edit = StableDiffusionXLInstructPix2PixPipeline.from_single_file(
    edit_file,
    num_in_channels=8,
    is_cosxl_edit=True,
    vae=vae,
    torch_dtype=torch.float16,
)
pipe_edit.scheduler = EDMEulerScheduler(
    sigma_min=0.002,
    sigma_max=120.0,
    sigma_data=1.0,
    prediction_type="v_prediction",
)
pipe_edit.to("cuda")
# Generator
# NOTE(review): `spaces` is imported at the top of the file but no
# `@spaces.GPU` decorator is applied to this function — confirm whether the
# hosting hardware requires it.
def king(type,
         input_image,
         instruction: str,
         steps: int = 8,
         randomize_seed: bool = False,
         seed: int = 25,
         text_cfg_scale: float = 7.3,
         image_cfg_scale: float = 1.7,
         width: int = 1024,
         height: int = 1024,
         guidance_scale: float = 6,
         use_resolution_binning: bool = True,
         progress=gr.Progress(track_tqdm=True),
         ):
    """Generate a new image or edit an existing one, then refine the result.

    Args:
        type: Task selector. ``"Image Editing"`` runs the edit pipeline on
            ``input_image``; any other value runs text-to-image generation.
        input_image: Image to edit. Required for ``"Image Editing"``; ignored
            for generation.
        instruction: Prompt (generation) or edit instruction (editing). It is
            also used as the refiner prompt.
        steps: Number of inference steps for both the first stage and the
            refiner. Coerced to ``int``.
        randomize_seed: When truthy, ``seed`` is replaced by a random integer
            in ``[0, 99999]``.
        seed: Seed used when ``randomize_seed`` is falsy. Coerced to ``int``.
        text_cfg_scale: Text guidance scale for the edit pipeline.
        image_cfg_scale: Image guidance scale for the edit pipeline.
        width: Output width for generation.
        height: Output height for generation.
        guidance_scale: Guidance scale for generation and for the refiner.
        use_resolution_binning: Accepted for interface compatibility; unused.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Returns:
        A ``(seed, image)`` tuple: the seed actually used and the first image
        returned by the refiner.

    Raises:
        gr.Error: If editing is requested without an input image.
    """
    # Number inputs can arrive as floats; the seed and step count must be ints.
    steps = int(steps)
    seed = random.randint(0, 99999) if randomize_seed else int(seed)

    if type == "Image Editing":
        if input_image is None:
            raise gr.Error("Please upload an image to edit.")
        # Seeds torch's global default generator and reuses it for both stages.
        generator = torch.manual_seed(seed)
        latents = pipe_edit(
            instruction,
            image=input_image,
            guidance_scale=text_cfg_scale,
            image_guidance_scale=image_cfg_scale,
            num_inference_steps=steps,
            generator=generator,
            output_type="latent",
        ).images
    else:
        # Generation uses a dedicated generator instead of the global one.
        generator = torch.Generator().manual_seed(seed)
        latents = pipe(
            prompt=instruction,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            width=width,
            height=height,
            generator=generator,
            output_type="latent",
        ).images

    # Both tasks hand their latents to the same refiner pass.
    refined = refiner(
        prompt=instruction,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        image=latents,
        generator=generator,
    ).images[0]
    return seed, refined
# Inference client created with default settings; `response` uses it for
# zero-shot classification of the user's instruction.
client = InferenceClient()
# Prompt classifier
def response(instruction, input_image=None):
    """Decide whether a request is an image edit or a fresh generation.

    Args:
        instruction: The user's text instruction.
        input_image: The currently loaded image, or ``None`` if there is none.

    Returns:
        ``"Image Editing"`` or ``"Image Generation"``. Without an input image
        the answer is always ``"Image Generation"``. If the remote classifier
        fails for any reason, the function falls back to
        ``"Image Generation"`` as well.
    """
    # Nothing to edit without an image, so skip the remote call entirely.
    if input_image is None:
        return "Image Generation"

    labels = ["Image Editing", "Image Generation"]
    try:
        classification = client.zero_shot_classification(instruction, labels, multi_label=True)
        # Only the first result is inspected; its string form is searched so
        # the check does not depend on the exact result type.
        top_result = str(classification[0])
    except Exception:
        # Best-effort classification: any failure (network, empty result, ...)
        # falls back to generation rather than breaking the UI callback.
        return "Image Generation"

    return "Image Editing" if "Editing" in top_result else "Image Generation"
# Page styling: narrow the container, centre the title, hide the footer.
css = """
.gradio-container{max-width: 700px !important}
h1{text-align:center}
footer {
visibility: hidden
}
"""

# Example rows, each in the order [task, input image path or None, instruction].
examples = [
    ["Image Generation", None, "A luxurious supercar with a unique design. The car should have a pearl white finish, and gold accents. 4k, realistic."],
    ["Image Editing", "./supercar.png", "make it red"],
    ["Image Editing", "./red_car.png", "add some snow"],
    ["Image Generation", None, "An alien grasping a sign board contain word 'ALIEN', futuristic, neonpunk, detailed"],
    ["Image Generation", None, "Beautiful Eiffel Tower at Night"],
]
# NOTE(review): the source this layout was recovered from had lost its
# indentation; the row/column nesting below is the most plausible reading —
# confirm against the deployed UI.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Image Generator Pro")

    # Top bar: instruction box, task selector, and the run button.
    with gr.Row():
        with gr.Column(scale=4):
            instruction = gr.Textbox(lines=1, label="Instruction", interactive=True)
        with gr.Column(scale=1):
            # Named `task_type` so the builtin `type` is not shadowed module-wide.
            task_type = gr.Dropdown(["Image Generation","Image Editing"], label="Task", value="Image Generation",interactive=True, info="AI will select option based on your query, but if it selects wrong, please choose correct one.")
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate")

    # The same image component serves as edit input and as result display.
    with gr.Row():
        input_image = gr.Image(label="Image", type="pil", interactive=True)

    with gr.Row():
        text_cfg_scale = gr.Number(value=7.3, step=0.1, label="Text CFG", interactive=True)
        image_cfg_scale = gr.Number(value=1.7, step=0.1,label="Image CFG", interactive=True)
        guidance_scale = gr.Number(value=6.0, step=0.1, label="Image Generation Guidance Scale", interactive=True)
        # precision=0 makes the component deliver integers to `king`.
        steps = gr.Number(value=25, step=1, precision=0, label="Steps", interactive=True)
        # type="index" passes 0 for "Fix Seed" and 1 for "Randomize Seed",
        # which `king` reads as a truthy/falsy flag.
        randomize_seed = gr.Radio(
            ["Fix Seed", "Randomize Seed"],
            value="Randomize Seed",
            type="index",
            show_label=False,
            interactive=True,
        )
        seed = gr.Number(value=1371, step=1, precision=0, label="Seed", interactive=True)

    with gr.Row():
        width = gr.Slider( label="Width", minimum=256, maximum=2048, step=64, value=1024)
        height = gr.Slider( label="Height", minimum=256, maximum=2048, step=64, value=1024)

    # `king` returns (seed, image), so both outputs are declared here.
    gr.Examples(
        examples=examples,
        inputs=[task_type, input_image, instruction],
        fn=king,
        outputs=[seed, input_image],
        cache_examples=False,
    )

    gr.Markdown(help_text)

    # Re-classify the task whenever the instruction changes or an image is uploaded.
    instruction.change(fn=response, inputs=[instruction,input_image], outputs=task_type, queue=False)
    input_image.upload(fn=response, inputs=[instruction,input_image], outputs=task_type, queue=False)

    # Run the pipeline on button click or on Enter in the instruction box.
    # The input order must match `king`'s positional parameters.
    gr.on(
        triggers=[
            generate_button.click,
            instruction.submit,
        ],
        fn=king,
        inputs=[
            task_type,
            input_image,
            instruction,
            steps,
            randomize_seed,
            seed,
            text_cfg_scale,
            image_cfg_scale,
            width,
            height,
            guidance_scale,
        ],
        outputs=[seed, input_image],
    )

demo.queue(max_size=99999).launch()