# Hugging Face Space page header captured with the source (status: Sleeping).
import os
import shutil
import tempfile
import zipfile
from typing import Dict, Tuple

import gradio as gr
import requests
from PIL import Image
from replicate.client import Client
# Constants
# Name of the directory in which uploaded images are staged before zipping.
IMAGE_DIR = "uploaded_images"
# Lower-case file extensions (without the leading dot) accepted by save_images.
VALID_IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg']

# Global variables
# Per-user state shared between the callbacks below. All three dictionaries
# are keyed by the logged-in username (``request.username``) and live only in
# process memory.
# username -> Replicate API token; written by replicate_auth.
replicate_credentials: Dict[str, str] = {}
# username -> serving URL of the most recently uploaded zip; written by save_images.
global_serving_url_dict: Dict[str, str] = {}
# username -> number of images in the most recent upload; written by save_images.
num_of_input_images_dict: Dict[str, int] = {}
def validate_replicate_credentials(username: str, token: str) -> Tuple[bool, Dict[str, str]]:
    """Validate Replicate credentials and return account info if valid.

    Calls the Replicate ``/v1/account`` endpoint with *token* and accepts the
    credentials only when the ``username`` field of the response equals
    *username*.

    Args:
        username: Replicate username entered on the login form.
        token: Replicate API token entered on the login form.

    Returns:
        ``(True, account_info)`` with the decoded JSON response when the token
        is accepted and belongs to *username*; ``(False, {})`` in every other
        case, including network failures and unparseable responses.
    """
    headers = {"Authorization": f"Token {token}"}
    try:
        # A timeout keeps a stalled connection from hanging the login form.
        response = requests.get(
            "https://api.replicate.com/v1/account", headers=headers, timeout=30
        )
    except requests.RequestException as exc:
        # This runs inside the login callback: reject the login, don't crash.
        print(f"Failed to reach Replicate while validating credentials: {exc}")
        return False, {}

    if response.status_code != 200:
        return False, {}

    try:
        account_info = response.json()
    except ValueError:
        # Body was not JSON.
        return False, {}

    if isinstance(account_info, dict) and account_info.get("username") == username:
        return True, account_info
    return False, {}
def create_save_slots(username: str, replicate_token: str) -> None:
    """Create save slots for a given username.

    Ensures that the private models ``saveslot1`` .. ``saveslot4`` exist under
    the account that owns *replicate_token*, creating any that are missing.
    The work is best-effort: every failure is printed and the function
    returns (or moves on to the next slot) instead of raising, so a transient
    API problem cannot break the login flow that calls it.

    Args:
        username: Accepted for interface compatibility; the account name that
            is actually used is the one reported by the ``/account`` endpoint.
        replicate_token: Replicate API token used for every request.
    """
    base_url = "https://api.replicate.com/v1"
    headers = {"Authorization": f"Token {replicate_token}"}

    try:
        account_response = requests.get(f"{base_url}/account", headers=headers, timeout=30)
    except requests.RequestException as exc:
        print(f"Failed to get account information: {exc}")
        return

    if account_response.status_code != 200:
        print(f"Failed to get account information. Status Code: {account_response.status_code}")
        return

    try:
        account_name = account_response.json().get("username")
    except (ValueError, AttributeError):
        # Body was not JSON, or not a JSON object.
        account_name = None
    if not account_name:
        print("Account username not found in response.")
        return

    for i in range(1, 5):
        save_slot_name = f"saveslot{i}"
        model_name = f"{account_name}/{save_slot_name}"

        try:
            model_response = requests.get(
                f"{base_url}/models/{model_name}", headers=headers, timeout=30
            )
            if model_response.status_code == 200:
                print(f"{save_slot_name} already exists.")
                continue
            if model_response.status_code != 404:
                print(f"Failed to check {save_slot_name}. Status Code: {model_response.status_code}")
                continue

            # 404: the slot does not exist yet, so create it.
            create_model_response = requests.post(
                f"{base_url}/models",
                headers=headers,
                json={
                    "owner": account_name,
                    "name": save_slot_name,
                    "description": f"Save slot {i}",
                    "visibility": "private",
                    "hardware": "gpu-a40-small",
                },
                timeout=30,
            )
        except requests.RequestException as exc:
            # One unreachable slot should not stop the remaining ones.
            print(f"Failed to check or create {save_slot_name}: {exc}")
            continue

        if create_model_response.status_code == 201:
            print(f"{save_slot_name} created successfully.")
        else:
            print(f"Failed to create {save_slot_name}. Status Code: {create_model_response.status_code}")
