nigeljw commited on
Commit
78392d4
·
1 Parent(s): 8167b2b

Simplified demo with masks and better user control

Browse files
Files changed (3) hide show
  1. app.py +31 -42
  2. assets/masks/sphere.png +0 -0
  3. assets/masks/square.png +0 -0
app.py CHANGED
@@ -3,10 +3,7 @@ import torch
3
  import numpy
4
  from PIL import Image
5
  from torchvision import transforms
6
- #from torchvision import transforms
7
  from diffusers import StableDiffusionInpaintPipeline
8
- #from diffusers import StableDiffusionUpscalePipeline
9
- #from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
10
  from diffusers import DPMSolverMultistepScheduler
11
 
12
  deviceStr = "cuda" if torch.cuda.is_available() else "cpu"
@@ -19,68 +16,60 @@ if deviceStr == "cuda":
19
  safety_checker=lambda images, **kwargs: (images, False))
20
  pipeline.to(device)
21
  pipeline.enable_xformers_memory_efficient_attention()
 
22
  else:
23
  pipeline = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting",
24
- safety_checker=lambda images, **kwargs: (images, False))
25
-
26
- #superresolutionPipe = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler")
27
-
28
- #pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
29
- #generator = torch.Generator(device).manual_seed(seed)
30
- latents = torch.randn((1, 4, 64, 64), device=device)
31
- schedulers = [
32
- "DDIMScheduler", "LMSDiscreteScheduler", "PNDMScheduler"
33
- ]
34
- latentNoiseInputs = [
35
- "Uniform", "Low Discrepency Sequence"
36
- ]
37
-
38
- imageSize = (512, 512, 3)
39
- imageSize2 = (512, 512)
40
- #lastImage = Image.new(mode="RGB", size=(imageSize[0], imageSize[1]))
41
 
42
- def diffuse(prompt, negativePrompt, inputImage, mask, guidanceScale, numInferenceSteps, seed, noiseScheduler, latentNoise):
43
- #width = inputImage.size[1]
44
- #height = 512
45
- #print(inputImage.size)
46
- #image = numpy.resize(inputImage, imageSize)
47
- #pilImage.thumbnail(imageSize2)
48
 
49
- #transforms.Resize(imageSize2)(inputImage)
 
50
 
51
- #pilImage = Image.fromarray(inputImage)
52
- #pilImage.resize(imageSize2)
53
- #imageArray = numpy.asarray(pilImage)
54
 
55
- #inputImage = torch.nn.functional.interpolate(inputImage, size=imageSize)
56
-
57
- if mask is None:
58
- return inputImage
 
 
 
 
 
 
 
 
59
 
60
- generator = torch.Generator(device).manual_seed(seed)
61
-
62
  newImage = pipeline(prompt=prompt,
63
  negative_prompt=negativePrompt,
64
  image=inputImage,
65
  mask_image=mask,
66
  guidance_scale=guidanceScale,
67
  num_inference_steps=numInferenceSteps,
 
68
  generator=generator).images[0]
69
 
 
 
70
  return newImage
71
 
 
 
72
  prompt = gradio.Textbox(label="Prompt", placeholder="A person in a room", lines=3)
73
  negativePrompt = gradio.Textbox(label="Negative Prompt", placeholder="Text", lines=3)
74
- #inputImage = gradio.Image(label="Input Image", type="pil")
75
  inputImage = gradio.Image(label="Input Feed", source="webcam", shape=[512,512], streaming=True)
76
- mask = gradio.Image(label="Mask", type="pil")
77
  outputImage = gradio.Image(label="Extrapolated Field of View")
78
  guidanceScale = gradio.Slider(label="Guidance Scale", maximum=1, value=0.75)
79
  numInferenceSteps = gradio.Slider(label="Number of Inference Steps", maximum=100, value=25)
