alvdansen committed on
Commit 9b35819 • 1 Parent(s): eaaefc2

Update app.py

Files changed (1)
  1. app.py +22 -94
app.py CHANGED
@@ -1,63 +1,17 @@
 import json
 import random
-import requests
-import os
-from PIL import Image
 
 import gradio as gr
 import numpy as np
 import spaces
 import torch
 from diffusers import DiffusionPipeline, LCMScheduler
-from peft import PeftModel
-
-# Custom LCMScheduler to ignore unexpected attributes
-class CustomLCMScheduler(LCMScheduler):
-    @property
-    def config(self):
-        return {k: v for k, v in super().config.items() if k != "skip_prk_steps"}
-
-def get_image(image_data):
-    if isinstance(image_data, str):
-        return image_data
-
-    if isinstance(image_data, dict):
-        local_path = image_data.get('local_path')
-        hf_url = image_data.get('hf_url')
-    else:
-        print(f"Unexpected image_data format: {type(image_data)}")
-        return None
-
-    # Try loading from local path first
-    if local_path and os.path.exists(local_path):
-        try:
-            Image.open(local_path).verify() # Verify that it's a valid image
-            return local_path
-        except Exception as e:
-            print(f"Error loading local image {local_path}: {e}")
-
-    # If local path fails or doesn't exist, try URL
-    if hf_url:
-        try:
-            response = requests.get(hf_url)
-            if response.status_code == 200:
-                img = Image.open(requests.get(hf_url, stream=True).raw)
-                img.verify() # Verify that it's a valid image
-                img.save(local_path) # Save for future use
-                return local_path
-            else:
-                print(f"Failed to fetch image from URL {hf_url}. Status code: {response.status_code}")
-        except Exception as e:
-            print(f"Error loading image from URL {hf_url}: {e}")
-
-    print(f"Failed to load image for {image_data}")
-    return None
 
 with open("sdxl_lora.json", "r") as file:
     data = json.load(file)
     sdxl_loras_raw = [
         {
-            "image": get_image(item["image"]),
+            "image": item["image"],
             "title": item["title"],
             "repo": item["repo"],
             "trigger_word": item["trigger_word"],
@@ -69,27 +23,33 @@ with open("sdxl_lora.json", "r") as file:
         for item in data
     ]
 
+# Sort the loras by likes
 sdxl_loras_raw = sorted(sdxl_loras_raw, key=lambda x: x["likes"], reverse=True)
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 model_id = "stabilityai/stable-diffusion-xl-base-1.0"
 
 pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
-pipe.scheduler = CustomLCMScheduler.from_config(pipe.scheduler.config)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="lora")
 pipe.to(device=DEVICE, dtype=torch.float16)
 
-# Load Flash SDXL LoRA
-flash_sdxl_id = "jasperai/flash-sdxl"
-pipe.load_lora_weights(flash_sdxl_id, adapter_name="flash_lora")
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
-def update_selection(selected_state: gr.SelectData, gr_sdxl_loras):
+
+def update_selection(
+    selected_state: gr.SelectData,
+    gr_sdxl_loras,
+):
+
     lora_id = gr_sdxl_loras[selected_state.index]["repo"]
     trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
+
     return lora_id, trigger_word
 
+
 @spaces.GPU
 def infer(
     pre_prompt,
@@ -103,51 +63,19 @@ def infer(
     user_lora_weight,
     progress=gr.Progress(track_tqdm=True),
 ):
-    try:
-        # Load the user-selected LoRA
-        new_adapter_id = user_lora_selector.replace("/", "_")
-        pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
+    flash_sdxl_id = "jasperai/flash-sdxl"
 
-        # Set adapter weights
-        pipe.set_adapters(["flash_lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
-        gr.Info("LoRA setup complete")
+    new_adapter_id = user_lora_selector.replace("/", "_")
+    loaded_adapters = pipe.get_list_adapters()
 
-        if randomize_seed:
-            seed = random.randint(0, MAX_SEED)
-
-        generator = torch.Generator().manual_seed(seed)
-
-        if pre_prompt != "":
-            prompt = f"{pre_prompt} {prompt}"
-
-        # Use Flash Diffusion settings
-        image = pipe(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            guidance_scale=1.0, # Flash Diffusion typically uses guidance_scale=1
-            num_inference_steps=4, # Flash Diffusion uses fewer steps
-            generator=generator,
-        ).images[0]
-
-        return image
-    except Exception as e:
-        gr.Error(f"An error occurred: {str(e)}")
-        return None
+    if new_adapter_id not in loaded_adapters["unet"]:
+        gr.Info("Swapping LoRA")
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(flash_sdxl_id, adapter_name="lora")
+        pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
 
-@spaces.GPU
-def infer(
-    pre_prompt,
-    prompt,
-    seed,
-    randomize_seed,
-    num_inference_steps,
-    negative_prompt,
-    guidance_scale,
-    user_lora_selector,
-    user_lora_weight,
-    progress=gr.Progress(track_tqdm=True),
-):
-    load_lora_for_style(user_lora_selector)
+    pipe.set_adapters(["lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
+    gr.Info("LoRA setup done")
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
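
For reference, the combination this commit settles on (an LCMScheduler, the jasperai/flash-sdxl LoRA as a base adapter, and a user-selected style LoRA stacked on top) can be exercised outside the Gradio app roughly as follows. This is a minimal sketch, not the Space's code: the style-LoRA repo id, prompt, and output filename are placeholder assumptions, and the guidance_scale=1.0 / 4-step settings mirror the Flash Diffusion defaults noted in the removed comments.

import torch
from diffusers import DiffusionPipeline, LCMScheduler

# SDXL base in fp16, matching the setup in the updated app.py
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# Flash Diffusion runs on an LCM-style scheduler with very few steps
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Base "flash" adapter plus one style adapter; the style repo id is hypothetical
pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="lora")
pipe.load_lora_weights("some-user/some-style-lora", adapter_name="style")  # placeholder repo
pipe.set_adapters(["lora", "style"], adapter_weights=[1.0, 0.9])

generator = torch.Generator().manual_seed(0)
image = pipe(
    prompt="a fox in a forest, watercolor",  # placeholder prompt
    guidance_scale=1.0,       # Flash Diffusion typically uses guidance_scale=1
    num_inference_steps=4,    # and only a handful of denoising steps
    generator=generator,
).images[0]
image.save("flash_lora_demo.png")

Swapping styles afterwards follows the same pattern as the new infer(): check pipe.get_list_adapters()["unet"], call pipe.unload_lora_weights(), reload both adapters, then call pipe.set_adapters() again with the desired weights.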