alvdansen committed
Commit f1ed22c • 1 Parent(s): 41651eb

Update app.py

Files changed (1)
  1. app.py +75 -43
app.py CHANGED
@@ -1,26 +1,54 @@
 import json
 import random
 import requests
+import os
+from PIL import Image
+
 import gradio as gr
 import numpy as np
 import spaces
 import torch
 from diffusers import DiffusionPipeline, LCMScheduler
-from PIL import Image
-import os
+from peft import PeftModel
+
+# Custom LCMScheduler to ignore unexpected attributes
+class CustomLCMScheduler(LCMScheduler):
+    @property
+    def config(self):
+        return {k: v for k, v in super().config.items() if k != "skip_prk_steps"}
+
+def get_image(image_data):
+    # ... (keep the get_image function as is)
 
-# Load the JSON data
 with open("sdxl_lora.json", "r") as file:
     data = json.load(file)
-sdxl_loras_raw = sorted(data, key=lambda x: x["likes"], reverse=True)
+sdxl_loras_raw = [
+    {
+        "image": get_image(item["image"]),
+        "title": item["title"],
+        "repo": item["repo"],
+        "trigger_word": item["trigger_word"],
+        "weights": item["weights"],
+        "is_pivotal": item.get("is_pivotal", False),
+        "text_embedding_weights": item.get("text_embedding_weights", None),
+        "likes": item.get("likes", 0),
+    }
+    for item in data
+]
+
+sdxl_loras_raw = sorted(sdxl_loras_raw, key=lambda x: x["likes"], reverse=True)
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 model_id = "stabilityai/stable-diffusion-xl-base-1.0"
 
 pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16")
-pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+pipe.scheduler = CustomLCMScheduler.from_config(pipe.scheduler.config)
 pipe.to(device=DEVICE, dtype=torch.float16)
 
+# Load Flash SDXL LoRA
+flash_sdxl_id = "jasperai/flash-sdxl"
+pipe.load_lora_weights(flash_sdxl_id, adapter_name="flash_lora")
+
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
@@ -29,46 +57,50 @@ def update_selection(selected_state: gr.SelectData, gr_sdxl_loras):
     trigger_word = gr_sdxl_loras[selected_state.index]["trigger_word"]
     return lora_id, trigger_word
 
-def load_lora_for_style(style_repo):
-    pipe.unload_lora_weights()
-    pipe.load_lora_weights("jasperai/flash-sdxl", adapter_name="lora")
-
-def get_image(image_data):
-    if isinstance(image_data, str):
-        return image_data
-
-    if isinstance(image_data, dict):
-        local_path = image_data.get('local_path')
-        hf_url = image_data.get('hf_url')
-    else:
-        print(f"Unexpected image_data format: {type(image_data)}")
+@spaces.GPU
+def infer(
+    pre_prompt,
+    prompt,
+    seed,
+    randomize_seed,
+    num_inference_steps,
+    negative_prompt,
+    guidance_scale,
+    user_lora_selector,
+    user_lora_weight,
+    progress=gr.Progress(track_tqdm=True),
+):
+    try:
+        # Load the user-selected LoRA
+        new_adapter_id = user_lora_selector.replace("/", "_")
+        pipe.load_lora_weights(user_lora_selector, adapter_name=new_adapter_id)
+
+        # Set adapter weights
+        pipe.set_adapters(["flash_lora", new_adapter_id], adapter_weights=[1.0, user_lora_weight])
+        gr.Info("LoRA setup complete")
+
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+
+        generator = torch.Generator().manual_seed(seed)
+
+        if pre_prompt != "":
+            prompt = f"{pre_prompt} {prompt}"
+
+        # Use Flash Diffusion settings
+        image = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            guidance_scale=1.0,  # Flash Diffusion typically uses guidance_scale=1
+            num_inference_steps=4,  # Flash Diffusion uses fewer steps
+            generator=generator,
+        ).images[0]
+
+        return image
+    except Exception as e:
+        gr.Error(f"An error occurred: {str(e)}")
         return None
 
-    # Try loading from local path first
-    if local_path and os.path.exists(local_path):
-        try:
-            Image.open(local_path).verify()  # Verify that it's a valid image
-            return local_path
-        except Exception as e:
-            print(f"Error loading local image {local_path}: {e}")
-
-    # If local path fails or doesn't exist, try URL
-    if hf_url:
-        try:
-            response = requests.get(hf_url)
-            if response.status_code == 200:
-                img = Image.open(requests.get(hf_url, stream=True).raw)
-                img.verify()  # Verify that it's a valid image
-                img.save(local_path)  # Save for future use
-                return local_path
-            else:
-                print(f"Failed to fetch image from URL {hf_url}. Status code: {response.status_code}")
-        except Exception as e:
-            print(f"Error loading image from URL {hf_url}: {e}")
-
-    print(f"Failed to load image for {image_data}")
-    return None
-
 @spaces.GPU
 def infer(
     pre_prompt,
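
Note on get_image: the updated file keeps only a placeholder body ("# ... (keep the get_image function as is)") while the original implementation is deleted in the second hunk. A comment alone is not a valid function body in Python, so the earlier implementation presumably has to be carried over for the module to import. A sketch of that helper, reassembled verbatim from the removed lines (indentation assumed), looks like this:

def get_image(image_data):
    if isinstance(image_data, str):
        return image_data

    if isinstance(image_data, dict):
        local_path = image_data.get('local_path')
        hf_url = image_data.get('hf_url')
    else:
        print(f"Unexpected image_data format: {type(image_data)}")
        return None

    # Try loading from local path first
    if local_path and os.path.exists(local_path):
        try:
            Image.open(local_path).verify()  # Verify that it's a valid image
            return local_path
        except Exception as e:
            print(f"Error loading local image {local_path}: {e}")

    # If local path fails or doesn't exist, try URL
    if hf_url:
        try:
            response = requests.get(hf_url)
            if response.status_code == 200:
                img = Image.open(requests.get(hf_url, stream=True).raw)
                img.verify()  # Verify that it's a valid image
                img.save(local_path)  # Save for future use
                return local_path
            else:
                print(f"Failed to fetch image from URL {hf_url}. Status code: {response.status_code}")
        except Exception as e:
            print(f"Error loading image from URL {hf_url}: {e}")

    print(f"Failed to load image for {image_data}")
    return None

As in the removed version, img.verify() is called before img.save(); Pillow requires reopening an image after verify(), so the save may fall through to the except branch. That behavior is kept unchanged here, since the placeholder comment asks for the function to be kept as is.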