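"""Gradio server for a latent diffusion text-to-image demo.

Pipeline: CLIP (ViT-B/32) encodes the prompt, a Keras UNet downloaded from the
Hugging Face Hub denoises 16x16x4 latents with classifier-free guidance, and the
Stable Diffusion VAE (stabilityai/sd-vae-ft-ema) decodes the latents to images.
"""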
import os, math, numpy as np
from PIL import Image
import torch
from diffusers import AutoencoderKL
# ---- TensorFlow / Keras (kept, as requested) ----
import tensorflow as tf
from tensorflow import keras
from keras import layers
# ---- Gradio UI ----
import gradio as gr
# ---- CLIP (OpenAI repo) ----
import clip
# ---- Hugging Face Hub (to fetch your .h5) ----
from huggingface_hub import hf_hub_download
# ----------------------------
# Hub model location (YOUR PUBLIC REPO)
# ----------------------------
MODEL_REPO_ID = "Manjeet9812/LDM_trained_model"
MODEL_FILENAME = "text_to_image.h5"
# latent -> pixel scale for SD VAE
STD_LATENT_DEFAULT = 10.5
LATENT_H = 16 # your code uses 16x16x4 latents
LATENT_W = 16
LATENT_C = 4
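# The SD VAE upsamples by a factor of 8, so 16x16x4 latents decode to 128x128 RGB images.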
# ----------------------------
# Utils (kept from your code)
# ----------------------------
def attention(qkv):
    # Dot-product attention (no scaling): scores = softmax(K @ Q^T), output = scores @ V
    q, k, v = qkv
    vector = tf.matmul(k, q, transpose_b=True)
    score = tf.nn.softmax(vector)
    o = tf.matmul(score, v)
    return o
def spatial_attention(img):
    filters = img.shape[3]
    orig_shape = (img.shape[1], img.shape[2], img.shape[3])
    img = layers.BatchNormalization()(img)
    q = layers.Conv2D(filters // 8, 1, padding="same")(img)
    k = layers.Conv2D(filters // 8, 1, padding="same")(img)
    v = layers.Conv2D(filters, 1, padding="same")(img)
    k = layers.Reshape((k.shape[1] * k.shape[2], k.shape[3]))(k)
    q = layers.Reshape((q.shape[1] * q.shape[2], q.shape[3]))(q)
    v = layers.Reshape((v.shape[1] * v.shape[2], v.shape[3]))(v)
    img = layers.Lambda(attention)([q, k, v])
    img = layers.Reshape(orig_shape)(img)
    img = layers.Conv2D(filters, 1, padding="same")(img)
    img = layers.BatchNormalization()(img)
    return img
def cross_attention(img, text):
    # Cross-attention: queries and values come from the text features, keys from the image features
    filters = img.shape[3]
    orig_shape = (img.shape[1], img.shape[2], img.shape[3])
    img = layers.BatchNormalization()(img)
    text = layers.BatchNormalization()(text)
    q = layers.Conv2D(filters // 8, 1, padding="same")(text)
    k = layers.Conv2D(filters // 8, 1, padding="same")(img)
    v = layers.Conv2D(filters, 1, padding="same")(text)
    q = layers.Reshape((q.shape[1] * q.shape[2], q.shape[3]))(q)
    k = layers.Reshape((k.shape[1] * k.shape[2], k.shape[3]))(k)
    v = layers.Reshape((v.shape[1] * v.shape[2], v.shape[3]))(v)
    img = layers.Lambda(attention)([q, k, v])
    img = layers.Reshape(orig_shape)(img)
    img = layers.Conv2D(filters, 1, padding="same")(img)
    img = layers.BatchNormalization()(img)
    return img
def sinusoidal_embedding(x):
    embedding_min_frequency = 1.0
    embedding_max_frequency = 1000.0
    embedding_dims = 32
    frequencies = tf.exp(
        tf.linspace(
            tf.math.log(embedding_min_frequency),
            tf.math.log(embedding_max_frequency),
            embedding_dims // 2,
        )
    )
    angular_speeds = 2.0 * math.pi * frequencies
    embeddings = tf.concat(
        [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=3
    )
    return embeddings
def dynamic_thresholding(img, perc=99.5):
    # Clip to the perc-th percentile of |img| (at least 1) and rescale into [-1, 1]
    s = np.percentile(np.abs(img.ravel()), perc)
    s = np.max([s, 1])
    img = img.clip(-s, s) / s
    return img
class Diffuser:
    def __init__(self, denoiser, class_guidance, diffusion_steps, perc_thresholding=99.5, batch_size=64):
        self.denoiser = denoiser
        self.class_guidance = class_guidance
        self.diffusion_steps = diffusion_steps
        # Decreasing noise schedule; the final level is clamped to 0.01
        self.noise_levels = 1 - np.power(np.arange(0.0001, 0.99, 1 / self.diffusion_steps), 1 / 3)
        self.noise_levels[-1] = 0.01
        self.perc_thresholding = perc_thresholding
        self.batch_size = batch_size
    def predict_x_zero(self, x_t, label, noise_level):
        # Run the denoiser on conditioned and unconditioned copies in one batch,
        # then blend the two predictions (classifier-free guidance).
        num_imgs = len(x_t)
        label_empty_ohe = np.zeros(shape=label.shape)
        noise_in = np.array([noise_level] * num_imgs)[:, None, None, None]
        nn_inputs = [np.vstack([x_t, x_t]),
                     np.vstack([noise_in, noise_in]),
                     np.vstack([label, label_empty_ohe])]
        x0_pred = self.denoiser.predict(nn_inputs, batch_size=self.batch_size, verbose=0)
        x0_pred_label = x0_pred[:num_imgs]
        x0_pred_no_label = x0_pred[num_imgs:]
        x0_pred = self.class_guidance * x0_pred_label + (1 - self.class_guidance) * x0_pred_no_label
        x0_pred = dynamic_thresholding(x0_pred, perc=self.perc_thresholding)
        return x0_pred
    def reverse_diffusion(self, seeds, label, show_img=False):
        new_img = seeds
        for i in range(len(self.noise_levels) - 1):
            curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]
            x0_pred = self.predict_x_zero(new_img, label, curr_noise)
            # Blend the predicted x0 with the current sample, weighted by next_noise / curr_noise
            new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
        return x0_pred
def decode_latents(latent_arr, std_latent=STD_LATENT_DEFAULT, batch_size=32, vae=None, torch_device="cpu"):
    assert vae is not None
    decoded_imgs = []
    N = latent_arr.shape[0]
    with torch.no_grad():
        for start in range(0, N, batch_size):
            end = min(start + batch_size, N)
            # NHWC numpy latents -> NCHW torch tensor, rescaled back to the VAE's latent range
            encoded = torch.from_numpy(latent_arr[start:end] * float(std_latent)).permute(0, 3, 1, 2).to(torch_device, dtype=torch.float32)
            decoded = vae.decode(encoded).sample
            decoded_imgs.append(decoded.permute(0, 2, 3, 1).cpu().numpy())
    return np.concatenate(decoded_imgs, axis=0)
def to_pil_list(imgs_np):
    # imgs in [-1, 1] -> [0, 255] uint8
    imgs = np.clip((imgs_np + 1.0) * 127.5, 0, 255).astype("uint8")
    out = []
    for i in range(imgs.shape[0]):
        out.append(Image.fromarray(imgs[i]))
    return out
# ----------------------------
# Load models (on import)
# ----------------------------
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
# SD VAE
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(torch_device)
vae.eval()
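# The VAE is only used to decode latents to pixels; it stays in eval mode and is always called under torch.no_grad().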
# Keras custom objects
custom_objects = {"sinusoidal_embedding": sinusoidal_embedding, "attention": attention}
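# load_model needs these names so it can rebuild the Lambda layers serialized inside the .h5.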
# Download your UNet .h5 from the Hub and load it
try:
    LOCAL_MODEL_PATH = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
except Exception as e:
    raise RuntimeError(
        f"Could not download {MODEL_FILENAME} from {MODEL_REPO_ID}. "
        f"Error: {e}"
    )
unet = keras.models.load_model(LOCAL_MODEL_PATH, custom_objects=custom_objects, compile=False)
# CLIP (kept as in your test code)
clip_model, _ = clip.load("ViT-B/32", device=torch_device)
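# ViT-B/32 text embeddings are 512-dimensional; they are the conditioning input fed to the UNet.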
# ----------------------------
# Gradio inference wrapper
# ----------------------------
def generate(
    prompt: str,
    num_images: int = 4,
    diffusion_steps: int = 100,
    class_guidance: float = 6.0,
    perc_thresholding: float = 99.75,
    std_latent: float = STD_LATENT_DEFAULT,
    seed: int | None = None,
    batch_size: int = 32,
):
    if not prompt or prompt.strip() == "":
        return []
    # Seed (0 means random)
    if seed is not None and int(seed) != 0:
        np.random.seed(int(seed))
    # Text -> CLIP embedding (repeated across the batch)
    with torch.no_grad():
        # Note: truncate=False raises if the prompt exceeds CLIP's 77-token context
        text_tokens = clip.tokenize(prompt, truncate=False).to(torch_device)
        text_encoding = clip_model.encode_text(text_tokens)  # (1, D)
        text_encoding = text_encoding.detach().float()
        text_np = np.vstack([text_encoding.cpu().numpy()] * int(num_images))  # (N, D)
    # Random latent seeds (NHWC)
    rand_latents = np.random.normal(0, 1, (int(num_images), LATENT_H, LATENT_W, LATENT_C)).astype("float32")
    # Run the diffuser with the Keras UNet
    diffuser = Diffuser(
        denoiser=unet,
        class_guidance=float(class_guidance),
        diffusion_steps=int(diffusion_steps),
        perc_thresholding=float(perc_thresholding),
        batch_size=int(batch_size),
    )
    # TF GPU is usually not present on Spaces, so keep UNet predict on CPU
    with tf.device("/CPU:0"):
        imgs_latent = diffuser.reverse_diffusion(rand_latents, text_np)
    # Decode with the SD VAE
    with torch.no_grad():
        imgs = decode_latents(
            imgs_latent,
            std_latent=float(std_latent),
            batch_size=int(batch_size),
            vae=vae,
            torch_device=torch_device,
        )
    # Convert to PIL
    return to_pil_list(imgs)
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks(title="LDM Host (Keras UNet + SD VAE + CLIP)") as demo:
    gr.Markdown("## LDM Host: Text → Image (kept same logic)")
    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Text prompt", value="beautiful landscape with river and flowers")
            num_images = gr.Slider(1, 8, value=4, step=1, label="Number of images")
            diffusion_steps = gr.Slider(10, 200, value=100, step=1, label="Diffusion steps")
            class_guidance = gr.Slider(0.0, 10.0, value=6.0, step=0.25, label="Classifier-free guidance (blend)")
            perc_thr = gr.Slider(90.0, 99.95, value=99.75, step=0.05, label="Dynamic thresholding percentile")
            std_latent = gr.Slider(1.0, 20.0, value=STD_LATENT_DEFAULT, step=0.1, label="Latent std scale (SD VAE)")
            batch_size = gr.Slider(1, 64, value=32, step=1, label="Batch size (UNet predict)")
            seed = gr.Number(label="Seed (0 = random)", value=0, precision=0)
            run = gr.Button("Generate")
        with gr.Column(scale=1):
            gallery = gr.Gallery(label="Outputs", columns=2, rows=2, height=560)
    run.click(
        fn=generate,
        inputs=[prompt, num_images, diffusion_steps, class_guidance, perc_thr, std_latent, seed, batch_size],
        outputs=[gallery],
        api_name="generate",
    )
if __name__ == "__main__":
    demo.launch()
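# Example of calling this app remotely with gradio_client. This is only a sketch:
# the Space id below is an assumption (replace it with the repo that actually hosts
# this file); the endpoint name matches api_name="generate" set on run.click above.
#
#   from gradio_client import Client
#   client = Client("Manjeet9812/new_LDM_hosting")  # assumed Space id
#   images = client.predict(
#       "beautiful landscape with river and flowers",  # prompt
#       4,      # num_images
#       100,    # diffusion_steps
#       6.0,    # class_guidance
#       99.75,  # perc_thresholding
#       10.5,   # std_latent
#       0,      # seed (0 = random)
#       32,     # batch_size
#       api_name="/generate",
#   )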