JSPER / app.py
arusterholz's picture
Removed 'WEBP' support
daa802f verified
raw
history blame contribute delete
No virus
22.6 kB
import os
import shutil
import zipfile
from typing import Dict, Tuple
import gradio as gr
import requests
from PIL import Image
from replicate.client import Client
# Constants
# Local directory in which save_images stores the uploaded images before zipping them.
IMAGE_DIR = "uploaded_images"
# Lower-case file extensions (without the leading dot) that save_images accepts.
VALID_IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg']
# Global variables
# All three dictionaries are in-memory, per-process state keyed by the logged-in username.
# username -> Replicate API token; written by replicate_auth after a successful login.
replicate_credentials: Dict[str, str] = {}
# username -> serving URL of the user's most recent training-data upload; written by
# save_images and read by train_model.
global_serving_url_dict: Dict[str, str] = {}
# username -> number of images in the user's most recent upload; written by save_images
# and used by train_model to scale the number of training steps.
num_of_input_images_dict: Dict[str, int] = {}
def validate_replicate_credentials(username: str, token: str) -> Tuple[bool, Dict[str, str]]:
    """Validate Replicate credentials and return account info if valid.

    Args:
        username: Replicate username the caller claims to own.
        token: Replicate API token to verify.

    Returns:
        ``(True, account_info)`` when the account endpoint accepts the token
        (HTTP 200) and the returned ``username`` equals *username*;
        ``(False, {})`` in every other case, including network failures and
        responses that are not valid JSON.
    """
    headers = {"Authorization": f"Token {token}"}
    try:
        # The timeout keeps a stalled connection from blocking the login forever.
        response = requests.get(
            "https://api.replicate.com/v1/account", headers=headers, timeout=30
        )
        if response.status_code != 200:
            return False, {}
        account_info = response.json()
    except (requests.RequestException, ValueError) as err:
        # Treat an unreachable API or an unparsable body as a failed login
        # instead of letting the exception escape the auth callback.
        print(f"Failed to validate Replicate credentials: {err}")
        return False, {}
    if isinstance(account_info, dict) and account_info.get("username") == username:
        return True, account_info
    return False, {}
def create_save_slots(username: str, replicate_token: str) -> None:
    """Create save slots for a given username.

    Looks up the account that owns *replicate_token* and makes sure the
    private models ``saveslot1`` .. ``saveslot4`` exist on it, creating every
    slot the API reports as missing (HTTP 404). Progress and failures are
    reported with ``print``; HTTP and network errors never raise, so a failed
    slot cannot abort the caller.

    Args:
        username: Not used; the owner name is taken from the account the
            token resolves to. Kept so existing callers keep working.
        replicate_token: Replicate API token sent with every request.
    """
    base_url = "https://api.replicate.com/v1"
    headers = {"Authorization": f"Token {replicate_token}"}
    timeout = 30  # seconds; avoids hanging forever on a stalled connection

    try:
        account_response = requests.get(f"{base_url}/account", headers=headers, timeout=timeout)
        if account_response.status_code != 200:
            print(f"Failed to get account information. Status Code: {account_response.status_code}")
            return
        account_name = account_response.json().get("username")
    except (requests.RequestException, ValueError) as err:
        print(f"Failed to get account information. Error: {err}")
        return
    if not account_name:
        print("Account username not found in response.")
        return

    for i in range(1, 5):
        save_slot_name = f"saveslot{i}"
        model_name = f"{account_name}/{save_slot_name}"
        try:
            model_response = requests.get(
                f"{base_url}/models/{model_name}", headers=headers, timeout=timeout
            )
            if model_response.status_code == 200:
                print(f"{save_slot_name} already exists.")
                continue
            if model_response.status_code != 404:
                print(f"Failed to check {save_slot_name}. Status Code: {model_response.status_code}")
                continue
            # 404 means the slot does not exist yet, so create it.
            create_model_response = requests.post(
                f"{base_url}/models",
                headers=headers,
                json={
                    "owner": account_name,
                    "name": save_slot_name,
                    "description": f"Save slot {i}",
                    "visibility": "private",
                    "hardware": "gpu-a40-small",
                },
                timeout=timeout,
            )
        except requests.RequestException as err:
            # One unreachable slot should not prevent the remaining slots
            # from being checked.
            print(f"Failed to set up {save_slot_name}. Error: {err}")
            continue
        if create_model_response.status_code == 201:
            print(f"{save_slot_name} created successfully.")
        else:
            print(f"Failed to create {save_slot_name}. Status Code: {create_model_response.status_code}")
