import os
import shutil
import tempfile
import zipfile
from typing import Dict, Tuple

import gradio as gr
import requests
from PIL import Image
from replicate.client import Client

# Constants
IMAGE_DIR = "uploaded_images"
VALID_IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg']
REPLICATE_API_URL = "https://api.replicate.com/v1"
UPLOAD_ENDPOINT = "https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip"
# Seconds to wait on a single network operation. Without a timeout a stalled
# Replicate call would block the Gradio worker (and the login form) indefinitely.
REQUEST_TIMEOUT = 60

# Global variables: per-user state keyed by the Replicate username used to log in.
replicate_credentials: Dict[str, str] = {}
global_serving_url_dict: Dict[str, str] = {}
num_of_input_images_dict: Dict[str, int] = {}


def _auth_headers(token: str) -> Dict[str, str]:
    """Build the Authorization header used for Replicate HTTP API calls."""
    return {"Authorization": f"Token {token}"}


def validate_replicate_credentials(username: str, token: str) -> Tuple[bool, Dict[str, str]]:
    """Validate Replicate credentials and return account info if valid.

    Returns:
        ``(True, account_info)`` when ``GET /account`` answers 200 and the
        reported username equals *username*; otherwise ``(False, {})``.
        A network failure is reported as invalid credentials rather than
        crashing the Gradio login handler.
    """
    try:
        response = requests.get(
            f"{REPLICATE_API_URL}/account", headers=_auth_headers(token), timeout=REQUEST_TIMEOUT
        )
    except requests.RequestException as exc:
        print(f"Could not reach Replicate to validate credentials: {exc}")
        return False, {}
    if response.status_code == 200:
        account_info = response.json()
        if account_info.get("username") == username:
            return True, account_info
    return False, {}


def create_save_slots(username: str, replicate_token: str) -> None:
    """Ensure the private models ``<account>/saveslot1`` .. ``saveslot4`` exist.

    The owner name is taken from ``GET /account`` for *replicate_token*; the
    *username* argument is not used. Failures are printed, not raised.
    """
    headers = _auth_headers(replicate_token)
    account_response = requests.get(
        f"{REPLICATE_API_URL}/account", headers=headers, timeout=REQUEST_TIMEOUT
    )
    if account_response.status_code != 200:
        print(f"Failed to get account information. Status Code: {account_response.status_code}")
        return
    account_name = account_response.json().get("username")
    if not account_name:
        print("Account username not found in response.")
        return

    for i in range(1, 5):
        save_slot_name = f"saveslot{i}"
        model_name = f"{account_name}/{save_slot_name}"
        model_response = requests.get(
            f"{REPLICATE_API_URL}/models/{model_name}", headers=headers, timeout=REQUEST_TIMEOUT
        )
        if model_response.status_code == 200:
            print(f"{save_slot_name} already exists.")
        elif model_response.status_code == 404:
            # Slot is missing: create it as a private model on the caller's account.
            create_model_response = requests.post(
                f"{REPLICATE_API_URL}/models",
                headers=headers,
                json={
                    "owner": account_name,
                    "name": save_slot_name,
                    "description": f"Save slot {i}",
                    "visibility": "private",
                    "hardware": "gpu-a40-small",
                },
                timeout=REQUEST_TIMEOUT,
            )
            if create_model_response.status_code == 201:
                print(f"{save_slot_name} created successfully.")
            else:
                print(f"Failed to create {save_slot_name}. Status Code: {create_model_response.status_code}")
        else:
            print(f"Failed to check {save_slot_name}. Status Code: {model_response.status_code}")


def replicate_auth(username: str, token: str) -> bool:
    """Authenticate user with Replicate and create save slots.

    Used as the Gradio ``auth`` callback: on success the token is remembered
    for this username and the four save-slot models are ensured to exist.
    """
    is_valid, _account_info = validate_replicate_credentials(username, token)
    if not is_valid:
        return False
    replicate_credentials[username] = token
    create_save_slots(username, token)
    return True


