radames HF staff commited on
Commit
331e32b
β€’
1 Parent(s): db6894a

use new runwayml inpaint model πŸŽ‰

Browse files
frontend/src/lib/App.svelte CHANGED
@@ -71,12 +71,15 @@
71
  const sessionHash = crypto.randomUUID();
72
  const base64Crop = $maskEl.toDataURL('image/png');
73
 
74
- const payload = {
75
  fn_index: 0,
76
- data: [base64Crop, prompt, 0.75, 7.5, 35, 'patchmatch'],
77
  session_hash: sessionHash
78
  };
79
 
 
 
 
 
80
  const websocket = new WebSocket(PUBLIC_WS_INPAINTING);
81
  // websocket.onopen = async function (event) {
82
  // websocket.send(JSON.stringify({ hash: sessionHash }));
@@ -94,9 +97,12 @@
94
  const data = JSON.parse(event.data);
95
  $loadingState = '';
96
  switch (data.msg) {
 
 
 
97
  case 'send_data':
98
  $loadingState = 'Sending Data';
99
- websocket.send(JSON.stringify(payload));
100
  break;
101
  case 'queue_full':
102
  $loadingState = 'Queue full';
 
71
  const sessionHash = crypto.randomUUID();
72
  const base64Crop = $maskEl.toDataURL('image/png');
73
 
74
+ const hashpayload = {
75
  fn_index: 0,
 
76
  session_hash: sessionHash
77
  };
78
 
79
+ const datapayload = {
80
+ data: [base64Crop, prompt, 0.75, 7.5, 35, 'patchmatch']
81
+ }
82
+
83
  const websocket = new WebSocket(PUBLIC_WS_INPAINTING);
84
  // websocket.onopen = async function (event) {
85
  // websocket.send(JSON.stringify({ hash: sessionHash }));
 
97
  const data = JSON.parse(event.data);
98
  $loadingState = '';
99
  switch (data.msg) {
100
+ case 'send_hash':
101
+ websocket.send(JSON.stringify(hashpayload));
102
+ break;
103
  case 'send_data':
104
  $loadingState = 'Sending Data';
105
+ websocket.send(JSON.stringify({...hashpayload, ...datapayload}));
106
  break;
107
  case 'queue_full':
108
  $loadingState = 'Queue full';
frontend/src/lib/Buttons/RoomsSelector.svelte CHANGED
@@ -25,10 +25,10 @@
25
  }
26
  onMount(() => {
27
  refreshRooms();
28
- window.addEventListener('click', clickHandler, true);
29
  const interval = setInterval(refreshRooms, 3000);
30
  return () => {
31
- window.removeEventListener('click', clickHandler, true);
32
  clearInterval(interval);
33
  };
34
  });
 
25
  }
26
  onMount(() => {
27
  refreshRooms();
28
+ window.addEventListener('pointerdown', clickHandler, true);
29
  const interval = setInterval(refreshRooms, 3000);
30
  return () => {
31
+ window.removeEventListener('pointerdown', clickHandler, true);
32
  clearInterval(interval);
33
  };
34
  });
frontend/src/lib/PromptModal.svelte CHANGED
@@ -23,11 +23,12 @@
23
  inputEl.focus();
24
  prompt = initPrompt;
25
  window.addEventListener('keyup', onKeyup);
26
- window.addEventListener('click', cancelHandler, true);
 
27
 
28
  return () => {
29
  window.removeEventListener('keyup', onKeyup);
30
- window.removeEventListener('click', cancelHandler, true);
31
  };
32
  });
33
  let timer: NodeJS.Timeout;
 
23
  inputEl.focus();
24
  prompt = initPrompt;
25
  window.addEventListener('keyup', onKeyup);
26
+ window.addEventListener('pointerdown', cancelHandler, true);
27
+
28
 
29
  return () => {
30
  window.removeEventListener('keyup', onKeyup);
31
+ window.removeEventListener('pointerdown', cancelHandler, true);
32
  };
33
  });
34
  let timer: NodeJS.Timeout;
requirements.txt CHANGED
@@ -1,15 +1,15 @@
1
  --extra-index-url https://download.pytorch.org/whl/cu113
2
  torch
3
  huggingface_hub
4
- diffusers
5
- transformers
6
- scikit-image
7
- pillow
8
- opencv-python-headless
9
- fastapi
10
- uvicorn
11
- httpx
12
- gradio
13
- boto3
14
- python-magic
15
- fastapi-utils
 
1
  --extra-index-url https://download.pytorch.org/whl/cu113
2
  torch
3
  huggingface_hub
4
+ diffusers==0.6.0
5
+ transformers==4.23.1
6
+ scikit-image==0.19.3
7
+ Pillow==9.2.0
8
+ opencv-python-headless==4.6.0.66
9
+ fastapi==0.85.1
10
+ uvicorn==0.18.3
11
+ httpx==0.23.0
12
+ gradio==3.6
13
+ boto3==1.24.93
14
+ python-magic==0.4.27
15
+ fastapi-utils==0.2.1
stablediffusion-infinity/app.py CHANGED
@@ -71,21 +71,11 @@ model = {}
71
 
72
 
73
  def get_model():
