| | from diffusers import AutoencoderKL |
| | from PIL import Image |
| | import io |
| | from transformers import CLIPProcessor, CLIPModel |
| | from model import Model |
| | from pathlib import Path |
| | from noise_scheduler import NoiseSchedule |
| | import torch |
| | import base64 |
| | from typing import Any, Dict |
| |
|
# Module-level hyperparameters for the UNet, the noise schedule and the latent space.

# Diffusion / UNet settings.
T = 1000  # forwarded to Model(T=...); equals the T the NoiseSchedule is built with
t_dim = 512  # forwarded to Model(t_dim=...)
depth = 2  # forwarded to Model(depth=...)
filters = [64, 128, 256, 512]  # presumably per-level UNet channel widths — TODO confirm
LDM = True  # forwarded to Model(LDM=...); presumably "latent diffusion model" mode

# Image / latent geometry.
image_size = 512  # NOTE(review): not referenced in the visible code
latent_dim, latent_size = 4, 64  # matches the (4, 64, 64) latent shape the NoiseSchedule samples
| |
|
class CLIP:
    """Frozen CLIP model plus its processor, used to embed images and text.

    Args:
        model_name: Hugging Face model id (or local path) passed to
            ``from_pretrained`` for both the processor and the model.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model = CLIPModel.from_pretrained(model_name)
        self.model.eval()
        # Only used for inference here: freeze every parameter.
        self.model.requires_grad_(False)

    @torch.inference_mode()
    def embed_images(self, images):
        """Return CLIP image features for ``images`` (one row per input image).

        NOTE: the CLIP image processor rescales pixel values by 1/255 by
        default, so float images that are already in [0, 1] should be
        converted (e.g. to uint8) before being passed in; otherwise they are
        rescaled twice and become nearly black.
        """
        inputs = self.processor(images=images, return_tensors="pt").to(self.model.device)
        return self.model.get_image_features(**inputs)

    @torch.inference_mode()
    def embed_text(self, text):
        """Return CLIP text features for ``text`` (e.g. a single prompt string).

        ``text`` is passed to the processor by keyword so the call does not
        depend on the processor's positional argument order.
        """
        inputs = self.processor(text=text, padding=True, return_tensors="pt").to(self.model.device)
        return self.model.get_text_features(**inputs)
| |
|
class Inference:
    """Sample latents with the UNet, decode them with the VAE, and rank the
    decoded images by CLIP cosine similarity to a text prompt.

    The UNet checkpoint is read from ``unet.pt`` next to this file.
    """

    def __init__(self, prompt: str = "A photo of a cat"):
        """Load CLIP, the VAE, the UNet checkpoint and the noise schedule.

        Args:
            prompt: text whose CLIP embedding is used to score generated images.
        """
        here = Path(__file__).resolve().parent
        ckpt_path = here / "unet.pt"

        self.clip = CLIP()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.ae = AutoencoderKL.from_pretrained(
            "runwayml/stable-diffusion-v1-5", subfolder="vae"
        ).to(device)
        self.ae.eval()
        self.ae.requires_grad_(False)

        # Build the UNet from the module-level hyperparameters; copy the list so
        # Model receives its own object, as with the former inline literal.
        self.unet = Model(T=T, filters=list(filters), t_dim=t_dim, depth=depth, LDM=LDM)
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects (which can execute code); only load trusted checkpoints. A
        # plain state_dict should also load with weights_only=True — TODO confirm.
        state_dict = torch.load(ckpt_path, weights_only=False, map_location=torch.device("cpu"))
        self.unet.load_state_dict(state_dict)
        self.unet.eval()
        self.unet.requires_grad_(False)

        # Latent shape, presumably (channels, height, width), from the module constants.
        self.noise_scheduler = NoiseSchedule(
            T=T,
            shape=(latent_dim, latent_size, latent_size),
            ddim_mod=50,
            trainer_mode=True,
        )

        target = self.clip.embed_text(prompt)[0]
        # Unit-normalise so the dot product in __call__ is a cosine similarity
        # (the image embeddings are normalised the same way there).
        self.target_vector = target / target.norm(p=2, dim=-1, keepdim=True)

    @torch.inference_mode()
    def __call__(self, num_images=8):
        """Generate ``num_images`` images and return the best-scoring one.

        Args:
            num_images: number of latents to sample; must be at least 1.

        Returns:
            ``(best_image, best_score, scores)``:
            - ``best_image`` is an (H, W, C) float tensor on the CPU in [0, 1];
            - ``best_score`` is its cosine similarity to the prompt;
            - ``scores`` is a 1-D tensor with one score per generated image.

        Raises:
            ValueError: if ``num_images`` is less than 1.
        """
        if num_images < 1:
            raise ValueError("num_images must be at least 1")

        # NOTE(review): sampling runs on the CPU even when CUDA is available.
        latents = self.noise_scheduler.generate(self.unet, num_images=num_images, device="cpu")

        ae_device = self.ae.device
        scale = self.ae.config.scaling_factor
        images = []
        for latent in latents:
            # Latents come back on the CPU, but the VAE may live on CUDA.
            decoded = self.ae.decode(latent.unsqueeze(0).to(ae_device) / scale)[0][0]
            # Channel-first -> channel-last, then map [-1, 1] to [0, 1].
            image = decoded.cpu().permute(1, 2, 0) / 2 + 0.5
            images.append(torch.clamp(image, 0.0, 1.0))

        # The CLIP processor rescales pixels by 1/255 by default, so hand it
        # uint8 images instead of floats that are already in [0, 1].
        clip_inputs = [(image * 255).round().to(torch.uint8) for image in images]
        embeddings = self.clip.embed_images(clip_inputs)
        embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
        scores = embeddings @ self.target_vector

        best = int(torch.argmax(scores).item())
        return images[best], scores[best], scores
| |
|
class EndpointHandler:
    """Request handler that returns one generated image as a base64 PNG.

    Each call asks :class:`Inference` for ``num_images`` candidates (default 1)
    and returns the one the engine scores highest, together with its score.
    """

    # Upper bound on a requested ``num_images`` so one request cannot tie up
    # the sampler indefinitely. Arbitrary cap; adjust as needed.
    max_images = 16

    def __init__(self, path: str = ""):
        """Build the inference engine.

        Args:
            path: accepted for interface compatibility; unused, because
                ``Inference`` locates its checkpoint next to this file.
        """
        self.engine = Inference()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate an image and encode it for the response.

        Args:
            data: request payload. An optional
                ``data["parameters"]["num_images"]`` (default 1) sets how many
                candidates are ranked; everything else is ignored.
                NOTE(review): assumes an ``{"inputs": ..., "parameters": {...}}``
                request layout — confirm against the serving platform.

        Returns:
            ``{"image": <base64-encoded PNG str>, "score": <float>}``, where the
            score is the engine's score for the returned image.

        Raises:
            ValueError: if ``num_images`` is not an integer in ``[1, max_images]``.
        """
        num_images = self._requested_num_images(data)
        img_tensor, score, _ = self.engine(num_images=num_images)

        # Round (rather than truncate) when mapping [0, 1] floats onto 0..255.
        img_uint8 = (img_tensor.clamp(0, 1) * 255).round().to(torch.uint8).numpy()
        pil_img = Image.fromarray(img_uint8)

        buf = io.BytesIO()
        pil_img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"image": b64, "score": float(score)}

    def _requested_num_images(self, data):
        """Read and validate the optional candidate count from ``data``."""
        params = data.get("parameters") if isinstance(data, dict) else None
        if not isinstance(params, dict):
            return 1
        try:
            num_images = int(params.get("num_images", 1))
        except (TypeError, ValueError) as err:
            raise ValueError("num_images must be an integer") from err
        if not 1 <= num_images <= self.max_images:
            raise ValueError(f"num_images must be between 1 and {self.max_images}")
        return num_images
| |
|
| |
|