praeclarumjj3 commited on
Commit
ba9718c
β€’
1 Parent(s): a5c1cde

Switch to Streamlit

Browse files
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: FcF Inpainting
3
  emoji: πŸͺ„ ✨ ✨
4
  colorFrom: blue
5
  colorTo: purple
6
- sdk: gradio
7
- sdk_version: 2.9.4
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
1
  ---
2
+ title: FcF-Inpainting
3
  emoji: πŸͺ„ ✨ ✨
4
  colorFrom: blue
5
  colorTo: purple
6
+ sdk: streamlit
7
+ sdk_version: 1.11.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -9,13 +9,9 @@ from PIL import Image
9
  import numpy as np
10
  import torch
11
  import legacy
12
- import paddlehub as hub
13
  import cv2
14
-
15
- u2net = hub.Module(name='U2Net')
16
-
17
- # gradio app imports
18
- import gradio as gr
19
  from torchvision.transforms import ToTensor, ToPILImage
20
  image_to_tensor = ToTensor()
21
  tensor_to_image = ToPILImage()
@@ -24,6 +20,17 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  class_idx = None
25
  truncation_psi = 0.1
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  def create_model(network_pkl):
28
  print('Loading networks from "%s"...' % network_pkl)
29
  with dnnlib.util.open_url(network_pkl) as f:
@@ -48,11 +55,6 @@ def fcf_inpaint(G, org_img, erased_img, mask):
48
  comp_img = mask.to(device) * pred_img + (1 - mask).to(device) * org_img.to(device)
49
  return comp_img
50
 
51
- def show_images(img, width, height):
52
- """ Display a batch of images inline. """
53
- img = Image.fromarray(img)
54
- img = img.resize((width, height))
55
- return img
56
 
57
  def denorm(img):
58
  img = np.asarray(img[0].cpu(), dtype=np.float32).transpose(1, 2, 0)
@@ -64,38 +66,32 @@ def pil_to_numpy(pil_img: Image) -> Tuple[torch.Tensor, torch.Tensor]:
64
  img = np.array(pil_img)
65
  return torch.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1
66
 
67
- def inpaint(input_img, mask, option):
68
- width, height = input_img.size
69
-
70
- if option == "Automatic":
71
- result = u2net.Segmentation(
72
- images=[cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)],
73
- paths=None,
74
- batch_size=1,
75
- input_size=320,
76
- output_dir='output',
77
- visualization=True)
78
- mask = Image.fromarray(result[0]['mask'])
79
- mask = mask.convert('L')
80
- else:
81
- mask = mask.resize((width,height))
82
 
83
- if width != 512 or height != 512:
84
- input_img = input_img.resize((512, 512))
85
- mask = mask.resize((512, 512))
86
-
87
- rgb = input_img.convert('RGB')
88
  rgb = np.array(rgb)
 
 
 
89
 
90
- mask = np.array(mask)
91
- if option == 'Manual':
92
- mask = (mask[:, :, 0] == rgb[:, :, 0]) * (mask[:, :, 1] == rgb[:, :, 1]) * (mask[:, :, 2] == rgb[:, :, 2])
93
- mask = 1. - mask.astype(np.float32) * 1.
94
- kernel = np.ones((3, 3), np.uint8)
95
- mask = cv2.dilate(mask, kernel)
96
- mask = mask * 255.
 
 
97
 
98
- mask /= 255.
 
 
 
 
 
99
  mask_tensor = torch.from_numpy(mask).to(torch.float32)
100
  mask_tensor = mask_tensor.unsqueeze(0)
101
  mask_tensor = mask_tensor.unsqueeze(0).to(device)
@@ -107,60 +103,93 @@ def inpaint(input_img, mask, option):
107
  rgb_erased = rgb_erased * (1 - mask_tensor) # erase rgb
108
  rgb_erased = rgb_erased.to(torch.float32)
109
 