def replicate_auth(username: str, token: str) -> bool:
    """Log a user in against Replicate and prepare their save slots.

    When the credentials check out, the token is remembered in
    ``replicate_credentials`` under *username* and ``create_save_slots`` is
    invoked for the account.

    Returns:
        ``True`` if the credentials were accepted, ``False`` otherwise.
    """
    credentials_ok = validate_replicate_credentials(username, token)[0]
    if not credentials_ok:
        return False
    replicate_credentials[username] = token
    create_save_slots(username, token)
    return True
def save_images(uploaded_files: list, request: gr.Request) -> str:
    """Save uploaded images to a directory and return the serving URL.

    The images are re-saved into ``IMAGE_DIR`` (which is wiped first), packed
    into ``data.zip`` and uploaded with the token stored for
    ``request.username``. On success the user's image count and serving URL
    are recorded for ``train_model``.

    Args:
        uploaded_files: The selected files; each item is either a path string
            or an object exposing the path as ``.name``. May be ``None`` or
            empty when nothing was selected.
        request: Gradio request; ``request.username`` identifies the user.

    Returns:
        A status message: the success text containing the serving URL, or an
        explanation of why an input file was rejected.

    Raises:
        Exception: If the upload URL cannot be obtained, the upload itself
            fails, or the response lacks the upload/serving URL.
    """
    # Pressing the button with nothing selected must not crash on enumerate(None).
    if not uploaded_files:
        return "No valid images were uploaded. Please upload at least one PNG, JPG, or JPEG file."

    if os.path.exists(IMAGE_DIR):
        shutil.rmtree(IMAGE_DIR)
    os.makedirs(IMAGE_DIR, exist_ok=True)

    image_count = 0
    for i, file_info in enumerate(uploaded_files):
        # Accept both plain path strings and objects that carry the path in `.name`.
        source_path = getattr(file_info, "name", file_info)
        extension = os.path.splitext(source_path)[1][1:].lower()
        if extension not in VALID_IMAGE_EXTENSIONS:
            return f"Error: '{source_path}' is not a valid image file. Please upload only PNG, JPG, or JPEG files."
        try:
            with Image.open(source_path) as image:
                image_path = os.path.join(IMAGE_DIR, f"image_{i}.{extension}")
                image.save(image_path)
        except OSError as e:
            return f"Error processing image {source_path}: {e}"
        image_count += 1

    zip_path = 'data.zip'
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for root, _dirs, files in os.walk(IMAGE_DIR):
            for file in files:
                # Store every image at the archive root, without directories.
                zipf.write(os.path.join(root, file), arcname=file)

    replicate_token = replicate_credentials[request.username]
    headers = {"Authorization": f"Token {replicate_token}"}
    response = requests.post(
        "https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip",
        headers=headers,
        timeout=30,
    )
    if response.status_code != 200:
        raise Exception(f"Failed to get upload URL. Response: {response.text}")
    data = response.json()
    upload_url = data.get("upload_url")
    if not upload_url:
        raise Exception("Upload URL not found in response")

    with open(zip_path, "rb") as file:
        # Generous read timeout: the archive may be large, but a dead
        # connection should still fail eventually.
        upload_response = requests.put(
            upload_url,
            headers={"Content-Type": "application/zip"},
            data=file,
            timeout=(10, 300),
        )
    if upload_response.status_code != 200:
        raise Exception(f"Failed to upload file. Response: {upload_response.text}")

    serving_url = data.get("serving_url")
    if not serving_url:
        raise Exception("Serving URL not found in response")

    # Record both pieces of per-user state only after the upload succeeded so
    # a failed upload never pairs a new image count with a stale serving URL.
    num_of_input_images_dict[request.username] = image_count
    global_serving_url_dict[request.username] = serving_url
    return f"Images have been successfully uploaded to: {serving_url}"
def on_files_selected(data: list) -> str:
    """Produce the status hint displayed once files have been picked.

    The selection in *data* is not inspected; the same prompt is returned for
    any input.
    """
    next_step_hint = "Click 'Upload Training Data' to upload your images."
    return next_step_hint
def on_files_cleared(data: list) -> str:
    """Produce the status hint displayed after the file selection is cleared.

    The value of *data* is not inspected; the same prompt is returned for any
    input.
    """
    idle_hint = "Upload images and press the button."
    return idle_hint
