import os
import random
import sys
import asyncio
import tempfile
from typing import Sequence, Mapping, Any, Union
import torch
import gradio as gr
from PIL import Image
import numpy as np
import spaces
from huggingface_hub import hf_hub_download
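# The local_dir targets below mirror ComfyUI's default model folder layout
# (models/checkpoints, models/vae, models/text_encoders, models/loras), so the
# loader nodes used later can find these files by name.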
# Download required models at startup
def download_models():
"""Download all required models from HuggingFace Hub"""
# Download models as specified
try:
print("๐Ÿ“ฅ Downloading FLUX Kontext checkpoint...")
hf_hub_download(
repo_id="black-forest-labs/FLUX.1-Kontext-dev",
filename="flux1-kontext-dev.safetensors",
local_dir="models/checkpoints"
)
print("โœ… FLUX Kontext checkpoint downloaded")
except Exception as e:
print(f"โŒ Error downloading FLUX checkpoint: {e}")
try:
print("๐Ÿ“ฅ Downloading VAE model...")
hf_hub_download(
repo_id="black-forest-labs/FLUX.1-Kontext-dev",
filename="ae.safetensors",
local_dir="models/vae"
)
print("โœ… VAE model downloaded")
except Exception as e:
print(f"โŒ Error downloading VAE: {e}")
try:
print("๐Ÿ“ฅ Downloading CLIP text encoder...")
hf_hub_download(
repo_id="DoozyWo/Kontext_Clip_model",
filename="model.safetensors",
local_dir="models/text_encoders"
)
print("โœ… CLIP text encoder downloaded")
except Exception as e:
print(f"โŒ Error downloading CLIP text encoder: {e}")
try:
print("๐Ÿ“ฅ Downloading T5 text encoder...")
hf_hub_download(
repo_id="comfyanonymous/flux_text_encoders",
filename="t5xxl_fp8_e4m3fn.safetensors",
local_dir="models/text_encoders"
)
print("โœ… T5 text encoder downloaded")
except Exception as e:
print(f"โŒ Error downloading T5 text encoder: {e}")
try:
print("๐Ÿ“ฅ Downloading Avatar LoRA...")
hf_hub_download(
repo_id="DoozyWo/Kontext_avatar_LoRA",
filename="Avataar_LoRA_000003000.safetensors",
local_dir="models/loras"
)
print("โœ… Avatar LoRA downloaded")
except Exception as e:
print(f"โŒ Error downloading Avatar LoRA: {e}")
print("โœ… Model downloads completed!")
# Download models on import
download_models()
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
"""Returns the value at the given index of a sequence or mapping."""
try:
return obj[index]
except KeyError:
return obj["result"][index]
def find_path(name: str, path: str = None) -> str:
"""Recursively looks for a path starting from the given directory."""
if path is None:
path = os.getcwd()
if name in os.listdir(path):
path_name = os.path.join(path, name)
print(f"{name} found: {path_name}")
return path_name
parent_directory = os.path.dirname(path)
if parent_directory == path:
return None
return find_path(name, parent_directory)
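# e.g. find_path("ComfyUI") checks the current working directory, then each
# parent in turn, and returns the first entry named "ComfyUI" it finds (or None).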
def add_comfyui_directory_to_sys_path() -> None:
"""Add ComfyUI to the sys.path"""
comfyui_path = find_path("ComfyUI")
if comfyui_path and os.path.isdir(comfyui_path):
sys.path.append(comfyui_path)
print(f"'{comfyui_path}' added to sys.path")
def add_extra_model_paths() -> None:
"""Parse the optional extra_model_paths.yaml file and add paths to sys.path."""
try:
from main import load_extra_path_config
except ImportError:
try:
from utils.extra_config import load_extra_path_config
except ImportError:
print("Could not import load_extra_path_config")
return
extra_model_paths = find_path("extra_model_paths.yaml")
if extra_model_paths:
load_extra_path_config(extra_model_paths)
else:
print("Could not find the extra_model_paths config file.")
# Initialize ComfyUI
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
async def import_custom_nodes() -> None:
"""Import and initialize ComfyUI custom nodes."""
import execution
from nodes import init_extra_nodes
import server
    # We are inside a coroutine, so an event loop is guaranteed to be running
    loop = asyncio.get_running_loop()
# Initialize server and nodes
server_instance = server.PromptServer(loop)
execution.PromptQueue(server_instance)
# Await the async function
await init_extra_nodes()
# Import NODE_CLASS_MAPPINGS after ComfyUI is set up
from nodes import NODE_CLASS_MAPPINGS
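# NODE_CLASS_MAPPINGS maps a node's registered name (e.g. "KSampler") to its
# class; instantiating an entry and calling its method (load_clip, sample, ...)
# runs that node directly, outside the usual graph executor.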
# Global initialization
_initialized = False
_model_loaders = None
async def initialize_models():
"""Initialize and preload models for faster inference."""
global _initialized, _model_loaders
if _initialized:
return _model_loaders
await import_custom_nodes()
# Use no_grad instead of inference_mode for better compatibility
with torch.no_grad():
# Initialize all node classes
dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
checkpointloadersimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
# Check what CLIP models are available
text_encoders_path = "models/text_encoders"
available_files = []
if os.path.exists(text_encoders_path):
available_files = os.listdir(text_encoders_path)
print(f"Available text encoder files: {available_files}")
# Try different CLIP loading approaches
try:
# First try: Use the expected model names
dualcliploader_184 = dualcliploader.load_clip(
clip_name1="model.safetensors",
clip_name2="t5xxl_fp8_e4m3fn.safetensors",
type="flux",
device="default",
)
except Exception as e:
print(f"First CLIP load attempt failed: {e}")
try:
# Second try: Use alternative names
dualcliploader_184 = dualcliploader.load_clip(
clip_name1="clip_l.safetensors",
clip_name2="t5xxl_fp8_e4m3fn.safetensors",
type="flux",
device="default",
)
except Exception as e2:
print(f"Second CLIP load attempt failed: {e2}")
# Third try: Download and use standard FLUX text encoders
print("๐Ÿ“ฅ Downloading standard FLUX text encoders as fallback...")
try:
hf_hub_download(
repo_id="comfyanonymous/flux_text_encoders",
filename="clip_l.safetensors",
local_dir="models/text_encoders"
)
hf_hub_download(
repo_id="comfyanonymous/flux_text_encoders",
filename="t5xxl_fp16.safetensors",
local_dir="models/text_encoders"
)
dualcliploader_184 = dualcliploader.load_clip(
clip_name1="clip_l.safetensors",
clip_name2="t5xxl_fp16.safetensors",
type="flux",
device="default",
)
except Exception as e3:
print(f"Fallback CLIP download failed: {e3}")
raise e3
vaeloader_39 = vaeloader.load_vae(vae_name="ae.safetensors")
checkpointloadersimple_188 = checkpointloadersimple.load_checkpoint(
ckpt_name="flux1-kontext-dev.safetensors"
)
loraloadermodelonly_186 = loraloadermodelonly.load_lora_model_only(
lora_name="Avataar_LoRA_000003000.safetensors",
strength_model=1,
model=get_value_at_index(checkpointloadersimple_188, 0),
)
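        # Each loader returns a tuple; get_value_at_index(..., 0) pulls out the
        # usable model/CLIP/VAE object that downstream nodes consume.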
# Store all loaded models
_model_loaders = {
'clip': dualcliploader_184,
'vae': vaeloader_39,
'checkpoint': checkpointloadersimple_188,
'lora_model': loraloadermodelonly_186
}
# Load models to GPU for faster inference with better error handling
try:
from comfy import model_management
# Collect valid models more safely
valid_models = []
for loader_name, loader in _model_loaders.items():
try:
if loader and len(loader) > 0:
model_obj = loader[0]
if hasattr(model_obj, 'patcher') and model_obj.patcher is not None:
if not isinstance(model_obj.patcher, dict):
valid_models.append(model_obj.patcher)
elif not isinstance(model_obj, dict):
valid_models.append(model_obj)
except Exception as e:
print(f"Warning: Could not process model {loader_name}: {e}")
continue
            if valid_models:
                print(f"Loading {len(valid_models)} models to GPU...")
                model_management.load_models_gpu(valid_models)
                print("✅ Models loaded to GPU successfully")
            else:
                print("⚠️ No valid models found for GPU loading")
        except Exception as e:
            print(f"⚠️ Warning: Could not load models to GPU: {e}")
            print("Models will run on CPU/default device")
_initialized = True
return _model_loaders
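# initialize_models caches its result in the module-level globals above, so the
# heavyweight loaders run once and later requests reuse the same models.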
@spaces.GPU(duration=60)
def generate_image(input_image, custom_prompt=""):
"""Synchronous wrapper for the async generate_image function - main entry point."""
async def _generate():
return await generate_image_async(input_image, custom_prompt)
# Run the async function in the event loop
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# If loop is already running, create a new task
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, _generate())
return future.result()
else:
return loop.run_until_complete(_generate())
except RuntimeError:
# No event loop exists, create one
return asyncio.run(_generate())
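# Example direct call (outside Gradio), assuming a local portrait file:
#   result_image, status = generate_image(Image.open("portrait.jpg"), "glowing tattoos")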
def generate_image_sync(input_image, custom_prompt=""):
"""Alternative synchronous wrapper - calls the main generate_image function."""
return generate_image(input_image, custom_prompt)
async def generate_image_async(input_image, custom_prompt=""):
"""Transform an input image using Avatar LoRA with FLUX Kontext model."""
if input_image is None:
return None, "Please provide an input image."
try:
# Initialize models
model_loaders = await initialize_models()
# Use no_grad instead of inference_mode for better compatibility
with torch.no_grad():
# Force garbage collection before starting
import gc
gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
# Initialize node classes
loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
imagescaletototalpixels = NODE_CLASS_MAPPINGS["ImageScaleToTotalPixels"]()
vaeencode = NODE_CLASS_MAPPINGS["VAEEncode"]()
cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
referencelatent = NODE_CLASS_MAPPINGS["ReferenceLatent"]()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
# Save input image temporarily
temp_dir = tempfile.mkdtemp()
temp_image_path = os.path.join(temp_dir, "input_image.jpg")
input_image.save(temp_image_path)
try:
# Load and process input image
loadimage_133 = loadimage.load_image(image=temp_image_path)
# Scale image - detach tensors to avoid version tracking issues
image_tensor = get_value_at_index(loadimage_133, 0)
if isinstance(image_tensor, torch.Tensor):
image_tensor = image_tensor.detach().clone()
imagescaletototalpixels_187 = imagescaletototalpixels.upscale(
upscale_method="bicubic",
megapixels=1,
image=image_tensor,
)
# Encode image - ensure tensor is detached
scaled_image = get_value_at_index(imagescaletototalpixels_187, 0)
if isinstance(scaled_image, torch.Tensor):
scaled_image = scaled_image.detach().clone()
vaeencode_124 = vaeencode.encode(
pixels=scaled_image,
vae=get_value_at_index(model_loaders['vae'], 0),
)
# Base prompt for Avatar transformation
base_prompt = "Turn this into a photorealistic Na'vi character from Avatar, with blue bioluminescent skin, large eyes, and set in the glowing jungle of Pandora."
# Combine with custom prompt if provided
if custom_prompt.strip():
full_prompt = f"{base_prompt} {custom_prompt.strip()}"
else:
full_prompt = base_prompt
# Encode text prompts
cliptextencode_181 = cliptextencode.encode(
text=full_prompt,
clip=get_value_at_index(model_loaders['clip'], 0),
)
cliptextencode_182 = cliptextencode.encode(
text="", clip=get_value_at_index(model_loaders['clip'], 0)
)
# Generate Avatar transformation
referencelatent_176 = referencelatent.append(
conditioning=get_value_at_index(cliptextencode_181, 0),
latent=get_value_at_index(vaeencode_124, 0),
)
fluxguidance_179 = fluxguidance.append(
guidance=4.5, conditioning=get_value_at_index(referencelatent_176, 0)
)
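                # Note: FLUX.1 dev is guidance-distilled, so the KSampler below runs
                # with cfg=1 and the effective guidance strength (4.5) is injected
                # through the FluxGuidance conditioning above.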
                # Use a random seed so repeated runs give varied results
                random_seed = random.randint(0, 2**32 - 1)
                ksampler_178 = ksampler.sample(
                    seed=random_seed,
steps=25,
cfg=1,
sampler_name="euler",
scheduler="simple",
denoise=1,
model=get_value_at_index(model_loaders['lora_model'], 0),
positive=get_value_at_index(fluxguidance_179, 0),
negative=get_value_at_index(cliptextencode_182, 0),
latent_image=get_value_at_index(vaeencode_124, 0),
)
vaedecode_177 = vaedecode.decode(
samples=get_value_at_index(ksampler_178, 0),
vae=get_value_at_index(model_loaders['vae'], 0),
)
# Get the result image and properly handle tensor conversion
result_images = get_value_at_index(vaedecode_177, 0)
# Convert tensor to PIL Image with proper detachment
if isinstance(result_images, torch.Tensor):
# Detach and clone to avoid version tracking issues
image_tensor = result_images.detach().clone().squeeze(0)
# Move to CPU if on GPU
if image_tensor.is_cuda:
image_tensor = image_tensor.cpu()
# Convert to numpy
image_np = image_tensor.numpy()
image_np = np.clip(image_np, 0.0, 1.0)
image_np = (image_np * 255).astype(np.uint8)
if len(image_np.shape) == 3 and image_np.shape[-1] == 3:
result_image = Image.fromarray(image_np, 'RGB')
elif len(image_np.shape) == 3:
result_image = Image.fromarray(image_np[:,:,0], 'L').convert('RGB')
else:
result_image = Image.fromarray(image_np, 'L').convert('RGB')
else:
result_image = result_images
# Force cleanup of tensors
del loadimage_133, imagescaletototalpixels_187, vaeencode_124
del cliptextencode_181, cliptextencode_182, referencelatent_176
del fluxguidance_179, ksampler_178, vaedecode_177
gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                return result_image, "✨ Avatar transformation complete!"
            finally:
                # Clean up temporary files
                try:
                    os.remove(temp_image_path)
                    os.rmdir(temp_dir)
                except OSError:
                    pass
    except Exception as e:
        # Force cleanup on error
        import gc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return None, f"❌ Error: {str(e)}"
def create_gradio_interface():
"""Create the Gradio interface."""
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');
.gradio-container {
background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
min-height: 100vh;
font-family: 'Inter', sans-serif;
padding: 2rem;
}
.main-header {
text-align: center;
font-size: 3.5rem;
font-weight: 700;
background: linear-gradient(135deg, #06b6d4, #8b5cf6);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 1rem;
}
.main-description {
text-align: center;
font-size: 1.3rem;
color: #cbd5e1;
max-width: 800px;
margin: 0 auto 3rem auto;
line-height: 1.6;
}
.image-container {
background: rgba(15, 23, 42, 0.6);
border-radius: 20px;
padding: 2rem;
border: 2px solid rgba(6, 182, 212, 0.3);
backdrop-filter: blur(10px);
}
.big-image-upload {
border: 2px dashed rgba(6, 182, 212, 0.5) !important;
border-radius: 16px !important;
background: rgba(15, 23, 42, 0.3) !important;
min-height: 500px !important;
transition: all 0.3s ease !important;
}
.big-image-upload:hover {
border-color: rgba(6, 182, 212, 0.8) !important;
background: rgba(6, 182, 212, 0.05) !important;
}
.section-title {
color: #06b6d4;
font-weight: 600;
font-size: 1.5rem;
text-align: center;
margin-bottom: 1.5rem;
}
.custom-textbox textarea {
background: rgba(15, 23, 42, 0.6) !important;
border: 2px solid rgba(6, 182, 212, 0.3) !important;
border-radius: 12px !important;
color: #e2e8f0 !important;
font-size: 1.1rem !important;
padding: 1rem !important;
min-height: 120px !important;
}
.custom-textbox textarea:focus {
border-color: rgba(6, 182, 212, 0.8) !important;
background: rgba(15, 23, 42, 0.8) !important;
}
.transform-button {
background: linear-gradient(135deg, #06b6d4, #3b82f6) !important;
color: white !important;
font-weight: 600 !important;
font-size: 1.3rem !important;
padding: 1.2rem 3rem !important;
border-radius: 12px !important;
border: none !important;
box-shadow: 0 8px 25px rgba(6, 182, 212, 0.4) !important;
transition: all 0.3s ease !important;
width: 100% !important;
margin-top: 1.5rem !important;
}
.transform-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 12px 35px rgba(6, 182, 212, 0.6) !important;
}
.status-display {
background: rgba(15, 23, 42, 0.6);
border-radius: 12px;
padding: 1.5rem;
margin-top: 1rem;
border: 1px solid rgba(6, 182, 212, 0.2);
min-height: 60px;
display: flex;
align-items: center;
justify-content: center;
}
.status-success {
color: #10b981 !important;
font-weight: 600 !important;
font-size: 1.2rem !important;
}
.status-error {
color: #ef4444 !important;
font-weight: 600 !important;
font-size: 1.2rem !important;
}
.status-processing {
color: #06b6d4 !important;
font-weight: 600 !important;
font-size: 1.2rem !important;
}
@media (max-width: 768px) {
.main-header {
font-size: 2.5rem;
}
.big-image-upload {
min-height: 400px !important;
}
.gradio-container {
padding: 1rem;
}
}
"""
with gr.Blocks(css=custom_css, title="Avatar Transformation Studio") as interface:
# Header
gr.HTML("""
<div class="main-header">
๐Ÿงžโ€โ™‚๏ธ Avatar Transformation Studio
</div>
<div class="main-description">
Transform any portrait into a stunning Na'vi character with bioluminescent blue skin and mystical Avatar features using FLUX Kontext + Avatar LoRA.
</div>
""")
with gr.Row(equal_height=False):
# Input Column
with gr.Column(scale=1):
with gr.Column(elem_classes="image-container"):
gr.HTML('<div class="section-title">๐Ÿ“ธ Upload Your Portrait</div>')
input_image = gr.Image(
label="",
type="pil",
height=500,
elem_classes="big-image-upload",
show_label=False
)
gr.HTML('<div class="section-title" style="margin-top: 2rem;">โœจ Additional Details (Optional)</div>')
custom_prompt = gr.TextArea(
label="",
placeholder="Add creative details like: 'warrior markings', 'glowing tattoos', 'forest setting', 'ceremonial jewelry', etc.",
lines=4,
elem_classes="custom-textbox",
show_label=False
)
transform_btn = gr.Button(
"๐ŸŒŸ Transform to Na'vi Avatar",
elem_classes="transform-button"
)
# Output Column
with gr.Column(scale=1):
with gr.Column(elem_classes="image-container"):
gr.HTML('<div class="section-title">๐Ÿงžโ€โ™‚๏ธ Your Avatar Transformation</div>')
output_image = gr.Image(
label="",
height=500,
show_label=False,
elem_classes="big-image-upload"
)
status_text = gr.HTML(
value="<div class='status-display'><span style='color: #64748b; font-style: italic;'>Ready to transform! Upload an image to begin โœจ</span></div>",
show_label=False
)
# Tips Section
gr.HTML("""
<div style="background: rgba(6, 182, 212, 0.1); border: 1px solid rgba(6, 182, 212, 0.3); border-radius: 16px; padding: 2rem; margin-top: 3rem; text-align: center;">
<h3 style="color: #06b6d4; font-weight: 600; margin-bottom: 1rem;">๐Ÿ’ก Tips for Best Results</h3>
<p style="color: #cbd5e1; font-size: 1.1rem; line-height: 1.6;">
๐Ÿ“ท Use clear, well-lit photos โ€ข ๐Ÿ‘ค Front-facing works best โ€ข ๐ŸŽจ High resolution recommended โ€ข โœจ Be creative with prompts
</p>
</div>
""")
# Event Handler
        def gradio_transform(image, prompt):
            if image is None:
                error_html = "<div class='status-display'><span class='status-error'>⚠️ Please upload an image first!</span></div>"
                # This is a generator, so the error must be yielded (a plain
                # return value from a generator is swallowed)
                yield None, error_html
                return
            # Show processing status while the model runs
            processing_html = "<div class='status-display'><span class='status-processing'>🔮 Creating your Avatar transformation... This may take a moment! ✨</span></div>"
            yield None, processing_html
            try:
                result_image, status = generate_image(image, prompt)
                if result_image:
                    success_html = "<div class='status-display'><span class='status-success'>🎉 Transformation complete! Welcome to Pandora! 🧞‍♂️</span></div>"
                    yield result_image, success_html
                else:
                    error_html = f"<div class='status-display'><span class='status-error'>❌ {status}</span></div>"
                    yield None, error_html
            except Exception as e:
                error_html = f"<div class='status-display'><span class='status-error'>💥 Error: {str(e)}</span></div>"
                yield None, error_html
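        # Because gradio_transform is a generator, the "processing" status is
        # streamed before the final result; depending on the Gradio version this
        # may require the event queue to be enabled (it is by default in recent releases).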
transform_btn.click(
fn=gradio_transform,
inputs=[input_image, custom_prompt],
outputs=[output_image, status_text],
show_progress=True
)
return interface
if __name__ == "__main__":
# Initialize models on startup
print("๐Ÿš€ Initializing Avatar Transformation Studio...")
print("๐Ÿ“ฆ Loading models...")
# Run async initialization
async def init_app():
await initialize_models()
print("โœ… Models loaded successfully!")
# Initialize models asynchronously
asyncio.run(init_app())
# Launch interface (synchronously)
demo = create_gradio_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True,
debug=False
)