HopeKr commited on
Commit
d0d8da6
1 Parent(s): 3f6b5b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -28
app.py CHANGED
@@ -1,24 +1,18 @@
1
- from huggingface_hub import notebook_login
2
-
3
- notebook_login()
4
  import inspect
 
5
  from typing import List, Optional, Union
6
-
7
  import numpy as np
8
  import torch
9
-
10
  import PIL
11
  import gradio as gr
12
  from diffusers import StableDiffusionInpaintPipeline
13
- device = "cuda"
14
- model_path = "runwayml/stable-diffusion-inpainting"
15
-
16
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
17
- model_path,
18
- torch_dtype=torch.float16,
19
- ).to(device)
20
  import requests
21
  from io import BytesIO
 
 
 
 
22
 
23
  def image_grid(imgs, rows, cols):
24
  assert len(imgs) == rows*cols
@@ -30,24 +24,36 @@ def image_grid(imgs, rows, cols):
30
  for i, img in enumerate(imgs):
31
  grid.paste(img, box=(i%cols*w, i//cols*h))
32
  return grid
 
33
 
 
 
 
 
 
34
 
35
  def download_image(url):
36
  response = requests.get(url)
37
  return PIL.Image.open(BytesIO(response.content)).convert("RGB")
38
 
39
- img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
40
- mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
 
 
 
 
 
 
 
 
41
  image = download_image(img_url).resize((512, 512))
42
- image
43
- mask_image = download_image(mask_url).resize((512, 512))
44
- mask_image
45
- prompt = "a mecha robot sitting on a bench"
46
 
47
  guidance_scale=7.5
48
  num_samples = 3
49
- generator = torch.Generator(device="cuda").manual_seed(0) # change the seed to get different results
50
-
51
  images = pipe(
52
  prompt=prompt,
53
  image=image,
@@ -56,15 +62,10 @@ images = pipe(
56
  generator=generator,
57
  num_images_per_prompt=num_samples,
58
  ).images
59
- # insert initial image in the list so we can compare side by side
60
  images.insert(0, image)
61
  image_grid(images, 1, num_samples + 1)
62
- def predict(dict, prompt):
63
- image = dict['image'].convert("RGB").resize((512, 512))
64
- mask_image = dict['mask'].convert("RGB").resize((512, 512))
65
- images = pipe(prompt=prompt, image=image, mask_image=mask_image).images
66
- return(images[0])
67
- gr.Interface(
68
  predict,
69
  title = 'Stable Diffusion In-Painting',
70
  inputs=[
@@ -74,4 +75,4 @@ def predict(dict, prompt):
74
  outputs = [
75
  gr.Image()
76
  ]
77
- ).launch(debug=True)
 
 
 
 
1
  import inspect
2
+ import os
3
  from typing import List, Optional, Union
 
4
  import numpy as np
5
  import torch
 
6
  import PIL
7
  import gradio as gr
8
  from diffusers import StableDiffusionInpaintPipeline
9
+ from rembg import remove
 
 
 
 
 
 
10
  import requests
11
  from io import BytesIO
12
+ from huggingface_hub import login
13
+
14
+ token = os.getenv("WRITE_TOKEN")
15
+ login(token, True)
16
 
17
  def image_grid(imgs, rows, cols):
18
  assert len(imgs) == rows*cols
 
24
  for i, img in enumerate(imgs):
25
  grid.paste(img, box=(i%cols*w, i//cols*h))
26
  return grid
27
+
28
 
29
+ def predict(dict, prompt):
30
+ image = dict['image'].convert("RGB").resize((512, 512))
31
+ mask_image = dict['mask'].convert("RGB").resize((512, 512))
32
+ images = pipe(prompt=prompt, image=image, mask_image=mask_image).images
33
+ return(images[0])
34
 
35
  def download_image(url):
36
  response = requests.get(url)
37
  return PIL.Image.open(BytesIO(response.content)).convert("RGB")
38
 
39
+ model_path = "runwayml/stable-diffusion-inpainting"
40
+
41
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
42
+ model_path,
43
+ # revision="fp16",
44
+ # torch_dtype=torch.float16,
45
+ use_auth_token=True
46
+ )
47
+
48
+ img_url = "https://cdn.faire.com/fastly/893b071985d70819da5f0d485f1b1bb97ee4f16a6e14ef1bdd4a086b3588be58.png" # wino
49
  image = download_image(img_url).resize((512, 512))
50
+ inverted_mask_image = remove(data = image, only_mask = True)
51
+ mask_image = PIL.ImageOps.invert(inverted_mask_image)
52
+ prompt = "crazy portal universe"
 
53
 
54
  guidance_scale=7.5
55
  num_samples = 3
56
+ generator = torch.Generator(device="cpu").manual_seed(0) # change the seed to get different results
 
57
  images = pipe(
58
  prompt=prompt,
59
  image=image,
 
62
  generator=generator,
63
  num_images_per_prompt=num_samples,
64
  ).images
 
65
  images.insert(0, image)
66
  image_grid(images, 1, num_samples + 1)
67
+
68
+ gr.Interface(
 
 
 
 
69
  predict,
70
  title = 'Stable Diffusion In-Painting',
71
  inputs=[
 
75
  outputs = [
76
  gr.Image()
77
  ]
78
+ ).launch(debug=True)