|
import torch |
|
from safetensors.torch import save_file, load_file |
|
import gradio as gr |
|
import os |
|
|
|
# Width of the CLIP-L text-encoder vectors written to the "clip_l" key.
_CLIP_L_DIM = 768
# Width of the CLIP-G text-encoder vectors written to the "clip_g" key.
_CLIP_G_DIM = 1280


def _load_sd15_tensor(path):
    """Load the SD-1.5 embedding vectors stored in *path* as a 2-D tensor.

    Two on-disk layouts are accepted (extension match is case-insensitive):

    * ``.pt``: a ``torch.save``-d dict whose ``["string_to_param"]["*"]``
      entry holds the vectors.
    * ``.safetensors``: a file with the vectors under the ``"emb_params"`` key.

    A 1-D tensor is treated as a single vector of shape ``(1, width)``.

    Raises:
        ValueError: unsupported extension, missing/invalid embedding entry, or
            a tensor that is not ``(num_vectors, width <= 768)``.
    """
    raw_extension = os.path.splitext(path)[1]
    extension = raw_extension.lower()

    if extension == ".pt":
        # SECURITY NOTE(review): torch.load unpickles the file and this path
        # is reached with user uploads, so a crafted .pt can execute code.
        # Consider torch.load(..., weights_only=True) once it is confirmed
        # that the supported embedding files load under that mode.
        embedding = torch.load(path, map_location=torch.device("cpu"))
        params = embedding.get("string_to_param", {}) if isinstance(embedding, dict) else None
        tensor = params.get("*") if hasattr(params, "get") else None
    elif extension == ".safetensors":
        tensor = load_file(path).get("emb_params")
    else:
        raise ValueError(f"Unsupported file format: {raw_extension}")

    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f"Invalid embedding structure in file: {path}")

    # Loaded parameters may carry requires_grad; the output needs plain data.
    tensor = tensor.detach()
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    if tensor.dim() != 2 or tensor.shape[1] > _CLIP_L_DIM:
        raise ValueError(
            f"Invalid embedding shape {tuple(tensor.shape)} in file: {path}; "
            f"expected (num_vectors, width <= {_CLIP_L_DIM})"
        )
    return tensor


def _unique_output_path(directory, stem, taken):
    """Return ``<directory>/<stem>_XL.safetensors``, suffixed with a counter
    if that path is already in *taken* (paths produced earlier in this call)."""
    candidate = os.path.join(directory, f"{stem}_XL.safetensors")
    counter = 1
    while candidate in taken:
        candidate = os.path.join(directory, f"{stem}_XL_{counter}.safetensors")
        counter += 1
    return candidate


def convert_embedding(uploaded_files, max_files=5, output_dir=None):
    """Convert SD-1.5 embedding files into SDXL-style safetensors files.

    For each input, the SD-1.5 vectors are copied (as float16) into the
    leading columns of a zero-initialised ``(num_vectors, 768)`` "clip_l"
    tensor. A zero ``(num_vectors, 1280)`` "clip_g" tensor is written next to
    it. Both tensors are saved to ``<stem>_XL.safetensors``.

    Args:
        uploaded_files: Paths of ``.pt`` / ``.safetensors`` embeddings. ``None``
            or an empty sequence (e.g. a cleared upload widget) yields ``[]``.
        max_files: Maximum number of files accepted in one call.
        output_dir: Directory for the converted files. When ``None``, a fresh
            temporary directory is created per call. This keeps concurrent
            requests from overwriting each other's outputs.

    Returns:
        List of written output paths, one per input, in input order. Inputs
        whose stems collide get distinct, counter-suffixed output names.

    Raises:
        ValueError: too many files, unsupported format, or invalid embedding.
    """
    import tempfile

    if not uploaded_files:
        return []
    if len(uploaded_files) > max_files:
        raise ValueError(f"You can upload a maximum of {max_files} files at a time.")
    if output_dir is None:
        output_dir = tempfile.mkdtemp(prefix="sdxl_embeddings_")

    output_files = []
    for uploaded_file in uploaded_files:
        sd15_tensor = _load_sd15_tensor(uploaded_file)
        num_vectors, width = sd15_tensor.shape

        clip_g = torch.zeros((num_vectors, _CLIP_G_DIM), dtype=torch.float16)
        clip_l = torch.zeros((num_vectors, _CLIP_L_DIM), dtype=torch.float16)
        # Narrower vectors are zero-padded on the right.
        clip_l[:, :width] = sd15_tensor.to(dtype=torch.float16)

        stem = os.path.splitext(os.path.basename(uploaded_file))[0]
        output_path = _unique_output_path(output_dir, stem, output_files)
        save_file({"clip_g": clip_g, "clip_l": clip_l}, output_path)
        output_files.append(output_path)

    return output_files
|
|
|
# Stylesheet for the Gradio app: a dark page background, a centred
# 800px-wide container with an orange glow, orange headings, and orange
# buttons that darken on hover. It is passed to gr.Interface via its
# `css` argument.
custom_css = """
body {
    background-color: #121212;
    color: #ffffff;
    font-family: Arial, sans-serif;
}
.gradio-container {
    max-width: 800px;
    margin: auto;
    padding: 20px;
    border-radius: 10px;
    background: #1e1e1e;
    box-shadow: 0 0 10px rgba(255, 102, 0, 0.5);
}
.gradio-container h1 {
    text-align: center;
    font-size: 24px;
    color: #ff6600;
}
.gradio-container button {
    background-color: #ff6600;
    color: white;
    padding: 10px 15px;
    border: none;
    border-radius: 5px;
    cursor: pointer;
    font-size: 16px;
}
.gradio-container button:hover {
    background-color: #cc5500;
}
"""
|
|
|
# Gradio UI: upload up to five SD-1.5 embeddings and download the SDXL
# safetensors files produced by convert_embedding.
iface = gr.Interface(
    fn=convert_embedding,
    inputs=gr.Files(label="Upload SD-1.5 embeddings (Max 5)", type="filepath"),
    outputs=gr.Files(label="Download converted SDXL safetensors embeddings"),
    title="✨ SD-1.5 to SDXL Embedding Converter | Now supports multiple files ⚡",
    description="Upload up to 5 SD-1.5 embedding files to convert them to SDXL. Stylish and efficient!",
    theme="default",
    css=custom_css,
    # Gradio's `live` mode re-runs fn whenever the input changes (including
    # when the upload is cleared), instead of waiting for a submit click.
    live=True,
)

# Start the server only when run as a script. Importing this module
# (e.g. to reuse convert_embedding or in tests) must not block on launch().
if __name__ == "__main__":
    iface.launch()
|
|