74
- if "text2img" not in model:
75
- text2img = StableDiffusionPipeline.from_pretrained(
76
- "CompVis/stable-diffusion-v1-4",
77
  revision="fp16",
78
- torch_dtype=torch.float16,
79
- use_auth_token=HF_TOKEN,
80
- ).to("cuda")
81
- inpaint = StableDiffusionInpaintPipeline(
82
- vae=text2img.vae,
83
- text_encoder=text2img.text_encoder,
84
- tokenizer=text2img.tokenizer,
85
- unet=text2img.unet,
86
- scheduler=text2img.scheduler,
87
- safety_checker=text2img.safety_checker,
88
- feature_extractor=text2img.feature_extractor,
89
  ).to("cuda")
90
 
91
  # lms = LMSDiscreteScheduler(
@@ -108,11 +98,10 @@ def get_model():
108
  # inpaint.enable_attention_slicing()
109
  # except:
110
  # pass
111
- model["text2img"] = text2img
112
  model["inpaint"] = inpaint
113
  # model["img2img"] = img2img
114
 
115
- return model["text2img"], model["inpaint"]
116
  # model["img2img"]
117
 
118
 
@@ -127,8 +116,10 @@ def run_outpaint(
127
  guidance,
128
  step,
129
  fill_mode,
 
 
130
  ):
131
- text2img, inpaint = get_model()
132
  sel_buffer = np.array(input_image)
133
  img = sel_buffer[:, :, 0:3]
134
  mask = sel_buffer[:, :, -1]
@@ -167,25 +158,30 @@ def run_outpaint(
167
  mask_image = Image.fromarray(mask)
168
 
169
  # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
170
- with autocast("cuda"):
171
- images = inpaint(
172
- prompt=prompt_text,
173
- init_image=init_image.resize(
174
- (process_size, process_size), resample=SAMPLING_MODE
175
- ),
176
- mask_image=mask_image.resize((process_size, process_size)),
177
- strength=strength,
178
- num_inference_steps=step,
179
- guidance_scale=guidance,
180
- )
181
  else:
182
  print("text2image")
183
- with autocast("cuda"):
184
- images = text2img(
185
- prompt=prompt_text, height=process_size, width=process_size,
186
- )
 
 
 
187
 
188
- return images['sample'][0], images["nsfw_content_detected"][0]
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
 
191
  with blocks as demo:
@@ -242,6 +238,8 @@ blocks.config['dev_mode'] = False
242
 
243
  app = gr.mount_gradio_app(app, blocks, "/gradio",
244
  gradio_api_url="http://0.0.0.0:7860/gradio/")
 
 
245
  def generateAuthToken():
246
  response = requests.get(f"https://liveblocks.io/api/authorize",
247
  headers={"Authorization": f"Bearer {LIVEBLOCKS_SECRET}"})
 
71
 
72
 
73
  def get_model():
74
+ if "inpaint" not in model:
75
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
76
+ "runwayml/stable-diffusion-inpainting",
77
  revision="fp16",
78
+ torch_dtype=torch.float16
 
 
 
 
 
 
 
 
 
 
79
  ).to("cuda")
80
 
81
  # lms = LMSDiscreteScheduler(
 
98
  # inpaint.enable_attention_slicing()
99
  # except:
100
  # pass
 
101
  model["inpaint"] = inpaint
102
  # model["img2img"] = img2img
103
 
104
+ return model["inpaint"]
105
  # model["img2img"]
106
 
107
 
 
116
  guidance,
117
  step,
118
  fill_mode,
119
+
120
+
121
  ):
122
+ inpaint = get_model()
123
  sel_buffer = np.array(input_image)
124
  img = sel_buffer[:, :, 0:3]
125
  mask = sel_buffer[:, :, -1]
 
158
  mask_image = Image.fromarray(mask)
159
 
160
  # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
161
+
 
 
 
 
 
 
 
 
 
 
162
  else:
163
  print("text2image")
164
+ print("inpainting")
165
+ img, mask = functbl[fill_mode](img, mask)
166
+ init_image = Image.fromarray(img)
167
+ mask = 255 - mask
168
+ mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
169
+ mask = mask.repeat(8, axis=0).repeat(8, axis=1)
170
+ mask_image = Image.fromarray(mask)
171
 
172
+ # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
173
+ with autocast("cuda"):
174
+ output = inpaint(
175
+ prompt=prompt_text,
176
+ image=init_image.resize(
177
+ (process_size, process_size), resample=SAMPLING_MODE
178
+ ),
179
+ mask_image=mask_image.resize((process_size, process_size)),
180
+ strength=strength,
181
+ num_inference_steps=step,
182
+ guidance_scale=guidance,
183
+ )
184
+ return output['images'][0], output["nsfw_content_detected"][0]
185
 
186
 
187
  with blocks as demo:
 
238
 
239
  app = gr.mount_gradio_app(app, blocks, "/gradio",
240
  gradio_api_url="http://0.0.0.0:7860/gradio/")
241
+
242
+
243
  def generateAuthToken():
244
  response = requests.get(f"https://liveblocks.io/api/authorize",
245
  headers={"Authorization": f"Bearer {LIVEBLOCKS_SECRET}"})
stablediffusion-infinity/rooms.db CHANGED
Binary files a/stablediffusion-infinity/rooms.db and b/stablediffusion-infinity/rooms.db differ