radames HF staff commited on
Commit
3207814
1 Parent(s): d6fedfa
app.py CHANGED
@@ -3,7 +3,7 @@ from fastapi import FastAPI
3
  from config import args
4
  from device import device, torch_dtype
5
  from app_init import init_app
6
- from user_queue import user_data_events
7
  from util import get_pipeline_class
8
 
9
 
@@ -11,4 +11,4 @@ app = FastAPI()
11
 
12
  pipeline_class = get_pipeline_class(args.pipeline)
13
  pipeline = pipeline_class(args, device, torch_dtype)
14
- init_app(app, user_data_events, args, pipeline)
 
3
  from config import args
4
  from device import device, torch_dtype
5
  from app_init import init_app
6
+ from user_queue import user_data
7
  from util import get_pipeline_class
8
 
9
 
 
11
 
12
  pipeline_class = get_pipeline_class(args.pipeline)
13
  pipeline = pipeline_class(args, device, torch_dtype)
14
+ init_app(app, user_data, args, pipeline)
app_init.py CHANGED
@@ -7,16 +7,15 @@ from fastapi import Request
7
  import logging
8
  import traceback
9
  from config import Args
10
- from user_queue import UserDataEventMap, UserDataEvent
11
  import uuid
12
- from asyncio import Event, sleep
13
  import time
14
- from PIL import Image
15
  from types import SimpleNamespace
16
- from util import pil_to_frame, is_firefox
 
17
 
18
 
19
- def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipeline):
20
  app.add_middleware(
21
  CORSMiddleware,
22
  allow_origins=["*"],
@@ -28,44 +27,42 @@ def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipel
28
  @app.websocket("/ws")
29
  async def websocket_endpoint(websocket: WebSocket):
30
  await websocket.accept()
31
- if args.max_queue_size > 0 and len(user_data_events) >= args.max_queue_size:
 
32
  print("Server is full")
33
  await websocket.send_json({"status": "error", "message": "Server is full"})
34
  await websocket.close()
35
  return
36
-
37
  try:
38
- uid = str(uuid.uuid4())
39
- print(f"New user connected: {uid}")
40
- await websocket.send_json(
41
- {"status": "success", "message": "Connected", "userId": uid}
42
- )
43
- user_data_events[uid] = UserDataEvent()
44
  await websocket.send_json(
45
- {"status": "start", "message": "Start Streaming", "userId": uid}
46
  )
47
- await handle_websocket_data(websocket, uid)
48
  except WebSocketDisconnect as e:
49
- logging.error(f"WebSocket Error: {e}, {uid}")
50
  traceback.print_exc()
51
  finally:
52
- print(f"User disconnected: {uid}")
53
- del user_data_events[uid]
54
 
55
  @app.get("/queue_size")
56
  async def get_queue_size():
57
- queue_size = len(user_data_events)
58
  return JSONResponse({"queue_size": queue_size})
59
 
60
  @app.get("/stream/{user_id}")
61
  async def stream(user_id: uuid.UUID, request: Request):
62
- uid = str(user_id)
63
  try:
 
64
 
65
  async def generate():
66
  while True:
67
- data = await user_data_events[uid].wait_for_data()
68
- params = data["params"]
 
69
  image = pipeline.predict(params)
70
  if image is None:
71
  continue
@@ -81,13 +78,12 @@ def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipel
81
  headers={"Cache-Control": "no-cache"},
82
  )
83
  except Exception as e:
84
- logging.error(f"Streaming Error: {e}, {user_data_events}")
85
  traceback.print_exc()
86
  return HTTPException(status_code=404, detail="User not found")
87
 
88
- async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
89
- uid = str(user_id)
90
- if uid not in user_data_events:
91
  return HTTPException(status_code=404, detail="User not found")
92
  last_time = time.time()
93
  try:
