nightfury commited on
Commit
31f6f75
1 Parent(s): 14f47c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -74
app.py CHANGED
@@ -59,7 +59,7 @@ model_id_or_path = "CompVis/stable-diffusion-v1-4"
59
  pipe = StableDiffusionInpaintingPipeline.from_pretrained(
60
  model_id_or_path,
61
  revision="fp16",
62
- torch_dtype=torch.half, #float16
63
  use_auth_token=auth_token
64
  )
65
 
@@ -69,23 +69,25 @@ model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
69
  model.eval()
70
  model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
71
 
 
 
72
  transform = transforms.Compose([
73
  transforms.ToTensor(),
74
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
75
- transforms.Resize((512, 512)),
76
  ])
77
 
78
  def predict(radio, dict, word_mask, prompt=""):
79
  if(radio == "draw a mask above"):
80
  with autocast(device): #"cuda"
81
- init_image = dict["image"].convert("RGB").resize((512, 512))
82
- mask = dict["mask"].convert("RGB").resize((512, 512))
83
  elif(radio == "type what to keep"):
84
  img = transform(dict["image"]).squeeze(0)
85
  word_masks = [word_mask]
86
  with torch.no_grad():
87
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
88
- init_image = dict['image'].convert('RGB').resize((512, 512))
89
  filename = f"{uuid.uuid4()}.png"
90
  plt.imsave(filename,torch.sigmoid(preds[0][0]))
91
  img2 = cv2.imread(filename)
@@ -99,7 +101,7 @@ def predict(radio, dict, word_mask, prompt=""):
99
  word_masks = [word_mask]
100
  with torch.no_grad():
101
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
102
- init_image = dict['image'].convert('RGB').resize((512, 512))
103
  filename = f"{uuid.uuid4()}.png"
104
  plt.imsave(filename,torch.sigmoid(preds[0][0]))
105
  img2 = cv2.imread(filename)
@@ -127,68 +129,6 @@ css = '''
127
  .acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
128
  #image_upload .touch-none{display: flex}
129
 
130
- .markdown-body {
131
- font-family: -apple-system,BlinkMacSystemFont,"Segoe UI",Helvetica,Arial,sans-serif,"Apple Color Emoji","Segoe UI Emoji";
132
- font-size: 16px;
133
- line-height: 1.5;
134
- word-wrap: break-word;
135
- }
136
- .container-lg {
137
- max-width: 1012px;
138
- margin-right: auto;
139
- margin-left: auto;
140
- }
141
- [data-color-mode="auto"][data-light-theme*="light"] {
142
- --color-workflow-card-connector: var(--color-scale-gray-3);
143
- --color-workflow-card-connector-bg: var(--color-scale-gray-3);
144
- --color-workflow-card-connector-inactive: var(--color-border-default);
145
- --color-workflow-card-connector-inactive-bg: var(--color-border-default);
146
- --color-workflow-card-connector-highlight: var(--color-scale-blue-4);
147
- --color-workflow-card-connector-highlight-bg: var(--color-scale-blue-4);
148
- --color-workflow-card-bg: var(--color-scale-white);
149
- --color-workflow-card-inactive-bg: var(--color-canvas-inset);
150
- --color-workflow-card-header-shadow: rgba(0, 0, 0, 0);
151
- --color-workflow-card-progress-complete-bg: var(--color-scale-blue-4);
152
- --color-workflow-card-progress-incomplete-bg: var(--color-scale-gray-2);
153
- --color-discussions-state-answered-icon: var(--color-scale-white);
154
- --color-bg-discussions-row-emoji-box: rgba(209, 213, 218, 0.5);
155
- --color-notifications-button-text: var(--color-fg-muted);
156
- --color-notifications-button-hover-text: var(--color-fg-default);
157
- --color-notifications-button-hover-bg: var(--color-scale-gray-2);
158
- --color-notifications-row-read-bg: var(--color-canvas-subtle);
159
- --color-notifications-row-bg: var(--color-scale-white);
160
- --color-icon-directory: var(--color-scale-blue-3);
161
- --color-checks-step-error-icon: var(--color-scale-red-4);
162
- --color-calendar-halloween-graph-day-L1-bg: #ffee4a;
163
- --color-calendar-halloween-graph-day-L2-bg: #ffc501;
164
- --color-calendar-halloween-graph-day-L3-bg: #fe9600;
165
- --color-calendar-halloween-graph-day-L4-bg: #03001c;
166
- --color-calendar-graph-day-bg: #ebedf0;
167
- --color-calendar-graph-day-border: rgba(27, 31, 35, 0.06);
168
- --color-calendar-graph-day-L1-bg: #9be9a8;
169
- --color-calendar-graph-day-L2-bg: #40c463;
170
- --color-calendar-graph-day-L3-bg: #30a14e;
171
- --color-calendar-graph-day-L4-bg: #216e39;
172
- --color-calendar-graph-day-L1-border: rgba(27, 31, 35, 0.06);
173
- --color-calendar-graph-day-L2-border: rgba(27, 31, 35, 0.06);
174
- --color-calendar-graph-day-L3-border: rgba(27, 31, 35, 0.06);
175
- --color-calendar-graph-day-L4-border: rgba(27, 31, 35, 0.06);
176
- --color-user-mention-fg: var(--color-fg-default);
177
- --color-user-mention-bg: var(--color-attention-subtle);
178
- --color-text-white: var(--color-scale-white);
179
- }
180
- :root {
181
- --Layout-pane-width: 220px;
182
- --Layout-content-width: 100%;
183
- --Layout-template-columns: 1fr var(--Layout-pane-width);
184
- --Layout-template-areas: "content pane";
185
- --Layout-column-gap: 16px;
186
- --Layout-row-gap: 16px;
187
- --Layout-outer-spacing-x: 0px;
188
- --Layout-outer-spacing-y: 0px;
189
- --Layout-inner-spacing-min: 0px;
190
- --Layout-inner-spacing-max: 0px;
191
- }
192
  '''
