radames HF staff commited on
Commit
61988b7
β€’
1 Parent(s): 3d2cb9e

image uploads on predict fn, much faster πŸš„

Browse files
frontend/src/lib/App.svelte CHANGED
@@ -10,7 +10,7 @@
10
  import { PUBLIC_WS_INPAINTING } from '$env/static/public';
11
  import type { PromptImgKey } from '$lib/types';
12
  import { Status } from '$lib/types';
13
- import { loadingState, currZoomTransform, maskEl } from '$lib/store';
14
  import { useMyPresence, useObject, useOthers } from '$lib/liveblocks';
15
  import { base64ToBlob, uploadImage } from '$lib/utils';
16
  import { nanoid } from 'nanoid';
@@ -50,6 +50,7 @@
50
  $loadingState = 'Pending';
51
  const prompt = $myPresence.currentPrompt;
52
  const position = $myPresence.frame;
 
53
  console.log('Generating...', prompt, position);
54
  myPresence.update({
55
  status: Status.loading
@@ -106,23 +107,31 @@
106
  break;
107
  case 'process_completed':
108
  try {
109
- const imgBase64 = data.output.data[0] as string;
110
- const isNSWF = data.output.data[1] as boolean;
 
 
 
 
 
 
111
  if (isNSWF) {
112
  throw new Error('NFSW');
113
  }
114
  const key = getKey(position);
115
- const imgBlob = await base64ToBlob(imgBase64);
116
- const imgURL = await uploadImage(imgBlob, prompt, key);
117
- const promptImg = {
118
  prompt,
119
- imgURL: imgURL.filename,
120
  position,
121
  date: new Date().getTime(),
122
- id: nanoid()
 
123
  };
124
- $promptImgStorage.set(key, promptImg);
125
- console.log(imgURL);
 
 
126
  $loadingState = data.success ? 'Complete' : 'Error';
127
  setTimeout(() => {
128
  $loadingState = '';
@@ -139,7 +148,7 @@
139
  });
140
  setTimeout(() => {
141
  $loadingState = '';
142
- }, 2000);
143
  }
144
  websocket.close();
145
  return;
 
10
  import { PUBLIC_WS_INPAINTING } from '$env/static/public';
11
  import type { PromptImgKey } from '$lib/types';
12
  import { Status } from '$lib/types';
13
+ import { loadingState, currZoomTransform, maskEl, selectedRoomID } from '$lib/store';
14
  import { useMyPresence, useObject, useOthers } from '$lib/liveblocks';
15
  import { base64ToBlob, uploadImage } from '$lib/utils';
16
  import { nanoid } from 'nanoid';
 
50
  $loadingState = 'Pending';
51
  const prompt = $myPresence.currentPrompt;
52
  const position = $myPresence.frame;
53
+ const room = $selectedRoomID || 'default';
54
  console.log('Generating...', prompt, position);
55
  myPresence.update({
56
  status: Status.loading
 
107
  break;
108
  case 'process_completed':
109
  try {
110
+ const params = data.output.data[0] as {
111
+ is_nsfw: boolean;
112
+ image: {
113
+ url: string;
114
+ filename: string;
115
+ };
116
+ };
117
+ const isNSWF = params.is_nsfw;
118
  if (isNSWF) {
119
  throw new Error('NFSW');
120
  }
121
  const key = getKey(position);
122
+ // const imgBlob = await base64ToBlob(imgBase64);
123
+ const promptImgParams = {
 
124
  prompt,
125
+ imgURL: params.image.filename,
126
  position,
127
  date: new Date().getTime(),
128
+ id: nanoid(),
129
+ room: room
130
  };
131
+ // const imgURL = await uploadImage(imgBlob, promptImgParams);
132
+
133
+ $promptImgStorage.set(key, promptImgParams);
134
+ console.log(params.image.url);
135
  $loadingState = data.success ? 'Complete' : 'Error';
136
  setTimeout(() => {
137
  $loadingState = '';
 
148
  });
149
  setTimeout(() => {
150
  $loadingState = '';
151
+ }, 10000);
152
  }
153
  websocket.close();
154
  return;
frontend/src/lib/constants.ts CHANGED
@@ -14,7 +14,7 @@ export const COLORS = [
14
  export const EMOJIS = ['🐝', '🐌', '🐞', '🐜', 'πŸ¦‹', 'πŸ›', '🐝', '🐞', '🦟', 'πŸ¦—', 'πŸ•·', 'πŸ¦‚', '🐒', '🐍', '🦎', 'πŸ¦–', 'πŸ¦•', 'πŸ™', 'πŸ¦‘', '🐠', '🐟', '🐑', '🐬', '🦈', '🐳', 'πŸ‹', '🐊', 'πŸ…', 'πŸ†', 'πŸ¦“', '🦍', '🦧', '🐘', 'πŸ¦›', '🦏', 'πŸͺ', '🐫', 'πŸ¦’', 'πŸƒ', 'πŸ‚', 'πŸ„', '🐎', 'πŸ–',
15
  '🐏', 'πŸ‘', '🐐', 'πŸ•', '🐩', '🐈', 'πŸ“', 'πŸ¦ƒ', 'πŸ¦…', 'πŸ¦†', '🦒', 'πŸ¦‰', '🦚', '🦜', 'πŸ¦‡', '🐁', 'πŸ€', '🐿', 'πŸ‡', '🐿', 'πŸ¦”', 'πŸ¦‡', '🐻', '🐻', '🐨', '🐼', '🐡', 'πŸ™ˆ', 'πŸ™‰', 'πŸ™Š', 'πŸ’', 'πŸ‰', '🐲', 'πŸ¦•', 'πŸ¦–', '🐊', '🐒', '🦎', '🐍', '🐦', '🐧', 'πŸ¦…', 'πŸ¦†', 'πŸ¦‰', 'πŸ¦‡']
16
 
17
- export const MAX_CAPACITY = 10;
18
 
19
  export const CANVAS_SIZE = {
20
  width: 512 * 8,
 
14
  export const EMOJIS = ['🐝', '🐌', '🐞', '🐜', 'πŸ¦‹', 'πŸ›', '🐝', '🐞', '🦟', 'πŸ¦—', 'πŸ•·', 'πŸ¦‚', '🐒', '🐍', '🦎', 'πŸ¦–', 'πŸ¦•', 'πŸ™', 'πŸ¦‘', '🐠', '🐟', '🐑', '🐬', '🦈', '🐳', 'πŸ‹', '🐊', 'πŸ…', 'πŸ†', 'πŸ¦“', '🦍', '🦧', '🐘', 'πŸ¦›', '🦏', 'πŸͺ', '🐫', 'πŸ¦’', 'πŸƒ', 'πŸ‚', 'πŸ„', '🐎', 'πŸ–',
15
  '🐏', 'πŸ‘', '🐐', 'πŸ•', '🐩', '🐈', 'πŸ“', 'πŸ¦ƒ', 'πŸ¦…', 'πŸ¦†', '🦒', 'πŸ¦‰', '🦚', '🦜', 'πŸ¦‡', '🐁', 'πŸ€', '🐿', 'πŸ‡', '🐿', 'πŸ¦”', 'πŸ¦‡', '🐻', '🐻', '🐨', '🐼', '🐡', 'πŸ™ˆ', 'πŸ™‰', 'πŸ™Š', 'πŸ’', 'πŸ‰', '🐲', 'πŸ¦•', 'πŸ¦–', '🐊', '🐒', '🦎', '🐍', '🐦', '🐧', 'πŸ¦…', 'πŸ¦†', 'πŸ¦‰', 'πŸ¦‡']
16
 
17
+ export const MAX_CAPACITY = 50;
18
 
19
  export const CANVAS_SIZE = {
20
  width: 512 * 8,
frontend/src/lib/types.ts CHANGED
@@ -31,6 +31,7 @@ export type PromptImgObject = {
31
  }
32
  date: number;
33
  id: string;
 
34
  };
35
 
36
  export type PromptImgKey = string;
 
31
  }
32
  date: number;
33
  id: string;
34
+ roomid: string;
35
  };
36
 
37
  export type PromptImgKey = string;
frontend/src/lib/utils.ts CHANGED
@@ -21,20 +21,31 @@ export function base64ToBlob(base64image: string): Promise<Blob> {
21
  img.src = base64image;
22
  });
23
  }
24
- export async function uploadImage(imagBlob: Blob, prompt: string, key: string): Promise<{
 
 
 
 
 
 
 
25
  url: string;
26
  filename: string;
27
  }> {
28
  // simple regex slugify string for file name
29
- const promptSlug = slugify(prompt);
30
-
31
- const hash = crypto.randomUUID().split('-')[0];
32
- const fileName = `color-palette-${hash}-${promptSlug}-${key}.jpeg`;
33
 
34
  const file = new File([imagBlob], fileName, { type: 'image/jpeg' });
35
 
36
  const formData = new FormData()
37
  formData.append('file', file)
 
 
 
 
 
38
 
39
  const response = await fetch(PUBLIC_API_BASE + "/uploadfile", {
40
  method: 'POST',
 
21
  img.src = base64image;
22
  });
23
  }
24
+
25
+ export async function uploadImage(imagBlob: Blob, params: {
26
+ prompt: string;
27
+ position: { x: number; y: number };
28
+ date: number;
29
+ id: string;
30
+ room: string;
31
+ }): Promise<{
32
  url: string;
33
  filename: string;
34
  }> {
35
  // simple regex slugify string for file name
36
+ const promptSlug = slugify(params.prompt);
37
+ const key = `${params.position.x}_${params.position.y}`;
38
+ const fileName = `sd-${params.id}-${promptSlug}-${key}.jpeg`;
 
39
 
40
  const file = new File([imagBlob], fileName, { type: 'image/jpeg' });
41
 
42
  const formData = new FormData()
43
  formData.append('file', file)
44
+ formData.append('prompt', params.prompt)
45
+ formData.append('id', params.id)
46
+ formData.append('position', JSON.stringify(params.position))
47
+ formData.append('room', params.room)
48
+ formData.append('date', JSON.stringify(params.date))
49
 
50
  const response = await fetch(PUBLIC_API_BASE + "/uploadfile", {
51
  method: 'POST',
requirements.txt CHANGED
@@ -12,4 +12,5 @@ httpx==0.23.0
12
  gradio==3.6
13
  boto3==1.24.93
14
  python-magic==0.4.27
15
- fastapi-utils==0.2.1
 
 
12
  gradio==3.6
13
  boto3==1.24.93
14
  python-magic==0.4.27
15
+ fastapi-utils==0.2.1
16
+ shortuuid==1.0.9
stablediffusion-infinity/app.py CHANGED
@@ -3,7 +3,7 @@ import os
3
 
4
  from pathlib import Path
5
  import uvicorn
6
- from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, Depends, status, Request
7
  from fastapi.staticfiles import StaticFiles
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi_utils.tasks import repeat_every
@@ -23,7 +23,8 @@ import boto3
23
  import magic
24
  import sqlite3
25
  import requests
26
- import uuid
 
27
 
28
  AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
29
  AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
@@ -118,15 +119,13 @@ def get_model():
118
  get_model()
119
 
120
 
121
- def run_outpaint(
122
  input_image,
123
  prompt_text,
124
  strength,
125
  guidance,
126
  step,
127
  fill_mode,
128
-
129
-
130
  ):
131
  inpaint = get_model()
132
  sel_buffer = np.array(input_image)
@@ -176,7 +175,19 @@ def run_outpaint(
176
  num_inference_steps=step,
177
  guidance_scale=guidance,
178
  )
179
- return output['images'][0], output["nsfw_content_detected"][0]
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
 
182
  with blocks as demo:
@@ -212,8 +223,7 @@ with blocks as demo:
212
 
213
  model_input = gr.Image(label="Input", type="pil", image_mode="RGBA")
214
  proceed_button = gr.Button("Proceed", elem_id="proceed")
215
- model_output = gr.Image(label="Output")
216
- is_nsfw = gr.JSON()
217
 
218
  proceed_button.click(
219
  fn=run_outpaint,
@@ -225,7 +235,7 @@ with blocks as demo:
225
  sd_step,
226
  init_mode,
227
  ],
228
- outputs=[model_output, is_nsfw],
229
  )
230
 
231
 
@@ -257,8 +267,8 @@ def get_room_count(room_id: str, jwtToken: str = ''):
257
  raise Exception("Error getting room count")
258
 
259
 
260
- @app.on_event("startup")
261
- @repeat_every(seconds=60)
262
  async def sync_rooms():
263
  print("Syncing rooms")
264
  try:
@@ -277,18 +287,18 @@ async def sync_rooms():
277
  print("Rooms update failed")
278
 
279
 
280
- @app.get('/api/rooms')
281
  async def get_rooms(db: sqlite3.Connection = Depends(get_db)):
282
  rooms = db.execute("SELECT * FROM rooms").fetchall()
283
  return rooms
284
 
285
 
286
- @app.post('/api/auth')
287
  async def autorize(request: Request, db: sqlite3.Connection = Depends(get_db)):
288
  data = await request.json()
289
  room = data["room"]
290
  payload = {
291
- "userId": str(uuid.uuid4()),
292
  "userInfo": {
293
  "name": "Anon"
294
  }}
@@ -307,8 +317,40 @@ async def autorize(request: Request, db: sqlite3.Connection = Depends(get_db)):
307
  raise Exception(response.status_code, response.text)
308
 
309
 
310
- @app.post('/api/uploadfile')
311
- async def create_upload_file(background_tasks: BackgroundTasks, file: UploadFile):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  contents = await file.read()
313
  file_size = len(contents)
314
  if not 0 < file_size < 20E+06:
@@ -329,6 +371,8 @@ async def create_upload_file(background_tasks: BackgroundTasks, file: UploadFile
329
  file.filename, ExtraArgs={"ContentType": file.content_type, "CacheControl": "max-age=31536000"})
330
  temp_file.close()
331
 
 
 
332
  return {"url": f'https://d26smi9133w0oo.cloudfront.net/uploads/{file.filename}', "filename": file.filename}
333
 
334
 
 
3
 
4
  from pathlib import Path
5
  import uvicorn
6
+ from fastapi import FastAPI, BackgroundTasks, HTTPException, UploadFile, Form, Depends, status, Request
7
  from fastapi.staticfiles import StaticFiles
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi_utils.tasks import repeat_every
 
23
  import magic
24
  import sqlite3
25
  import requests
26
+ import shortuuid
27
+ import re
28
 
29
  AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
30
  AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY')
 
119
  get_model()
120
 
121
 
122
+ async def run_outpaint(
123
  input_image,
124
  prompt_text,
125
  strength,
126
  guidance,
127
  step,
128
  fill_mode,
 
 
129
  ):
130
  inpaint = get_model()
131
  sel_buffer = np.array(input_image)
 
175
  num_inference_steps=step,
176
  guidance_scale=guidance,
177
  )
178
+ image = output["images"][0]
179
+ is_nsfw = output["nsfw_content_detected"][0]
180
+ image_url = {}
181
+
182
+ if not is_nsfw:
183
+ print("not nsfw, uploading")
184
+ image_url = await upload_file(image, prompt_text)
185
+
186
+ params = {
187
+ "is_nsfw": is_nsfw,
188
+ "image": image_url
189
+ }
190
+ return params
191
 
192
 
193
  with blocks as demo:
 
223
 
224
  model_input = gr.Image(label="Input", type="pil", image_mode="RGBA")
225
  proceed_button = gr.Button("Proceed", elem_id="proceed")
226
+ params = gr.JSON()
 
227
 
228
  proceed_button.click(
229
  fn=run_outpaint,
 
235
  sd_step,
236
  init_mode,
237
  ],
238
+ outputs=[params],
239
  )
240
 
241
 
 
267
  raise Exception("Error getting room count")
268
 
269
 
270
+ @ app.on_event("startup")
271
+ @ repeat_every(seconds=60)
272
  async def sync_rooms():
273
  print("Syncing rooms")
274
  try:
 
287
  print("Rooms update failed")
288
 
289
 
290
+ @ app.get('/api/rooms')
291
  async def get_rooms(db: sqlite3.Connection = Depends(get_db)):
292
  rooms = db.execute("SELECT * FROM rooms").fetchall()
293
  return rooms
294
 
295
 
296
+ @ app.post('/api/auth')
297
  async def autorize(request: Request, db: sqlite3.Connection = Depends(get_db)):
298
  data = await request.json()
299
  room = data["room"]
300
  payload = {
301
+ "userId": str(shortuuid.uuid()),
302
  "userInfo": {
303
  "name": "Anon"
304
  }}
 
317
  raise Exception(response.status_code, response.text)
318
 
319
 
320
+ def slugify(value):
321
+ value = re.sub(r'[^\w\s-]', '', value).strip().lower()
322
+ out = re.sub(r'[-\s]+', '-', value)
323
+ return out[:400]
324
+
325
+
326
+
327
+ async def upload_file(image: Image.Image, prompt: str):
328
+ image = image.convert('RGB')
329
+ print("Uploading file from predict")
330
+ temp_file = io.BytesIO()
331
+ image.save(temp_file, format="JPEG")
332
+ temp_file.seek(0)
333
+ id = shortuuid.uuid()
334
+ prompt_slug = slugify(prompt)
335
+ filename = f"{id}-{prompt_slug}.jpg"
336
+ s3.upload_fileobj(Fileobj=temp_file, Bucket=AWS_S3_BUCKET_NAME, Key="uploads/" +
337
+ filename, ExtraArgs={"ContentType": "image/jpeg", "CacheControl": "max-age=31536000"})
338
+ temp_file.close()
339
+
340
+ out = {"url": f'https://d26smi9133w0oo.cloudfront.net/uploads/{filename}',
341
+ "filename": filename}
342
+ print(out)
343
+ return out
344
+
345
+
346
+ @ app.post('/api/uploadfile')
347
+ async def create_upload_file(background_tasks: BackgroundTasks,
348
+ file: UploadFile,
349
+ prompt: str = Form(),
350
+ id: str = Form(),
351
+ position: object = Form(),
352
+ room: str = Form(),
353
+ date: int = Form()):
354
  contents = await file.read()
355
  file_size = len(contents)
356
  if not 0 < file_size < 20E+06:
 
371
  file.filename, ExtraArgs={"ContentType": file.content_type, "CacheControl": "max-age=31536000"})
372
  temp_file.close()
373
 
374
+ print("File uploaded", prompt, id, position, room, date)
375
+
376
  return {"url": f'https://d26smi9133w0oo.cloudfront.net/uploads/{file.filename}', "filename": file.filename}
377
 
378