#@title Prepare the Concepts Library to be used
import requests
import os
import gradio as gr
import wget
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from huggingface_hub import HfApi
from transformers import CLIPTextModel, CLIPTokenizer
import html
# Hub client used to enumerate the community concept repositories.
api = HfApi()
# Repositories published by the "sd-concepts-library" author, sorted by likes
# (direction=-1 requests descending order).
models_list = api.list_models(
    author="sd-concepts-library",
    sort="likes",
    direction=-1,
)
# Filled by the setup loop with one metadata dict per loaded concept.
models = []
# Stable Diffusion v1-4 from the fp16 revision, loaded as float16 and moved to
# the GPU. `use_auth_token=True` relies on a locally stored Hub token.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
torch.backends.cudnn.benchmark = True
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    """Register a learned embedding under a new token in a tokenizer/text encoder.

    The file at ``learned_embeds_path`` is loaded with ``torch.load`` and is
    expected to be a mapping; its first key is taken as the trained token and
    the associated value as the embedding.

    Args:
        learned_embeds_path: Path of the serialized embedding file.
        text_encoder: Object exposing ``get_input_embeddings()`` and
            ``resize_token_embeddings(n)``.
        tokenizer: Object exposing ``add_tokens``, ``convert_tokens_to_ids``
            and ``__len__``.
        token: Token string to register. When ``None`` the trained token
            stored in the file is used.

    Returns:
        The token string that was actually added. If the requested token is
        already known to the tokenizer, numbered variants ``...-1>``,
        ``...-2>``, ... are tried until one is accepted.
    """
    # NOTE(security): torch.load unpickles the file; these files are downloaded
    # from third-party repositories, so only load sources you trust.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # Separate the token and the embedding.
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # Cast to the dtype of the text encoder. Tensor.to returns a new tensor,
    # so the result has to be rebound for the cast to take effect.
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # Add the token to the tokenizer, deriving every fallback name from the
    # requested token so suffixes do not pile up ("<x-1-2>").
    base_token = token if token is not None else trained_token
    token = base_token
    num_added_tokens = tokenizer.add_tokens(token)
    i = 1
    while num_added_tokens == 0:
        print(f"The tokenizer already contains the token {token}.")
        # Drops the last character (presumably the closing ">") before
        # appending the counter.
        token = f"{base_token[:-1]}-{i}>"
        print(f"Attempting to add the token {token}.")
        num_added_tokens = tokenizer.add_tokens(token)
        i += 1

    # Grow the embedding matrix to cover the enlarged vocabulary.
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Write the learned vector into the row that belongs to the new token.
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token
# Download every public concept, register its embedding with the pipeline and
# record its metadata in `models`.
print("Setting up the public library")
for model in models_list:
    model_content = {}
    model_id = model.modelId
    model_content["id"] = model_id
    embeds_url = f"https://huggingface.co/{model_id}/resolve/main/learned_embeds.bin"
    os.makedirs(model_id, exist_ok=True)
    if not os.path.exists(f"{model_id}/learned_embeds.bin"):
        try:
            wget.download(embeds_url, out=model_id)
        except Exception:
            # Best effort: a repository without a downloadable embedding is
            # skipped. `Exception` (rather than a bare except) keeps Ctrl-C
            # and interpreter shutdown working during the long download loop.
            continue

    token_identifier = f"https://huggingface.co/{model_id}/raw/main/token_identifier.txt"
    response = requests.get(token_identifier)
    # On a failed request the body is an error message, not a token. Passing
    # None makes load_learned_embed_in_clip fall back to the token stored in
    # the embedding file.
    token_name = response.text if response.ok else None

    concept_type = f"https://huggingface.co/{model_id}/raw/main/type_of_concept.txt"
    response = requests.get(concept_type)
    concept_name = response.text
    model_content["concept_type"] = concept_name

    # Fetch up to four preview images; the ones that are missing are skipped.
    images = []
    for i in range(4):
        url = f"https://huggingface.co/{model_id}/resolve/main/concept_images/{i}.jpeg"
        image_download = requests.get(url)
        if image_download.status_code == 200:
            image_path = f"{model_id}/{i}.jpeg"
            # The context manager closes the file even if the write fails.
            with open(image_path, "wb") as image_file:
                image_file.write(image_download.content)
            images.append(image_path)
    model_content["images"] = images

    learned_token = load_learned_embed_in_clip(f"{model_id}/learned_embeds.bin", pipe.text_encoder, pipe.tokenizer, token_name)
    model_content["token"] = learned_token
    models.append(model_content)
#@title Run the app to navigate around [the Library](https://huggingface.co/sd-concepts-library)
#@markdown Click the `Running on public URL:` result to run the Gradio app
# Label text used by checkbox_block() for the concept-selection checkbox.
SELECT_LABEL = "Select concept"
def assembleHTML(model):
    """Assemble the gallery markup string for a collection of concepts.

    Args:
        model: Iterable of concept dicts, each providing an ``"images"``
            sequence. Despite the singular name this is the whole collection;
            the only visible caller passes the module-level ``models`` list.

    Returns:
        The concatenated markup: one leading fragment, then for every concept
        an opening fragment, one fragment per image and a closing fragment,
        followed by one trailing fragment.
    """
    # NOTE(review): every template fragment in this copy of the file is just a
    # newline; the markup that presumably wrapped tokens and images appears to
    # be missing. Confirm against the upstream notebook before relying on it.
    fragments = ["\n"]
    # Iterate the argument itself. The previous loop reused the parameter name
    # as its loop variable and read the global list, ignoring the argument.
    for concept in model:
        fragments.append("\n")
        fragments.extend("\n" for _ in concept["images"])
        fragments.append("\n")
    fragments.append("\n")
    return "".join(fragments)