def train_model(category: str, token_string: str, caption_prefix: str, mask_target_prompts: str,
                destination_choice: str, request: gr.Request) -> Tuple[str, str]:
    """Train a model with the specified parameters.

    Starts a Replicate training of the SDXL version pinned below on the data
    the user last uploaded, writing the result to the chosen save slot.

    Args:
        category: One of "Person", "Object", "Animal" or "Style"; selects the
            training preset.
        token_string: Passed through as the ``token_string`` training input.
        caption_prefix: Passed through as the ``caption_prefix`` training input.
        mask_target_prompts: Passed through as ``mask_target_prompts``.
        destination_choice: "Slot 1" .. "Slot 4"; mapped to
            ``<username>/saveslot<N>``.
        request: Gradio request; ``request.username`` identifies the user.

    Returns:
        ``(training_id, status_message)``.

    Raises:
        gr.Error: If the user has not uploaded training data yet, or no
            category / save slot has been selected.
    """
    account_name = request.username

    # Validate everything up front so the user gets an actionable message
    # instead of a bare KeyError from one of the lookups below.
    if account_name not in num_of_input_images_dict or account_name not in global_serving_url_dict:
        raise gr.Error("Please upload your training images before starting training.")
    num_of_input_images = num_of_input_images_dict[account_name]
    global_serving_url = global_serving_url_dict[account_name]

    destination_decider = {f"Slot {i}": f"{account_name}/saveslot{i}" for i in range(1, 5)}
    destination = destination_decider.get(destination_choice)
    if destination is None:
        raise gr.Error("Please select a save slot before starting training.")

    # Per-category presets; the step count grows with the number of images.
    settings_by_category = {
        "Person": {
            "resolution": 1024,
            "train_batch_size": 4,
            "max_train_steps": 1400 + (num_of_input_images * 50),
            "unet_learning_rate": 1e-6,
            "is_lora": True,
            "lora_lr": 1e-4,
            "crop_based_on_salience": False,
            "use_face_detection_instead": True,
            "clipseg_temperature": 1.0,
        },
        "Object": {
            "resolution": 1024,
            "train_batch_size": 4,
            "max_train_steps": 1000 + (num_of_input_images * 50),
            "unet_learning_rate": 1e-6,
            "is_lora": True,
            "lora_lr": 1e-4,
            "crop_based_on_salience": True,
            "use_face_detection_instead": False,
            "clipseg_temperature": 0.9,
        },
        "Animal": {
            "resolution": 1024,
            "train_batch_size": 4,
            "max_train_steps": 500 + (num_of_input_images * 75),
            "unet_learning_rate": 1e-6,
            "is_lora": True,
            "lora_lr": 1e-4,
            "crop_based_on_salience": True,
            "use_face_detection_instead": False,
            "clipseg_temperature": 0.9,
        },
        "Style": {
            "resolution": 1024,
            "train_batch_size": 4,
            "max_train_steps": 300 + (num_of_input_images * 100),
            "unet_learning_rate": 1e-6,
            "is_lora": True,
            "lora_lr": 3e-4,
            "crop_based_on_salience": False,
            "use_face_detection_instead": False,
            "clipseg_temperature": 1.0,
        },
    }
    settings = settings_by_category.get(category)
    if settings is None:
        raise gr.Error("Please select a category before starting training.")

    replicate_client = Client(api_token=replicate_credentials[account_name])
    training = replicate_client.trainings.create(
        version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
        input={
            "input_images": global_serving_url,
            **settings,
            "token_string": token_string,
            "caption_prefix": caption_prefix,
            "mask_target_prompts": mask_target_prompts,
        },
        destination=destination,
    )
    train_id = training.id
    return train_id, "Training started successfully!"
def cancel_training(training_id: str, request: gr.Request) -> str:
    """Cancel the specified training.

    Args:
        training_id: Id of the training to cancel; empty when no training has
            been started from this session yet.
        request: Gradio request; ``request.username`` selects the stored token.

    Returns:
        A status message describing whether the cancellation succeeded.
    """
    # Without an id the URL below would be malformed ("/trainings//cancel").
    if not training_id:
        return "No training has been started yet, so there is nothing to cancel."

    replicate_token = replicate_credentials[request.username]
    headers = {"Authorization": f"Token {replicate_token}"}
    url = f"https://api.replicate.com/v1/trainings/{training_id}/cancel"
    response = requests.post(url, headers=headers, timeout=30)
    if response.status_code == 200:
        return "Training cancelled successfully!"
    return f"Failed to cancel training. Response code: {response.status_code}, Message: {response.text}"
