radames HF staff commited on
Commit
1d3190d
1 Parent(s): fd757d2
app_init.py CHANGED
@@ -36,10 +36,16 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
36
  try:
37
  user_id = uuid.uuid4()
38
  print(f"New user connected: {user_id}")
 
39
  await user_data.create_user(user_id, websocket)
40
  await websocket.send_json(
41
  {"status": "connected", "message": "Connected", "userId": str(user_id)}
42
  )
 
 
 
 
 
43
  await handle_websocket_data(user_id, websocket)
44
  except WebSocketDisconnect as e:
45
  logging.error(f"WebSocket Error: {e}, {user_id}")
@@ -48,6 +54,46 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
48
  print(f"User disconnected: {user_id}")
49
  user_data.delete_user(user_id)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @app.get("/queue_size")
52
  async def get_queue_size():
53
  queue_size = user_data.get_user_count()
@@ -59,10 +105,20 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
59
  print(f"New stream request: {user_id}")
60
 
61
  async def generate():
 
 
62
  while True:
63
  params = await user_data.get_latest_data(user_id)
64
- if not params:
 
 
 
 
 
 
65
  continue
 
 
66
  image = pipeline.predict(params)
67
  if image is None:
68
  continue
@@ -71,6 +127,11 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
71
  # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
72
  if not is_firefox(request.headers["user-agent"]):
73
  yield frame
 
 
 
 
 
74
 