def title_block(title, id):
    """Return a Markdown component: a level-3 heading linking `title` to huggingface.co/`id`."""
    heading = f"### [`{title}`](https://huggingface.co/{id})"
    return gr.Markdown(heading)
def image_block(image_list, concept_type):
    """Return a two-column, auto-height gallery of `image_list` labelled with `concept_type`."""
    gallery = gr.Gallery(label=concept_type, value=image_list, elem_id="gallery")
    return gallery.style(grid=[2], height="auto")
def checkbox_block():
    """Return a container-less checkbox labelled with SELECT_LABEL."""
    return gr.Checkbox(label=SELECT_LABEL).style(container=False)
def infer(text):
    """Run the pipeline on two copies of `text` and return the generated samples.

    Args:
        text: Prompt string; it is submitted twice in a single batch.

    Returns:
        A list holding every item of the pipeline result's ``"sample"`` entry.
    """
    prompts = [text, text]
    result = pipe(prompts, num_inference_steps=50, guidance_scale=7.5)
    return list(result["sample"])
# Stylesheet handed to gr.Blocks(css=...). The focus-ring shadow needs an
# explicit "+" inside calc(); without an operator the expression is invalid
# CSS and the ring shadow is dropped.
css = '''
.gradio-container {font-family: 'IBM Plex Sans', sans-serif}
#top_title{margin-bottom: .5em}
#top_title h2{margin-bottom: 0; text-align: center}
#main_row{flex-wrap: wrap; gap: 1em; max-height: 550px; overflow-y: scroll; flex-direction: row}
@media (min-width: 768px){#main_row > div{flex: 1 1 32%; margin-left: 0 !important}}
.gr-prose code::before, .gr-prose code::after {content: "" !important}
::-webkit-scrollbar {width: 10px}
::-webkit-scrollbar-track {background: #f1f1f1}
::-webkit-scrollbar-thumb {background: #888}
::-webkit-scrollbar-thumb:hover {background: #555}
.gr-button {white-space: nowrap}
.gr-button:focus {
border-color: rgb(147 197 253 / var(--tw-border-opacity));
outline: none;
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
--tw-border-opacity: 1;
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
--tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
--tw-ring-opacity: .5;
}
#prompt_input{flex: 1 3 auto}
#prompt_area{margin-bottom: .75em}
#prompt_area > div:first-child{flex: 1 3 auto}
'''
# Sample prompts offered through gr.Examples.
# NOTE(review): several prompts read as if a concept token is missing (for
# example "a in style"); confirm the intended placeholders with the upstream
# notebook.
examples = [
    "a in style",
    "a style mecha robot",
    "a piano being played by ",
    "Candid photo of , high resolution photo, trending on artstation, interior design",
]
# Build the Gradio UI: concept gallery, prompt box with run button, output
# gallery and example prompts.
# NOTE(review): this copy of the file had lost its indentation, so the nesting
# of the layout containers below is reconstructed. Confirm it against the
# intended layout.
with gr.Blocks(css=css) as demo:
    # NOTE(review): the component handle is rebound to a plain dict on the next
    # statement, so nothing refers to the gr.Variable afterwards. Kept as-is to
    # leave the component registration untouched.
    state = gr.Variable({
        'selected': -1
    })
    state = {}

    def update_state(i):
        """Toggle the selection flag of index `i` in `checkbox_states` and `state`.

        An index that has not been seen yet counts as unselected, so the first
        toggle selects it instead of raising KeyError.
        """
        selected = not checkbox_states.get(i, False)
        checkbox_states[i] = selected
        state[i] = selected

    gr.HTML('''
Stable Diffusion Conceptualizer
Navigate through community created concepts and styles via Stable Diffusion Textual Inversion and pick yours for inference.
To train your own concepts and contribute to the library check out this notebook .
''')
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"### Navigate {len(models)}+ Textual-Inversion community trained concepts")
            with gr.Row():
                image_blocks = []
                #for i, model in enumerate(models):
                with gr.Box().style(border=None):
                    gr.HTML(assembleHTML(models))
                    #title_block(model["token"], model["id"])
                    #image_blocks.append(image_block(model["images"], model["concept_type"]))
        with gr.Box():
            with gr.Row(elem_id="prompt_area").style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt", placeholder="Enter your prompt", show_label=False, max_lines=1, elem_id="prompt_input"
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False
                )
                btn = gr.Button("Run",elem_id="run_btn").style(
                    margin=False,
                    rounded=(False, True, True, False)
                )
            with gr.Row().style():
                infer_outputs = gr.Gallery(show_label=False).style(grid=[2], height="512px")
            with gr.Row():
                # Triple-quoted because the text spans several lines; a plain
                # "..." literal cannot contain raw line breaks.
                gr.HTML('''Prompting may not work as you are used to. objects
may need the concept added at the end, styles
may work better at the beginning. You can navigate on lexica.art to get inspired on prompts
''')
        with gr.Row():
            # cache_examples=True makes Gradio run `infer` on the examples
            # ahead of time.
            gr.Examples(examples=examples, fn=infer, inputs=[text], outputs=infer_outputs, cache_examples=True)
    # Per-index selection flags read and written by update_state.
    checkbox_states = {}
    inputs = [text]
    btn.click(
        infer,
        inputs=inputs,
        outputs=infer_outputs
    )
# Turn on request queuing (max_size=25) and start serving the app.
demo.queue(max_size=25)
demo.launch()