FoodDesert's picture
Create app.py
a71b31f verified
raw
history blame
1.43 kB
import io
import os
import tempfile

import gradio as gr
import torch
from safetensors.torch import save_file
# Width of the CLIP-L text-encoder slot in the SDXL embedding; SD1.5 vectors
# are copied into it.
_CLIP_L_DIM = 768
# Width of the CLIP-G text-encoder slot in the SDXL embedding; left zero-filled.
_CLIP_G_DIM = 1280


def _read_upload(upload):
    """Return the raw bytes of an uploaded file.

    Accepts raw ``bytes``/``bytearray``, a filesystem path (``str`` or
    ``os.PathLike``), a temp-file wrapper whose ``.name`` is a file on disk,
    or any other object exposing ``.read()``.
    """
    if isinstance(upload, (bytes, bytearray)):
        return bytes(upload)
    if isinstance(upload, (str, os.PathLike)):
        with open(upload, "rb") as fh:
            return fh.read()
    # A temp-file wrapper may already be closed by the time we get it;
    # re-opening it by path sidesteps "I/O operation on closed file".
    path = getattr(upload, "name", None)
    if isinstance(path, str) and os.path.isfile(path):
        with open(path, "rb") as fh:
            return fh.read()
    return upload.read()


def convert_embedding(sd15_embedding):
    """Convert an SD1.5 textual-inversion embedding to the SDXL layout.

    The SD1.5 vectors found at ``['string_to_param']['*']`` are copied into
    the ``clip_l`` tensor (width 768, zero-padded on the right if narrower).
    ``clip_g`` (width 1280) is left all zeros. Both are stored as float16.

    Args:
        sd15_embedding: The uploaded embedding. This may be bytes, a path, or
            a file-like object (see ``_read_upload``).

    Returns:
        Path to a newly written ``.safetensors`` file containing ``clip_g``
        and ``clip_l``. Each call writes into its own temporary directory, so
        concurrent requests cannot overwrite each other's results.

    Raises:
        KeyError: If the loaded object has no ``['string_to_param']['*']``.
        ValueError: If the vectors are wider than 768 or are not 1-D/2-D.
    """
    raw = _read_upload(sd15_embedding)
    # Loading from memory avoids a shared, fixed-name temp file that
    # concurrent requests would clobber and that leaked on errors.
    # NOTE(review): torch.load unpickles the upload, and a malicious pickle can
    # execute code on torch versions where ``weights_only`` defaults to False.
    # Consider passing ``weights_only=True`` once it is confirmed that the
    # embeddings this app must accept still load under it.
    data = torch.load(io.BytesIO(raw), map_location=torch.device("cpu"))
    # detach() so a stored nn.Parameter does not drag autograd state along.
    vectors = data["string_to_param"]["*"].detach()
    if vectors.dim() == 1:
        # A single vector stored without the leading "number of vectors" axis.
        vectors = vectors.unsqueeze(0)
    num_vectors, width = vectors.shape
    if width > _CLIP_L_DIM:
        raise ValueError(
            f"expected SD1.5 embedding vectors of width <= {_CLIP_L_DIM}, "
            f"got {width}"
        )

    clip_g = torch.zeros((num_vectors, _CLIP_G_DIM), dtype=torch.float16)
    clip_l = torch.zeros((num_vectors, _CLIP_L_DIM), dtype=torch.float16)
    clip_l[:, :width] = vectors.to(dtype=torch.float16)

    # A fresh directory per call keeps the download filename unchanged while
    # making the path unique.
    output_path = os.path.join(tempfile.mkdtemp(), "temp_output.safetensors")
    save_file({"clip_g": clip_g, "clip_l": clip_l}, output_path)
    return output_path
# Gradio UI: a single file upload in, the converted .safetensors file out.
# ``gr.File`` is used directly because the ``gr.inputs`` / ``gr.outputs``
# namespaces were deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=convert_embedding,
    inputs=gr.File(label="Upload SD1.5 Embedding"),
    outputs=gr.File(label="Download Converted SDXL Embedding"),
    title="SD1.5 to SDXL Embedding Converter",
    description="Upload an SD1.5 embedding file to convert it to SDXL format.",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()