75
  return StreamingResponse(
76
  generate(),
@@ -82,37 +143,6 @@ def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
82
  traceback.print_exc()
83
  return HTTPException(status_code=404, detail="User not found")
84
 
85
- async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
86
- if not user_data.check_user(user_id):
87
- return HTTPException(status_code=404, detail="User not found")
88
- last_time = time.time()
89
- try:
90
- while True:
91
- params = await websocket.receive_json()
92
- params = pipeline.InputParams(**params)
93
- info = pipeline.Info()
94
- params = SimpleNamespace(**params.dict())
95
- if info.input_mode == "image":
96
- image_data = await websocket.receive_bytes()
97
- params.image = bytes_to_pil(image_data)
98
-
99
- await user_data.update_data(user_id, params)
100
- if args.timeout > 0 and time.time() - last_time > args.timeout:
101
- await websocket.send_json(
102
- {
103
- "status": "timeout",
104
- "message": "Your session has ended",
105
- "userId": user_id,
106
- }
107
- )
108
- await websocket.close()
109
- return
110
- await asyncio.sleep(1.0 / 24)
111
-
112
- except Exception as e:
113
- logging.error(f"Error: {e}")
114
- traceback.print_exc()
115
-
116
  # route to setup frontend
117
  @app.get("/settings")
118
  async def settings():
 
36
  try:
37
  user_id = uuid.uuid4()
38
  print(f"New user connected: {user_id}")
39
+
40
  await user_data.create_user(user_id, websocket)
41
  await websocket.send_json(
42
  {"status": "connected", "message": "Connected", "userId": str(user_id)}
43
  )
44
+ await websocket.send_json(
45
+ {
46
+ "status": "send_frame",
47
+ }
48
+ )
49
  await handle_websocket_data(user_id, websocket)
50
  except WebSocketDisconnect as e:
51
  logging.error(f"WebSocket Error: {e}, {user_id}")
 
54
  print(f"User disconnected: {user_id}")
55
  user_data.delete_user(user_id)
56
 
57
+ async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
58
+ if not user_data.check_user(user_id):
59
+ return HTTPException(status_code=404, detail="User not found")
60
+ last_time = time.time()
61
+ try:
62
+ while True:
63
+ data = await websocket.receive_json()
64
+ if data["status"] != "next_frame":
65
+ asyncio.sleep(1.0 / 24)
66
+ continue
67
+
68
+ params = await websocket.receive_json()
69
+ params = pipeline.InputParams(**params)
70
+ info = pipeline.Info()
71
+ params = SimpleNamespace(**params.dict())
72
+ if info.input_mode == "image":
73
+ image_data = await websocket.receive_bytes()
74
+ params.image = bytes_to_pil(image_data)
75
+ await user_data.update_data(user_id, params)
76
+ await websocket.send_json(
77
+ {
78
+ "status": "wait",
79
+ }
80
+ )
81
+ if args.timeout > 0 and time.time() - last_time > args.timeout:
82
+ await websocket.send_json(
83
+ {
84
+ "status": "timeout",
85
+ "message": "Your session has ended",
86
+ "userId": user_id,
87
+ }
88
+ )
89
+ await websocket.close()
90
+ return
91
+ await asyncio.sleep(1.0 / 24)
92
+
93
+ except Exception as e:
94
+ logging.error(f"Error: {e}")
95
+ traceback.print_exc()
96
+
97
  @app.get("/queue_size")
98
  async def get_queue_size():
99
  queue_size = user_data.get_user_count()
 
105
  print(f"New stream request: {user_id}")
106
 
107
  async def generate():
108
+ websocket = user_data.get_websocket(user_id)
109
+ last_params = SimpleNamespace()
110
  while True:
111
  params = await user_data.get_latest_data(user_id)
112
+ if not vars(params) or params.__dict__ == last_params.__dict__:
113
+ await websocket.send_json(
114
+ {
115
+ "status": "send_frame",
116
+ }
117
+ )
118
+ await asyncio.sleep(0.1)
119
  continue
120
+
121
+ last_params = params
122
  image = pipeline.predict(params)
123
  if image is None:
124
  continue
 
127
  # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
128
  if not is_firefox(request.headers["user-agent"]):
129
  yield frame
130
+ await websocket.send_json(
131
+ {
132
+ "status": "send_frame",
133
+ }
134
+ )
135
 
136
  return StreamingResponse(
137
  generate(),
 
143
  traceback.print_exc()
144
  return HTTPException(status_code=404, detail="User not found")
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # route to setup frontend
147
  @app.get("/settings")
148
  async def settings():
frontend/src/lib/components/VideoInput.svelte CHANGED
@@ -12,7 +12,6 @@
12
  let videoFrameCallbackId: number;
13
  const WIDTH = 512;
14
  const HEIGHT = 512;
15
- const THROTTLE_FPS = 6;
16
 
17
  onDestroy(() => {
18
  if (videoFrameCallbackId) videoEl.cancelVideoFrameCallback(videoFrameCallbackId);
@@ -22,13 +21,9 @@
22
  videoEl.srcObject = $mediaStream;
23
  }
24
 
25
- let last_millis = 0;
26
  async function onFrameChange(now: DOMHighResTimeStamp, metadata: VideoFrameCallbackMetadata) {
27
- if (now - last_millis > 1000 / THROTTLE_FPS) {
28
- const blob = await grapBlobImg();
29
- onFrameChangeStore.set({ now, metadata, blob });
30
- last_millis = now;
31
- }
32
  videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
33
  }
34
 
 
12
  let videoFrameCallbackId: number;
13
  const WIDTH = 512;
14
  const HEIGHT = 512;
 
15
 
16
  onDestroy(() => {
17
  if (videoFrameCallbackId) videoEl.cancelVideoFrameCallback(videoFrameCallbackId);
 
21
  videoEl.srcObject = $mediaStream;
22
  }
23
 
 
24
  async function onFrameChange(now: DOMHighResTimeStamp, metadata: VideoFrameCallbackMetadata) {
25
+ const blob = await grapBlobImg();
26
+ onFrameChangeStore.set({ blob });
 
 
 
27
  videoFrameCallbackId = videoEl.requestVideoFrameCallback(onFrameChange);
28
  }
29
 
frontend/src/lib/lcmLive.ts CHANGED
@@ -6,6 +6,7 @@ export enum LCMLiveStatus {
6
  CONNECTED = "connected",
7
  DISCONNECTED = "disconnected",
8
  WAIT = "wait",
 
9
  }
10
 
11
  const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
@@ -15,7 +16,7 @@ export const streamId = writable<string | null>(null);
15
 
16
  let websocket: WebSocket | null = null;
17
  export const lcmLiveActions = {
18
- async start() {
19
  return new Promise((resolve, reject) => {
20
 
21
  try {
@@ -43,6 +44,17 @@ export const lcmLiveActions = {
43
  streamId.set(userId);
44
  resolve(userId);
45
  break;
 
 
 
 
 
 
 
 
 
 
 
46
  case "timeout":
47
  console.log("timeout");
48
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
@@ -60,7 +72,6 @@ export const lcmLiveActions = {
60
  console.error(err);
61
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
62
  streamId.set(null);
63
-
64
  reject(err);
65
  }
66
  });
 
6
  CONNECTED = "connected",
7
  DISCONNECTED = "disconnected",
8
  WAIT = "wait",
9
+ SEND_FRAME = "send_frame",
10
  }
11
 
12
  const initStatus: LCMLiveStatus = LCMLiveStatus.DISCONNECTED;
 
16
 
17
  let websocket: WebSocket | null = null;
18
  export const lcmLiveActions = {
19
+ async start(getSreamdata: () => any[]) {
20
  return new Promise((resolve, reject) => {
21
 
22
  try {
 
44
  streamId.set(userId);
45
  resolve(userId);
46
  break;
47
+ case "send_frame":
48
+ lcmLiveStatus.set(LCMLiveStatus.SEND_FRAME);
49
+ const streamData = getSreamdata();
50
+ websocket?.send(JSON.stringify({ status: "next_frame" }));
51
+ for (const d of streamData) {
52
+ this.send(d);
53
+ }
54
+ break;
55
+ case "wait":
56
+ lcmLiveStatus.set(LCMLiveStatus.WAIT);
57
+ break;
58
  case "timeout":
59
  console.log("timeout");
60
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
 
72
  console.error(err);
73
  lcmLiveStatus.set(LCMLiveStatus.DISCONNECTED);
74
  streamId.set(null);
 
75
  reject(err);
76
  }
77
  });
frontend/src/lib/mediaStream.ts CHANGED
@@ -5,7 +5,7 @@ export enum MediaStreamStatusEnum {
5
  CONNECTED = "connected",
6
  DISCONNECTED = "disconnected",
7
  }
8
- export const onFrameChangeStore: Writable<{ now: Number, metadata: VideoFrameCallbackMetadata, blob: Blob }> = writable();
9
 
10
  export const mediaDevices = writable<MediaDeviceInfo[]>([]);
11
  export const mediaStreamStatus = writable(MediaStreamStatusEnum.INIT);
 
5
  CONNECTED = "connected",
6
  DISCONNECTED = "disconnected",
7
  }
8
+ export const onFrameChangeStore: Writable<{ blob: Blob }> = writable({ blob: new Blob() });
9
 
10
  export const mediaDevices = writable<MediaDeviceInfo[]>([]);
11
  export const mediaStreamStatus = writable(MediaStreamStatusEnum.INIT);
frontend/src/routes/+page.svelte CHANGED
@@ -35,25 +35,18 @@
35
  pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
36
  }
37
 
38
- // send Webcam stream to LCM if image mode
39
- $: {
40
- if (
41
- isImageMode &&
42
- $lcmLiveStatus === LCMLiveStatus.CONNECTED &&
43
- $mediaStreamStatus === MediaStreamStatusEnum.CONNECTED
44
- ) {
45
- lcmLiveActions.send(getPipelineValues());
46
- lcmLiveActions.send($onFrameChangeStore.blob);
47
- }
48
- }
49
- $: {
50
- if (!isImageMode && $lcmLiveStatus === LCMLiveStatus.CONNECTED) {
51
- lcmLiveActions.send($deboucedPipelineValues);
52
  }
53
  }
54
 
55
  $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
56
-
 
 
57
  let disabled = false;
58
  async function toggleLcmLive() {
59
  if (!isLCMRunning) {
@@ -62,7 +55,7 @@
62
  await mediaStreamActions.start();
63
  }
64
  disabled = true;
65
- await lcmLiveActions.start();
66
  disabled = false;
67
  } else {
68
  if (isImageMode) {
 
35
  pipelineParams = pipelineParams.filter((e) => e?.disabled !== true);
36
  }
37
 
38
+ function getSreamdata() {
39
+ if (isImageMode) {
40
+ return [getPipelineValues(), $onFrameChangeStore?.blob];
41
+ } else {
42
+ return [$deboucedPipelineValues];
 
 
 
 
 
 
 
 
 
43
  }
44
  }
45
 
46
  $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
47
+ $: {
48
+ console.log('lcmLiveStatus', $lcmLiveStatus);
49
+ }
50
  let disabled = false;
51
  async function toggleLcmLive() {
52
  if (!isLCMRunning) {
 
55
  await mediaStreamActions.start();
56
  }
57
  disabled = true;
58
+ await lcmLiveActions.start(getSreamdata);
59
  disabled = false;
60
  } else {
61
  if (isImageMode) {
pipelines/img2img.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ AutoPipelineForImage2Image,
3
+ AutoencoderTiny,
4
+ )
5
+ from compel import Compel
6
+ import torch
7
+
8
+ try:
9
+ import intel_extension_for_pytorch as ipex # type: ignore
10
+ except:
11
+ pass
12
+
13
+ import psutil
14
+ from config import Args
15
+ from pydantic import BaseModel, Field
16
+ from PIL import Image
17
+
18
+ base_model = "SimianLuo/LCM_Dreamshaper_v7"
19
+ taesd_model = "madebyollin/taesd"
20
+
21
+ default_prompt = "Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
22
+
23
+
24
+ class Pipeline:
25
+ class Info(BaseModel):
26
+ name: str = "img2img"
27
+ title: str = "Image-to-Image LCM"
28
+ description: str = "Generates an image from a text prompt"
29
+ input_mode: str = "image"
30
+
31
+ class InputParams(BaseModel):
32
+ prompt: str = Field(
33
+ default_prompt,
34
+ title="Prompt",
35
+ field="textarea",
36
+ id="prompt",
37
+ )
38
+ seed: int = Field(
39
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
40
+ )
41
+ steps: int = Field(
42
+ 4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
43
+ )
44
+ width: int = Field(
45
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
46
+ )
47
+ height: int = Field(
48
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
49
+ )
50
+ guidance_scale: float = Field(
51
+ 0.2,
52
+ min=0,
53
+ max=20,
54
+ step=0.001,
55
+ title="Guidance Scale",
56
+ field="range",
57
+ hide=True,
58
+ id="guidance_scale",
59
+ )
60
+ strength: float = Field(
61
+ 0.5,
62
+ min=0.25,
63
+ max=1.0,
64
+ step=0.001,
65
+ title="Strength",
66
+ field="range",
67
+ hide=True,
68
+ id="strength",
69
+ )
70
+
71
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
72
+ if args.safety_checker:
73
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(base_model)
74
+ else:
75
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(
76
+ base_model,
77
+ safety_checker=None,
78
+ )
79
+ if args.use_taesd:
80
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
81
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
82
+ )
83
+
84
+ self.pipe.set_progress_bar_config(disable=True)
85
+ self.pipe.to(device=device, dtype=torch_dtype)
86
+ self.pipe.unet.to(memory_format=torch.channels_last)
87
+
88
+ # check if computer has less than 64GB of RAM using sys or os
89
+ if psutil.virtual_memory().total < 64 * 1024**3:
90
+ self.pipe.enable_attention_slicing()
91
+
92
+ if args.torch_compile:
93
+ print("Running torch compile")
94
+ self.pipe.unet = torch.compile(
95
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
96
+ )
97
+ self.pipe.vae = torch.compile(
98
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
99
+ )
100
+
101
+ self.pipe(
102
+ prompt="warmup",
103
+ image=[Image.new("RGB", (768, 768))],
104
+ )
105
+
106
+ self.compel_proc = Compel(
107
+ tokenizer=self.pipe.tokenizer,
108
+ text_encoder=self.pipe.text_encoder,
109
+ truncate_long_prompts=False,
110
+ )
111
+
112
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
113
+ generator = torch.manual_seed(params.seed)
114
+ prompt_embeds = self.compel_proc(params.prompt)
115
+ results = self.pipe(
116
+ image=params.image,
117
+ prompt_embeds=prompt_embeds,
118
+ generator=generator,
119
+ strength=params.strength,
120
+ num_inference_steps=params.steps,
121
+ guidance_scale=params.guidance_scale,
122
+ width=params.width,
123
+ height=params.height,
124
+ output_type="pil",
125
+ )
126
+
127
+ nsfw_content_detected = (
128
+ results.nsfw_content_detected[0]
129
+ if "nsfw_content_detected" in results
130
+ else False
131
+ )
132
+ if nsfw_content_detected:
133
+ return None
134
+ result_image = results.images[0]
135
+
136
+ return result_image
user_queue.py CHANGED
@@ -36,6 +36,7 @@ class UserData:
36
  async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
37
  user_session = self.data_content[user_id]
38
  queue = user_session["queue"]
 
39
  try:
40
  return await queue.get()
41
  except asyncio.QueueEmpty:
@@ -55,5 +56,8 @@ class UserData:
55
  def get_user_count(self) -> int:
56
  return len(self.data_content)
57
 
 
 
 
58
 
59
  user_data = UserData()
 
36
  async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
37
  user_session = self.data_content[user_id]
38
  queue = user_session["queue"]
39
+
40
  try:
41
  return await queue.get()
42
  except asyncio.QueueEmpty:
 
56
  def get_user_count(self) -> int:
57
  return len(self.data_content)
58
 
59
+ def get_websocket(self, user_id: UUID) -> WebSocket:
60
+ return self.data_content[user_id]["websocket"]
61
+
62
 
63
  user_data = UserData()