110
- model = create_model("models/places_512.pkl")
111
  comp_img = fcf_inpaint(G=model, org_img=rgb.to(torch.float32), erased_img=rgb_erased.to(torch.float32), mask=mask_tensor.to(torch.float32))
112
  rgb_erased = denorm(rgb_erased)
113
  comp_img = denorm(comp_img)
 
114
 
115
- return show_images(rgb_erased, width, height), show_images(comp_img, width, height)
116
-
117
- gradio_inputs = [gr.inputs.Image(type='pil',
118
- tool="editor",
119
- label="Image"),
120
- # gr.inputs.Image(type='pil',source="canvas", label="Mask", invert_colors=True),
121
- gr.inputs.Image(type='pil',
122
- tool="editor",
123
- label="Mask"),
124
- gr.inputs.Radio(choices=["Automatic", "Manual"], type="value", default="Manual", label="Masking Choice")
125
- ]
126
-
127
- gradio_outputs = [gr.outputs.Image(label='Image with Hole'),
128
- gr.outputs.Image(label='Inpainted Image')]
129
-
130
- examples = [['test_512/person512.png', 'test_512/person512.png', 'Automatic'],
131
- ['test_512/a_org.png', 'test_512/a_overlay.png', 'Manual'],
132
- ['test_512/f_org.png', 'test_512/f_overlay.png', 'Manual'],
133
- ['test_512/g_org.png', 'test_512/g_overlay.png', 'Manual'],
134
- ['test_512/h_org.png', 'test_512/h_overlay.png', 'Manual'],
135
- ['test_512/i_org.png', 'test_512/i_overlay.png', 'Manual'],
136
- ['test_512/b_org.png', 'test_512/b_overlay.png', 'Manual'],
137
- ['test_512/c_org.png', 'test_512/c_overlay.png', 'Manual'],
138
- ['test_512/d_org.png', 'test_512/d_overlay.png', 'Manual'],
139
- ['test_512/e_org.png', 'test_512/e_overlay.png', 'Manual']]
140
-
141
- title = "FcF-Inpainting"
142
 