80
- seed = gradio.Slider(label="Generator Seed", maximum=1000, value=512)
81
- noiseScheduler = gradio.Dropdown(schedulers, label="Noise Scheduler", value="DDIMScheduler")
82
- latentNoise = gradio.Dropdown(latentNoiseInputs, label="Latent Noise", value="Iniform")
83
 
84
- inputs=[prompt, negativePrompt, inputImage, mask, guidanceScale, numInferenceSteps, seed, noiseScheduler, latentNoise]
85
  ux = gradio.Interface(fn=diffuse, title="View Diffusion", inputs=inputs, outputs=outputImage, live=True)
86
  ux.launch()
 
3
  import numpy
4
  from PIL import Image
5
  from torchvision import transforms
 
6
  from diffusers import StableDiffusionInpaintPipeline
 
 
7
  from diffusers import DPMSolverMultistepScheduler
8
 
9
  deviceStr = "cuda" if torch.cuda.is_available() else "cpu"
 
16
  safety_checker=lambda images, **kwargs: (images, False))
17
  pipeline.to(device)
18
  pipeline.enable_xformers_memory_efficient_attention()
19
+ latents = torch.randn((1, 4, 64, 64), device=device, dtype=torch.float16)
20
  else:
21
  pipeline = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting",
22
+ safety_checker=lambda images, **kwargs: (images, False))
23
+ latents = torch.randn((1, 4, 64, 64), device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ imageSize = (512, 512)
26
+ lastImage = Image.new(mode="RGB", size=imageSize)
 
 
 
 
27
 
28
+ lastSeed = 512
29
+ generator = torch.Generator(device).manual_seed(512)
30
 
31
+ def diffuse(staticLatents, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps, seed):
32
+ global latents, lastSeed, generator, deviceStr, lastImage
 
33
 
34
+ if mask is None or pauseInference is True:
35
+ return lastImage
36
+
37
+ if staticLatents is False:
38
+ if deviceStr == "cuda":
39
+ latents = torch.randn((1, 4, 64, 64), device=device, dtype=torch.float16)
40
+ else:
41
+ latents = torch.randn((1, 4, 64, 64), device=device)
42
+
43
+ if lastSeed != seed:
44
+ generator = torch.Generator(device).manual_seed(seed)
45
+ lastSeed = seed
46
 
 
 
47
  newImage = pipeline(prompt=prompt,
48
  negative_prompt=negativePrompt,
49
  image=inputImage,
50
  mask_image=mask,
51
  guidance_scale=guidanceScale,
52
  num_inference_steps=numInferenceSteps,
53
+ latents=latents,
54
  generator=generator).images[0]
55
 
56
+ lastImage = newImage
57
+
58
  return newImage
59
 
60
+ defaultMask = Image.open("assets\masks\sphere.png")
61
+
62
  prompt = gradio.Textbox(label="Prompt", placeholder="A person in a room", lines=3)
63
  negativePrompt = gradio.Textbox(label="Negative Prompt", placeholder="Text", lines=3)
 
64
  inputImage = gradio.Image(label="Input Feed", source="webcam", shape=[512,512], streaming=True)
65
+ mask = gradio.Image(label="Mask", type="pil", value=defaultMask)
66
  outputImage = gradio.Image(label="Extrapolated Field of View")
67
  guidanceScale = gradio.Slider(label="Guidance Scale", maximum=1, value=0.75)
68
  numInferenceSteps = gradio.Slider(label="Number of Inference Steps", maximum=100, value=25)
69
+ seed = gradio.Slider(label="Generator Seed", maximum=10000, value=4096)
70
+ staticLatents =gradio.Checkbox(label="Static Latents", value=True)
71
+ pauseInference = gradio.Checkbox(label="Pause Inference", value=False)
72
 
73
+ inputs=[staticLatents, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps, seed]
74
  ux = gradio.Interface(fn=diffuse, title="View Diffusion", inputs=inputs, outputs=outputImage, live=True)
75
  ux.launch()
assets/masks/sphere.png ADDED
assets/masks/square.png ADDED