# Hugging Face Space — status badge: "Running on Zero" (ZeroGPU hardware).
# Kept first: on ZeroGPU the `spaces` package must be imported before CUDA is initialized.
import spaces

import json
import random
import uuid
from io import BytesIO

import boto3
import gradio as gr
import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DiffusionPipeline,
    EulerDiscreteScheduler,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLControlNetPipeline,
    UniPCMultistepScheduler,
)
from diffusers.utils import load_image
from gradio_imageslider import ImageSlider
from PIL import Image
from transformers import CLIPImageProcessor
from transformers import DPTFeatureExtractor, DPTForDepthEstimation, DPTImageProcessor
device = "cuda" | |
base_model_id = "SG161222/RealVisXL_V4.0" | |
controlnet_model_id = "diffusers/controlnet-depth-sdxl-1.0" | |
vae_model_id = "madebyollin/sdxl-vae-fp16-fix" | |
if torch.cuda.is_available(): | |
# load pipe | |
controlnet = ControlNetModel.from_pretrained( | |
controlnet_model_id, | |
# variant="fp16", | |
use_safetensors=True, | |
torch_dtype=torch.float32 | |
) | |
# vae = AutoencoderKL.from_pretrained(vae_model_id, torch_dtype=torch.float16) | |
pipe = StableDiffusionXLControlNetPipeline.from_pretrained( | |
base_model_id, | |
controlnet=controlnet, | |
# vae=vae, | |
# variant="fp16", | |
use_safetensors=True, | |
torch_dtype=torch.float32, | |
) | |
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) | |
pipe.to(device) | |
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") | |
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") | |
MAX_SEED = np.iinfo(np.int32).max | |
MAX_IMAGE_SIZE = 1024 | |
USE_TORCH_COMPILE = 0 | |
ENABLE_CPU_OFFLOAD = 0 | |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed for the next generation.

    Args:
        seed: Seed currently shown in the UI.
        randomize_seed: When true, ignore ``seed`` and draw a new one.

    Returns:
        A uniformly random integer in ``[0, MAX_SEED]`` if ``randomize_seed``
        is truthy, otherwise ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def get_depth_map(image):
    """Estimate a min-max normalized depth map for ``image`` with the DPT model.

    Args:
        image: PIL image to analyse.

    Returns:
        A PIL RGB image with the same width and height as ``image``. All three
        channels hold the same values: the model's prediction rescaled so its
        minimum maps to 0 and its maximum maps to 255.
    """
    # PIL reports (width, height); torch interpolate expects (height, width).
    target_hw = (image.size[1], image.size[0])
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(pixel_values).predicted_depth
    # Autocast may yield fp16; upcast so the normalization epsilon below is representable.
    depth_map = depth_map.float()

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=target_hw,
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    # A constant prediction (max == min) would otherwise divide by zero and produce NaNs,
    # which turn into garbage when cast to uint8.
    depth_range = (depth_max - depth_min).clamp_min(1e-8)
    depth_map = (depth_map - depth_min) / depth_range

    # Replicate the single channel to RGB and convert NCHW -> HWC for PIL.
    rgb = torch.cat([depth_map] * 3, dim=1)
    rgb = rgb.permute(0, 2, 3, 1).cpu().numpy()[0]
    return Image.fromarray((rgb * 255.0).clip(0, 255).astype(np.uint8))
def upload_to_s3(image, region, access_key, secret_key, bucket_name):
    """Upload ``image`` to S3 as a PNG and return the object key it was stored under.

    Args:
        image: PIL image to upload.
        region: AWS region name for the boto3 client.
        access_key: AWS access key ID.
        secret_key: AWS secret access key.
        bucket_name: Destination bucket.

    Returns:
        The object key, of the form ``generated_images/<32 hex chars>.png``.

    Raises:
        Any exception raised by boto3 while creating the client or uploading
        (e.g. bad credentials, missing bucket, network errors) propagates.
    """
    s3 = boto3.client(
        "s3",
        region_name=region,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )
    # A uuid4 key avoids the collisions (and silent overwrites of earlier
    # uploads) that a random 31-bit integer allows after enough images.
    image_key = f"generated_images/{uuid.uuid4().hex}.png"
    with BytesIO() as buffer:
        image.save(buffer, "PNG")
        buffer.seek(0)
        # Without an explicit ContentType S3 stores binary/octet-stream, so
        # browsers download the object instead of displaying it.
        s3.upload_fileobj(
            buffer,
            bucket_name,
            image_key,
            ExtraArgs={"ContentType": "image/png"},
        )
    return image_key
# The UI section below rebinds the module-level name ``upload_to_s3`` to a
# Checkbox component, and ``process`` receives a parameter with that same name.
# Capture the S3 helper under a private alias while it is still the function.
_upload_image_to_s3 = upload_to_s3