193
  def swap_word_mask(radio_option):
194
  if(radio_option == "draw a mask above"):
@@ -226,18 +166,18 @@ with image_blocks as demo:
226
  <rect x="69" y="69" width="23" height="23" fill="black"></rect>
227
  <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
228
  <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
229
- <rect x="115" y="46" width="23" height="23" fill="white"></rect>
230
- <rect x="115" y="115" width="23" height="23" fill="white"></rect>
231
  <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
232
  <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
233
  <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
234
  <rect x="92" y="69" width="23" height="23" fill="white"></rect>
235
- <rect x="69" y="46" width="23" height="23" fill="white"></rect>
236
  <rect x="69" y="115" width="23" height="23" fill="white"></rect>
237
  <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
238
- <rect x="46" y="46" width="23" height="23" fill="black"></rect>
239
  <rect x="46" y="115" width="23" height="23" fill="black"></rect>
240
- <rect x="46" y="69" width="23" height="23" fill="black"></rect>
241
  <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
242
  <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
243
  <rect x="23" y="69" width="23" height="23" fill="black"></rect>
@@ -258,7 +198,7 @@ with image_blocks as demo:
258
  with gr.Box(elem_id="mask_radio").style(border=False):
259
  radio = gr.Radio(["draw a mask above", "type what to mask below", "type what to keep"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
260
  word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
261
- img_res = gr.inputs.Dropdown("512*512", "256*256")
262
  prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
263
  radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
264
  radio.change(None, inputs=[], outputs=image_blocks, _js = """
59
  pipe = StableDiffusionInpaintingPipeline.from_pretrained(
60
  model_id_or_path,
61
  revision="fp16",
62
+ torch_dtype=torch.float16, #float16
63
  use_auth_token=auth_token
64
  )
65
 
69
  model.eval()
70
  model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
71
 
72
+ imgRes = 256
73
+
74
  transform = transforms.Compose([
75
  transforms.ToTensor(),
76
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
77
+ transforms.Resize((imgRes, imgRes)),
78
  ])
79
 
80
  def predict(radio, dict, word_mask, prompt=""):
81
  if(radio == "draw a mask above"):
82
  with autocast(device): #"cuda"
83
+ init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
84
+ mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
85
  elif(radio == "type what to keep"):
86
  img = transform(dict["image"]).squeeze(0)
87
  word_masks = [word_mask]
88
  with torch.no_grad():
89
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
90
+ init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
91
  filename = f"{uuid.uuid4()}.png"
92
  plt.imsave(filename,torch.sigmoid(preds[0][0]))
93
  img2 = cv2.imread(filename)
101
  word_masks = [word_mask]
102
  with torch.no_grad():
103
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
104
+ init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
105
  filename = f"{uuid.uuid4()}.png"
106
  plt.imsave(filename,torch.sigmoid(preds[0][0]))
107
  img2 = cv2.imread(filename)
129
  .acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
130
  #image_upload .touch-none{display: flex}
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  '''
133
  def swap_word_mask(radio_option):
134
  if(radio_option == "draw a mask above"):
166
  <rect x="69" y="69" width="23" height="23" fill="black"></rect>
167
  <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
168
  <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
169
+ <rect x="115" y="46" width="23" height="23" fill="black"></rect>
170
+ <rect x="115" y="115" width="23" height="23" fill="black"></rect>
171
  <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
172
  <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
173
  <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
174
  <rect x="92" y="69" width="23" height="23" fill="white"></rect>
175
+ <rect x="69" y="46" width="23" height="23" fill="black"></rect>
176
  <rect x="69" y="115" width="23" height="23" fill="white"></rect>
177
  <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
178
+ <rect x="46" y="46" width="23" height="23" fill="white"></rect>
179
  <rect x="46" y="115" width="23" height="23" fill="black"></rect>
180
+ <rect x="46" y="69" width="23" height="23" fill="white"></rect>
181
  <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
182
  <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
183
  <rect x="23" y="69" width="23" height="23" fill="black"></rect>
198
  with gr.Box(elem_id="mask_radio").style(border=False):
199
  radio = gr.Radio(["draw a mask above", "type what to mask below", "type what to keep"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
200
  word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
201
+ img_res = gr.inputs.Dropdown("512*512", "256*256").style(container=True)
202
  prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
203
  radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
204
  radio.change(None, inputs=[], outputs=image_blocks, _js = """