import os, math, numpy as np
from PIL import Image
import torch
from diffusers import AutoencoderKL
# ---- TensorFlow / Keras (kept, as requested) ----
import tensorflow as tf
from tensorflow import keras
from keras import layers
# ---- Gradio UI ----
import gradio as gr
# ---- CLIP (OpenAI repo) ----
import clip
# ---- Hugging Face Hub (to fetch your .h5) ----
from huggingface_hub import hf_hub_download
# ----------------------------
# Hub model location (YOUR PUBLIC REPO)
# ----------------------------
MODEL_REPO_ID = "Manjeet9812/LDM_trained_model"
MODEL_FILENAME = "text_to_image.h5"
# latent -> pixel scale for SD VAE
STD_LATENT_DEFAULT = 10.5
LATENT_H = 16  # your code uses 16x16x4 latents
LATENT_W = 16
LATENT_C = 4
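# Note: the SD VAE decoder upsamples by 8x, so 16x16x4 latents decode to
# 128x128 RGB images.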
# ----------------------------
# Utils (kept from your code)
# ----------------------------
def attention(qkv):
    q, k, v = qkv
    vector = tf.matmul(k, q, transpose_b=True)
    score = tf.nn.softmax(vector)
    o = tf.matmul(score, v)
    return o
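# Spatial self-attention: 1x1 convs project the feature map to q/k/v, the
# spatial positions attend to each other via the (unscaled) dot-product
# attention above, and the result is projected back to the original channels.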
def spatial_attention(img):
    filters = img.shape[3]
    orig_shape = (img.shape[1], img.shape[2], img.shape[3])
    img = layers.BatchNormalization()(img)
    q = layers.Conv2D(filters // 8, 1, padding="same")(img)
    k = layers.Conv2D(filters // 8, 1, padding="same")(img)
    v = layers.Conv2D(filters, 1, padding="same")(img)
    k = layers.Reshape((k.shape[1] * k.shape[2], k.shape[3]))(k)
    q = layers.Reshape((q.shape[1] * q.shape[2], q.shape[3]))(q)
    v = layers.Reshape((v.shape[1] * v.shape[2], v.shape[3]))(v)
    img = layers.Lambda(attention)([q, k, v])
    img = layers.Reshape(orig_shape)(img)
    img = layers.Conv2D(filters, 1, padding="same")(img)
    img = layers.BatchNormalization()(img)
    return img
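# Cross-attention between image and text features: keys are projected from the
# image and queries/values from the text tensor, so each spatial position of
# the image gathers a weighted mix of text features, reshaped back to the
# image grid.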
def cross_attention(img, text):
    filters = img.shape[3]
    orig_shape = (img.shape[1], img.shape[2], img.shape[3])
    img = layers.BatchNormalization()(img)
    text = layers.BatchNormalization()(text)
    q = layers.Conv2D(filters // 8, 1, padding="same")(text)
    k = layers.Conv2D(filters // 8, 1, padding="same")(img)
    v = layers.Conv2D(filters, 1, padding="same")(text)
    q = layers.Reshape((q.shape[1] * q.shape[2], q.shape[3]))(q)
    k = layers.Reshape((k.shape[1] * k.shape[2], k.shape[3]))(k)
    v = layers.Reshape((v.shape[1] * v.shape[2], v.shape[3]))(v)
    img = layers.Lambda(attention)([q, k, v])
    img = layers.Reshape(orig_shape)(img)
    img = layers.Conv2D(filters, 1, padding="same")(img)
    img = layers.BatchNormalization()(img)
    return img
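# Sinusoidal embedding of the scalar noise level (32 dims, frequencies spaced
# log-uniformly between 1 and 1000), analogous to transformer positional
# encodings.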
def sinusoidal_embedding(x):
    embedding_min_frequency = 1.0
    embedding_max_frequency = 1000.0
    embedding_dims = 32
    frequencies = tf.exp(
        tf.linspace(
            tf.math.log(embedding_min_frequency),
            tf.math.log(embedding_max_frequency),
            embedding_dims // 2,
        )
    )
    angular_speeds = 2.0 * math.pi * frequencies
    embeddings = tf.concat(
        [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=3
    )
    return embeddings
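# Imagen-style dynamic thresholding: clip the prediction to its `perc`
# percentile of absolute values and rescale into [-1, 1] to avoid saturation.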
def dynamic_thresholding(img, perc=99.5):
    s = np.percentile(np.abs(img.ravel()), perc)
    s = np.max([s, 1])
    img = img.clip(-s, s) / s
    return img
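# Reverse-diffusion sampler around the Keras UNet. The denoiser predicts x0
# directly; classifier-free guidance blends the conditional and unconditional
# predictions, and each step interpolates the current latent toward that
# predicted x0.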
class Diffuser:
    def __init__(self, denoiser, class_guidance, diffusion_steps, perc_thresholding=99.5, batch_size=64):
        self.denoiser = denoiser
        self.class_guidance = class_guidance
        self.diffusion_steps = diffusion_steps
        self.noise_levels = 1 - np.power(np.arange(0.0001, 0.99, 1 / self.diffusion_steps), 1 / 3)
        self.noise_levels[-1] = 0.01
        self.perc_thresholding = perc_thresholding
        self.batch_size = batch_size

    def predict_x_zero(self, x_t, label, noise_level):
        num_imgs = len(x_t)
        label_empty_ohe = np.zeros(shape=label.shape)
        noise_in = np.array([noise_level] * num_imgs)[:, None, None, None]
        nn_inputs = [np.vstack([x_t, x_t]),
                     np.vstack([noise_in, noise_in]),
                     np.vstack([label, label_empty_ohe])]
        x0_pred = self.denoiser.predict(nn_inputs, batch_size=self.batch_size, verbose=0)
        x0_pred_label = x0_pred[:num_imgs]
        x0_pred_no_label = x0_pred[num_imgs:]
        x0_pred = self.class_guidance * x0_pred_label + (1 - self.class_guidance) * x0_pred_no_label
        x0_pred = dynamic_thresholding(x0_pred, perc=self.perc_thresholding)
        return x0_pred

    def reverse_diffusion(self, seeds, label, show_img=False):
        new_img = seeds
        for i in range(len(self.noise_levels) - 1):
            curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]
            x0_pred = self.predict_x_zero(new_img, label, curr_noise)
            new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
        return x0_pred
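# Decode latents with the SD VAE in batches: un-normalize by std_latent,
# convert NHWC -> NCHW for torch, decode, and return NHWC numpy arrays.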
def decode_latents(latent_arr, std_latent=STD_LATENT_DEFAULT, batch_size=32, vae=None, torch_device="cpu"):
    assert vae is not None
    decoded_imgs = []
    N = latent_arr.shape[0]
    with torch.no_grad():
        for start in range(0, N, batch_size):
            end = min(start + batch_size, N)
            encoded = torch.from_numpy(latent_arr[start:end] * float(std_latent)).permute(0, 3, 1, 2).to(torch_device, dtype=torch.float32)
            decoded = vae.decode(encoded).sample
            decoded_imgs.append(decoded.permute(0, 2, 3, 1).cpu().numpy())
    return np.concatenate(decoded_imgs, axis=0)
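# The VAE decoder outputs images roughly in [-1, 1]; map them to uint8 PIL images.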
def to_pil_list(imgs_np):
    # imgs in [-1, 1] -> [0, 255] uint8
    imgs = np.clip((imgs_np + 1.0) * 127.5, 0, 255).astype("uint8")
    out = []
    for i in range(imgs.shape[0]):
        out.append(Image.fromarray(imgs[i]))
    return out
# ----------------------------
# Load models (on import)
# ----------------------------
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
# SD VAE
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(torch_device)
vae.eval()
# Keras custom objects
custom_objects = {"sinusoidal_embedding": sinusoidal_embedding, "attention": attention}
# Download your UNet .h5 from the Hub and load it
try:
    LOCAL_MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
except Exception as e:
    raise RuntimeError(
        f"Could not download {MODEL_FILENAME} from {MODEL_REPO_ID}. "
        f"Error: {e}"
    )
unet = keras.models.load_model(LOCAL_MODEL_PATH, custom_objects=custom_objects, compile=False)
# CLIP (kept as in your test code)
clip_model, _ = clip.load("ViT-B/32", device=torch_device)
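# CLIP ViT-B/32 text embeddings are 512-dimensional; the UNet is assumed to
# have been trained against these same embeddings as its conditioning input.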
# ----------------------------
# Gradio inference wrapper
# ----------------------------
def generate(
    prompt: str,
    num_images: int = 4,
    diffusion_steps: int = 100,
    class_guidance: float = 6.0,
    perc_thresholding: float = 99.75,
    std_latent: float = STD_LATENT_DEFAULT,
    seed: int | None = None,
    batch_size: int = 32,
):
    if not prompt or prompt.strip() == "":
        return []
    # Seed
    if seed is not None and int(seed) != 0:
        np.random.seed(int(seed))
    # Text -> CLIP embedding (repeat for batch)
    with torch.no_grad():
        text_tokens = clip.tokenize(prompt, truncate=False).to(torch_device)
        text_encoding = clip_model.encode_text(text_tokens)  # (1, D)
        text_encoding = text_encoding.detach().float()
    text_np = np.vstack([text_encoding.cpu().numpy()] * int(num_images))  # (N, D)
    # Random latent seeds (NHWC)
    rand_latents = np.random.normal(0, 1, (int(num_images), LATENT_H, LATENT_W, LATENT_C)).astype("float32")
    # Run your diffuser (with your UNet)
    diffuser = Diffuser(
        denoiser=unet,
        class_guidance=float(class_guidance),
        diffusion_steps=int(diffusion_steps),
        perc_thresholding=float(perc_thresholding),
        batch_size=int(batch_size),
    )
    # TF GPU is usually not present on Spaces, so keep the UNet predict on CPU
    with tf.device("/CPU:0"):
        imgs_latent = diffuser.reverse_diffusion(rand_latents, text_np)
    # Decode with SD VAE
    with torch.no_grad():
        imgs = decode_latents(
            imgs_latent,
            std_latent=float(std_latent),
            batch_size=int(batch_size),
            vae=vae,
            torch_device=torch_device
        )
    # Convert to PIL
    return to_pil_list(imgs)
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="LDM Host (Keras UNet + SD VAE + CLIP)") as demo:
    gr.Markdown("## LDM Host: Text → Image (kept same logic)")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Text prompt", value="beautiful landscape with river and flowers")
            num_images = gr.Slider(1, 8, value=4, step=1, label="Number of images")
            diffusion_steps = gr.Slider(10, 200, value=100, step=1, label="Diffusion steps")
            class_guidance = gr.Slider(0.0, 10.0, value=6.0, step=0.25, label="Classifier-free guidance (blend)")
            perc_thr = gr.Slider(90.0, 99.95, value=99.75, step=0.05, label="Dynamic thresholding percentile")
            std_latent = gr.Slider(1.0, 20.0, value=STD_LATENT_DEFAULT, step=0.1, label="Latent std scale (SD VAE)")
            batch_size = gr.Slider(1, 64, value=32, step=1, label="Batch size (UNet predict)")
            seed = gr.Number(label="Seed (0 = random)", value=0, precision=0)
            run = gr.Button("Generate")
        with gr.Column(scale=1):
            gallery = gr.Gallery(label="Outputs", columns=2, rows=2, height=560)
    run.click(
        fn=generate,
        inputs=[prompt, num_images, diffusion_steps, class_guidance, perc_thr, std_latent, seed, batch_size],
        outputs=[gallery],
        api_name="generate"
    )

if __name__ == "__main__":
    demo.launch()
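# Because the click handler exposes api_name="generate", the app can also be
# called programmatically. A minimal sketch with gradio_client (the Space id
# below is a placeholder, not the actual deployment):
#
#   from gradio_client import Client
#   client = Client("your-username/your-space-name")  # hypothetical Space id
#   images = client.predict(
#       "beautiful landscape with river and flowers",  # prompt
#       4, 100, 6.0, 99.75, 10.5, 0, 32,               # sliders/seed in UI order
#       api_name="/generate",
#   )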