"""Gradio app that converts a PyTorch ``.ckpt`` checkpoint to ``.safetensors``."""

import os
from collections.abc import Mapping

import gradio as gr
import torch
from safetensors.torch import save_file


def _resolve_upload_path(ckpt_file):
    """Return the filesystem path of an uploaded file, or ``None``.

    Depending on the Gradio version and the ``gr.File(type=...)`` setting,
    the callback receives either a tempfile-like object exposing ``.name`` or
    a plain path; ``None`` means the user submitted without uploading a file.
    """
    if ckpt_file is None:
        return None
    if isinstance(ckpt_file, (str, os.PathLike)):
        return os.fspath(ckpt_file)
    return getattr(ckpt_file, "name", None)


def _storage_ptr(tensor):
    """Return the address of the storage backing *tensor*."""
    # untyped_storage() exists from PyTorch 2.0; fall back for older releases.
    if hasattr(tensor, "untyped_storage"):
        return tensor.untyped_storage().data_ptr()
    return tensor.storage().data_ptr()


def _prepare_tensors(state_dict):
    """Split *state_dict* into a safetensors-compatible dict and skipped keys.

    ``save_file`` only accepts ``torch.Tensor`` values and refuses
    non-contiguous tensors and tensors that share storage. Non-tensor entries
    are skipped (their keys are returned so the caller can report them),
    every tensor is made contiguous, and any tensor whose storage was already
    seen is cloned so each saved tensor owns its memory.

    Returns:
        A ``(tensors, skipped_keys)`` tuple.
    """
    tensors = {}
    skipped_keys = []
    seen_ptrs = set()
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            skipped_keys.append(key)
            continue
        # contiguous() is a no-op for already-contiguous tensors.
        tensor = value.detach().contiguous()
        ptr = _storage_ptr(tensor)
        if ptr in seen_ptrs:
            # Aliased/tied weight: give it its own copy of the data.
            tensor = tensor.clone()
        elif ptr != 0:
            # Empty tensors report a null storage pointer; don't track those.
            seen_ptrs.add(ptr)
        tensors[key] = tensor
    return tensors, skipped_keys


def convert_ckpt_to_safetensors(ckpt_file):
    """Convert an uploaded ``.ckpt`` file to ``.safetensors`` next to it.

    The output is written to the same directory as the uploaded file, with
    the extension replaced by ``.safetensors``.

    Args:
        ckpt_file: The value Gradio passes for the ``gr.File`` input -- a
            file-like object with ``.name``, a path string, or ``None``.

    Returns:
        A human-readable status message. Failures are reported in the message
        rather than raised, so the UI always shows a result.
    """
    ckpt_path = _resolve_upload_path(ckpt_file)
    if not isinstance(ckpt_path, str) or not ckpt_path.lower().endswith(".ckpt"):
        return "Please upload a .ckpt file."

    try:
        # NOTE(security): torch.load unpickles the file, and unpickling an
        # untrusted upload can execute arbitrary code. weights_only=True
        # (available since PyTorch 1.13, the default since 2.6) restricts
        # loading to tensors and plain containers, but rejects checkpoints
        # that pickle other objects (e.g. training callbacks). Left as-is to
        # avoid breaking such files -- review before exposing this publicly.
        ckpt = torch.load(ckpt_path, map_location="cpu")

        # Some training frameworks (e.g. PyTorch Lightning) nest the weights
        # under "state_dict"; otherwise treat the loaded object as the
        # state dict itself.
        if isinstance(ckpt, Mapping) and "state_dict" in ckpt:
            state_dict = ckpt["state_dict"]
        else:
            state_dict = ckpt
        if not isinstance(state_dict, Mapping):
            return "Error during conversion: checkpoint does not contain a state dict."

        tensors, skipped_keys = _prepare_tensors(state_dict)
        if not tensors:
            return "Error during conversion: no tensors found in checkpoint."

        output_file = os.path.splitext(ckpt_path)[0] + ".safetensors"
        save_file(tensors, output_file)
    except Exception as e:  # UI boundary: surface any failure as a message.
        return f"Error during conversion: {e}"

    message = f"Conversion successful. Saved as {output_file}"
    if skipped_keys:
        message += f" (skipped {len(skipped_keys)} non-tensor entries)"
    return message


iface = gr.Interface(
    fn=convert_ckpt_to_safetensors,
    inputs=gr.File(label="Upload .ckpt file"),
    outputs="text",
    title="CKPT to Safetensors Converter",
    description="Upload a .ckpt file to convert it to the safetensors format.",
)

if __name__ == "__main__":
    iface.launch()