Spaces:
Sleeping
Sleeping
| import math | |
| import random | |
| import gradio as gr | |
| import torch | |
| from PIL import Image, ImageOps | |
| from diffusers import StableDiffusionInstructPix2PixPipeline | |
# Model identifier handed to `from_pretrained` in main().
model_id = "timbrooks/instruct-pix2pix"
def main():
    """Load the InstructPix2Pix pipeline and serve a Gradio editing demo.

    The pipeline is created once (moved to CPU, safety checker disabled) and
    shared by every request through the ``generate`` closure. The UI is built
    with ``gr.Blocks``, queued with a concurrency of one, and launched without
    a public share link.
    """
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, safety_checker=None).to("cpu")

    # Single source of truth for widget defaults, so that "Reset" restores
    # exactly the values the widgets start with.
    default_steps = 50
    default_seed = 1371
    default_text_cfg = 7.5
    default_image_cfg = 1.5

    def generate(
        input_image: Image.Image,
        instruction: str,
        steps: int,
        randomize_seed: bool,
        seed: int,
        randomize_cfg: bool,
        text_cfg_scale: float,
        image_cfg_scale: float,
    ):
        """Edit ``input_image`` according to ``instruction``.

        Args:
            input_image: Uploaded image, or ``None`` when nothing was uploaded.
            instruction: Edit prompt; when empty, no model call is made.
            steps: Number of inference steps passed to the pipeline.
            randomize_seed: Truthy to draw a fresh seed in ``[0, 100000]``.
                The radio widgets use ``type="index"``, so this arrives as
                0 or 1 rather than a real bool.
            seed: Seed used when ``randomize_seed`` is falsy.
            randomize_cfg: Truthy to draw both CFG scales at random.
            text_cfg_scale: Text guidance scale used when not randomized.
            image_cfg_scale: Image guidance scale used when not randomized.

        Returns:
            ``[seed, text_cfg_scale, image_cfg_scale, image]`` in the order of
            the ``outputs`` wired to the Generate button. ``image`` is the
            edited image, the resized input when ``instruction`` is empty,
            or ``None`` when no input image was provided.
        """
        # Without an image there is nothing to resize or edit; leave the
        # controls as they are instead of failing on ``input_image.size``.
        if input_image is None:
            return [seed, text_cfg_scale, image_cfg_scale, None]

        if randomize_seed:
            seed = random.randint(0, 100000)
        if randomize_cfg:
            text_cfg_scale = round(random.uniform(6.0, 9.0), ndigits=2)
            image_cfg_scale = round(random.uniform(1.2, 1.8), ndigits=2)

        # Scale so the longer side is about 512, nudge the factor so the
        # shorter side rounds up to a multiple of 64, then floor both sides
        # to multiples of 64.
        # NOTE(review): the multiple-of-64 constraint presumably comes from
        # the model's expected input sizes - confirm against diffusers docs.
        width, height = input_image.size
        short_side = min(width, height)
        factor = 512 / max(width, height)
        factor = math.ceil(short_side * factor / 64) * 64 / short_side
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)

        # No instruction: show the resized input as the result. The return
        # value must still match the four outputs wired to the click handler.
        if not instruction:
            return [seed, text_cfg_scale, image_cfg_scale, input_image]

        generator = torch.manual_seed(seed)
        edited_image = pipe(
            instruction,
            image=input_image,
            guidance_scale=text_cfg_scale,
            image_guidance_scale=image_cfg_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]
        return [seed, text_cfg_scale, image_cfg_scale, edited_image]

    def reset():
        """Return the default value for every control, in ``outputs`` order.

        Order: steps, seed mode, seed, CFG mode, text CFG, image CFG,
        edited image (cleared).
        """
        return [
            default_steps,
            "Randomize Seed",
            default_seed,
            "Fix CFG",
            default_text_cfg,
            default_image_cfg,
            None,
        ]

    with gr.Blocks() as demo:
        # Top row: action buttons and the edit instruction.
        with gr.Row():
            with gr.Column(scale=1, min_width=100):
                generate_button = gr.Button("Generate")
            with gr.Column(scale=1, min_width=100):
                reset_button = gr.Button("Reset")
            with gr.Column(scale=3):
                instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)

        # Middle row: input image next to the (read-only) result.
        with gr.Row():
            input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            edited_image = gr.Image(label="Edited Image", type="pil", interactive=False)
            input_image.style(height=512, width=512)
            edited_image.style(height=512, width=512)

        # Bottom row: sampling controls.
        with gr.Row():
            steps = gr.Number(value=default_steps, precision=0, label="Steps", interactive=True)
            randomize_seed = gr.Radio(
                ["Fix Seed", "Randomize Seed"],
                value="Randomize Seed",
                type="index",
                show_label=False,
                interactive=True,
            )
            seed = gr.Number(value=default_seed, precision=0, label="Seed", interactive=True)
            randomize_cfg = gr.Radio(
                ["Fix CFG", "Randomize CFG"],
                value="Fix CFG",
                type="index",
                show_label=False,
                interactive=True,
            )
            text_cfg_scale = gr.Number(value=default_text_cfg, label="Text CFG", interactive=True)
            image_cfg_scale = gr.Number(value=default_image_cfg, label="Image CFG", interactive=True)

        generate_button.click(
            fn=generate,
            inputs=[
                input_image,
                instruction,
                steps,
                randomize_seed,
                seed,
                randomize_cfg,
                text_cfg_scale,
                image_cfg_scale,
            ],
            # generate() writes back the seed / CFG values it actually used.
            outputs=[seed, text_cfg_scale, image_cfg_scale, edited_image],
        )
        reset_button.click(
            fn=reset,
            inputs=[],
            outputs=[steps, randomize_seed, seed, randomize_cfg, text_cfg_scale, image_cfg_scale, edited_image],
        )

    demo.queue(concurrency_count=1)
    demo.launch(share=False)
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()