def _upload_training_zip(zip_path: str, replicate_token: str) -> str:
    """Upload *zip_path* through Replicate's upload endpoint and return its serving URL.

    Raises:
        gr.Error: if any step of the upload handshake fails (shown to the user).
    """
    response = requests.post(UPLOAD_ENDPOINT, headers=_auth_headers(replicate_token), timeout=REQUEST_TIMEOUT)
    if response.status_code != 200:
        raise gr.Error(f"Failed to get upload URL. Response: {response.text}")
    data = response.json()
    upload_url = data.get("upload_url")
    if not upload_url:
        raise gr.Error("Upload URL not found in response")

    with open(zip_path, "rb") as zip_file:
        upload_response = requests.put(
            upload_url,
            headers={"Content-Type": "application/zip"},
            data=zip_file,
            timeout=REQUEST_TIMEOUT,
        )
    if upload_response.status_code != 200:
        raise gr.Error(f"Failed to upload file. Response: {upload_response.text}")

    serving_url = data.get("serving_url")
    if not serving_url:
        raise gr.Error("Serving URL not found in response")
    return serving_url


def save_images(uploaded_files: list, request: gr.Request) -> str:
    """Save uploaded images to a directory, upload them as a zip, and return the serving URL.

    Each image is re-encoded with Pillow as ``image_<i>.<ext>``, zipped, and
    uploaded to Replicate. On success the serving URL and image count are
    stored for ``request.username`` so training can use them.

    Returns:
        A status message (success text, or a validation error message).

    Raises:
        gr.Error: if the upload to Replicate fails.
    """
    if not uploaded_files:
        return "No valid images were uploaded. Please upload at least one PNG, JPG, or JPEG file."

    # A private scratch dir per call: concurrent users no longer rmtree or
    # overwrite each other's images/zip, and everything is removed afterwards.
    with tempfile.TemporaryDirectory(prefix=f"{IMAGE_DIR}_") as work_dir:
        image_dir = os.path.join(work_dir, IMAGE_DIR)
        os.makedirs(image_dir)

        image_count = 0
        for i, file_info in enumerate(uploaded_files):
            # gr.File(type="filepath") yields plain str paths; older Gradio
            # versions yield tempfile wrappers that expose the path as `.name`.
            src_path = getattr(file_info, "name", file_info)
            extension = os.path.splitext(src_path)[1][1:].lower()
            if extension not in VALID_IMAGE_EXTENSIONS:
                return f"Error: '{src_path}' is not a valid image file. Please upload only PNG, JPG, or JPEG files."
            try:
                with Image.open(src_path) as image:
                    image.save(os.path.join(image_dir, f"image_{i}.{extension}"))
            except OSError as e:
                return f"Error processing image {src_path}: {e}"
            image_count += 1

        zip_path = os.path.join(work_dir, "data.zip")
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for file_name in sorted(os.listdir(image_dir)):
                zipf.write(os.path.join(image_dir, file_name), arcname=file_name)

        serving_url = _upload_training_zip(zip_path, replicate_credentials[request.username])

    # Record state only after a successful upload so the count and URL always match.
    num_of_input_images_dict[request.username] = image_count
    global_serving_url_dict[request.username] = serving_url
    return f"Images have been successfully uploaded to: {serving_url}"


def on_files_selected(data: list) -> str:
    """Callback function when files are selected (the file list itself is unused)."""
    return "Click 'Upload Training Data' to upload your images."


def on_files_cleared(data: list) -> str:
    """Callback function when files are cleared (the argument is unused)."""
    return "Upload images and press the button."
def _slot_destination(account_name: str, slot_choice: str) -> str:
    """Map a "Slot N" radio choice to the caller's Replicate model ``<account>/saveslotN``.

    Raises:
        gr.Error: if no slot (or an unknown one) was selected.
    """
    slots = {f"Slot {i}": f"{account_name}/saveslot{i}" for i in range(1, 5)}
    if slot_choice not in slots:
        raise gr.Error("Please choose a slot first.")
    return slots[slot_choice]


def _training_settings(category: str, num_of_input_images: int) -> Dict[str, object]:
    """Return the SDXL training inputs for *category*.

    ``max_train_steps`` grows linearly with the number of uploaded images, with a
    category-specific base and per-image increment.

    Raises:
        gr.Error: if *category* is not one of Person / Object / Animal / Style.
    """
    presets = {
        "Person": {
            "resolution": 1024,
            "train_batch_size": 4,
            "max_train_steps": 1400 + (num_of_input_images * 50),
            "unet_learning_rate": 1e-6,
            "is_lora": True,
            "lora_lr": 1e-4,
            "crop_based_on_salience": False,
            "use_face_detection_instead": True,
            "clipseg_temperature": 1.0,
        },
        "Object": {
            "resolution": 1024,
            "train_batch_size": 4,
            "max_train_steps": 1000 + (num_of_input_images * 50),
            "unet_learning_rate": 1e-6,
            "is_lora": True,
            "lora_lr": 1e-4,
            "crop_based_on_salience": True,
            "use_face_detection_instead": False,
            "clipseg_temperature": 0.9,
        },
        "Animal": {
            "resolution": 1024,
            "train_batch_size": 4,
            "max_train_steps": 500 + (num_of_input_images * 75),
            "unet_learning_rate": 1e-6,
            "is_lora": True,
            "lora_lr": 1e-4,
            "crop_based_on_salience": True,
            "use_face_detection_instead": False,
            "clipseg_temperature": 0.9,
        },
        "Style": {
            "resolution": 1024,
            "train_batch_size": 4,
            "max_train_steps": 300 + (num_of_input_images * 100),
            "unet_learning_rate": 1e-6,
            "is_lora": True,
            "lora_lr": 3e-4,
            "crop_based_on_salience": False,
            "use_face_detection_instead": False,
            "clipseg_temperature": 1.0,
        },
    }
    if category not in presets:
        raise gr.Error("Please choose a category first.")
    return presets[category]


