ViewDiffusion / app.py
nigeljw's picture
switched IO paths for host OS
5e772ca
raw
history blame
3.44 kB
import gradio
import torch
import numpy
from PIL import Image
from torchvision import transforms
from diffusers import StableDiffusionInpaintPipeline
from diffusers import DPMSolverMultistepScheduler
# Select the compute device once: CUDA when torch reports it, otherwise CPU.
deviceStr = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(deviceStr)

if deviceStr != "cuda":
    # CPU path: load the pipeline without any revision or dtype override.
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        # Pass-through stand-in for the safety checker: it returns the images
        # unchanged together with a False flag.
        safety_checker=lambda images, **kwargs: (images, False),
    )
    latents = torch.randn((1, 4, 64, 64), device=device)
else:
    # GPU path: load the "fp16" revision in half precision, move it to the
    # device and switch on xformers memory-efficient attention.
    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        revision="fp16",
        torch_dtype=torch.float16,
        safety_checker=lambda images, **kwargs: (images, False),
    )
    pipeline.to(device)
    pipeline.enable_xformers_memory_efficient_attention()
    # The initial latents use the same half-precision dtype as the pipeline.
    latents = torch.randn((1, 4, 64, 64), device=device, dtype=torch.float16)

# Module-level state that diffuse() reads and updates between calls.
imageSize = (512, 512)
lastImage = Image.new(mode="RGB", size=imageSize)  # blank frame until the first result
lastSeed = 512
generator = torch.Generator(device).manual_seed(lastSeed)
def diffuse(staticLatents, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps, seed):
    """Inpaint the current input frame and return the resulting image.

    Args:
        staticLatents: When False, fresh random latents are drawn for this
            call; otherwise the module-level ``latents`` are reused.
        inputImage: Frame to inpaint. When None (no frame available yet) the
            previous result is returned instead of calling the pipeline.
        mask: Mask image handed to the pipeline as ``mask_image``. When None
            the previous result is returned.
        pauseInference: When True the previous result is returned unchanged.
        prompt: Text prompt forwarded to the pipeline.
        negativePrompt: Negative prompt forwarded to the pipeline.
        guidanceScale: Forwarded to the pipeline as ``guidance_scale``.
        numInferenceSteps: Number of inference steps; converted to ``int``
            before being forwarded.
        seed: Generator seed; converted to ``int``. The shared generator is
            re-created only when the seed differs from the last one used.

    Returns:
        The first image produced by the pipeline, or the last produced image
        when inference is skipped.
    """
    global latents, lastSeed, generator, lastImage

    # Skip inference when there is nothing to process or the user paused it.
    # The inputImage check prevents calling the pipeline without a frame.
    if inputImage is None or mask is None or pauseInference is True:
        return lastImage

    if staticLatents is False:
        # Match the dtype used at start-up: half precision on CUDA, torch's
        # default dtype (dtype=None) otherwise.
        latentDtype = torch.float16 if deviceStr == "cuda" else None
        latents = torch.randn((1, 4, 64, 64), device=device, dtype=latentDtype)

    # UI sliders may deliver floats; manual_seed() only accepts integers.
    seed = int(seed)
    if lastSeed != seed:
        generator = torch.Generator(device).manual_seed(seed)
        lastSeed = seed

    result = pipeline(
        prompt=prompt,
        negative_prompt=negativePrompt,
        image=inputImage,
        mask_image=mask,
        guidance_scale=guidanceScale,
        # The step count must be a whole number as well.
        num_inference_steps=int(numInferenceSteps),
        latents=latents,
        generator=generator,
    )
    newImage = result.images[0]

    # Remember the result so skipped calls can keep showing it.
    lastImage = newImage
    return newImage
# Mask shown in the UI until the user supplies another one.
defaultMask = Image.open("assets/masks/sphere.png")

# --- Text inputs ---
prompt = gradio.Textbox(
    label="Prompt",
    placeholder="A person in a room",
    lines=3,
)
negativePrompt = gradio.Textbox(
    label="Negative Prompt",
    placeholder="Text",
    lines=3,
)

# --- Image inputs and output ---
# The input feed is a streaming webcam component; the mask is delivered as a
# PIL image and starts out as the default mask.
inputImage = gradio.Image(
    label="Input Feed",
    source="webcam",
    shape=[512,512],
    streaming=True,
)
mask = gradio.Image(
    label="Mask",
    type="pil",
    value=defaultMask,
)
outputImage = gradio.Image(label="Extrapolated Field of View")

# --- Numeric controls ---
# NOTE(review): the guidance slider is capped at 1; diffusers presumably only
# enables classifier-free guidance above 1, which would make the negative
# prompt a no-op -- confirm whether this cap is intentional.
guidanceScale = gradio.Slider(
    label="Guidance Scale",
    maximum=1,
    value=0.75,
)
numInferenceSteps = gradio.Slider(
    label="Number of Inference Steps",
    maximum=100,
    value=25,
)
seed = gradio.Slider(
    label="Generator Seed",
    maximum=10000,
    value=4096,
)

# --- Toggles ---
staticLatents = gradio.Checkbox(label="Static Latents", value=True)
pauseInference = gradio.Checkbox(label="Pause Inference", value=False)

# The order here must match the positional parameters of diffuse().
inputs = [
    staticLatents,
    inputImage,
    mask,
    pauseInference,
    prompt,
    negativePrompt,
    guidanceScale,
    numInferenceSteps,
    seed,
]

# live=True asks the interface to re-run diffuse as the inputs change.
ux = gradio.Interface(
    fn=diffuse,
    title="View Diffusion",
    inputs=inputs,
    outputs=outputImage,
    live=True,
)
ux.launch()