def replicate_auth(username: str, token: str) -> bool:
    """Gradio login callback: check the credentials and prepare the account.

    When the username/token pair is accepted, the token is remembered in
    ``replicate_credentials`` under *username* and the save slots are created.

    Returns:
        ``True`` when the credentials are valid, ``False`` otherwise.
    """
    credentials_ok, _account_info = validate_replicate_credentials(username, token)
    if not credentials_ok:
        return False

    replicate_credentials[username] = token
    create_save_slots(username, token)
    return True
def save_images(uploaded_files: list, request: gr.Request) -> str:
    """Save uploaded images to a directory and return the serving URL.

    Each upload is re-saved with Pillow as ``image_<index>.<ext>``, the
    results are zipped into ``data.zip`` and the archive is uploaded through
    the Replicate upload endpoint. On success the serving URL and the image
    count are remembered for ``request.username``.

    All intermediate files live in a per-call temporary directory, so
    concurrent uploads from different users cannot overwrite each other and
    nothing is left on disk afterwards.

    Args:
        uploaded_files: Files from the ``gr.File`` component; ``None`` or
            empty when nothing was selected.
        request: Gradio request; ``request.username`` identifies the user.

    Returns:
        A status message for the UI: either the success message containing
        the serving URL or a description of what was wrong with the input.

    Raises:
        Exception: If the upload endpoint rejects a request or its response
            lacks the upload or serving URL.
    """
    # The component yields None when the button is pressed with no selection.
    if not uploaded_files:
        return "No valid images were uploaded. Please upload at least one PNG, JPG, or JPEG file."

    replicate_token = replicate_credentials.get(request.username)
    if replicate_token is None:
        return "Error: no Replicate credentials found for this session. Please log in again."

    with tempfile.TemporaryDirectory() as work_dir:
        staging_dir = os.path.join(work_dir, IMAGE_DIR)
        os.makedirs(staging_dir)

        image_count = 0
        for i, file_info in enumerate(uploaded_files):
            # Accept both objects exposing ``.name`` and plain path strings.
            source_path = getattr(file_info, "name", file_info)
            extension = os.path.splitext(source_path)[1][1:].lower()
            if extension not in VALID_IMAGE_EXTENSIONS:
                return f"Error: '{source_path}' is not a valid image file. Please upload only PNG, JPG, or JPEG files."
            try:
                with Image.open(source_path) as image:
                    image.save(os.path.join(staging_dir, f"image_{i}.{extension}"))
            except OSError as e:
                return f"Error processing image {source_path}: {e}"
            image_count += 1

        if image_count == 0:
            return "No valid images were uploaded. Please upload at least one PNG, JPG, or JPEG file."

        zip_path = os.path.join(work_dir, "data.zip")
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for file_name in sorted(os.listdir(staging_dir)):
                zipf.write(os.path.join(staging_dir, file_name), arcname=file_name)

        headers = {"Authorization": f"Token {replicate_token}"}
        response = requests.post(
            "https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip",
            headers=headers,
            timeout=30,
        )
        if response.status_code != 200:
            raise Exception(f"Failed to get upload URL. Response: {response.text}")

        data = response.json()
        upload_url = data.get("upload_url")
        if not upload_url:
            raise Exception("Upload URL not found in response")
        serving_url = data.get("serving_url")
        if not serving_url:
            raise Exception("Serving URL not found in response")

        with open(zip_path, "rb") as file:
            # Generous timeout: the archive can be large.
            upload_response = requests.put(
                upload_url,
                headers={"Content-Type": "application/zip"},
                data=file,
                timeout=300,
            )
        if upload_response.status_code != 200:
            raise Exception(f"Failed to upload file. Response: {upload_response.text}")

    # Record both values only after success so they always describe the same upload.
    num_of_input_images_dict[request.username] = image_count
    global_serving_url_dict[request.username] = serving_url
    return f"Images have been successfully uploaded to: {serving_url}"
def on_files_selected(data: list) -> str:
    """Return the status hint shown after files have been picked.

    *data* is the event input and is not inspected.
    """
    next_step_hint = "Click 'Upload Training Data' to upload your images."
    return next_step_hint
def on_files_cleared(data: list) -> str:
    """Return the initial status hint once the file selection is emptied.

    *data* is the event input and is not inspected.
    """
    initial_hint = "Upload images and press the button."
    return initial_hint