def train_model(category: str, token_string: str, caption_prefix: str, mask_target_prompts: str,
                destination_choice: str, request: gr.Request) -> Tuple[str, str]:
    """Train a model with the specified parameters.

    Starts a Replicate training of ``stability-ai/sdxl`` on the images this user
    uploaded, writing the result into the chosen save slot.

    Returns:
        ``(training_id, status_message)``.

    Raises:
        gr.Error: if no training data was uploaded yet, or no category / slot was chosen.
    """
    account_name = request.username
    global_serving_url = global_serving_url_dict.get(account_name)
    num_of_input_images = num_of_input_images_dict.get(account_name)
    if not global_serving_url or not num_of_input_images:
        raise gr.Error("Please upload training data on the 'Upload Data' tab first.")

    settings = _training_settings(category, num_of_input_images)
    destination = _slot_destination(account_name, destination_choice)

    replicate_client = Client(api_token=replicate_credentials[account_name])
    training = replicate_client.trainings.create(
        version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
        input={
            "input_images": global_serving_url,
            **settings,
            "token_string": token_string,
            "caption_prefix": caption_prefix,
            "mask_target_prompts": mask_target_prompts,
        },
        destination=destination,
    )
    return training.id, "Training started successfully!"


def cancel_training(training_id: str, request: gr.Request) -> str:
    """Cancel the specified training and return a status message."""
    if not training_id:
        # Without an id the URL would be ".../trainings//cancel".
        return "No training has been started yet."
    headers = {"Authorization": f"Token {replicate_credentials[request.username]}"}
    url = f"https://api.replicate.com/v1/trainings/{training_id}/cancel"
    response = requests.post(url, headers=headers)
    if response.status_code == 200:
        return "Training cancelled successfully!"
    return f"Failed to cancel training. Response code: {response.status_code}, Message: {response.text}"


def check_training_status(training_id: str, request: gr.Request) -> str:
    """Check the status of the specified training.

    While processing, the last 15 log lines are appended. On success, the
    full training record is appended.
    """
    if not training_id:
        # GET /trainings/ (empty id) returns a list, not a training record.
        return "No training has been started yet."
    headers = {"Authorization": f"Token {replicate_credentials[request.username]}"}
    url = f"https://api.replicate.com/v1/trainings/{training_id}"
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        return f"Request failed with status code: {response.status_code}"

    training = response.json()
    status = training.get('status')
    output = f"Training Status: {status}\n"
    if status == 'processing':
        # `logs` can be null right after the training starts.
        logs = training.get('logs') or ""
        output += "\n".join(logs.split("\n")[-15:])
    elif status == 'succeeded':
        output += str(training)
    return output


def generate_images(destination_choice: str, resolution: str, prompt: str, negative_prompt: str,
                    lora_scale: float, num_outputs: int, request: gr.Request) -> list:
    """Generate images using the specified parameters.

    Runs the latest version of the chosen save-slot model and returns its output.

    Raises:
        gr.Error: if no slot was chosen, the model cannot be fetched, or the
            slot has no trained version yet.
    """
    account_name = request.username
    destination = _slot_destination(account_name, destination_choice)
    replicate_token = replicate_credentials[account_name]
    headers = {"Authorization": f"Token {replicate_token}"}

    response = requests.get(f"https://api.replicate.com/v1/models/{destination}", headers=headers)
    if response.status_code != 200:
        print("Failed to fetch the model details. Status Code:", response.status_code)
        raise gr.Error(f"Could not load {destination_choice} (status code {response.status_code}).")

    # `latest_version` is null for a slot that was never trained.
    latest_version = (response.json().get('latest_version') or {}).get('id')
    if not latest_version:
        raise gr.Error(f"{destination_choice} has no trained model yet. Train one on the 'Train AI' tab first.")
    print("Latest version ID:", latest_version)

    replicate_client = Client(api_token=replicate_token)
    return replicate_client.run(
        f"{destination}:{latest_version}",
        input={
            "width": int(resolution),
            "height": int(resolution),
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "lora_scale": lora_scale,
            "num_outputs": num_outputs,
            "apply_watermark": False,
            "disable_safety_checker": True,
        },
    )


