import numpy as np
import torch
from torchvision import transforms
from transformers import CLIPProcessor, CLIPVisionModel

from models.pretrained_models import Plonk
from models.samplers.riemannian_flow_sampler import riemannian_flow_sampler
from models.postprocessing import CartesiantoGPS
from models.schedulers import (
    SigmoidScheduler,
    LinearScheduler,
    CosineScheduler,
)
from models.preconditioning import DDPMPrecond
from utils.image_processing import CenterCrop
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
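# Supported pretrained checkpoints, mapped to the image embedder each one expects.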
MODELS = {
    "nicolas-dufour/PLONK_YFCC": {"emb_name": "dinov2"},
    "nicolas-dufour/PLONK_OSV_5M": {"emb_name": "street_clip"},
    "nicolas-dufour/PLONK_iNaturalist": {"emb_name": "dinov2"},
}
def scheduler_fn(
scheduler_type: str, start: float, end: float, tau: float, clip_min: float = 1e-9
):
if scheduler_type == "sigmoid":
return SigmoidScheduler(start, end, tau, clip_min)
elif scheduler_type == "cosine":
return CosineScheduler(start, end, tau, clip_min)
elif scheduler_type == "linear":
return LinearScheduler(clip_min=clip_min)
else:
raise ValueError(f"Scheduler type {scheduler_type} not supported")
class DinoV2FeatureExtractor:
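    """Embeds images with DINOv2 (ViT-L/14 with registers); used by the YFCC and iNaturalist checkpoints."""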
    def __init__(self, device=device):
        self.device = device
        self.emb_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
        self.emb_model.eval()
        self.emb_model.to(self.device)
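        # Preprocessing: square center crop, bicubic resize to 336px, ImageNet normalization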
self.augmentation = transforms.Compose(
[
CenterCrop(ratio="1:1"),
transforms.Resize(
336, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
),
]
)
def __call__(self, batch):
embs = []
with torch.no_grad():
for img in batch["img"]:
emb = self.emb_model(
self.augmentation(img).unsqueeze(0).to(self.device)
).squeeze(0)
embs.append(emb)
batch["emb"] = torch.stack(embs)
return batch
class StreetClipFeatureExtractor:
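    """Embeds images with the StreetCLIP vision encoder (CLS token of the last hidden state); used by the OSV-5M checkpoint."""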
def __init__(self, device=device):
self.device = device
self.emb_model = CLIPVisionModel.from_pretrained("geolocal/StreetCLIP").to(
device
)
self.processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
def __call__(self, batch):
inputs = self.processor(images=batch["img"], return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.emb_model(**inputs)
embeddings = outputs.last_hidden_state[:, 0]
batch["emb"] = embeddings
return batch
def load_preprocessing(model_name):
    if MODELS[model_name]["emb_name"] == "dinov2":
        return DinoV2FeatureExtractor()
    elif MODELS[model_name]["emb_name"] == "street_clip":
        return StreetClipFeatureExtractor()
    else:
        raise ValueError(f"Embedding model {MODELS[model_name]['emb_name']} not found")
class PlonkPipeline:
"""
The CADT2IPipeline class is designed to facilitate the generation of images from text prompts using a pre-trained CAD model.
It integrates various components such as samplers, schedulers, and post-processing techniques to produce high-quality images.
Initialization:
CADT2IPipeline(
model_path,
sampler="ddim",
scheduler="sigmoid",
postprocessing="sd_1_5_vae",
scheduler_start=-3,
scheduler_end=3,
scheduler_tau=1.1,
device="cuda",
)
Parameters:
model_path (str): Path to the pre-trained CAD model.
sampler (str): The sampling method to use. Options are "ddim", "ddpm", "dpm", "dpm_2S", "dpm_2M". Default is "ddim".
scheduler (str): The scheduler type to use. Options are "sigmoid", "cosine", "linear". Default is "sigmoid".
postprocessing (str): The post-processing method to use. Options are "consistency-decoder", "sd_1_5_vae". Default is "sd_1_5_vae".
scheduler_start (float): Start value for the scheduler. Default is -3.
scheduler_end (float): End value for the scheduler. Default is 3.
scheduler_tau (float): Tau value for the scheduler. Default is 1.1.
device (str): Device to run the model on. Default is "cuda".
Methods:
model(*args, **kwargs):
Runs the preconditioning on the network with the provided arguments.
__call__(...):
Generates images based on the provided conditions and parameters.
Parameters:
cond (str or list of str): The conditioning text or list of texts.
num_samples (int, optional): Number of samples to generate. If not provided, it is inferred from cond.
x_N (torch.Tensor, optional): Initial noise tensor. If not provided, it is generated.
latents (torch.Tensor, optional): Previous latents.
num_steps (int, optional): Number of steps for the sampler. If not provided, the default is used.
sampler (callable, optional): Custom sampler function. If not provided, the default sampler is used.
scheduler (callable, optional): Custom scheduler function. If not provided, the default scheduler is used.
cfg (float): Classifier-free guidance scale. Default is 15.
guidance_type (str): Type of guidance. Default is "constant".
guidance_start_step (int): Step to start guidance. Default is 0.
generator (torch.Generator, optional): Random number generator.
coherence_value (float): Doherence value for sampling. Default is 1.0.
uncoherence_value (float): Uncoherence value for sampling. Default is 0.0.
unconfident_prompt (str, optional): Unconfident prompt text.
thresholding_type (str): Type of thresholding. Default is "clamp".
clamp_value (float): Clamp value for thresholding. Default is 1.0.
thresholding_percentile (float): Percentile for thresholding. Default is 0.995.
Returns:
torch.Tensor: The generated image tensor after post-processing.
to(device):
Moves the model and its components to the specified device.
Parameters:
device (str): The device to move the model to (e.g., "cuda", "cpu").
Returns:
CADT2IPipeline: The pipeline instance with updated device.
Example Usage:
pipe = CADT2IPipeline(
"nicolas-dufour/",
)
pipe.to("cuda")
image = pipe(
"a beautiful landscape with a river and mountains",
num_samples=4,
)
"""
def __init__(
self,
model_path,
scheduler="sigmoid",
scheduler_start=-7,
scheduler_end=3,
scheduler_tau=1.0,
device=device,
):
self.network = Plonk.from_pretrained(model_path).to(device)
self.network.requires_grad_(False).eval()
assert scheduler in [
"sigmoid",
"cosine",
"linear",
], f"Scheduler {scheduler} not supported"
self.scheduler = scheduler_fn(
scheduler, scheduler_start, scheduler_end, scheduler_tau
)
        self.cond_preprocessing = load_preprocessing(model_name=model_path)
self.postprocessing = CartesiantoGPS()
self.sampler = riemannian_flow_sampler
self.model_path = model_path
self.preconditioning = DDPMPrecond()
self.device = device
def model(self, *args, **kwargs):
return self.preconditioning(self.network, *args, **kwargs)
def __call__(
self,
images,
batch_size=None,
x_N=None,
num_steps=None,
scheduler=None,
cfg=0,
generator=None,
callback=None,
):
"""Sample from the model given conditioning.
Args:
cond: Conditioning input (image or list of images)
batch_size: Number of samples to generate (inferred from cond if not provided)
x_N: Initial noise tensor (generated if not provided)
num_steps: Number of sampling steps (uses default if not provided)
sampler: Custom sampler function (uses default if not provided)
scheduler: Custom scheduler function (uses default if not provided)
cfg: Classifier-free guidance scale (default 15)
generator: Random number generator
callback: Optional callback function to report progress (step, total_steps)
Returns:
Sampled GPS coordinates after postprocessing
"""
        # Set up batch size and initial noise; locations are represented as
        # 3D Cartesian points and converted to GPS coordinates at the end
        shape = [3]
        if not isinstance(images, list):
            images = [images]
        if x_N is None:
            if batch_size is None:
                batch_size = len(images)
            x_N = torch.randn(
                batch_size, *shape, device=self.device, generator=generator
            )
        else:
            x_N = x_N.to(self.device)
            if x_N.ndim == 1:
                # Promote a single (3,) noise vector to a batch of one
                x_N = x_N.unsqueeze(0)
            batch_size = x_N.shape[0]
        # Set up batch with conditioning
        batch = {"y": x_N, "img": images}
        batch = self.cond_preprocessing(batch)
        if len(images) > 1:
            assert (
                len(images) == batch_size
            ), "Number of images must match batch_size when passing multiple images"
        else:
            # A single image conditions every sample in the batch
            batch["emb"] = batch["emb"].repeat(batch_size, 1)
        # Use the default sampler, and the default scheduler unless one is provided
        sampler = self.sampler
        if scheduler is None:
            scheduler = self.scheduler
        if num_steps is None:
            num_steps = 16  # Default number of sampling steps
        # Wrap the model so each sampler step can report progress through the callback
        def model_with_progress(*args, **kwargs):
            step = kwargs.pop("current_step", 0)
            if callback:
                callback(step, num_steps)
            return self.model(*args, **kwargs)
output = sampler(
model_with_progress,
batch,
conditioning_keys="emb",
scheduler=scheduler,
num_steps=num_steps,
cfg_rate=cfg,
generator=generator,
callback=callback,
)
        # Map sampled Cartesian points back to GPS coordinates
        output = self.postprocessing(output)
        # Convert from radians to degrees
        output = np.degrees(output.detach().cpu().numpy())
        return output
    def to(self, device):
        self.network.to(device)
        self.postprocessing.to(device)
        # Also move the conditioning feature extractor so embeddings land on the same device
        self.cond_preprocessing.emb_model.to(device)
        self.cond_preprocessing.device = torch.device(device)
        self.device = torch.device(device)
        return self
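
# Minimal usage sketch; the image path "photo.jpg" is illustrative, and the
# checkpoint can be any key of MODELS above.
if __name__ == "__main__":
    from PIL import Image

    pipe = PlonkPipeline("nicolas-dufour/PLONK_YFCC")
    img = Image.open("photo.jpg").convert("RGB")
    # Draw 4 candidate locations for one image; each row is (latitude, longitude) in degrees.
    coords = pipe(img, batch_size=4)
    print(coords)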