@@ -98,19 +94,20 @@ def init_app(app: FastAPI, user_data_events: UserDataEventMap, args: Args, pipel
98
  params = SimpleNamespace(**params.dict())
99
  if info.input_mode == "image":
100
  image_data = await websocket.receive_bytes()
101
- pil_image = Image.open(io.BytesIO(image_data))
102
- params.image = pil_image
103
- user_data_events[uid].update_data({"params": params})
104
  if args.timeout > 0 and time.time() - last_time > args.timeout:
105
  await websocket.send_json(
106
  {
107
  "status": "timeout",
108
  "message": "Your session has ended",
109
- "userId": uid,
110
  }
111
  )
112
  await websocket.close()
113
  return
 
114
 
115
  except Exception as e:
116
  logging.error(f"Error: {e}")
 
7
  import logging
8
  import traceback
9
  from config import Args
10
+ from user_queue import UserData
11
  import uuid
 
12
  import time
 
13
  from types import SimpleNamespace
14
+ from util import pil_to_frame, bytes_to_pil, is_firefox
15
+ import asyncio
16
 
17
 
18
+ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
19
  app.add_middleware(
20
  CORSMiddleware,
21
  allow_origins=["*"],
 
27
  @app.websocket("/ws")
28
  async def websocket_endpoint(websocket: WebSocket):
29
  await websocket.accept()
30
+ user_count = user_data.get_user_count()
31
+ if args.max_queue_size > 0 and user_count >= args.max_queue_size:
32
  print("Server is full")
33
  await websocket.send_json({"status": "error", "message": "Server is full"})
34
  await websocket.close()
35
  return
 
36
  try:
37
+ user_id = uuid.uuid4()
38
+ print(f"New user connected: {user_id}")
39
+ await user_data.create_user(user_id, websocket)
 
 
 
40
  await websocket.send_json(
41
+ {"status": "connected", "message": "Connected", "userId": str(user_id)}
42
  )
43
+ await handle_websocket_data(user_id, websocket)
44
  except WebSocketDisconnect as e:
45
+ logging.error(f"WebSocket Error: {e}, {user_id}")
46
  traceback.print_exc()
47
  finally:
48
+ print(f"User disconnected: {user_id}")
49
+ user_data.delete_user(user_id)
50
 
51
  @app.get("/queue_size")
52
  async def get_queue_size():
53
+ queue_size = user_data.get_user_count()
54
  return JSONResponse({"queue_size": queue_size})
55
 
56
  @app.get("/stream/{user_id}")
57
  async def stream(user_id: uuid.UUID, request: Request):
 
58
  try:
59
+ print(f"New stream request: {user_id}")
60
 
61
  async def generate():
62
  while True:
63
+ params = await user_data.get_latest_data(user_id)
64
+ if not params:
65
+ continue
66
  image = pipeline.predict(params)
67
  if image is None:
68
  continue
 
78
  headers={"Cache-Control": "no-cache"},
79
  )
80
  except Exception as e:
81
+ logging.error(f"Streaming Error: {e}, {user_id} ")
82
  traceback.print_exc()
83
  return HTTPException(status_code=404, detail="User not found")
84
 
85
+ async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
86
+ if not user_data.check_user(user_id):
 
87
  return HTTPException(status_code=404, detail="User not found")
88
  last_time = time.time()
89
  try:
 
94
  params = SimpleNamespace(**params.dict())
95
  if info.input_mode == "image":
96
  image_data = await websocket.receive_bytes()
97
+ params.image = bytes_to_pil(image_data)
98
+
99
+ await user_data.update_data(user_id, params)
100
  if args.timeout > 0 and time.time() - last_time > args.timeout:
101
  await websocket.send_json(
102
  {
103
  "status": "timeout",
104
  "message": "Your session has ended",
105
+ "userId": user_id,
106
  }
107
  )
108
  await websocket.close()
109
  return
110
+ await asyncio.sleep(1.0 / 24)
111
 
112
  except Exception as e:
113
  logging.error(f"Error: {e}")
frontend/src/lib/components/Button.svelte CHANGED
@@ -7,7 +7,7 @@
7
  <slot />
8
  </button>
9
 
10
- <style lang="postcss">
11
  .button {
12
  @apply rounded bg-gray-700 p-2 font-normal text-white hover:bg-gray-800 disabled:cursor-not-allowed disabled:bg-gray-300 dark:disabled:bg-gray-700 dark:disabled:text-black;
13
  }
 
7
  <slot />
8
  </button>
9
 
10
+ <style lang="postcss" scoped>
11
  .button {
12
  @apply rounded bg-gray-700 p-2 font-normal text-white hover:bg-gray-800 disabled:cursor-not-allowed disabled:bg-gray-300 dark:disabled:bg-gray-700 dark:disabled:text-black;
13
  }
frontend/src/lib/components/ImagePlayer.svelte CHANGED
@@ -1,18 +1,19 @@
1
  <script lang="ts">
2
- import { isLCMRunning, lcmLiveState, lcmLiveActions } from '$lib/lcmLive';
3
  import { onFrameChangeStore } from '$lib/mediaStream';
4
  import { PUBLIC_BASE_URL } from '$env/static/public';
5
 
6
- $: streamId = $lcmLiveState?.streamId;
7
  $: {
8
- console.log('streamId', streamId);
9
  }
 
 
10
  </script>
11
 
12
  <div class="relative overflow-hidden rounded-lg border border-slate-300">
13
  <!-- svelte-ignore a11y-missing-attribute -->
14
- {#if $isLCMRunning}
15
- <img class="aspect-square w-full rounded-lg" src={PUBLIC_BASE_URL + '/stream/' + streamId} />
16
  {:else}
17
  <div class="aspect-square w-full rounded-lg" />
18
  {/if}
 
1
  <script lang="ts">
2
+ import { lcmLiveStatus, LCMLiveStatus, streamId } from '$lib/lcmLive';
3
  import { onFrameChangeStore } from '$lib/mediaStream';
4
  import { PUBLIC_BASE_URL } from '$env/static/public';
5
 
 
6
  $: {
7
+ console.log('streamId', $streamId);
8
  }
9
+ $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
10
+ $: console.log('isLCMRunning', isLCMRunning);
11
  </script>
12
 
13
  <div class="relative overflow-hidden rounded-lg border border-slate-300">
14
  <!-- svelte-ignore a11y-missing-attribute -->
15
+ {#if isLCMRunning}
16
+ <img class="aspect-square w-full rounded-lg" src={PUBLIC_BASE_URL + '/stream/' + $streamId} />
17
  {:else}
18
  <div class="aspect-square w-full rounded-lg" />
19
  {/if}
frontend/src/lib/components/PipelineOptions.svelte CHANGED
@@ -17,13 +17,13 @@
17
  <div>
18
  {#if featuredOptions}
19
  {#each featuredOptions as params}
20
- {#if params.field === FieldType.range}
21
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
22
- {:else if params.field === FieldType.seed}
23
  <SeedInput bind:value={$pipelineValues[params.id]}></SeedInput>
24
- {:else if params.field === FieldType.textarea}
25
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
26
- {:else if params.field === FieldType.checkbox}
27
  <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
28
  {/if}
29
  {/each}
@@ -33,17 +33,17 @@
33
  <details open>
34
  <summary class="cursor-pointer font-medium">Advanced Options</summary>
35
  <div
36
- class="grid grid-cols-1 items-center gap-3 {pipelineValues.length > 5 ? 'sm:grid-cols-2' : ''}"
37
  >
38
  {#if advanceOptions}
39
  {#each advanceOptions as params}
40
- {#if params.field === FieldType.range}
41
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
42
- {:else if params.field === FieldType.seed}
43
  <SeedInput bind:value={$pipelineValues[params.id]}></SeedInput>
44
- {:else if params.field === FieldType.textarea}
45
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
46
- {:else if params.field === FieldType.checkbox}
47
  <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
48
  {/if}
49
  {/each}
 
17
  <div>
18
  {#if featuredOptions}
19
  {#each featuredOptions as params}
20
+ {#if params.field === FieldType.RANGE}
21
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
22
+ {:else if params.field === FieldType.SEED}
23
  <SeedInput bind:value={$pipelineValues[params.id]}></SeedInput>
24
+ {:else if params.field === FieldType.TEXTAREA}
25
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
26
+ {:else if params.field === FieldType.CHECKBOX}
27
  <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
28
  {/if}
29
  {/each}
 
33
  <details open>
34
  <summary class="cursor-pointer font-medium">Advanced Options</summary>
35
  <div
36
+ class="grid grid-cols-1 items-center gap-3 {pipelineParams.length > 5 ? 'sm:grid-cols-2' : ''}"
37
  >
38
  {#if advanceOptions}
39
  {#each advanceOptions as params}
40
+ {#if params.field === FieldType.RANGE}
41
  <InputRange {params} bind:value={$pipelineValues[params.id]}></InputRange>
42
+ {:else if params.field === FieldType.SEED}
43
  <SeedInput bind:value={$pipelineValues[params.id]}></SeedInput>
44
+ {:else if params.field === FieldType.TEXTAREA}
45
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
46
+ {:else if params.field === FieldType.CHECKBOX}
47
  <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
48
  {/if}
49
  {/each}
frontend/src/lib/components/VideoInput.svelte CHANGED
@@ -1,42 +1,38 @@
1
  <script lang="ts">
2
  import 'rvfc-polyfill';
3
- import { onMount, onDestroy } from 'svelte';
4
  import {
5
- mediaStreamState,
6
- mediaStreamActions,
7
- isMediaStreaming,
8
- MediaStreamStatus,
9
- onFrameChangeStore
10
  } from '$lib/mediaStream';
11
 
12
- $: mediaStream = $mediaStreamState.mediaStream;
13
-
14
  let videoEl: HTMLVideoElement;
15
  let videoFrameCallbackId: number;
16
  const WIDTH = 512;
17
  const HEIGHT = 512;
 
18
 
19
  onDestroy(() => {
20
  if (videoFrameCallbackId) videoEl.cancelVideoFrameCallback(videoFrameCallbackId);
21
  });
22
 
23
- function srcObject(node: HTMLVideoElement, stream: MediaStream) {
24
- node.srcObject = stream;
25
- return {
26
- update(newStream: MediaStream) {
27
- if (node.srcObject != newStream) {
28
- node.srcObject = newStream;
29
- }
30
- }
31
- };
32
  }
 
 
33
  async function onFrameChange(now: DOMHighResTimeStamp, metadata: VideoFrameCallbackMetadata) {
34
- const blob = await grapBlobImg();
35
- onFrameChangeStore.set({ now, metadata, blob });
 
 
 
36
  videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
37
  }
38
 
39
- $: if ($isMediaStreaming == MediaStreamStatus.CONNECTED) {
40
  videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
41
  }
42
  async function grapBlobImg() {
@@ -70,7 +66,6 @@
70
  autoplay
71
  muted
72
  loop
73
- use:srcObject={mediaStream}
74
  ></video>
75
  </div>
76
  <svg
 
1
  <script lang="ts">
2
  import 'rvfc-polyfill';
3
+ import { onDestroy } from 'svelte';
4
  import {
5
+ mediaStreamStatus,
6
+ MediaStreamStatusEnum,
7
+ onFrameChangeStore,
8
+ mediaStream
 
9
  } from '$lib/mediaStream';
10
 
 
 
11
  let videoEl: HTMLVideoElement;
12
  let videoFrameCallbackId: number;
13
  const WIDTH = 512;
14
  const HEIGHT = 512;
15
+ const THROTTLE_FPS = 10;
16
 
17
  onDestroy(() => {
18
  if (videoFrameCallbackId) videoEl.cancelVideoFrameCallback(videoFrameCallbackId);
19
  });
20
 
21
+ $: if (videoEl) {
22
+ videoEl.srcObject = $mediaStream;
 
 
 
 
 
 
 
23
  }
24
+
25
+ let last_millis = 0;
26
  async function onFrameChange(now: DOMHighResTimeStamp, metadata: VideoFrameCallbackMetadata) {
27
+ if (now - last_millis > 1000 / THROTTLE_FPS) {
28
+ const blob = await grapBlobImg();
29
+ onFrameChangeStore.set({ now, metadata, blob });
30
+ last_millis = now;
31
+ }
32
  videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
33
  }
34
 
35
+ $: if ($mediaStreamStatus == MediaStreamStatusEnum.CONNECTED) {
36
  videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
37
  }
38
  async function grapBlobImg() {
 
66
  autoplay
67
  muted
68
  loop
 
69
  ></video>
70
  </div>
71
  <svg
frontend/src/lib/lcmLive.ts CHANGED
@@ -1,27 +1,17 @@
1
  import { writable } from 'svelte/store';
2
  import { PUBLIC_WSS_URL } from '$env/static/public';
3
 
4
- export const isStreaming = writable(false);
5
- export const isLCMRunning = writable(false);
6
-
7
 
8
  export enum LCMLiveStatus {
9
- INIT = "init",
10
  CONNECTED = "connected",
11
  DISCONNECTED = "disconnected",
 
12
  }
13
 
14
- interface lcmLive {
15
- streamId: string | null;
16
- status: LCMLiveStatus
17
- }
18
-
19
- const initialState: lcmLive = {
20
- streamId: null,
21
- status: LCMLiveStatus.INIT
22
- };
23
 
24
- export const lcmLiveState = writable(initialState);
 
25
 
26
  let websocket: WebSocket | null = null;
27
  export const lcmLiveActions = {
@@ -37,12 +27,8 @@ export const lcmLiveActions = {
37
  console.log("Connected to websocket");
38
  };
39
  websocket.onclose = () => {
40
- lcmLiveState.update((state) => ({
41
- ...state,
42
- status: LCMLiveStatus.DISCONNECTED
43
- }));
44
  console.log("Disconnected from websocket");
45
- isLCMRunning.set(false);
46
  };
47
  websocket.onerror = (err) => {
48
  console.error(err);
@@ -51,47 +37,29 @@ export const lcmLiveActions = {
51
  const data = JSON.parse(event.data);
52
  console.log("WS: ", data);
53
  switch (data.status) {
54
- case "success":
55
- break;
56
- case "start":
57
- const streamId = data.userId;
58
- lcmLiveState.update((state) => ({
59
- ...state,
60
- status: LCMLiveStatus.CONNECTED,
61
- streamId: streamId,
62
- }));
63
- isLCMRunning.set(true);
64
- resolve(streamId);
65
  break;
66
  case "timeout":
67
  console.log("timeout");
68
- isLCMRunning.set(false);
69
- lcmLiveState.update((state) => ({
70
- ...state,
71
- status: LCMLiveStatus.DISCONNECTED,
72
- streamId: null,
73
- }));
74
  reject("timeout");
75
  case "error":
76
  console.log(data.message);
77
- isLCMRunning.set(false);
78
- lcmLiveState.update((state) => ({
79
- ...state,
80
- status: LCMLiveStatus.DISCONNECTED,
81
- streamId: null,
82
- }));
83
  reject(data.message);
84
  }
85
  };
86
 
87
  } catch (err) {
88
  console.error(err);
89
- isLCMRunning.set(false);
90
- lcmLiveState.update((state) => ({
91
- ...state,
92
- status: LCMLiveStatus.DISCONNECTED,
93
- streamId: null,
94
- }));
95
  reject(err);
96
  }
97
  });
@@ -113,7 +81,7 @@ export const lcmLiveActions = {
113
  websocket.close();
114
  }
115
  websocket = null;
116
- lcmLiveState.set({ status: LCMLiveStatus.DISCONNECTED, streamId: null });
117
- isLCMRunning.set(false)
118
  },
119
  };
 
1
  import { writable } from 'svelte/store';
2
  import { PUBLIC_WSS_URL } from '$env/static/public';
3
 
 
 
 
4
 
5
  export enum LCMLiveStatus {
 
6
  CONNECTED = "connected",
7
  DISCONNECTED = "disconnected",
8
+ WAIT = "wait",
9
  }
10
 
11
+ const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
 
 
 
 
 
 
 
 
12
 
13
+ export const lcmLiveStatus = writable<LCMLiveStatus>(initStatus);
14
+ export const streamId = writable<string | null>(null);
15
 
16
  let websocket: WebSocket | null = null;
17
  export const lcmLiveActions = {
 
27
  console.log("Connected to websocket");
28
  };
29
  websocket.onclose = () => {
30
+ lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
 
 
 
31
  console.log("Disconnected from websocket");
 
32
  };
33
  websocket.onerror = (err) => {
34
  console.error(err);
 
37
  const data = JSON.parse(event.data);
38
  console.log("WS: ", data);
39
  switch (data.status) {
40
+ case "connected":
41
+ const userId = data.userId;
42
+ lcmLiveStatus.set(LCMLiveStatus.CONNECTED);
43
+ streamId.set(userId);
 
 
 
 
 
 
 
44
  break;
45
  case "timeout":
46
  console.log("timeout");
47
+ lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
48
+ streamId.set(null);
 
 
 
 
49
  reject("timeout");
50
  case "error":
51
  console.log(data.message);
52
+ lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
53
+ streamId.set(null);
 
 
 
 
54
  reject(data.message);
55
  }
56
  };
57
 
58
  } catch (err) {
59
  console.error(err);
60
+ lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
61
+ streamId.set(null);
62
+
 
 
 
63
  reject(err);
64
  }
65
  });
 
81
  websocket.close();
82
  }
83
  websocket = null;
84
+ lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
85
+ streamId.set(null);
86
  },
87
  };
frontend/src/lib/mediaStream.ts CHANGED
@@ -1,38 +1,23 @@
1
  import { writable, type Writable } from 'svelte/store';
2
 
3
- export enum MediaStreamStatus {
4
  INIT = "init",
5
  CONNECTED = "connected",
6
  DISCONNECTED = "disconnected",
7
  }
8
  export const onFrameChangeStore: Writable<{ now: Number, metadata: VideoFrameCallbackMetadata, blob: Blob }> = writable();
9
- export const isMediaStreaming = writable(MediaStreamStatus.INIT);
10
 
11
- interface mediaStream {
12
- mediaStream: MediaStream | null;
13
- status: MediaStreamStatus
14
- devices: MediaDeviceInfo[];
15
- }
16
-
17
- const initialState: mediaStream = {
18
- mediaStream: null,
19
- status: MediaStreamStatus.INIT,
20
- devices: [],
21
- };
22
-
23
- export const mediaStreamState = writable(initialState);
24
 
25
  export const mediaStreamActions = {
26
  async enumerateDevices() {
27
- console.log("Enumerating devices");
28
  await navigator.mediaDevices.enumerateDevices()
29
  .then(devices => {
30
  const cameras = devices.filter(device => device.kind === 'videoinput');
31
- console.log("Cameras: ", cameras);
32
- mediaStreamState.update((state) => ({
33
- ...state,
34
- devices: cameras,
35
- }));
36
  })
37
  .catch(err => {
38
  console.error(err);
@@ -48,17 +33,14 @@ export const mediaStreamActions = {
48
 
49
  await navigator.mediaDevices
50
  .getUserMedia(constraints)
51
- .then((mediaStream) => {
52
- mediaStreamState.update((state) => ({
53
- ...state,
54
- mediaStream: mediaStream,
55
- status: MediaStreamStatus.CONNECTED,
56
- }));
57
- isMediaStreaming.set(MediaStreamStatus.CONNECTED);
58
  })
59
  .catch((err) => {
60
  console.error(`${err.name}: ${err.message}`);
61
- isMediaStreaming.set(MediaStreamStatus.DISCONNECTED);
 
62
  });
63
  },
64
  async switchCamera(mediaDevicedID: string) {
@@ -68,26 +50,19 @@ export const mediaStreamActions = {
68
  };
69
  await navigator.mediaDevices
70
  .getUserMedia(constraints)
71
- .then((mediaStream) => {
72
- mediaStreamState.update((state) => ({
73
- ...state,
74
- mediaStream: mediaStream,
75
- status: MediaStreamStatus.CONNECTED,
76
- }));
77
  })
78
  .catch((err) => {
79
  console.error(`${err.name}: ${err.message}`);
80
  });
81
  },
82
  async stop() {
83
- navigator.mediaDevices.getUserMedia({ video: true }).then((mediaStream) => {
84
- mediaStream.getTracks().forEach((track) => track.stop());
85
  });
86
- mediaStreamState.update((state) => ({
87
- ...state,
88
- mediaStream: null,
89
- status: MediaStreamStatus.DISCONNECTED,
90
- }));
91
- isMediaStreaming.set(MediaStreamStatus.DISCONNECTED);
92
  },
93
  };
 
1
  import { writable, type Writable } from 'svelte/store';
2
 
3
+ export enum MediaStreamStatusEnum {
4
  INIT = "init",
5
  CONNECTED = "connected",
6
  DISCONNECTED = "disconnected",
7
  }
8
  export const onFrameChangeStore: Writable<{ now: Number, metadata: VideoFrameCallbackMetadata, blob: Blob }> = writable();
 
9
 
10
+ export const mediaDevices = writable<MediaDeviceInfo[]>([]);
11
+ export const mediaStreamStatus = writable(MediaStreamStatusEnum.INIT);
12
+ export const mediaStream = writable<MediaStream | null>(null);
 
 
 
 
 
 
 
 
 
 
13
 
14
  export const mediaStreamActions = {
15
  async enumerateDevices() {
16
+ // console.log("Enumerating devices");
17
  await navigator.mediaDevices.enumerateDevices()
18
  .then(devices => {
19
  const cameras = devices.filter(device => device.kind === 'videoinput');
20
+ mediaDevices.set(cameras);
 
 
 
 
21
  })
22
  .catch(err => {
23
  console.error(err);
 
33
 
34
  await navigator.mediaDevices
35
  .getUserMedia(constraints)
36
+ .then((stream) => {
37
+ mediaStreamStatus.set(MediaStreamStatusEnum.CONNECTED);
38
+ mediaStream.set(stream);
 
 
 
 
39
  })
40
  .catch((err) => {
41
  console.error(`${err.name}: ${err.message}`);
42
+ mediaStreamStatus.set(MediaStreamStatusEnum.DISCONNECTED);
43
+ mediaStream.set(null);
44
  });
45
  },
46
  async switchCamera(mediaDevicedID: string) {
 
50
  };
51
  await navigator.mediaDevices
52
  .getUserMedia(constraints)
53
+ .then((stream) => {
54
+ mediaStreamStatus.set(MediaStreamStatusEnum.CONNECTED);
55
+ mediaStream.set(stream)
 
 
 
56
  })
57
  .catch((err) => {
58
  console.error(`${err.name}: ${err.message}`);
59
  });
60
  },
61
  async stop() {
62
+ navigator.mediaDevices.getUserMedia({ video: true }).then((stream) => {
63
+ stream.getTracks().forEach((track) => track.stop());
64
  });
65
+ mediaStreamStatus.set(MediaStreamStatusEnum.DISCONNECTED);
66
+ mediaStream.set(null);
 
 
 
 
67
  },
68
  };
frontend/src/lib/store.ts CHANGED
@@ -1,4 +1,4 @@
1
 
2
  import { writable, type Writable } from 'svelte/store';
3
 
4
- export const pipelineValues = writable({});
 
1
 
2
  import { writable, type Writable } from 'svelte/store';
3
 
4
+ export const pipelineValues = writable({} as Record<string, any>);
frontend/src/lib/types.ts CHANGED
@@ -1,13 +1,13 @@
1
  export const enum FieldType {
2
- range = "range",
3
- seed = "seed",
4
- textarea = "textarea",
5
- checkbox = "checkbox",
6
  }
7
  export const enum PipelineMode {
8
- image = "image",
9
- video = "video",
10
- text = "text",
11
  }
12
 
13
  export interface FieldProps {
 
1
  export const enum FieldType {
2
+ RANGE = "range",
3
+ SEED = "seed",
4
+ TEXTAREA = "textarea",
5
+ CHECKBOX = "checkbox",
6
  }
7
  export const enum PipelineMode {
8
+ IMAGE = "image",
9
+ VIDEO = "video",
10
+ TEXT = "text",
11
  }
12
 
13
  export interface FieldProps {
frontend/src/routes/+page.svelte CHANGED
@@ -8,12 +8,12 @@
8
  import Button from '$lib/components/Button.svelte';
9
  import PipelineOptions from '$lib/components/PipelineOptions.svelte';
10
  import Spinner from '$lib/icons/spinner.svelte';
11
- import { isLCMRunning, lcmLiveState, lcmLiveActions, LCMLiveStatus } from '$lib/lcmLive';
12
  import {
13
- mediaStreamState,
14
  mediaStreamActions,
15
- isMediaStreaming,
16
- onFrameChangeStore
 
17
  } from '$lib/mediaStream';
18
  import { pipelineValues } from '$lib/store';
19
 
@@ -30,7 +30,7 @@
30
  const settings = await fetch(`${PUBLIC_BASE_URL}/settings`).then((r) => r.json());
31
  pipelineParams = Object.values(settings.input_params.properties);
32
  pipelineInfo = settings.info.properties;
33
- isImageMode = pipelineInfo.input_mode.default === PipelineMode.image;
34
  maxQueueSize = settings.max_queue_size;
35
  pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
36
  console.log('PARAMS', pipelineParams);
@@ -38,22 +38,37 @@
38
  }
39
  console.log('isImageMode', isImageMode);
40
 
 
 
 
 
 
 
41
  // send Webcam stream to LCM if image mode
42
  $: {
43
- if (isImageMode && $lcmLiveState.status === LCMLiveStatus.CONNECTED) {
 
 
 
 
44
  lcmLiveActions.send($pipelineValues);
45
  lcmLiveActions.send($onFrameChangeStore.blob);
46
  }
47
  }
48
 
49
- // send Webcam stream to LCM
50
- $: {
51
- if ($lcmLiveState.status === LCMLiveStatus.CONNECTED) {
52
- lcmLiveActions.send($pipelineValues);
53
- }
54
- }
 
 
 
 
 
55
  async function toggleLcmLive() {
56
- if (!$isLCMRunning) {
57
  if (isImageMode) {
58
  await mediaStreamActions.enumerateDevices();
59
  await mediaStreamActions.start();
@@ -112,13 +127,13 @@
112
  <PipelineOptions {pipelineParams}></PipelineOptions>
113
  <div class="flex gap-3">
114
  <Button on:click={toggleLcmLive}>
115
- {#if $isLCMRunning}
116
  Stop
117
  {:else}
118
  Start
119
  {/if}
120
  </Button>
121
- <Button disabled={$isLCMRunning} classList={'ml-auto'}>Snapshot</Button>
122
  </div>
123
 
124
  <ImagePlayer>
 
8
  import Button from '$lib/components/Button.svelte';
9
  import PipelineOptions from '$lib/components/PipelineOptions.svelte';
10
  import Spinner from '$lib/icons/spinner.svelte';
11
+ import { lcmLiveStatus, lcmLiveActions, LCMLiveStatus } from '$lib/lcmLive';
12
  import {
 
13
  mediaStreamActions,
14
+ mediaStreamStatus,
15
+ onFrameChangeStore,
16
+ MediaStreamStatusEnum
17
  } from '$lib/mediaStream';
18
  import { pipelineValues } from '$lib/store';
19
 
 
30
  const settings = await fetch(`${PUBLIC_BASE_URL}/settings`).then((r) => r.json());
31
  pipelineParams = Object.values(settings.input_params.properties);
32
  pipelineInfo = settings.info.properties;
33
+ isImageMode = pipelineInfo.input_mode.default === PipelineMode.IMAGE;
34
  maxQueueSize = settings.max_queue_size;
35
  pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
36
  console.log('PARAMS', pipelineParams);
 
38
  }
39
  console.log('isImageMode', isImageMode);
40
 
41
+ $: {
42
+ console.log('lcmLiveState', $lcmLiveStatus);
43
+ }
44
+ $: {
45
+ console.log('mediaStreamState', $mediaStreamStatus);
46
+ }
47
  // send Webcam stream to LCM if image mode
48
  $: {
49
+ if (
50
+ isImageMode &&
51
+ $lcmLiveStatus === LCMLiveStatus.CONNECTED &&
52
+ $mediaStreamStatus === MediaStreamStatusEnum.CONNECTED
53
+ ) {
54
  lcmLiveActions.send($pipelineValues);
55
  lcmLiveActions.send($onFrameChangeStore.blob);
56
  }
57
  }
58
 
59
+ $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
60
+ // $: {
61
+ // console.log('onFrameChangeStore', $onFrameChangeStore);
62
+ // }
63
+
64
+ // // send Webcam stream to LCM
65
+ // $: {
66
+ // if ($lcmLiveState.status === LCMLiveStatus.CONNECTED) {
67
+ // lcmLiveActions.send($pipelineValues);
68
+ // }
69
+ // }
70
  async function toggleLcmLive() {
71
+ if (!isLCMRunning) {
72
  if (isImageMode) {
73
  await mediaStreamActions.enumerateDevices();
74
  await mediaStreamActions.start();
 
127
  <PipelineOptions {pipelineParams}></PipelineOptions>
128
  <div class="flex gap-3">
129
  <Button on:click={toggleLcmLive}>
130
+ {#if isLCMRunning}
131
  Stop
132
  {:else}
133
  Start
134
  {/if}
135
  </Button>
136
+ <Button disabled={isLCMRunning} classList={'ml-auto'}>Snapshot</Button>
137
  </div>
138
 
139
  <ImagePlayer>
frontend/svelte.config.js CHANGED
@@ -1,10 +1,8 @@
1
  import adapter from '@sveltejs/adapter-static';
2
  import { vitePreprocess } from '@sveltejs/kit/vite';
3
-
4
  /** @type {import('@sveltejs/kit').Config} */
5
  const config = {
6
- preprocess: vitePreprocess(),
7
-
8
  kit: {
9
  adapter: adapter({
10
  pages: '../public',
 
1
  import adapter from '@sveltejs/adapter-static';
2
  import { vitePreprocess } from '@sveltejs/kit/vite';
 
3
  /** @type {import('@sveltejs/kit').Config} */
4
  const config = {
5
+ preprocess: vitePreprocess({ postcss: true }),
 
6
  kit: {
7
  adapter: adapter({
8
  pages: '../public',
user_queue.py CHANGED
@@ -1,29 +1,52 @@
1
- from typing import Dict, Union
2
  from uuid import UUID
3
  import asyncio
4
- from PIL import Image
5
- from typing import Dict, Union
6
- from PIL import Image
 
7
 
8
- InputParams = dict
9
- UserId = UUID
10
- EventDataContent = Dict[str, InputParams]
11
 
12
 
13
- class UserDataEvent:
14
  def __init__(self):
15
- self.data_event = asyncio.Event()
16
- self.data_content: EventDataContent = {}
17
 
18
- def update_data(self, new_data: EventDataContent):
19
- self.data_content = new_data
20
- self.data_event.set()
 
 
 
21
 
22
- async def wait_for_data(self) -> EventDataContent:
23
- await self.data_event.wait()
24
- self.data_event.clear()
25
- return self.data_content
26
 
 
 
 
 
 
 
 
 
 
27
 
28
- UserDataEventMap = Dict[UserId, UserDataEvent]
29
- user_data_events: UserDataEventMap = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
  from uuid import UUID
3
  import asyncio
4
+ from fastapi import WebSocket
5
+ from types import SimpleNamespace
6
+ from typing import Dict
7
+ from typing import Union
8
 
9
+ UserDataContent = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
 
 
10
 
11
 
12
+ class UserData:
13
  def __init__(self):
14
+ self.data_content: Dict[UUID, UserDataContent] = {}
 
15
 
16
+ async def create_user(self, user_id: UUID, websocket: WebSocket):
17
+ self.data_content[user_id] = {
18
+ "websocket": websocket,
19
+ "queue": asyncio.Queue(),
20
+ }
21
+ await asyncio.sleep(1)
22
 
23
+ def check_user(self, user_id: UUID) -> bool:
24
+ return user_id in self.data_content
 
 
25
 
26
+ async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
27
+ user_session = self.data_content[user_id]
28
+ queue = user_session["queue"]
29
+ while not queue.empty():
30
+ try:
31
+ queue.get_nowait()
32
+ except asyncio.QueueEmpty:
33
+ continue
34
+ await queue.put(new_data)
35
 
36
+ async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
37
+ user_session = self.data_content[user_id]
38
+ queue = user_session["queue"]
39
+ try:
40
+ return await queue.get()
41
+ except asyncio.QueueEmpty:
42
+ return None
43
+
44
+ def delete_user(self, user_id: UUID):
45
+ if user_id in self.data_content:
46
+ del self.data_content[user_id]
47
+
48
+ def get_user_count(self) -> int:
49
+ return len(self.data_content)
50
+
51
+
52
+ user_data = UserData()
util.py CHANGED
@@ -20,6 +20,11 @@ def get_pipeline_class(pipeline_name: str) -> ModuleType:
20
  return pipeline_class
21
 
22
 
 
 
 
 
 
23
  def pil_to_frame(image: Image.Image) -> bytes:
24
  frame_data = io.BytesIO()
25
  image.save(frame_data, format="JPEG")
 
20
  return pipeline_class
21
 
22
 
23
+ def bytes_to_pil(image_bytes: bytes) -> Image.Image:
24
+ image = Image.open(io.BytesIO(image_bytes))
25
+ return image
26
+
27
+
28
  def pil_to_frame(image: Image.Image) -> bytes:
29
  frame_data = io.BytesIO()
30
  image.save(frame_data, format="JPEG")