def check_training_status(training_id: str, request: gr.Request) -> str:
    """Check the status of the specified training.

    Args:
        training_id: Id of the training to inspect; empty when no training
            has been started from this session yet.
        request: Gradio request; ``request.username`` selects the stored token.

    Returns:
        A text report: the training status, followed by the last 15 log lines
        while the training is processing, or by the full training record once
        it has succeeded. On a non-200 response, the HTTP status code.
    """
    # An empty id would query "/trainings/", which is not a single training
    # and carries no 'status' field.
    if not training_id:
        return "No training has been started yet. Start a training first."

    replicate_token = replicate_credentials[request.username]
    headers = {"Authorization": f"Token {replicate_token}"}
    url = f"https://api.replicate.com/v1/trainings/{training_id}"
    response = requests.get(url, headers=headers, timeout=30)
    if response.status_code != 200:
        return f"Request failed with status code: {response.status_code}"

    training = response.json()
    status = training['status']
    output = f"Training Status: {status}\n"
    if status == 'processing':
        # Logs can be missing or null; fall back to an empty string.
        logs = training.get('logs') or ""
        output += "\n".join(logs.split("\n")[-15:])
    elif status == 'succeeded':
        output += str(training)
    return output
def generate_images(destination_choice: str, resolution: str, prompt: str, negative_prompt: str,
                    lora_scale: float, num_outputs: int, request: gr.Request) -> list:
    """Generate images using the specified parameters.

    Resolves the latest version of the model stored in the chosen save slot
    and runs it through the Replicate client.

    Args:
        destination_choice: "Slot 1" .. "Slot 4"; mapped to
            ``<username>/saveslot<N>``.
        resolution: Edge length in pixels as a string; used for both width
            and height.
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        lora_scale: Passed through as the ``lora_scale`` input.
        num_outputs: Number of images to request.
        request: Gradio request; ``request.username`` identifies the user.

    Returns:
        Whatever ``replicate_client.run`` returns for the model.

    Raises:
        gr.Error: If no slot is selected, the model lookup fails, or the slot
            has no trained version yet.
    """
    account_name = request.username
    replicate_token = replicate_credentials[account_name]
    replicate_client = Client(api_token=replicate_token)

    destination_decider = {f"Slot {i}": f"{account_name}/saveslot{i}" for i in range(1, 5)}
    destination = destination_decider.get(destination_choice)
    if destination is None:
        raise gr.Error("Please select a load slot before generating images.")

    headers = {"Authorization": f"Token {replicate_token}"}
    response = requests.get(
        f"https://api.replicate.com/v1/models/{destination}", headers=headers, timeout=30
    )
    if response.status_code != 200:
        print("Failed to fetch the model details. Status Code:", response.status_code)
        # Stop here: there is no version id to run.
        raise gr.Error(
            f"Failed to fetch the model details for {destination_choice}. "
            f"Status Code: {response.status_code}"
        )

    # 'latest_version' may be absent or null when the slot was never trained.
    latest_version = (response.json().get('latest_version') or {}).get('id')
    if not latest_version:
        raise gr.Error(f"{destination_choice} has no trained model yet. Train a model into this slot first.")
    print("Latest version ID:", latest_version)

    output = replicate_client.run(
        f"{destination}:{latest_version}",
        input={
            "width": int(resolution),
            "height": int(resolution),
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "lora_scale": lora_scale,
            # Coerce in case the UI delivers the count as a float such as 2.0.
            "num_outputs": int(num_outputs),
            "apply_watermark": False,
            "disable_safety_checker": True,
        },
    )
    return output
