# app.py — Hugging Face Space by ilanser.
# Last commit: "Update app.py" (77db537); file size 2.31 kB.
import base64
import functools
import glob
import io

import cv2
import gradio as gr
import numpy as np
import torch
from controlnet_aux import HEDdetector
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from PIL import Image
# Hugging Face Hub ids for the edge annotator, base Stable Diffusion model and
# the scribble-conditioned ControlNet.
HED_ID = "lllyasviel/Annotators"
MODEL_ID = "runwayml/stable-diffusion-v1-5"
CONTROLNET_ID = "lllyasviel/sd-controlnet-scribble"


@functools.lru_cache(maxsize=1)
def _get_hed_detector():
    """Load the HED edge detector once; later calls reuse the same instance."""
    return HEDdetector.from_pretrained(HED_ID)


@functools.lru_cache(maxsize=1)
def _get_pipeline():
    """Build the ControlNet + Stable Diffusion pipeline once and reuse it.

    Loading these weights is by far the most expensive step, so it must not be
    repeated on every request.
    """
    controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(MODEL_ID, controlnet=controlnet)
    # Swap in the UniPC multistep scheduler, reusing the default scheduler's config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    # Optional accelerations, left disabled as before (they need a CUDA GPU /
    # xformers installed):
    #   pipe.enable_model_cpu_offload()
    #   pipe.enable_xformers_memory_efficient_attention()
    #   pipe = pipe.to("cuda")
    return pipe


def predict(sketch, description, num_inference_steps=10):
    """Turn a hand-drawn sketch plus a text prompt into a generated image.

    The sketch is converted to a scribble map by the HED detector. That map
    conditions the ControlNet-guided Stable Diffusion pipeline.

    Args:
        sketch: Image array from the Gradio Sketchpad (anything accepted by
            ``PIL.Image.fromarray``); ``None`` when nothing was drawn.
        description: Text prompt describing the desired output.
        num_inference_steps: Number of denoising steps (default 10, as before).

    Returns:
        The first generated ``PIL.Image`` from the pipeline output.

    Raises:
        gr.Error: If no sketch was provided, shown to the user in the UI.
    """
    if sketch is None:
        raise gr.Error("Please draw a sketch before submitting.")

    sketch_pil = Image.fromarray(sketch)
    # scribble=True makes the detector emit a scribble-style control map,
    # matching the sd-controlnet-scribble model.
    control_image = _get_hed_detector()(sketch_pil, scribble=True)

    output = _get_pipeline()(
        description,
        image=control_image,
        num_inference_steps=num_inference_steps,
    )
    return output.images[0]
with gr.Blocks() as iface:
    # Drawing canvas (shape 400x300, brush radius 5) whose value feeds predict().
    sketchpad = gr.Sketchpad(shape=(400, 300), brush_radius=5, label="Sketchpad- Draw something")
    txt = gr.Textbox(lines=3, label="Description - Describe your sketch with style")
    im = gr.Image(label="Output Image", interactive=False)
    button = gr.Button(value="Submit")
    button.click(predict, inputs=[sketchpad, txt], outputs=im)

    # Flagging: save the current sketch, description and generated image with
    # Gradio's CSVLogger into the "flagged_data_points" directory.
    flag = gr.CSVLogger()
    flag.setup([sketchpad, txt, im], "flagged_data_points")
    button_flag = gr.Button(value="Flag")
    # preprocess=False hands the components' unprocessed values to the logger,
    # following Gradio's Blocks flagging pattern.
    button_flag.click(lambda *args: flag.flag(args), [sketchpad, txt, im], None, preprocess=False)

    # TODO: optionally surface previously flagged sketches via gr.Examples.
    # Flags are written under "flagged_data_points" (see setup above), not
    # "flagged/sketch/" — confirm the on-disk layout before globbing for files.

iface.launch()