# Gradio UI Block
with gr.Blocks(theme='prismglider/JSPERTheme2') as demo:
    with gr.Tab("Upload Data"):
        with gr.Row():
            with gr.Column():
                upload_button = gr.File(label="Upload Images", type="filepath", file_count="multiple", interactive=True)
            with gr.Column():
                status_message = gr.Textbox(label="Status", interactive=False, value="Upload images and press the button.")
                btn = gr.Button("Upload Training Data")
        btn.click(save_images, inputs=[upload_button], outputs=[status_message])
        upload_button.clear(on_files_cleared, inputs=upload_button, outputs=status_message)
        upload_button.upload(on_files_selected, inputs=upload_button, outputs=status_message)
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """

JSPER (Just Stablediffusion Plus Easy Retraining)

JSPER makes it EASY to train and generate custom art pieces with Stable Diffusion!

Follow along with our dog and mascot, Jasper, to see how it works!

Step 1: Upload Data

"""
                )

    with gr.Tab("Train AI"):
        with gr.Row():
            with gr.Column(scale=2):
                category = gr.Radio(["Person", "Object", "Animal", "Style"], label="Category")
                token_string = gr.Textbox(label="Token", value="JSP")
                caption_prefix = gr.Textbox(label="Tokenized Prompt", value="photo of a JSP thing")
                mask_target_prompts = gr.Textbox(label="Reference Prompt", value="photo of a thing")
                destination_choice = gr.Radio(["Slot 1", "Slot 2", "Slot 3", "Slot 4"], label="Save Slot")
                submit_button = gr.Button("Start Training")
            with gr.Column():
                training_status = gr.Textbox(label="Training Status")
                # Hidden holder for the id returned by train_model.
                training_id = gr.Textbox(visible=False)
                with gr.Row():
                    check_status_button = gr.Button("Check Training Status")
                    cancel_training_button = gr.Button("Cancel Training")
        submit_button.click(
            train_model,
            inputs=[category, token_string, caption_prefix, mask_target_prompts, destination_choice],
            outputs=[training_id, training_status],
        )
        check_status_button.click(check_training_status, inputs=[training_id], outputs=training_status)
        cancel_training_button.click(cancel_training, inputs=[training_id], outputs=training_status)
        with gr.Row():
            with gr.Column():
                gr.HTML(
                    """

JSPER (Just Stablediffusion Plus Easy Retraining)

Step 2: Train AI

"""
                )

    with gr.Tab("Generate Images"):
        with gr.Row():
            with gr.Column():
                # Distinct names so the Train tab's components are not shadowed.
                load_slot_choice = gr.Radio(["Slot 1", "Slot 2", "Slot 3", "Slot 4"], label="Load Slot")
                resolution = gr.Radio(["512", "1024", "2048"], label="Resolution", value="1024")
                prompt = gr.Textbox(label="Prompt", value="a JSP object")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                lora_scale = gr.Slider(label="Training Scale", minimum=.5, maximum=1, value=0.6, step=.01)
                num_outputs = gr.Slider(label="Number of Outputs", minimum=1, maximum=4, value=2, step=1)
            with gr.Column():
                images = gr.Gallery(label="Generated Images")
                generate_button = gr.Button("Generate")
        generate_button.click(
            generate_images,
            inputs=[load_slot_choice, resolution, prompt, negative_prompt, lora_scale, num_outputs],
            outputs=images,
        )
        with gr.Row():
            with gr.Column():
                gr.HTML(
                    """

Step 3: Generate Images

In the Generate Images tab:

"""
                )

auth_message = """

Welcome to JSPER! (Just Stablediffusion Plus Easy Retraining)

Before you begin training StableDiffusion, please follow these steps to log in:

  • You need to have a Replicate account to use this application.
  • Once you have an account, log in to the application using your Replicate username and Replicate token.
  • Your Replicate token can be found in your account settings on the Replicate website.
"""

demo.launch(auth=replicate_auth, auth_message=auth_message, show_api=False)