# Gradio UI Block
# Three tabs mirror the workflow: upload training data, train, generate.
# NOTE(review): the nesting of components inside rows/columns below is a
# reconstruction -- confirm it matches the intended layout.
# The handlers wired here also declare a `request` parameter that is not listed
# in `inputs`; presumably Gradio supplies it -- verify against the Gradio docs.
with gr.Blocks(theme='prismglider/JSPERTheme2') as demo:
    # --- Tab 1: pick image files and upload them as training data ---
    with gr.Tab("Upload Data"):
        with gr.Row():
            with gr.Column():
                upload_button = gr.File(label="Upload Images", type="filepath", file_count="multiple", interactive=True)
            with gr.Column():
                status_message = gr.Textbox(label="Status", interactive=False, value="Upload images and press the button.")
                btn = gr.Button("Upload Training Data")
        # The button performs the actual upload; the file widget's own events
        # only refresh the hint shown in the status box.
        btn.click(save_images, inputs=[upload_button], outputs=[status_message])
        upload_button.clear(on_files_cleared, inputs=upload_button, outputs=status_message)
        upload_button.upload(on_files_selected, inputs=upload_button, outputs=status_message)
        # Static instructions for step 1.
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """
<style>
@media (max-width: 600px) {
.responsive-container {
flex-direction: column;
}
.responsive-container .text-content,
.responsive-container .image-content {
flex: 0 0 100%;
padding: 0;
}
}
</style>
<div class="responsive-container" style="display: flex; flex-wrap: wrap; align-items: flex-start;">
<div class="text-content" style="flex: 1;">
<h2>JSPER (Just Stablediffusion Plus Easy Retraining)</h2>
<p>JSPER makes it EASY to train and generate custom art pieces with Stable Diffusion!</p>
<p>Follow along with our dog and mascot, Jasper, to see how it works!</p>
<h3>Step 1: Upload Data</h3>
<ul>
<li>Upload images of your subject that will be used to train the AI art model</li>
<li>It's recommended to use 10+ images for training (~15 minutes of training time for 10 images)</li>
<li>After you've selected your images, click the 'upload training data' button</li>
<li>For best results, use HD images that show the subject in multiple poses, at multiple angles, and with various lighting conditions.</li>
</ul>
</div>
<div class="image-content" style="flex: 1; padding-left: 20px;">
<img src="https://i.postimg.cc/nrVnQRwD/Untitled-Artwork.png" style="width: 100%; max-width: 300px; height: auto;">
</div>
</div>
"""
                )
    # --- Tab 2: choose training parameters and start / monitor / cancel a training ---
    with gr.Tab("Train AI"):
        with gr.Row():
            with gr.Column(scale=2):
                category = gr.Radio(["Person", "Object", "Animal", "Style"], label="Category")
                token_string = gr.Textbox(label="Token", value="JSP")
                caption_prefix = gr.Textbox(label="Tokenized Prompt", value="photo of a JSP thing")
                mask_target_prompts = gr.Textbox(label="Reference Prompt", value="photo of a thing")
                destination_choice = gr.Radio(["Slot 1", "Slot 2", "Slot 3", "Slot 4"], label="Save Slot")
                submit_button = gr.Button("Start Training")
            with gr.Column():
                training_status = gr.Textbox(label="Training Status")
                # Hidden field that carries the id returned by train_model to
                # the status and cancel handlers.
                training_id = gr.Textbox(visible=False)
                with gr.Row():
                    check_status_button = gr.Button("Check Training Status")
                    cancel_training_button = gr.Button("Cancel Training")
        submit_button.click(
            train_model,
            inputs=[category, token_string, caption_prefix, mask_target_prompts, destination_choice],
            outputs=[training_id, training_status]
        )
        check_status_button.click(check_training_status, inputs=[training_id], outputs=training_status)
        cancel_training_button.click(cancel_training, inputs=[training_id], outputs=training_status)
        # Static instructions for step 2.
        with gr.Row():
            with gr.Column():
                gr.HTML(
                    """
<style>
@media (max-width: 600px) {
.responsive-container {
flex-direction: column;
}
.responsive-container .text-content,
.responsive-container .image-content {
flex: 0 0 100%;
padding: 0;
}
}
</style>
<div class="responsive-container" style="display: flex; flex-wrap: wrap; align-items: flex-start;">
<div class="text-content" style="flex: 1;">
<h2>JSPER (Just Stablediffusion Plus Easy Retraining)</h2>
<h3>Step 2: Train AI</h3>
<ul>
<li>Select the type of subject you are training (Person, Object, Animal, or Style).</li>
<li>Choose a save slot for your training.</li>
<ul style="padding-left: 20px;">
<li>For <strong>Persons</strong>: Use "JSP" or the person's name as the token. Your prompt should be "Photo of a JSP man/girl" or "Photo of {Firstname Lastname}".</li>
<li>For <strong>Objects</strong>: Use "JSP" as the token. Your prompt should be "Photo of a JSP {object}", e.g., "Photo of a JSP shoe".</li>
<li>For <strong>Animals</strong>: Use "JSP" as the token. Your prompt should be "Photo of a JSP {animal}", e.g., "Photo of a JSP dog".</li>
<li>For <strong>Styles</strong>: Your prompt should be "In the style of JSP".</li>
</ul>
</ul>
</div>
<div class="image-content" style="flex: 1; padding-left: 20px;">
<img src="https://i.postimg.cc/6qHtVrD1/Untitled-Artwork-2.png" style="width: 100%; max-width: 300px; height: auto;">
</div>
</div>
"""
                )
    # --- Tab 3: generate images from a trained save slot ---
    with gr.Tab("Generate Images"):
        with gr.Row():
            with gr.Column():
                # `destination_choice` (here) and `submit_button` (below) rebind
                # names used in the Train tab. The Train tab's handlers were
                # registered above with the earlier component objects, so they
                # are unaffected by the rebinding.
                destination_choice = gr.Radio(["Slot 1", "Slot 2", "Slot 3", "Slot 4"], label="Load Slot")
                resolution = gr.Radio(["512", "1024", "2048"], label="Resolution", value="1024")
                prompt = gr.Textbox(label="Prompt", value="a JSP object")
                negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                lora_scale = gr.Slider(label="Training Scale", minimum=.5, maximum=1, value=0.6, step=.01)
                num_outputs = gr.Slider(label="Number of Outputs", minimum=1, maximum=4, value=2, step=1)
            with gr.Column():
                images = gr.Gallery(label="Generated Images")
                submit_button = gr.Button("Generate")
        submit_button.click(
            generate_images,
            inputs=[destination_choice, resolution, prompt, negative_prompt, lora_scale, num_outputs],
            outputs=images
        )
        # Static instructions for step 3.
        with gr.Row():
            with gr.Column():
                gr.HTML(
                    """
<style>
@media (max-width: 700px) {
.responsive-container {
flex-direction: column;
}
.responsive-container .text-content,
.responsive-container .image-content {
flex: 0 0 100%;
padding: 0;
}
}
</style>
<div class="responsive-container" style="display: flex; flex-wrap: wrap; align-items: flex-start;">
<div class="text-content" style="flex: 1;">
<h2>Step 3: Generate Images</h2>
<p>In the <strong>Generate Images</strong> tab:</p>
<ul>
<li>Select a load slot corresponding to one of your previous trainings.</li>
<li>Choose the number of images you want to generate.</li>
<li>Enter a prompt that utilizes the token and prompt style used in the training section. For example:</li>
<ul style="padding-left: 20px;">
<li>For a person: "Photo of JSP man at a construction site" or "Renaissance style painting of Johnny Fakename".</li>
<li>For a style: "Digital landscape shot in the style of JSP".</li>
</ul>
<li>The prompts are flexible, and you can try multiple prompts to get good results.</li>
<li>If the results don't look enough like the subject, adjust the training scale up slightly.</li>
<li>If the results don't look creative or unique, lower the training scale slightly. A value of 0.6 is good in most cases.</li>
</ul>
</div>
<div class="image-content" style="flex: 1; padding-left: 20px;">
<img src="https://i.postimg.cc/qBQJdrzX/Untitled-Artwork-1.png" style="width: 100%; max-width: 300px; height: auto;">
</div>
</div>
"""
                )
