File size: 4,962 Bytes
4120479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cfc613
 
4120479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import copy
import json
import random
from time import sleep

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

# Fetch the curated SDXL LoRA catalogue (a JSON list) from the LoraTheExplorer Space.
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")

# Decode explicitly as UTF-8: catalogue text may contain non-ASCII characters and
# the platform default encoding is not guaranteed to be UTF-8.
with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalised view of the catalogue: required keys are read directly (a missing one
# raises KeyError), optional keys fall back to a default.
sdxl_loras = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
    }
    for item in data
]

# Download every catalogue LoRA weight file once at startup; hf_hub_download
# returns the local path of each file.
# NOTE(review): these paths are not read anywhere else in this file, so the list
# appears to exist only to warm the local cache before the first generation.
saved_names = [
    hf_hub_download(repo_id=lora_entry["repo"], filename=lora_entry["weights"])
    for lora_entry in sdxl_loras
]

# Custom CSS passed to gr.Blocks: centres the title and the "+" separator, pins the
# app container to 700px, narrows the prompt input by 160px and absolutely positions
# the Run button at the right edge (apparently to dock it beside the prompt box).
css = '''
#title{text-align:center}
#plus_column{align-self: center}
#plus_button{font-size: 250%; text-align: center}
.gradio-container{width: 700px !important; margin: 0 auto !important}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position:absolute;margin-top: 57px;right: 0;margin-right: 0.8em;border-bottom-left-radius: 0px;
    border-top-left-radius: 0px;}
'''

# Pristine fp16 SDXL base pipeline. merge_and_run deep-copies it for every request
# before fusing LoRAs, so it is never mutated and no second copy is needed here.
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
# Alias kept for compatibility. It used to be a separate deep copy, which doubled
# the host memory held by the model weights.
pipe = original_pipe

@spaces.GPU
def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
    """Fuse the two selected LoRAs into a fresh SDXL copy and generate one image.

    Args:
        prompt: Text prompt for generation.
        negative_prompt: Negative prompt; an empty string means "no negative prompt".
        shuffled_items: The LoRA entries picked by ``shuffle_images``; only the
            ``"repo"`` and ``"weights"`` keys of the first two are read.
        lora_1_scale: Fusion scale applied to the first LoRA.
        lora_2_scale: Fusion scale applied to the second LoRA.
        progress: Gradio progress tracker (``track_tqdm=True``); not used directly.

    Returns:
        The first image from the pipeline output.

    Raises:
        gr.Error: If fewer than two LoRAs have been selected yet.
    """
    if not shuffled_items or len(shuffled_items) < 2:
        raise gr.Error("No LoRAs selected yet, press 'Reshuffle LoRAs!' first.")

    # Work on a private copy so fused weights never leak into the shared base pipeline.
    pipe = copy.deepcopy(original_pipe)
    for item, scale in zip(shuffled_items[:2], (lora_1_scale, lora_2_scale)):
        pipe.load_lora_weights(item["repo"], weight_name=item["weights"])
        # fuse_lora's first positional parameter is not the scale, so pass it by keyword;
        # positionally the slider value was silently ignored.
        pipe.fuse_lora(lora_scale=scale)
        # The LoRA is now baked into the base weights; drop its adapter layers so the
        # next fuse_lora only applies the next LoRA.
        pipe.unload_lora_weights()

    pipe.to("cuda", torch.float16)
    # None is the pipeline's documented "no negative prompt" value (False is not a valid type).
    negative_prompt = negative_prompt or None
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, guidance_scale=7).images[0]
    return image

def get_description(item):
    """Build the Markdown caption shown under a LoRA preview.

    Args:
        item: A LoRA catalogue entry; only its ``"trigger_word"`` key is read.

    Returns:
        A ``(caption, trigger_word)`` pair. A falsy trigger word produces the
        "none, will be applied automatically" caption and is returned unchanged.
    """
    word = item["trigger_word"]
    if not word:
        return "LoRA trigger word: `none`, will be applied automatically", word
    return f"LoRA trigger word: `{word}`", word
    
def shuffle_images():
    """Pick two random compatible LoRAs and build the UI updates for them.

    Returns:
        A 6-tuple: image update for LoRA 1, its caption, image update for
        LoRA 2, its caption, a prompt update pre-filled with the non-empty
        trigger words, and the list of the two selected catalogue entries.

    Raises:
        ValueError: If fewer than two compatible LoRAs are available.
    """
    compatible_items = [item for item in data if item["is_compatible"]]
    # Draw two distinct entries directly instead of shuffling the whole list.
    two_shuffled_items = random.sample(compatible_items, 2)
    first, second = two_shuffled_items

    title_1 = gr.update(label=first["title"], value=first["image"])
    title_2 = gr.update(label=second["title"], value=second["image"])

    description_1, trigger_word_1 = get_description(first)
    description_2, trigger_word_2 = get_description(second)

    # Skip empty/missing trigger words so the prompt has no stray spaces or literal "None".
    prompt_text = " ".join(word for word in (trigger_word_1, trigger_word_2) if word)
    prompt = gr.update(value=prompt_text)
    return title_1, description_1, title_2, description_2, prompt, two_shuffled_items

with gr.Blocks(css=css) as demo:
    # Holds the two LoRA entries picked by shuffle_images; merge_and_run reads it.
    shuffled_items = gr.State()
    title = gr.HTML(
        '''<h1>LoRA Roulette 🎲</h1>
        <h4>Two LoRAs are loaded to SDXL at random, find a way to combine them for your art 🎨</h4>
        ''',
        elem_id="title"
    )
    # Side-by-side LoRA previews with a "+" separator between them.
    with gr.Row():
        with gr.Column(min_width=10, scale=6):
            lora_1 = gr.Image(interactive=False, height=350)
            lora_1_prompt = gr.Markdown()
        with gr.Column(min_width=10, scale=1, elem_id="plus_column"):
            plus = gr.HTML("+", elem_id="plus_button")
        with gr.Column(min_width=10, scale=6):
            lora_2 = gr.Image(interactive=False, height=350)
            lora_2_prompt = gr.Markdown()
    with gr.Row():
        prompt = gr.Textbox(label="Your prompt", info="arrange the trigger words of the two LoRAs in a coherent sentence", interactive=True, elem_id="prompt")
        run_btn = gr.Button("Run", elem_id="run_button")

    output_image = gr.Image()
    with gr.Accordion("Advanced settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt")
        with gr.Row():
            lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
            lora_2_scale = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
    shuffle_button = gr.Button("Reshuffle LoRAs!")

    # Order must match the 6-tuple returned by shuffle_images.
    shuffle_outputs = [lora_1, lora_1_prompt, lora_2, lora_2_prompt, prompt, shuffled_items]
    demo.load(shuffle_images, inputs=[], outputs=shuffle_outputs, queue=False, show_progress="hidden")
    shuffle_button.click(shuffle_images, outputs=shuffle_outputs, queue=False, show_progress="hidden")

    # Order must match merge_and_run's positional parameters.
    run_inputs = [prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale]
    run_btn.click(merge_and_run, inputs=run_inputs, outputs=[output_image])
    prompt.submit(merge_and_run, inputs=run_inputs, outputs=[output_image])
demo.queue()
demo.launch()