# core/image_generator.py
import os
from pathlib import Path
import gc
import torch
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image, DiffusionPipeline
from PIL import Image
from io import BytesIO
import base64
import tempfile
# --------------------------------------------------------------
# 🚨 ABSOLUTE FIX FOR PermissionError('/.cache') & '/root/.cache'
# --------------------------------------------------------------
HF_CACHE_DIR = Path("/tmp/hf_cache")
HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Set ALL key environment variables FIRST
os.environ.update({
"HF_HOME": str(HF_CACHE_DIR),
"HF_HUB_CACHE": str(HF_CACHE_DIR),
"DIFFUSERS_CACHE": str(HF_CACHE_DIR),
"TRANSFORMERS_CACHE": str(HF_CACHE_DIR),
"XDG_CACHE_HOME": str(HF_CACHE_DIR),
"HF_DATASETS_CACHE": str(HF_CACHE_DIR),
"HF_MODULES_CACHE": str(HF_CACHE_DIR),
"TMPDIR": str(HF_CACHE_DIR),
"CACHE_DIR": str(HF_CACHE_DIR),
"TORCH_HOME": str(HF_CACHE_DIR),
"HOME": str(HF_CACHE_DIR)
})
# Patch expanduser BEFORE any library imports that might touch ~/.cache
import os.path
if not hasattr(os.path, "expanduser_original"):
os.path.expanduser_original = os.path.expanduser
def safe_expanduser(path):
    path = str(path)  # tolerate Path objects as well as strings
    if (
        path.startswith("~")
        or path.startswith("/.cache")
        or path.startswith("/root/.cache")
    ):
        print(f"[DEBUG] πŸ”„ Patched expanduser call for: {path}")
        return str(HF_CACHE_DIR)
    return os.path.expanduser_original(path)
os.path.expanduser = safe_expanduser
tempfile.tempdir = str(HF_CACHE_DIR)
print("[DEBUG] βœ… Hugging Face, Diffusers, Datasets and Torch cache fully redirected to:", HF_CACHE_DIR)
# --------------------------------------------------------------
# βœ… WRITABLE STORAGE SETUP (for Hugging Face Spaces; /tmp is writable but ephemeral)
# --------------------------------------------------------------
MODEL_DIR = Path("/tmp/models/dreamshaper_sd15")
SEED_DIR = Path("/tmp/seed_images")
TMP_DIR = Path("/tmp/generated_images")
for d in [MODEL_DIR, SEED_DIR, TMP_DIR]:
d.mkdir(parents=True, exist_ok=True)
print("[DEBUG] βœ… Using persistent Hugging Face cache at:", HF_CACHE_DIR)
print("[DEBUG] βœ… Model directory:", MODEL_DIR)
print("[DEBUG] βœ… Seed directory:", SEED_DIR)
# --------------------------------------------------------------
# MODEL CONFIG
# --------------------------------------------------------------
MODEL_REPO = "lykon/dreamshaper-8" # Use Hugging Face repo
# ---------------- GLOBAL PIPELINE CACHE ----------------
pipe: DiffusionPipeline | None = None
img2img_pipe: DiffusionPipeline | None = None
# --------------------------------------------------------------
# MEMORY-SAFE PIPELINE MANAGER
# --------------------------------------------------------------
def unload_pipelines(target="all"):
    """Unload specific or all pipelines and release memory."""
    global pipe, img2img_pipe
    print("[ImageGen] 🧹 Clearing pipelines from memory...")
    # Dropping the reference and collecting is equivalent to del + rebind here.
    if target in ("pipe", "all"):
        pipe = None
    if target in ("img2img_pipe", "all"):
        img2img_pipe = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("[ImageGen] βœ… Memory cleared.")
def safe_load_pipeline(pipeline_class, pretrained_model_name):
    """Safely load a DreamShaper SD1.5 pipeline of the given class via from_pretrained."""
    try:
        print(f"[ImageGen] πŸ”„ Loading {pipeline_class.__name__} for {pretrained_model_name} ...")
        new_pipe = pipeline_class.from_pretrained(
            pretrained_model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            variant="fp16",  # load the fp16 weight files published in the repo
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        new_pipe = new_pipe.to(device)
        new_pipe.enable_attention_slicing()
        print(f"[ImageGen] βœ… Successfully loaded {pretrained_model_name}.")
        return new_pipe
    except Exception as e:
        print(f"[ImageGen] ❌ Failed to load {pretrained_model_name}: {e}")
        unload_pipelines()  # also runs gc.collect() and empties the CUDA cache
        raise
def load_pipeline():
    global pipe
    unload_pipelines(target="pipe")
    print("[ImageGen] Loading main (txt2img) pipeline...")
    pipe = safe_load_pipeline(AutoPipelineForText2Image, MODEL_REPO)
    print("[ImageGen] βœ… Text-to-image pipeline ready.")
    return pipe
def load_img2img_pipeline():
    global img2img_pipe
    unload_pipelines(target="img2img_pipe")
    print("[ImageGen] Loading img2img pipeline...")
    # Same DreamShaper checkpoint, but wrapped in an img2img pipeline so it
    # accepts the image= and strength= arguments used in story mode below.
    img2img_pipe = safe_load_pipeline(AutoPipelineForImage2Image, MODEL_REPO)
    print("[ImageGen] βœ… Img2Img pipeline ready.")
    return img2img_pipe
# --------------------------------------------------------------
# UTILITY: PIL β†’ BASE64
# --------------------------------------------------------------
def pil_to_base64(img: Image.Image) -> str:
buffered = BytesIO()
img.save(buffered, format="PNG")
return f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}"
# --------------------------------------------------------------
# UNIFIED IMAGE GENERATION FUNCTION
# --------------------------------------------------------------
async def generate_images(prompt_or_json, seed: int | None = None, num_images: int = 3):
global pipe, img2img_pipe
device = "cuda" if torch.cuda.is_available() else "cpu"
# ----------------------------------------------------------
# CASE 1: STRUCTURED JSON (story mode)
# ----------------------------------------------------------
if isinstance(prompt_or_json, dict):
story_json = prompt_or_json
print("[ImageGen] Detected structured JSON input. Generating cinematic visuals...")
# Step 1: Load only txt2img for character generation
pipe = load_pipeline()
seed_to_char_image = {}
for char in story_json.get("characters", []):
char_name = char["name"]
char_seed = int(char.get("seed", 0))
char_desc = char.get("description", "")
seed_image_path = SEED_DIR / f"seed_{char_seed}.png"
if seed_image_path.exists():
print(f"[ImageGen] πŸ” Reusing existing seed image for '{char_name}' (seed={char_seed})")
                image = Image.open(seed_image_path).convert("RGB")  # ensure 3-channel input for img2img
else:
print(f"[ImageGen] 🎨 Generating new character '{char_name}' (seed={char_seed})")
generator = torch.Generator(device).manual_seed(char_seed)
image = pipe(f"{char_name}, {char_desc}", num_inference_steps=30, generator=generator).images[0]
image.save(seed_image_path)
seed_to_char_image[char_seed] = image
# Free txt2img pipeline
unload_pipelines(target="pipe")
# Step 2: Load only img2img for keyframes
img2img_pipe = load_img2img_pipeline()
for key, scene_data in story_json.items():
if not key.startswith("scene"):
continue
for frame in scene_data.get("keyframes", []):
frame_seed = int(frame.get("seed", 0))
if frame_seed not in seed_to_char_image:
print(f"[WARN] Seed {frame_seed} not found in characters. Skipping keyframes...")
continue
init_image = seed_to_char_image[frame_seed]
for kf_key, kf_prompt in frame.items():
if kf_key.startswith("keyframe"):
print(f"[ImageGen] 🎬 Generating {key} β†’ {kf_key} using seed {frame_seed}")
generator = torch.Generator(device).manual_seed(frame_seed)
img = img2img_pipe(
prompt=kf_prompt,
image=init_image,
strength=0.55,
num_inference_steps=30,
generator=generator
).images[0]
out_path = TMP_DIR / f"{key}_{kf_key}_seed{frame_seed}.png"
img.save(out_path)
frame[kf_key] = pil_to_base64(img)
unload_pipelines(target="all") # unload both just in case
print("[ImageGen] βœ… Story JSON image generation complete.")
return story_json
# ----------------------------------------------------------
# CASE 2: NORMAL PROMPT
# ----------------------------------------------------------
print(f"[ImageGen] Generating {num_images} image(s) for prompt='{prompt_or_json}' seed={seed}")
pipe = load_pipeline()
images = []
for i in range(num_images):
gen = torch.Generator(device).manual_seed(seed + i) if seed is not None else None
try:
img = pipe(prompt_or_json, num_inference_steps=30, generator=gen).images[0]
img_path = TMP_DIR / f"prompt_{i}.png"
img.save(img_path)
images.append(pil_to_base64(img))
except Exception as e:
print(f"[ImageGen] ⚠️ Failed on image {i}: {e}")
unload_pipelines(target="pipe")
print(f"[ImageGen] βœ… Generated {len(images)} image(s) successfully.")
return images
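# --------------------------------------------------------------
# OPTIONAL LOCAL SMOKE TEST
# --------------------------------------------------------------
# Minimal sketch for running this module directly; the prompt, seed, and
# num_images values below are placeholders, not part of the module's API.
if __name__ == "__main__":
    import asyncio
    results = asyncio.run(
        generate_images("a lighthouse at dawn, cinematic lighting", seed=42, num_images=1)
    )
    print(f"[ImageGen] Smoke test produced {len(results)} image(s).")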