143
- description = "<p style='color:royalblue; font-weight: w300;'> \
144
- [Note: Queue time may take up to 20 seconds! The image and mask are resized to 512x512 before inpainting.] To use FcF-Inpainting: <br> \
145
- (1) <span style='color:#E0B941;'>Upload the Same Image to both</span> input boxes (Image and Mask) below. <br> \
146
- (2a) <span style='color:#E0B941;'>Manual Option:</span> The TUI Image Editor used by gradio <a style='color: #E0B941;' href='https://github.com/gradio-app/gradio/issues/1810' target='_blank'>changes the image when saved</a>. We compute the mask after comparing the two inputs. Therefore, we need to save both inputs: <br> \
147
- &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; - <span style='color:#E0B941;'>Image:</span> Click on the edit button on the top-right and save without making any changes. <br> \
148
- &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; - <span style='color:#E0B941;'>Mask:</span> Draw a mask (hole) using the brush (click on the edit button in the top right of the Mask View and select the draw option). <br> \
149
- (2b) <span style='color:#E0B941;'>Automatic Option:</span> This option will generate a mask using a pretrained U2Net model. <br> \
150
- (3) Click on <span style='color:#E0B941;'>Submit</span> and witness the MAGIC! πŸͺ„ ✨ ✨</p>"
151
-
152
- article = "<p style='color: #E0B941; text-align: center'> <a style='color: #E0B941;' href='https://praeclarumjj3.github.io/fcf-inpainting/' target='_blank'>Project Page</a> | <a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'> Keys to Better Image Inpainting: Structure and Texture Go Hand in Hand</a> | <a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'>Github Repo</a></p>"
153
-
154
- css = ".image-preview {height: 32rem; width: auto;} .output-image {height: 32rem; width: auto;} .panel-buttons { display: flex; flex-direction: row;}"
155
-
156
- iface = gr.Interface(fn=inpaint, inputs=gradio_inputs,
157
- outputs=gradio_outputs,
158
- css=css,
159
- layout="vertical",
160
- theme="dark-huggingface",
161
- examples_per_page=5,
162
- thumbnail="fcf_gan.png",
163
- allow_flagging="never",
164
- examples=examples, title=title,
165
- description=description, article=article)
166
- iface.launch(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import numpy as np
10
  import torch
11
  import legacy
 
12
  import cv2
13
+ from streamlit_drawable_canvas import st_canvas
14
+ import streamlit as st
 
 
 
15
  from torchvision.transforms import ToTensor, ToPILImage
16
  image_to_tensor = ToTensor()
17
  tensor_to_image = ToPILImage()
 
20
  class_idx = None
21
  truncation_psi = 0.1
22
 
23
+ title = "FcF-Inpainting"
24
+
25
+ description = "<p style='color:royalblue; font-size: 14px; font-weight: w300;'> \
26
+ [Note: The image and mask are resized to 512x512 before inpainting. The <span style='color:#E0B941;'>Run FcF-Inpainting</span> button will automatically appear after you draw a mask.] To use FcF-Inpainting: <br> \
27
+ (1) <span style='color:#E0B941;'>Upload an Image</span> or <span style='color:#E0B941;'> select a sample image on the left</span>. <br> \
28
+ (2) Adjust the brush stroke width and <span style='color:#E0B941;'>draw the mask on the image</span>. You may also change the drawing tool on the sidebar. <br>\
29
+ (3) After drawing a mask, click the <span style='color:#E0B941;'>Run FcF-Inpainting</span> and witness the MAGIC! πŸͺ„ ✨ ✨<br> \
30
+ (4) You may <span style='color:#E0B941;'>download/undo/redo/delete</span> the changes on the image using the options below the image box.</p>"
31
+
32
+ article = "<p style='color: #E0B941; font-size: 16px; font-weight: w500; text-align: center'> <a style='color: #E0B941;' href='https://praeclarumjj3.github.io/fcf-inpainting/' target='_blank'>Project Page</a> | <a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'> Keys to Better Image Inpainting: Structure and Texture Go Hand in Hand</a> | <a style='color: #E0B941;' href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'>Github</a></p>"
33
+
34
  def create_model(network_pkl):
35
  print('Loading networks from "%s"...' % network_pkl)
36
  with dnnlib.util.open_url(network_pkl) as f:
 
55
  comp_img = mask.to(device) * pred_img + (1 - mask).to(device) * org_img.to(device)
56
  return comp_img
57
 
 
 
 
 
 
58
 
59
  def denorm(img):
60
  img = np.asarray(img[0].cpu(), dtype=np.float32).transpose(1, 2, 0)
 
66
  img = np.array(pil_img)
67
  return torch.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1
68
 
69
+ def process_mask(input_img, mask):
70
+ rgb = cv2.cvtColor(input_img, cv2.COLOR_RGBA2RGB)
71
+ mask = 255 - mask[:,:,3]
72
+ mask = (mask > 0) * 1
 
 
 
 
 
 
 
 
 
 
 
73
 
 
 
 
 
 
74
  rgb = np.array(rgb)
75
+ mask_tensor = torch.from_numpy(mask).to(torch.float32)
76
+ mask_tensor = mask_tensor.unsqueeze(0)
77
+ mask_tensor = mask_tensor.unsqueeze(0).to(device)
78
 
79
+ rgb = rgb.transpose(2,0,1)
80
+ rgb = torch.from_numpy(rgb.astype(np.float32)).unsqueeze(0)
81
+ rgb = (rgb.to(torch.float32) / 127.5 - 1).to(device)
82
+ rgb_erased = rgb.clone()
83
+ rgb_erased = rgb_erased * (1 - mask_tensor) # erase rgb
84
+ rgb_erased = rgb_erased.to(torch.float32)
85
+
86
+ rgb_erased = denorm(rgb_erased)
87
+ return rgb_erased
88
 
89
+ def inpaint(input_img, mask, model):
90
+ rgb = cv2.cvtColor(input_img, cv2.COLOR_RGBA2RGB)
91
+ mask = 255 - mask[:,:,3]
92
+ mask = (mask > 0) * 1
93
+
94
+ rgb = np.array(rgb)
95
  mask_tensor = torch.from_numpy(mask).to(torch.float32)
96
  mask_tensor = mask_tensor.unsqueeze(0)
97
  mask_tensor = mask_tensor.unsqueeze(0).to(device)
 
103
  rgb_erased = rgb_erased * (1 - mask_tensor) # erase rgb
104
  rgb_erased = rgb_erased.to(torch.float32)
105
 
 
106
  comp_img = fcf_inpaint(G=model, org_img=rgb.to(torch.float32), erased_img=rgb_erased.to(torch.float32), mask=mask_tensor.to(torch.float32))
107
  rgb_erased = denorm(rgb_erased)
108
  comp_img = denorm(comp_img)
109
+ return comp_img
110
 
111
+ def run_app(model):
112
+
113
+ if "button_id" not in st.session_state:
114
+ st.session_state["button_id"] = ""
115
+ if "color_to_label" not in st.session_state:
116
+ st.session_state["color_to_label"] = {}
117
+ image_inpainting(model)
118
+
119
+ with st.sidebar:
120
+ st.markdown("---")
121
+
122
+ def image_inpainting(model):
123
+ if 'reuse_image' not in st.session_state:
124
+ st.session_state.reuse_image = None
125
+
126
+ st.title(title)
127
+ st.markdown(article, unsafe_allow_html=True)
128
+ st.markdown(description, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
129
 
130
+ image = st.sidebar.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
131
+
132
+ sample_image = st.sidebar.radio('Choose a Sample Image', [
133
+ 'scene-background.png',
134
+ 'fence-background.png',
135
+ 'bench.png',
136
+ 'house.png',
137
+ 'landscape.png',
138
+ 'truck.png',
139
+ 'scenery.png',
140
+ 'grass-texture.png',
141
+ 'mapview-texture.png',
142
+ ])
143
+
144
+ drawing_mode = st.sidebar.selectbox(
145
+ "Drawing tool:", ("freedraw", "line")
146
+ )
147
+
148
+ image = Image.open(image).convert("RGBA") if image else Image.open(f"./test_512/{sample_image}").convert("RGBA")
149
+ image = image.resize((512, 512))
150
+ width, height = image.size
151
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 100, 20)
152
+
153
+ canvas_result = st_canvas(
154
+ stroke_color="rgba(255, 0, 255, 0.8)",
155
+ stroke_width=stroke_width,
156
+ background_image=image,
157
+ height=height,
158
+ width=width,
159
+ drawing_mode=drawing_mode,
160
+ key="canvas",
161
+ )
162
+ if canvas_result.image_data is not None and image and len(canvas_result.json_data["objects"]) > 0:
163
+
164
+ im = canvas_result.image_data.copy()
165
+ background = np.where(
166
+ (im[:, :, 0] == 0) &
167
+ (im[:, :, 1] == 0) &
168
+ (im[:, :, 2] == 0)
169
+ )
170
+ drawing = np.where(
171
+ (im[:, :, 0] == 255) &
172
+ (im[:, :, 1] == 0) &
173
+ (im[:, :, 2] == 255)
174
+ )
175
+ im[background]=[0,0,0,255]
176
+ im[drawing]=[0,0,0,0] #RGBA
177
+ if st.button('Run FcF-Inpainting'):
178
+ col1, col2 = st.columns([1,1])
179
+ with col1:
180
+ # if st.button('Show Image with Holes'):
181
+ st.write("Masked Image")
182
+ mask_show = process_mask(np.array(image), np.array(im))
183
+ st.image(mask_show)
184
+ with col2:
185
+ st.write("Inpainted Image")
186
+ inpainted_img = inpaint(np.array(image), np.array(im), model)
187
+ st.image(inpainted_img)
188
+
189
+ if __name__ == "__main__":
190
+ st.set_page_config(
191
+ page_title="FcF-Inpainting", page_icon=":sparkles:"
192
+ )
193
+ st.sidebar.subheader("Configuration")
194
+ model = create_model("models/places_512.pkl")
195
+ run_app(model)
output/result_0.png DELETED
Binary file (247 kB)
 
output/result_mask_0.png DELETED
Binary file (28.5 kB)
 
requirements.txt CHANGED
@@ -1,7 +1,5 @@
1
  icecream==2.1.0
2
  opencv-python-headless
3
- paddlepaddle
4
- paddlehub
5
  psutil==5.8.0
6
  click
7
  requests
@@ -22,6 +20,7 @@ pydrive2
22
  pandas
23
  easydict
24
  kornia==0.5.0
25
- gradio==2.9.4
26
  ipython
27
- Jinja2
 
 
 
1
  icecream==2.1.0
2
  opencv-python-headless
 
 
3
  psutil==5.8.0
4
  click
5
  requests
 
20
  pandas
21
  easydict
22
  kornia==0.5.0
 
23
  ipython
24
+ Jinja2
25
+ streamlit-drawable-canvas
26
+ streamlit==1.11.0
test_512/.DS_Store CHANGED
Binary files a/test_512/.DS_Store and b/test_512/.DS_Store differ
 
test_512/a_mask.png DELETED
Binary file (3.23 kB)
 
test_512/a_overlay.png DELETED
Binary file (518 kB)
 
test_512/b_mask.png DELETED
Binary file (3.83 kB)
 
test_512/b_org.png DELETED
Binary file (573 kB)
 
test_512/b_overlay.png DELETED
Binary file (564 kB)
 
test_512/{c_org.png β†’ bench.png} RENAMED
File without changes
test_512/c_mask.png DELETED
Binary file (4.36 kB)
 
test_512/c_overlay.png DELETED
Binary file (587 kB)
 
test_512/d_mask.png DELETED
Binary file (4.51 kB)
 
test_512/d_overlay.png DELETED
Binary file (566 kB)
 
test_512/e_mask.png DELETED
Binary file (4.35 kB)
 
test_512/e_overlay.png DELETED
Binary file (363 kB)
 
test_512/f_mask.png DELETED
Binary file (3.75 kB)
 
test_512/f_overlay.png DELETED
Binary file (518 kB)
 
test_512/{a_org.png β†’ fence-background.png} RENAMED
File without changes
test_512/g_mask.png DELETED
Binary file (3.86 kB)
 
test_512/g_overlay.png DELETED
Binary file (562 kB)
 
test_512/grass-texture.png ADDED
test_512/h_mask.png DELETED
Binary file (4.18 kB)
 
test_512/h_overlay.png DELETED
Binary file (597 kB)
 
test_512/{f_org.png β†’ house.png} RENAMED
File without changes
test_512/i_mask.png DELETED
Binary file (3.28 kB)
 
test_512/i_org.png DELETED
Binary file (643 kB)
 
test_512/i_overlay.png DELETED
Binary file (637 kB)
 
test_512/{d_org.png β†’ landscape.png} RENAMED
File without changes
test_512/{h_org.png β†’ mapview-texture.png} RENAMED
File without changes
test_512/mask_auto.png DELETED
Binary file (2.29 kB)
 
test_512/{person512.png β†’ scene-background.png} RENAMED
File without changes
test_512/{e_org.png β†’ scenery.png} RENAMED
File without changes
test_512/{g_org.png β†’ truck.png} RENAMED
File without changes