# Per-category training profile:
# (base steps, extra steps per image, lora_lr, crop_based_on_salience,
#  use_face_detection_instead, clipseg_temperature)
_CATEGORY_PROFILES = {
    "Person": (1400, 50, 1e-4, False, True, 1.0),
    "Object": (1000, 50, 1e-4, True, False, 0.9),
    "Animal": (500, 75, 1e-4, True, False, 0.9),
    "Style": (300, 100, 3e-4, False, False, 1.0),
}


def _training_settings(category: str, num_of_input_images: int) -> dict:
    """Build the training input settings for *category* (KeyError if unknown)."""
    base_steps, steps_per_image, lora_lr, crop, face, temperature = _CATEGORY_PROFILES[category]
    return {
        "resolution": 1024,
        "train_batch_size": 4,
        "max_train_steps": base_steps + (num_of_input_images * steps_per_image),
        "unet_learning_rate": 1e-6,
        "is_lora": True,
        "lora_lr": lora_lr,
        "crop_based_on_salience": crop,
        "use_face_detection_instead": face,
        "clipseg_temperature": temperature,
    }


def train_model(category: str, token_string: str, caption_prefix: str, mask_target_prompts: str,
                destination_choice: str, request: gr.Request) -> Tuple[str, str]:
    """Train a model with the specified parameters.

    Starts a Replicate training of the pinned SDXL version on the images the
    user uploaded last, writing the result to the chosen save slot.

    Args:
        category: One of "Person", "Object", "Animal" or "Style"; selects the
            training profile.
        token_string: Token passed to the trainer as ``token_string``.
        caption_prefix: Passed to the trainer as ``caption_prefix``.
        mask_target_prompts: Passed to the trainer as ``mask_target_prompts``.
        destination_choice: "Slot 1" .. "Slot 4"; mapped to
            ``<username>/saveslot<N>``.
        request: Gradio request; ``request.username`` identifies the user.

    Returns:
        ``(training_id, status_message)``.

    Raises:
        gr.Error: If the user is not logged in, has not uploaded images, or
            has not selected a category or save slot. Raising, rather than
            returning an empty id, leaves a previously stored training id
            untouched.
    """
    account_name = request.username

    replicate_token = replicate_credentials.get(account_name)
    if replicate_token is None:
        raise gr.Error("No Replicate credentials found for this session. Please log in again.")
    if account_name not in global_serving_url_dict or account_name not in num_of_input_images_dict:
        raise gr.Error("Please upload your training images before starting a training.")
    if category not in _CATEGORY_PROFILES:
        raise gr.Error("Please select a category (Person, Object, Animal, or Style).")

    destination_decider = {
        'Slot 1': f"{account_name}/saveslot1",
        'Slot 2': f"{account_name}/saveslot2",
        'Slot 3': f"{account_name}/saveslot3",
        'Slot 4': f"{account_name}/saveslot4",
    }
    destination = destination_decider.get(destination_choice)
    if destination is None:
        raise gr.Error("Please select a save slot.")

    num_of_input_images = num_of_input_images_dict[account_name]
    global_serving_url = global_serving_url_dict[account_name]
    settings = _training_settings(category, num_of_input_images)

    replicate_client = Client(api_token=replicate_token)
    training = replicate_client.trainings.create(
        version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
        input={
            "input_images": global_serving_url,
            **settings,
            "token_string": token_string,
            "caption_prefix": caption_prefix,
            "mask_target_prompts": mask_target_prompts,
        },
        destination=destination,
    )
    return training.id, "Training started successfully!"
def cancel_training(training_id: str, request: gr.Request) -> str:
    """Cancel the specified training.

    Args:
        training_id: Id of the training to cancel; empty when no training has
            been started in this session.
        request: Gradio request; ``request.username`` identifies the user.

    Returns:
        A status message describing the outcome.
    """
    # With an empty id the URL would degenerate to ".../trainings//cancel".
    if not training_id:
        return "No training has been started yet."

    replicate_token = replicate_credentials.get(request.username)
    if replicate_token is None:
        return "No Replicate credentials found for this session. Please log in again."

    headers = {"Authorization": f"Token {replicate_token}"}
    url = f"https://api.replicate.com/v1/trainings/{training_id}/cancel"
    response = requests.post(url, headers=headers, timeout=30)

    if response.status_code == 200:
        return "Training cancelled successfully!"
    return f"Failed to cancel training. Response code: {response.status_code}, Message: {response.text}"