# HTML passed to demo.launch as `auth_message`; it tells users to log in with
# their Replicate username and Replicate token.
auth_message = """
<div style='display: flex; align-items: flex-start;'>
<div style='width: 200px; flex-shrink: 0;'>
<img src='https://i.postimg.cc/qBQJdrzX/Untitled-Artwork-1.png' style='width: 100%; height: auto;'>
</div>
<div style='flex: 1;'>
<h1 style="font-size: 1.5em; text-decoration: underline; margin-bottom: 20px;">Welcome to JSPER! (Just Stablediffusion Plus Easy Retraining)</h1>
<p>Before you begin training StableDiffusion, please follow these steps to log in:</p>
<div style='padding-left: 20px; padding-right: 20px;'>
<ul>
<li>You need to have a <u><strong><a href='https://replicate.com/signin?next=/'>Replicate account</a></strong></u> to use this application.</li>
<li>Once you have an account, log in to the application using your <strong>Replicate username</strong> and <strong>Replicate token</strong>.</li>
<li>Your Replicate token can be found in your account settings on the Replicate website.</li>
</ul>
</div>
</div>
</div>
"""
# Start the app with replicate_auth as the login check, so the submitted
# credentials are validated as a Replicate username/token pair.
demo.launch(auth=replicate_auth, auth_message=auth_message, show_api=False)