def _generation_size(width, height, max_side=MAX_IMAGE_SIZE, multiple=8):
    """Scale (width, height) down to fit ``max_side`` and snap both to ``multiple``.

    Keeps the aspect ratio (up to the snapping), never upscales, and never
    returns a side smaller than ``multiple``.
    """
    scale = min(1.0, max_side / max(width, height))
    snapped_width = max(multiple, round(width * scale) // multiple * multiple)
    snapped_height = max(multiple, round(height * scale) // multiple * multiple)
    return snapped_width, snapped_height


@spaces.GPU
def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed, upload_to_s3, region, access_key, secret_key, bucket="", progress=gr.Progress()):
    """Generate an image guided by the depth map of an input image, optionally uploading it to S3.

    Args:
        image: Array from the ``gr.Image`` input (presumably HxWx3 uint8, Gradio's
            default numpy output — confirm). Ignored when ``image_url`` is set.
        image_url: Optional URL of the input image; takes precedence over ``image``.
        prompt / n_prompt: Positive and negative text prompts.
        num_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        control_strength: ControlNet conditioning scale (how strongly the depth map steers generation).
        seed: Seed for the CPU ``torch.Generator``.
        upload_to_s3: Whether to upload the generated image.
        region, access_key, secret_key, bucket: S3 settings used when uploading.
        progress: Gradio progress tracker (injected by Gradio; unused here).

    Returns:
        ``[[depth_image, generated_image], logs_json]``. ``logs_json`` is a JSON string:
        on a successful upload ``{"status": "success", "url": <object key>}``. Without
        upload it is ``{"status": "success", "message": ...}``. If the upload was
        requested but failed, it is ``{"status": "error", "message": ...}``.

    Raises:
        gr.Error: if neither an image nor an image URL was provided.
    """
    if image_url:
        original_image = load_image(image_url)
    elif image is not None:
        original_image = Image.fromarray(image)
    else:
        raise gr.Error("Please provide an input image or an image URL.")

    # Bound the working size: unbounded uploads can exhaust GPU memory, and
    # sizes on a multiple of 8 keep image and latent dimensions aligned.
    width, height = _generation_size(*original_image.size)
    if (width, height) != original_image.size:
        original_image = original_image.resize((width, height))

    depth_image = get_depth_map(original_image)
    generator = torch.Generator().manual_seed(int(seed))
    generated_image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_steps),
        # The text-to-image ControlNet pipeline has no `strength` argument; the
        # control weight is `controlnet_conditioning_scale`.
        controlnet_conditioning_scale=control_strength,
        generator=generator,
        image=depth_image,
    ).images[0]

    if not upload_to_s3:
        result = {"status": "success", "message": "Image generated but not uploaded"}
    elif not bucket:
        result = {"status": "error", "message": "Image generated but not uploaded: no S3 bucket name given"}
    else:
        try:
            key = _upload_image_to_s3(generated_image, region, access_key, secret_key, bucket)
        except Exception as exc:  # UI boundary: report the failure but keep the generated image
            result = {"status": "error", "message": f"Image generated but upload failed: {exc}"}
        else:
            result = {"status": "success", "url": key}
    return [[depth_image, generated_image], json.dumps(result)]
# Gradio UI. Left column: inputs, advanced options and S3 settings.
# Right column: depth-vs-generated image slider plus a JSON log box.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image()
            image_url = gr.Textbox(label="Image Url", placeholder="Enter image URL here (optional)")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced options", open=True):
                num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=30, step=1)
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                control_strength = gr.Slider(label="Control Strength", minimum=0.1, maximum=4.0, value=0.8, step=0.1)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                n_prompt = gr.Textbox(
                    label="Negative prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )
                # Named so it does not clobber the module-level upload_to_s3() function.
                upload_checkbox = gr.Checkbox(label="Upload to S3", value=False)
                region = gr.Textbox(label="S3 Region", placeholder="Enter S3 region here")
                access_key = gr.Textbox(label="Access Key", placeholder="Enter S3 access key here")
                # Masked so the credential is not displayed on screen.
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter S3 secret key here", type="password")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter S3 bucket name here")
        with gr.Column():
            result = ImageSlider(label="Generated image", type="pil", slider_color="pink")
            logs = gr.Textbox(label="logs")

    # Order must match process()'s positional parameters.
    inputs = [
        image,
        image_url,
        prompt,
        n_prompt,
        num_steps,
        guidance_scale,
        control_strength,
        seed,
        upload_checkbox,
        region,
        access_key,
        secret_key,
        bucket,
    ]
    # First (unqueued) resolve the seed, then run generation with the updated seed value.
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=process,
        inputs=inputs,
        outputs=[result, logs],
        api_name=False,
    )

if __name__ == "__main__":
    demo.queue().launch()