def check_training_status(training_id: str, request: gr.Request) -> str:
    """Check the status of the specified training.

    Args:
        training_id: Id of the training to look up; empty when no training
            has been started in this session.
        request: Gradio request; ``request.username`` identifies the user.

    Returns:
        A text report: the status line, followed by the last 15 log lines
        while processing, the full training record once succeeded, or the
        reported error when failed. A short message is returned when the
        lookup cannot be made or the API answers with a non-200 status.
    """
    # With an empty id the URL would point at the trainings collection.
    if not training_id:
        return "No training has been started yet."

    replicate_token = replicate_credentials.get(request.username)
    if replicate_token is None:
        return "No Replicate credentials found for this session. Please log in again."

    headers = {"Authorization": f"Token {replicate_token}"}
    url = f"https://api.replicate.com/v1/trainings/{training_id}"
    response = requests.get(url, headers=headers, timeout=30)
    if response.status_code != 200:
        return f"Request failed with status code: {response.status_code}"

    training = response.json()
    status = training['status']
    output = f"Training Status: {status}\n"
    if status == 'processing':
        # ``logs`` may be missing or null before the trainer has written anything.
        logs = training.get('logs') or ""
        output += "\n".join(logs.split("\n")[-15:])
    elif status == 'succeeded':
        output += str(training)
    elif status == 'failed' and training.get('error'):
        output += f"Error: {training['error']}"
    return output
def generate_images(destination_choice: str, resolution: str, prompt: str, negative_prompt: str,
                    lora_scale: float, num_outputs: int, request: gr.Request) -> list:
    """Generate images using the specified parameters.

    Looks up the latest version of the model stored in the chosen slot and
    runs it on Replicate.

    Args:
        destination_choice: "Slot 1" .. "Slot 4"; mapped to
            ``<username>/saveslot<N>``.
        resolution: Square output size in pixels, as a string (e.g. "1024").
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        lora_scale: Passed to the model as ``lora_scale``.
        num_outputs: Number of images to request.
        request: Gradio request; ``request.username`` identifies the user.

    Returns:
        Whatever ``Client.run`` returns for the model, passed straight to the
        gallery component.

    Raises:
        gr.Error: If the user is not logged in, no slot is selected, the
            model lookup fails, or the slot has no trained version yet.
    """
    account_name = request.username

    replicate_token = replicate_credentials.get(account_name)
    if replicate_token is None:
        raise gr.Error("No Replicate credentials found for this session. Please log in again.")

    destination_decider = {
        'Slot 1': f"{account_name}/saveslot1",
        'Slot 2': f"{account_name}/saveslot2",
        'Slot 3': f"{account_name}/saveslot3",
        'Slot 4': f"{account_name}/saveslot4",
    }
    destination = destination_decider.get(destination_choice)
    if destination is None:
        raise gr.Error("Please select a load slot.")

    headers = {"Authorization": f"Token {replicate_token}"}
    response = requests.get(
        f"https://api.replicate.com/v1/models/{destination}", headers=headers, timeout=30
    )
    if response.status_code != 200:
        print("Failed to fetch the model details. Status Code:", response.status_code)
        # Without a version id there is nothing to run; stop here.
        raise gr.Error(f"Failed to fetch the model details. Status Code: {response.status_code}")

    # ``latest_version`` may be missing or null for a slot that has never been
    # trained; ``or {}`` covers both.
    latest_version = (response.json().get('latest_version') or {}).get('id')
    if not latest_version:
        raise gr.Error(f"{destination_choice} has no trained model yet. Train a model into this slot first.")
    print("Latest version ID:", latest_version)

    replicate_client = Client(api_token=replicate_token)
    output = replicate_client.run(
        f"{destination}:{latest_version}",
        input={
            "width": int(resolution),
            "height": int(resolution),
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "lora_scale": lora_scale,
            "num_outputs": num_outputs,
            "apply_watermark": False,
            "disable_safety_checker": True,
        },
    )
    return output
