Your Name committed
Commit 92627a4
1 Parent(s): 8cd501f

PD model and functional endpoint inference + check progress

feature_extractor/preprocessor_config.json CHANGED
@@ -14,7 +14,7 @@
     0.4578275,
     0.40821073
   ],
-  "image_processor_type": "CLIPImageProcessor",
+  "image_processor_type": "CLIPFeatureExtractor",
   "image_std": [
     0.26862954,
     0.26130258,
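
Note: the `image_processor_type` change to `CLIPFeatureExtractor` here is mirrored in `model_index.json` below. A quick way to verify that the renamed component still loads is to pull just this subfolder; a minimal sketch, where the repo id is a placeholder for this repository and `CLIPFeatureExtractor` is the legacy transformers class that `CLIPImageProcessor` superseded:

from transformers import CLIPFeatureExtractor

# Placeholder repo id; point it at the actual model repository (or a local clone).
fe = CLIPFeatureExtractor.from_pretrained("<user>/<this-repo>", subfolder="feature_extractor")
print(type(fe).__name__)  # expect a CLIP feature extractor / image processor instance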
handler.py CHANGED
@@ -7,6 +7,9 @@ from pprint import pprint
 from typing import Any, Dict, List
 import os
 from pathlib import Path
+from typing import Union
+from concurrent.futures import ThreadPoolExecutor
+import numpy as np
 
 import torch
 from diffusers import (
@@ -14,12 +17,12 @@ from diffusers import (
     DPMSolverMultistepScheduler,
     DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,
+    utils,
 )
 from safetensors.torch import load_file
-from torch import autocast
-
-# https://huggingface.co/philschmid/stable-diffusion-v1-4-endpoints
-# https://huggingface.co/docs/inference-endpoints/guides/custom_handler
+from torch import autocast, tensor
+import torchvision.transforms
+from PIL import Image
 
 REPO_DIR = Path(__file__).resolve().parent
 
@@ -48,6 +51,7 @@ class EndpointHandler:
         "detailed_eye-10": str(REPO_DIR / "lora/detailed_eye-10.safetensors"),
         "add_detail": str(REPO_DIR / "lora/add_detail.safetensors"),
         "MuscleGirl_v1": str(REPO_DIR / "lora/MuscleGirl_v1.safetensors"),
+        "flat2": str(REPO_DIR / "lora/flat2.safetensors"),
     }
 
     TEXTUAL_INVERSION = [
@@ -55,10 +59,6 @@ class EndpointHandler:
            "weight_name": str(REPO_DIR / "embeddings/EasyNegative.safetensors"),
            "token": "easynegative",
        },
-        {
-            "weight_name": str(REPO_DIR / "embeddings/EasyNegative.safetensors"),
-            "token": "EasyNegative",
-        },
        {
            "weight_name": str(REPO_DIR / "embeddings/badhandv4.pt"),
            "token": "badhandv4",
@@ -69,16 +69,12 @@
        },
        {
            "weight_name": str(REPO_DIR / "embeddings/NegfeetV2.pt"),
-            "token": "NegfeetV2",
+            "token": "negfeetv2",
        },
        {
            "weight_name": str(REPO_DIR / "embeddings/ng_deepnegative_v1_75t.pt"),
            "token": "ng_deepnegative_v1_75t",
        },
-        {
-            "weight_name": str(REPO_DIR / "embeddings/ng_deepnegative_v1_75t.pt"),
-            "token": "NG_DeepNegative_V1_75T",
-        },
        {
            "weight_name": str(REPO_DIR / "embeddings/bad-hands-5.pt"),
            "token": "bad-hands-5",
@@ -86,6 +82,15 @@
     ]
 
     def __init__(self, path="."):
+        self.inference_progress = {}  # Dictionary to store progress of each request
+        self.inference_images = {}  # Dictionary to store latest image of each request
+        self.total_steps = {}
+        self.inference_in_progress = False
+
+        self.executor = ThreadPoolExecutor(
+            max_workers=1
+        )  # max_workers can be adjusted as needed
+
         # load the optimized model
         self.pipe = DiffusionPipeline.from_pretrained(
             path,
@@ -94,30 +99,31 @@
         )
         self.pipe = self.pipe.to(device)
 
+        # https://stablediffusionapi.com/docs/a1111schedulers/
+
         # DPM++ 2M SDE Karras
         # increase step to avoid high contrast num_inference_steps=30
+        # self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+        #     self.pipe.scheduler.config,
+        #     use_karras_sigmas=True,
+        #     algorithm_type="sde-dpmsolver++",
+        # )
+        # DPM++ 2M Karras
+        # increase step to avoid high contrast num_inference_steps=30
         self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
             self.pipe.scheduler.config,
             use_karras_sigmas=True,
-            algorithm_type="sde-dpmsolver++",
         )
 
         # Mode boulardus
         self.pipe.safety_checker = None
 
+        # Disable progress bar
+        self.pipe.set_progress_bar_config(disable=True)
+
         # Load negative embeddings to avoid bad hands, etc
         self.load_embeddings()
 
-        # Load default Lora models
-        self.pipe = self.load_selected_loras(
-            [
-                ("polyhedron_new_skin_v1.1", 0.35),  # nice Skin
-                ("detailed_eye-10", 0.3),  # nice eyes
-                ("add_detail", 0.4),  # detailed pictures
-                ("MuscleGirl_v1", 0.3),  # shape persons
-            ],
-        )
-
         # boosts performance by another 20%
         self.pipe.enable_xformers_memory_efficient_attention()
         self.pipe.enable_attention_slicing()
@@ -215,14 +221,121 @@
         )
         return self.pipe
 
-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
-        """
-        Args:
-            data (:obj:):
-                includes the input data and the parameters for the inference.
-        Return:
-            A :obj:`dict`:. base64 encoded image
-        """
+    def __call__(self, data: Any) -> Dict:
+        """Handle incoming requests."""
+
+        action = data.get("action", None)
+        request_id = data.get("request_id")
+
+        # Check if the request_id is valid for all actions
+        if not request_id:
+            return {"flag": "error", "message": "Missing request_id."}
+
+        if action == "check_progress":
+            return self.check_progress(request_id)
+
+        elif action == "inference":
+            # Check if an inference is already in progress
+            if self.inference_in_progress:
+                return {
+                    "flag": "error",
+                    "message": "Another inference is already in progress. Please wait.",
+                }
+
+            # Set the inference state to in progress
+            self.clean_request_data(request_id)
+            self.inference_in_progress = True
+            self.inference_progress[request_id] = 0
+            self.inference_images[request_id] = None
+
+            self.executor.submit(self.start_inference, data)
+
+            return {
+                "flag": "success",
+                "message": "Inference started",
+                "request_id": request_id,
+            }
+
+        else:
+            return {"flag": "error", "message": f"Unsupported action: {action}"}
+
+    def clean_request_data(self, request_id: str):
+        """Clean up the data related to a specific request ID."""
+
+        # Remove the request ID from the progress dictionary
+        self.inference_progress.pop(request_id, None)
+
+        # Remove the request ID from the images dictionary
+        self.inference_images.pop(request_id, None)
+
+        # Remove the request ID from the total_steps dictionary
+        self.total_steps.pop(request_id, None)
+
+        # Set inference to False
+        self.inference_in_progress = False
+
+    def progress_callback(
+        self,
+        step: int,
+        timestep: int,
+        latents: Any,
+        request_id: str,
+        status: str,
+    ):
+        try:
+            if status == "progress":
+                # Latents to numpy
+                img_data = self.pipe.decode_latents(latents)
+                img_data = (img_data.squeeze() * 255).astype(np.uint8)
+                img = Image.fromarray(img_data, "RGB")
+                # print(img_data)
+            else:
+                # pil object
+                # print(latents)
+                img = latents
+
+            buffered = BytesIO()
+            img.save(buffered, format="PNG")
+
+            # print(status)
+            # Save the image to a file
+            # img.save("squirel.png", format="PNG")
+
+            # Encode the image into a base64 string representation
+            img_str = base64.b64encode(buffered.getvalue()).decode()
+
+        except Exception as e:
+            print(f"Error: {e}")
+
+        # Store progress and image
+        progress_percentage = (
+            step / self.total_steps[request_id]
+        ) * 100  # Assuming self.total_steps is the total number of steps for inference
+
+        self.inference_progress[request_id] = progress_percentage
+        self.inference_images[request_id] = img_str
+
+    def check_progress(self, request_id: str) -> Dict[str, Union[str, float]]:
+        progress = self.inference_progress.get(request_id, 0)
+        latest_image = self.inference_images.get(request_id, None)
+
+        # print(self.inference_progress)
+
+        if progress >= 100:
+            status = "complete"
+        else:
+            status = "in-progress"
+
+        return {
+            "flag": "success",
+            "status": status,
+            "progress": int(progress),
+            "image": latest_image,
+        }
+
+    def start_inference(self, data: Dict) -> Dict:
+        """Start a new inference."""
+
         global device
 
         # Which Lora do we load ?
@@ -241,8 +354,8 @@
             "width",
             "num_inference_steps",
             "height",
-            "seed",
             "guidance_scale",
+            "request_id",
         ]
 
         missing_fields = [field for field in required_fields if field not in data]
@@ -256,17 +369,21 @@
         # Now extract the fields
         prompt = data["prompt"]
         negative_prompt = data["negative_prompt"]
-        loras_model = data.pop("loras_model", None)
-        seed = data["seed"]
+        loras_model = data.get("loras_model", None)
+        seed = data.get("seed", None)
         width = data["width"]
        num_inference_steps = data["num_inference_steps"]
        height = data["height"]
        guidance_scale = data["guidance_scale"]
+        request_id = data["request_id"]
+
+        # Used for progress checker
+        self.total_steps[request_id] = num_inference_steps
 
         # USe this to add automatically some negative prompts
         forced_negative = (
             negative_prompt
-            + """easynegative, badhandv4, bad-artist-anime, NegfeetV2, ng_deepnegative_v1_75t, bad-hands-5 """
+            + """, easynegative, badhandv4, bad-artist-anime, negfeetv2, ng_deepnegative_v1_75t, bad-hands-5, """
         )
 
         # Set the generator seed if provided
@@ -288,15 +405,20 @@
                 negative_prompt=forced_negative,
                 generator=generator,
                 max_embeddings_multiples=5,
+                callback=lambda step, timestep, latents: self.progress_callback(
+                    step, timestep, latents, request_id, "progress"
+                ),
+                callback_steps=8,  # The frequency at which the callback function is called.
+                # output_type="pt",
             ).images[0]
 
-            # encode image as base 64
-            buffered = BytesIO()
-            image.save(buffered, format="JPEG")
-            img_str = base64.b64encode(buffered.getvalue())
+            # print(image)
+            self.progress_callback(
+                num_inference_steps, 0, image, request_id, "complete"
+            )
 
-            # Return the success response
-            return {"flag": "success", "image": img_str.decode()}
+            # for debug
+            # image.save("squirelb.png", format="PNG")
 
         except Exception as e:
             # Handle any other exceptions and return an error response
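
To make the new request flow concrete: `__call__` now expects an `action` plus a client-chosen `request_id`; `action: "inference"` queues `start_inference` on the single-worker executor and returns immediately, while `action: "check_progress"` polled with the same `request_id` returns the progress percentage and the latest base64-encoded preview produced by `progress_callback`. A minimal client sketch, assuming the endpoint URL and token placeholders below are replaced, that the deployed endpoint passes the JSON body straight through to the handler, and that the field values are purely illustrative:

import base64
import time
import uuid

import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HEADERS = {"Authorization": "Bearer <HF_TOKEN>", "Content-Type": "application/json"}  # placeholder token

request_id = str(uuid.uuid4())
payload = {
    "action": "inference",
    "request_id": request_id,
    "prompt": "portrait photo, studio lighting",
    "negative_prompt": "lowres",
    "width": 512,
    "height": 768,
    "num_inference_steps": 30,
    "guidance_scale": 7.0,
    # "seed" and "loras_model" are optional in the new handler
}

# Start the generation; the handler replies immediately with flag/message/request_id.
print(requests.post(ENDPOINT_URL, headers=HEADERS, json=payload).json())

# Poll for progress and keep the most recent base64-encoded preview.
while True:
    status = requests.post(
        ENDPOINT_URL,
        headers=HEADERS,
        json={"action": "check_progress", "request_id": request_id},
    ).json()
    if status.get("image"):
        with open("preview.png", "wb") as f:
            f.write(base64.b64decode(status["image"]))
    if status.get("status") == "complete":
        break
    time.sleep(2)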
lora/flat2.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:054e950e72181bb45ddbc7106d3625de406477725b5b313a91fe4522f03dbe0a
+size 6865699
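
The newly added `lora/flat2.safetensors` is wired into the handler through the `flat2` entry in `LORA_MODELS`, so it can be selected per request like the other LoRAs. Outside the handler it can also be exercised with the generic diffusers LoRA loader; a minimal sketch, assuming a local clone of this repository, a stock diffusers pipeline load, and an illustrative prompt and strength (the handler's own `load_selected_loras` weighting may differ):

import torch
from diffusers import DiffusionPipeline

# Assumed local clone of this repository; adjust the path as needed.
pipe = DiffusionPipeline.from_pretrained(".", torch_dtype=torch.float16).to("cuda")

# Load the new LoRA weights from the lora/ folder added in this commit.
pipe.load_lora_weights("lora", weight_name="flat2.safetensors")

image = pipe(
    "1girl, flat color, simple background",  # illustrative prompt
    cross_attention_kwargs={"scale": 0.5},   # illustrative LoRA strength
    num_inference_steps=30,
).images[0]
image.save("flat2_test.png")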
model_index.json CHANGED
@@ -3,7 +3,7 @@
   "_diffusers_version": "0.20.0",
   "feature_extractor": [
     "transformers",
-    "CLIPImageProcessor"
+    "CLIPFeatureExtractor"
   ],
   "requires_safety_checker": true,
   "safety_checker": [
safety_checker/config.json CHANGED
@@ -15,7 +15,7 @@
     "attention_dropout": 0.0,
     "bad_words_ids": null,
     "begin_suppress_tokens": null,
-    "bos_token_id": 0,
+    "bos_token_id": 49406,
     "chunk_size_feed_forward": 0,
     "cross_attention_hidden_size": null,
     "decoder_start_token_id": null,
@@ -24,7 +24,7 @@
     "dropout": 0.0,
     "early_stopping": false,
     "encoder_no_repeat_ngram_size": 0,
-    "eos_token_id": 2,
+    "eos_token_id": 49407,
     "exponential_decay_length_penalty": null,
     "finetuning_task": null,
     "forced_bos_token_id": null,
@@ -80,17 +80,11 @@
     "top_p": 1.0,
     "torch_dtype": null,
     "torchscript": false,
-    "transformers_version": "4.25.1",
+    "transformers_version": "4.31.0",
     "typical_p": 1.0,
     "use_bfloat16": false,
     "vocab_size": 49408
   },
-  "text_config_dict": {
-    "hidden_size": 768,
-    "intermediate_size": 3072,
-    "num_attention_heads": 12,
-    "num_hidden_layers": 12
-  },
   "torch_dtype": "float32",
   "transformers_version": null,
   "vision_config": {
@@ -167,15 +161,8 @@
     "top_p": 1.0,
     "torch_dtype": null,
     "torchscript": false,
-    "transformers_version": "4.25.1",
+    "transformers_version": "4.31.0",
     "typical_p": 1.0,
     "use_bfloat16": false
-  },
-  "vision_config_dict": {
-    "hidden_size": 1024,
-    "intermediate_size": 4096,
-    "num_attention_heads": 16,
-    "num_hidden_layers": 24,
-    "patch_size": 14
   }
 }
safety_checker/pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:16d28f2b37109f222cdc33620fdd262102ac32112be0352a7f77e9614b35a394
-size 1216064769
+oid sha256:753acd54aa6d288d6c0ce9d51468eb28f495fcbaacf0edf755fa5fc7ce678cd9
+size 1216062333
text_encoder/config.json CHANGED
@@ -19,6 +19,6 @@
   "pad_token_id": 1,
   "projection_dim": 768,
   "torch_dtype": "float32",
-  "transformers_version": "4.25.1",
+  "transformers_version": "4.31.0",
   "vocab_size": 49408
 }
text_encoder/pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:57f6e3badaffb5713c93e1f34ac3abf2ee3cd48e60d01714a0a6ed33f3406a5a
-size 492307041
+oid sha256:38a67003cd791d4fc008ae1fd24615b8b168f83cc8e853b746a7ec7bb3d64f42
+size 492306077
tokenizer/tokenizer_config.json CHANGED
@@ -8,6 +8,7 @@
     "rstrip": false,
     "single_word": false
   },
+  "clean_up_tokenization_spaces": true,
   "do_lower_case": true,
   "eos_token": {
     "__type": "AddedToken",
@@ -19,9 +20,7 @@
   },
   "errors": "replace",
   "model_max_length": 77,
-  "name_or_path": "openai/clip-vit-large-patch14",
   "pad_token": "<|endoftext|>",
-  "special_tokens_map_file": "./special_tokens_map.json",
   "tokenizer_class": "CLIPTokenizer",
   "unk_token": {
     "__type": "AddedToken",