radames commited on
Commit
f45636e
·
1 Parent(s): cd80689

controlnetLoraSD15.py

Browse files
frontend/src/lib/components/PipelineOptions.svelte CHANGED
@@ -6,6 +6,7 @@
6
  import SeedInput from './SeedInput.svelte';
7
  import TextArea from './TextArea.svelte';
8
  import Checkbox from './Checkbox.svelte';
 
9
  import { pipelineValues } from '$lib/store';
10
 
11
  export let pipelineParams: FieldProps[];
@@ -25,6 +26,8 @@
25
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
26
  {:else if params.field === FieldType.CHECKBOX}
27
  <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
 
 
28
  {/if}
29
  {/each}
30
  {/if}
@@ -45,6 +48,8 @@
45
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
46
  {:else if params.field === FieldType.CHECKBOX}
47
  <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
 
 
48
  {/if}
49
  {/each}
50
  {/if}
 
6
  import SeedInput from './SeedInput.svelte';
7
  import TextArea from './TextArea.svelte';
8
  import Checkbox from './Checkbox.svelte';
9
+ import Selectlist from './Selectlist.svelte';
10
  import { pipelineValues } from '$lib/store';
11
 
12
  export let pipelineParams: FieldProps[];
 
26
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
27
  {:else if params.field === FieldType.CHECKBOX}
28
  <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
29
+ {:else if params.field === FieldType.SELECT}
30
+ <Selectlist {params} bind:value={$pipelineValues[params.id]}></Selectlist>
31
  {/if}
32
  {/each}
33
  {/if}
 
48
  <TextArea {params} bind:value={$pipelineValues[params.id]}></TextArea>
49
  {:else if params.field === FieldType.CHECKBOX}
50
  <Checkbox {params} bind:value={$pipelineValues[params.id]}></Checkbox>
51
+ {:else if params.field === FieldType.SELECT}
52
+ <Selectlist {params} bind:value={$pipelineValues[params.id]}></Selectlist>
53
  {/if}
54
  {/each}
55
  {/if}