# Gradio UI Block
# Three tabs (upload, train, generate), each with its controls wired to one of
# the callbacks above and a help panel underneath.
# NOTE(review): the indentation of this block was lost in the captured source,
# so the nesting below is reconstructed. Placement of the buttons inside their
# columns and of the help panels at tab level should be confirmed against the
# running app.
with gr.Blocks(theme='prismglider/JSPERTheme2') as demo:
    # ---- Tab 1: select images and upload them as training data ----
    with gr.Tab("Upload Data"):
        with gr.Row():
            with gr.Column():
                upload_button = gr.File(label="Upload Images", type="filepath", file_count="multiple", interactive=True)
            with gr.Column():
                status_message = gr.Textbox(label="Status", interactive=False, value="Upload images and press the button.")
                btn = gr.Button("Upload Training Data")
        # save_images also takes a gr.Request parameter, which is not listed in `inputs`.
        btn.click(save_images, inputs=[upload_button], outputs=[status_message])
        # Keep the status hint in sync with the file selection.
        upload_button.clear(on_files_cleared, inputs=upload_button, outputs=status_message)
        upload_button.upload(on_files_selected, inputs=upload_button, outputs=status_message)
        # Help panel for step 1.
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """
<style>
@media (max-width: 600px) {
.responsive-container {
flex-direction: column;
}
.responsive-container .text-content,
.responsive-container .image-content {
flex: 0 0 100%;
padding: 0;
}
}
</style>
<div class="responsive-container" style="display: flex; flex-wrap: wrap; align-items: flex-start;">
<div class="text-content" style="flex: 1;">
<h2>JSPER (Just Stablediffusion Plus Easy Retraining)</h2>
<p>JSPER makes it EASY to train and generate custom art pieces with Stable Diffusion!</p>
<p>Follow along with our dog and mascot, Jasper, to see how it works!</p>
<h3>Step 1: Upload Data</h3>
<ul>
<li>Upload images of your subject that will be used to train the AI art model</li>
<li>It's recommended to use 10+ images for training (~15 minutes of training time for 10 images)</li>
<li>After you've selected your images, click the 'upload training data' button</li>
<li>For best results, use HD images that show the subject in multiple poses, at multiple angles, and with various lighting conditions.</li>
</ul>
</div>
<div class="image-content" style="flex: 1; padding-left: 20px;">
<img src="https://i.postimg.cc/nrVnQRwD/Untitled-Artwork.png" style="width: 100%; max-width: 300px; height: auto;">
</div>
</div>
"""
                )
    # ---- Tab 2: configure and start a training, then monitor or cancel it ----
    with gr.Tab("Train AI"):
        with gr.Row():
            with gr.Column(scale=2):
                # Neither radio below has a default value, so either may be unset on submit.
                category = gr.Radio(["Person", "Object", "Animal", "Style"], label="Category")
                token_string = gr.Textbox(label="Token", value="JSP")
                caption_prefix = gr.Textbox(label="Tokenized Prompt", value="photo of a JSP thing")
                mask_target_prompts = gr.Textbox(label="Reference Prompt", value="photo of a thing")
                destination_choice = gr.Radio(["Slot 1", "Slot 2", "Slot 3", "Slot 4"], label="Save Slot")
                submit_button = gr.Button("Start Training")
            with gr.Column():
                training_status = gr.Textbox(label="Training Status")
                # Hidden field holding the id returned by train_model; it feeds
                # the status and cancel buttons.
                training_id = gr.Textbox(visible=False)
                with gr.Row():
                    check_status_button = gr.Button("Check Training Status")
                    cancel_training_button = gr.Button("Cancel Training")
        submit_button.click(
            train_model,
            inputs=[category, token_string, caption_prefix, mask_target_prompts, destination_choice],
            outputs=[training_id, training_status]
        )
        check_status_button.click(check_training_status, inputs=[training_id], outputs=training_status)
        cancel_training_button.click(cancel_training, inputs=[training_id], outputs=training_status)
        # Help panel for step 2.
        with gr.Row():
            with gr.Column():
                gr.HTML(
                    """
<style>
@media (max-width: 600px) {
.responsive-container {
flex-direction: column;
}
.responsive-container .text-content,
.responsive-container .image-content {
flex: 0 0 100%;
padding: 0;
}
}
</style>
<div class="responsive-container" style="display: flex; flex-wrap: wrap; align-items: flex-start;">
<div class="text-content" style="flex: 1;">
<h2>JSPER (Just Stablediffusion Plus Easy Retraining)</h2>
<h3>Step 2: Train AI</h3>
<ul>
<li>Select the type of subject you are training (Person, Object, Animal, or Style).</li>
<li>Choose a save slot for your training.</li>
<ul style="padding-left: 20px;">
<li>For <strong>Persons</strong>: Use "JSP" or the person's name as the token. Your prompt should be "Photo of a JSP man/girl" or "Photo of {Firstname Lastname}".</li>
<li>For <strong>Objects</strong>: Use "JSP" as the token. Your prompt should be "Photo of a JSP {object}", e.g., "Photo of a JSP shoe".</li>
<li>For <strong>Animals</strong>: Use "JSP" as the token. Your prompt should be "Photo of a JSP {animal}", e.g., "Photo of a JSP dog".</li>
<li>For <strong>Styles</strong>: Your prompt should be "In the style of JSP".</li>
</ul>
</ul>
</div>
<div class="image-content" style="flex: 1; padding-left: 20px;">
<img src="https://i.postimg.cc/6qHtVrD1/Untitled-Artwork-2.png" style="width: 100%; max-width: 300px; height: auto;">
</div>
</div>
"""
                )
    # ---- Tab 3: generate images from a trained save slot ----
    with gr.Tab("Generate Images"):
        with gr.Row():
            with gr.Column():
                # `destination_choice` and `submit_button` are rebound here. The
                # Train tab's handlers were registered above, before the rebinding.
                destination_choice = gr.Radio(["Slot 1", "Slot 2", "Slot 3", "Slot 4"], label="Load Slot")
                resolution = gr.Radio(["512", "1024", "2048"], label="Resolution", value="1024")
                prompt = gr.Textbox(label="Prompt", value="a JSP object")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                lora_scale = gr.Slider(label="Training Scale", minimum=.5, maximum=1, value=0.6, step=.01)
                num_outputs = gr.Slider(label="Number of Outputs", minimum=1, maximum=4, value=2, step=1)
            with gr.Column():
                images = gr.Gallery(label="Generated Images")
                submit_button = gr.Button("Generate")
        submit_button.click(
            generate_images,
            inputs=[destination_choice, resolution, prompt, negative_prompt, lora_scale, num_outputs],
            outputs=images
        )
        # Help panel for step 3 (this one uses a 700px breakpoint; the others use 600px).
        with gr.Row():
            with gr.Column():
                gr.HTML(
                    """
<style>
@media (max-width: 700px) {
.responsive-container {
flex-direction: column;
}
.responsive-container .text-content,
.responsive-container .image-content {
flex: 0 0 100%;
padding: 0;
}
}
</style>
<div class="responsive-container" style="display: flex; flex-wrap: wrap; align-items: flex-start;">
<div class="text-content" style="flex: 1;">
<h2>Step 3: Generate Images</h2>
<p>In the <strong>Generate Images</strong> tab:</p>
<ul>
<li>Select a load slot corresponding to one of your previous trainings.</li>
<li>Choose the number of images you want to generate.</li>
<li>Enter a prompt that utilizes the token and prompt style used in the training section. For example:</li>
<ul style="padding-left: 20px;">
<li>For a person: "Photo of JSP man at a construction site" or "Renaissance style painting of Johnny Fakename".</li>
<li>For a style: "Digital landscape shot in the style of JSP".</li>
</ul>
<li>The prompts are flexible, and you can try multiple prompts to get good results.</li>
<li>If the results don't look enough like the subject, adjust the training scale up slightly.</li>
<li>If the results don't look creative or unique, lower the training scale slightly. A value of 0.6 is good in most cases.</li>
</ul>
</div>
<div class="image-content" style="flex: 1; padding-left: 20px;">
<img src="https://i.postimg.cc/qBQJdrzX/Untitled-Artwork-1.png" style="width: 100%; max-width: 300px; height: auto;">
</div>
</div>
"""
                )
# HTML passed to demo.launch as `auth_message`; it tells users to sign in
# with their Replicate username and API token.
auth_message = """
<div style='display: flex; align-items: flex-start;'>
<div style='width: 200px; flex-shrink: 0;'>
<img src='https://i.postimg.cc/qBQJdrzX/Untitled-Artwork-1.png' style='width: 100%; height: auto;'>
</div>
<div style='flex: 1;'>
<h1 style="font-size: 1.5em; text-decoration: underline; margin-bottom: 20px;">Welcome to JSPER! (Just Stablediffusion Plus Easy Retraining)</h1>
<p>Before you begin training StableDiffusion, please follow these steps to log in:</p>
<div style='padding-left: 20px; padding-right: 20px;'>
<ul>
<li>You need to have a <u><strong><a href='https://replicate.com/signin?next=/'>Replicate account</a></strong></u> to use this application.</li>
<li>Once you have an account, log in to the application using your <strong>Replicate username</strong> and <strong>Replicate token</strong>.</li>
<li>Your Replicate token can be found in your account settings on the Replicate website.</li>
</ul>
</div>
</div>
</div>
"""
# replicate_auth is the login callback: it validates the username/token pair
# against Replicate and stores the token for later requests.
demo.launch(auth=replicate_auth, auth_message=auth_message, show_api=False)