frontend/src/lib/components/Selectlist.svelte ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ import type { FieldProps } from '$lib/types';
3
+ import { onMount } from 'svelte';
4
+ export let value = '';
5
+ export let params: FieldProps;
6
+ onMount(() => {
7
+ value = String(params?.default);
8
+ });
9
+ </script>
10
+
11
+ <div class="grid max-w-md grid-cols-4 items-center justify-items-start gap-3">
12
+ <label for="model-list" class="font-medium">{params?.title} </label>
13
+ {#if params?.values}
14
+ <select
15
+ bind:value
16
+ id="model-list"
17
+ class="cursor-pointer rounded-md border-2 border-gray-500 p-2 font-light dark:text-black"
18
+ >
19
+ {#each params.values as model, i}
20
+ <option value={model} selected={i === 0}>{model}</option>
21
+ {/each}
22
+ </select>
23
+ {/if}
24
+ </div>
frontend/src/lib/lcmLive.ts CHANGED
@@ -36,7 +36,6 @@ export const lcmLiveActions = {
36
  };
37
  websocket.onmessage = (event) => {
38
  const data = JSON.parse(event.data);
39
- console.log("WS: ", data);
40
  switch (data.status) {
41
  case "connected":
42
  const userId = data.userId;
 
36
  };
37
  websocket.onmessage = (event) => {
38
  const data = JSON.parse(event.data);
 
39
  switch (data.status) {
40
  case "connected":
41
  const userId = data.userId;
frontend/src/lib/types.ts CHANGED
@@ -3,6 +3,7 @@ export const enum FieldType {
3
  SEED = "seed",
4
  TEXTAREA = "textarea",
5
  CHECKBOX = "checkbox",
 
6
  }
7
  export const enum PipelineMode {
8
  IMAGE = "image",
@@ -20,6 +21,7 @@ export interface FieldProps {
20
  disabled?: boolean;
21
  hide?: boolean;
22
  id: string;
 
23
  }
24
  export interface PipelineInfo {
25
  title: {
 
3
  SEED = "seed",
4
  TEXTAREA = "textarea",
5
  CHECKBOX = "checkbox",
6
+ SELECT = "select",
7
  }
8
  export const enum PipelineMode {
9
  IMAGE = "image",
 
21
  disabled?: boolean;
22
  hide?: boolean;
23
  id: string;
24
+ values?: string[];
25
  }
26
  export interface PipelineInfo {
27
  title: {
frontend/src/routes/+page.svelte CHANGED
@@ -9,12 +9,7 @@
9
  import PipelineOptions from '$lib/components/PipelineOptions.svelte';
10
  import Spinner from '$lib/icons/spinner.svelte';
11
  import { lcmLiveStatus, lcmLiveActions, LCMLiveStatus } from '$lib/lcmLive';
12
- import {
13
- mediaStreamActions,
14
- mediaStreamStatus,
15
- onFrameChangeStore,
16
- MediaStreamStatusEnum
17
- } from '$lib/mediaStream';
18
  import { getPipelineValues, deboucedPipelineValues } from '$lib/store';
19
 
20
  let pipelineParams: FieldProps[];
@@ -44,9 +39,7 @@
44
  }
45
 
46
  $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
47
- $: {
48
- console.log('lcmLiveStatus', $lcmLiveStatus);
49
- }
50
  let disabled = false;
51
  async function toggleLcmLive() {
52
  if (!isLCMRunning) {
 
9
  import PipelineOptions from '$lib/components/PipelineOptions.svelte';
10
  import Spinner from '$lib/icons/spinner.svelte';
11
  import { lcmLiveStatus, lcmLiveActions, LCMLiveStatus } from '$lib/lcmLive';
12
+ import { mediaStreamActions, onFrameChangeStore } from '$lib/mediaStream';
 
 
 
 
 
13
  import { getPipelineValues, deboucedPipelineValues } from '$lib/store';
14
 
15
  let pipelineParams: FieldProps[];
 
39
  }
40
 
41
  $: isLCMRunning = $lcmLiveStatus !== LCMLiveStatus.DISCONNECTED;
42
+
 
 
43
  let disabled = false;
44
  async function toggleLcmLive() {
45
  if (!isLCMRunning) {
pipelines/controlnetLoraSD15.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ StableDiffusionControlNetImg2ImgPipeline,
3
+ ControlNetModel,
4
+ LCMScheduler,
5
+ )
6
+ from compel import Compel
7
+ import torch
8
+ from pipelines.utils.canny_gpu import SobelOperator
9
+
10
+ try:
11
+ import intel_extension_for_pytorch as ipex # type: ignore
12
+ except:
13
+ pass
14
+
15
+ import psutil
16
+ from config import Args
17
+ from pydantic import BaseModel, Field
18
+ from PIL import Image
19
+
20
+ taesd_model = "madebyollin/taesd"
21
+ controlnet_model = "lllyasviel/control_v11p_sd15_canny"
22
+ base_models = [
23
+ "plasmo/woolitize",
24
+ "nitrosocke/Ghibli-Diffusion",
25
+ "nitrosocke/mo-di-diffusion",
26
+ ]
27
+ lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
28
+
29
+
30
+ default_prompt = "Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
31
+
32
+
33
+ class Pipeline:
34
+ class Info(BaseModel):
35
+ name: str = "controlnet+loras+sd15"
36
+ title: str = "LCM + LoRA + Controlnet "
37
+ description: str = "Generates an image from a text prompt"
38
+ input_mode: str = "image"
39
+
40
+ class InputParams(BaseModel):
41
+ prompt: str = Field(
42
+ default_prompt,
43
+ title="Prompt",
44
+ field="textarea",
45
+ id="prompt",
46
+ )
47
+ model_id: str = Field(
48
+ "plasmo/woolitize",
49
+ title="Base Models List",
50
+ values=base_models,
51
+ field="select",
52
+ id="model_id",
53
+ )
54
+ seed: int = Field(
55
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
56
+ )
57
+ steps: int = Field(
58
+ 4, min=2, max=15, title="Steps", field="range", hide=True, id="steps"
59
+ )
60
+ width: int = Field(
61
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
62
+ )
63
+ height: int = Field(
64
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
65
+ )
66
+ guidance_scale: float = Field(
67
+ 0.2,
68
+ min=0,
69
+ max=2,
70
+ step=0.001,
71
+ title="Guidance Scale",
72
+ field="range",
73
+ hide=True,
74
+ id="guidance_scale",
75
+ )
76
+ strength: float = Field(
77
+ 0.5,
78
+ min=0.25,
79
+ max=1.0,
80
+ step=0.001,
81
+ title="Strength",
82
+ field="range",
83
+ hide=True,
84
+ id="strength",
85
+ )
86
+ controlnet_scale: float = Field(
87
+ 0.8,
88
+ min=0,
89
+ max=1.0,
90
+ step=0.001,
91
+ title="Controlnet Scale",
92
+ field="range",
93
+ hide=True,
94
+ id="controlnet_scale",
95
+ )
96
+ controlnet_start: float = Field(
97
+ 0.0,
98
+ min=0,
99
+ max=1.0,
100
+ step=0.001,
101
+ title="Controlnet Start",
102
+ field="range",
103
+ hide=True,
104
+ id="controlnet_start",
105
+ )
106
+ controlnet_end: float = Field(
107
+ 1.0,
108
+ min=0,
109
+ max=1.0,
110
+ step=0.001,
111
+ title="Controlnet End",
112
+ field="range",
113
+ hide=True,
114
+ id="controlnet_end",
115
+ )
116
+ canny_low_threshold: float = Field(
117
+ 0.31,
118
+ min=0,
119
+ max=1.0,
120
+ step=0.001,
121
+ title="Canny Low Threshold",
122
+ field="range",
123
+ hide=True,
124
+ id="canny_low_threshold",
125
+ )
126
+ canny_high_threshold: float = Field(
127
+ 0.125,
128
+ min=0,
129
+ max=1.0,
130
+ step=0.001,
131
+ title="Canny High Threshold",
132
+ field="range",
133
+ hide=True,
134
+ id="canny_high_threshold",
135
+ )
136
+ debug_canny: bool = Field(
137
+ False,
138
+ title="Debug Canny",
139
+ field="checkbox",
140
+ hide=True,
141
+ id="debug_canny",
142
+ )
143
+
144
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
145
+ controlnet_canny = ControlNetModel.from_pretrained(
146
+ controlnet_model, torch_dtype=torch_dtype
147
+ ).to(device)
148
+
149
+ self.pipes = {}
150
+
151
+ if args.safety_checker:
152
+ for model_id in base_models:
153
+ pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
154
+ model_id,
155
+ controlnet=controlnet_canny,
156
+ )
157
+ self.pipes[model_id] = pipe
158
+ else:
159
+ for model_id in base_models:
160
+ pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
161
+ model_id,
162
+ safety_checker=None,
163
+ controlnet=controlnet_canny,
164
+ )
165
+ self.pipes[model_id] = pipe
166
+
167
+ self.canny_torch = SobelOperator(device=device)
168
+
169
+ for pipe in self.pipes.values():
170
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
171
+ pipe.set_progress_bar_config(disable=True)
172
+ pipe.to(device=device, dtype=torch_dtype).to(device)
173
+
174
+ if psutil.virtual_memory().total < 64 * 1024**3:
175
+ pipe.enable_attention_slicing()
176
+
177
+ # Load LCM LoRA
178
+ pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
179
+ pipe.compel_proc = Compel(
180
+ tokenizer=pipe.tokenizer,
181
+ text_encoder=pipe.text_encoder,
182
+ truncate_long_prompts=False,
183
+ )
184
+ if args.torch_compile:
185
+ pipe.unet = torch.compile(
186
+ pipe.unet, mode="reduce-overhead", fullgraph=True
187
+ )
188
+ pipe.vae = torch.compile(
189
+ pipe.vae, mode="reduce-overhead", fullgraph=True
190
+ )
191
+ pipe(
192
+ prompt="warmup",
193
+ image=[Image.new("RGB", (768, 768))],
194
+ control_image=[Image.new("RGB", (768, 768))],
195
+ )
196
+
197
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
198
+ generator = torch.manual_seed(params.seed)
199
+ print(f"Using model: {params.model_id}")
200
+ pipe = self.pipes[params.model_id]
201
+
202
+ prompt_embeds = pipe.compel_proc(params.prompt)
203
+ control_image = self.canny_torch(
204
+ params.image, params.canny_low_threshold, params.canny_high_threshold
205
+ )
206
+
207
+ results = pipe(
208
+ image=params.image,
209
+ control_image=control_image,
210
+ prompt_embeds=prompt_embeds,
211
+ generator=generator,
212
+ strength=params.strength,
213
+ num_inference_steps=params.steps,
214
+ guidance_scale=params.guidance_scale,
215
+ width=params.width,
216
+ height=params.height,
217
+ output_type="pil",
218
+ controlnet_conditioning_scale=params.controlnet_scale,
219
+ control_guidance_start=params.controlnet_start,
220
+ control_guidance_end=params.controlnet_end,
221
+ )
222
+
223
+ nsfw_content_detected = (
224
+ results.nsfw_content_detected[0]
225
+ if "nsfw_content_detected" in results
226
+ else False
227
+ )
228
+ if nsfw_content_detected:
229
+ return None
230
+ result_image = results.images[0]
231
+ if params.debug_canny:
232
+ # paste control_image on top of result_image
233
+ w0, h0 = (200, 200)
234
+ control_image = control_image.resize((w0, h0))
235
+ w1, h1 = result_image.size
236
+ result_image.paste(control_image, (w1 - w0, h1 - h0))
237
+
238
+ return result_image