jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
import functools
import os
from typing import Any, Dict, List, Optional, Tuple
import safetensors.torch
import torch
from accelerate import init_empty_weights
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, GlmModel
import finetrainers.functional as FF
from finetrainers.data import ImageArtifact
from finetrainers.models.modeling_utils import ControlModelSpecification
from finetrainers.models.utils import DiagonalGaussianDistribution, _expand_linear_with_zeroed_weights
from finetrainers.patches.dependencies.diffusers.control import control_channel_concat
from finetrainers.processors import CogView4GLMProcessor, ProcessorMixin
from finetrainers.typing import ArtifactType, SchedulerType
from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function
from .base_specification import CogView4LatentEncodeProcessor
class CogView4ControlModelSpecification(ControlModelSpecification):
def __init__(
self,
pretrained_model_name_or_path: str = "THUDM/CogView4-6B",
tokenizer_id: Optional[str] = None,
text_encoder_id: Optional[str] = None,
transformer_id: Optional[str] = None,
vae_id: Optional[str] = None,
text_encoder_dtype: torch.dtype = torch.bfloat16,
transformer_dtype: torch.dtype = torch.bfloat16,
vae_dtype: torch.dtype = torch.bfloat16,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
condition_model_processors: List[ProcessorMixin] = None,
latent_model_processors: List[ProcessorMixin] = None,
control_model_processors: List[ProcessorMixin] = None,
**kwargs,
) -> None:
super().__init__(
pretrained_model_name_or_path=pretrained_model_name_or_path,
tokenizer_id=tokenizer_id,
text_encoder_id=text_encoder_id,
transformer_id=transformer_id,
vae_id=vae_id,
text_encoder_dtype=text_encoder_dtype,
transformer_dtype=transformer_dtype,
vae_dtype=vae_dtype,
revision=revision,
cache_dir=cache_dir,
)
if condition_model_processors is None:
condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])]
if latent_model_processors is None:
latent_model_processors = [
CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"])
]
if control_model_processors is None:
control_model_processors = [
CogView4LatentEncodeProcessor(["control_latents", "original_size", "target_size", "crop_coords"])
]
self.condition_model_processors = condition_model_processors
self.latent_model_processors = latent_model_processors
self.control_model_processors = control_model_processors
@property
def control_injection_layer_name(self):
return "patch_embed.proj"
@property
def _resolution_dim_keys(self):
return {"latents": (2, 3)}
def load_condition_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
if self.tokenizer_id is not None:
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(
self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
)
if self.text_encoder_id is not None:
text_encoder = GlmModel.from_pretrained(
self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
)
else:
text_encoder = GlmModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="text_encoder",
torch_dtype=self.text_encoder_dtype,
**common_kwargs,
)
return {"tokenizer": tokenizer, "text_encoder": text_encoder}
def load_latent_models(self) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
if self.vae_id is not None:
vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
else:
vae = AutoencoderKL.from_pretrained(
self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
)
return {"vae": vae}
def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]:
common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
if self.transformer_id is not None:
transformer = CogView4Transformer2DModel.from_pretrained(
self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
)
else:
transformer = CogView4Transformer2DModel.from_pretrained(
self.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=self.transformer_dtype,
**common_kwargs,
)
actual_new_in_features = new_in_features * transformer.config.patch_size**2
transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
transformer.patch_embed.proj, new_in_features=actual_new_in_features
)
transformer.register_to_config(in_channels=new_in_features)
scheduler = FlowMatchEulerDiscreteScheduler()
return {"transformer": transformer, "scheduler": scheduler}
def load_pipeline(
self,
tokenizer: Optional[AutoTokenizer] = None,
text_encoder: Optional[GlmModel] = None,
transformer: Optional[CogView4Transformer2DModel] = None,
vae: Optional[AutoencoderKL] = None,
scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
enable_slicing: bool = False,
enable_tiling: bool = False,
enable_model_cpu_offload: bool = False,
training: bool = False,
**kwargs,
) -> CogView4Pipeline:
components = {
"tokenizer": tokenizer,
"text_encoder": text_encoder,
"transformer": transformer,
"vae": vae,
# Load the scheduler based on CogView4's config instead of using the default initialization being used for training
# "scheduler": scheduler,
}
components = get_non_null_items(components)
pipe = CogView4Pipeline.from_pretrained(
self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
)
pipe.text_encoder.to(self.text_encoder_dtype)
pipe.vae.to(self.vae_dtype)
_enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
if not training:
pipe.transformer.to(self.transformer_dtype)
if enable_model_cpu_offload:
pipe.enable_model_cpu_offload()
return pipe
@torch.no_grad()
def prepare_conditions(
self,
tokenizer: AutoTokenizer,
text_encoder: GlmModel,
caption: str,
max_sequence_length: int = 1024,
**kwargs,
) -> Dict[str, Any]:
conditions = {
"tokenizer": tokenizer,
"text_encoder": text_encoder,
"caption": caption,
"max_sequence_length": max_sequence_length,
**kwargs,
}
input_keys = set(conditions.keys())
conditions = super().prepare_conditions(**conditions)
conditions = {k: v for k, v in conditions.items() if k not in input_keys}
return conditions
@torch.no_grad()
def prepare_latents(
self,
vae: AutoencoderKL,
image: Optional[torch.Tensor] = None,
video: Optional[torch.Tensor] = None,
control_image: Optional[torch.Tensor] = None,
control_video: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
compute_posterior: bool = True,
_original_height: Optional[int] = None,
_original_width: Optional[int] = None,
**kwargs,
) -> Dict[str, torch.Tensor]:
common_kwargs = {
"vae": vae,
"generator": generator,
"compute_posterior": compute_posterior,
"_original_height": _original_height,
"_original_width": _original_width,
**kwargs,
}
conditions = {"image": image, "video": video, **common_kwargs}
input_keys = set(conditions.keys())
conditions = ControlModelSpecification.prepare_latents(self, self.latent_model_processors, **conditions)
conditions = {k: v for k, v in conditions.items() if k not in input_keys}
control_conditions = {"image": control_image, "video": control_video, **common_kwargs}
input_keys = set(control_conditions.keys())
control_conditions = ControlModelSpecification.prepare_latents(
self, self.control_model_processors, **control_conditions
)
control_conditions = {k: v for k, v in control_conditions.items() if k not in input_keys}
return {**control_conditions, **conditions}
def forward(
self,
transformer: CogView4Transformer2DModel,
condition_model_conditions: Dict[str, torch.Tensor],
latent_model_conditions: Dict[str, torch.Tensor],
sigmas: torch.Tensor,
generator: Optional[torch.Generator] = None,
compute_posterior: bool = True,
**kwargs,
) -> Tuple[torch.Tensor, ...]:
base_image_sequence_length = 256
base_shift = 0.25
max_shift = 0.75
if compute_posterior:
latents = latent_model_conditions.pop("latents")
control_latents = latent_model_conditions.pop("control_latents")
else:
posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
latents = posterior.sample(generator=generator)
del posterior
control_posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("control_latents"))
control_latents = control_posterior.sample(generator=generator)
del control_posterior
if getattr(self.vae_config, "shift_factor") is not None:
latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
control_latents = (control_latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
noise = torch.zeros_like(latents).normal_(generator=generator)
timesteps = (sigmas.flatten() * 1000.0).long()
image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2
mu = (image_sequence_length / base_image_sequence_length) ** 0.5
mu = mu * max_shift + base_shift
shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0)
noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas)
noisy_latents = torch.cat([noisy_latents, control_latents], dim=1)
latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
pred = transformer(
**latent_model_conditions,
**condition_model_conditions,
timestep=timesteps,
return_dict=False,
)[0]
target = FF.flow_match_target(noise, latents)
# NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation
# but let's keep it this way for now. Longer training runs should reveal more insights.
# return pred, target, sigmas
return pred, target, shifted_sigmas
def validation(
self,
pipeline: CogView4Pipeline,
prompt: str,
control_image: torch.Tensor,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
**kwargs,
) -> List[ArtifactType]:
with torch.no_grad():
dtype = pipeline.vae.dtype
device = pipeline._execution_device
in_channels = self.transformer_config.in_channels # We need to use the original in_channels
latents = pipeline.prepare_latents(1, in_channels, height, width, dtype, device, generator)
control_image = pipeline.image_processor.preprocess(control_image, height=height, width=width)
control_image = control_image.to(device=device, dtype=dtype)
control_latents = pipeline.vae.encode(control_image).latent_dist.sample(generator=generator)
if getattr(self.vae_config, "shift_factor") is not None:
control_latents = (control_latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
generation_kwargs = {
"latents": latents,
"prompt": prompt,
"height": height,
"width": width,
"num_inference_steps": num_inference_steps,
"generator": generator,
"return_dict": True,
"output_type": "pil",
}
generation_kwargs = get_non_null_items(generation_kwargs)
with control_channel_concat(pipeline.transformer, ["hidden_states"], [control_latents], dims=[1]):
image = pipeline(**generation_kwargs).images[0]
return [ImageArtifact(value=image)]
def _save_lora_weights(
self,
directory: str,
transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
norm_state_dict: Optional[Dict[str, torch.Tensor]] = None,
scheduler: Optional[SchedulerType] = None,
metadata: Optional[Dict[str, str]] = None,
*args,
**kwargs,
) -> None:
# TODO(aryan): this needs refactoring
if transformer_state_dict is not None:
CogView4Pipeline.save_lora_weights(
directory,
transformer_state_dict,
save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
safe_serialization=True,
)
if norm_state_dict is not None:
safetensors.torch.save_file(norm_state_dict, os.path.join(directory, "norm_state_dict.safetensors"))
if scheduler is not None:
scheduler.save_pretrained(os.path.join(directory, "scheduler"))
def _save_model(
self,
directory: str,
transformer: CogView4Transformer2DModel,
transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
scheduler: Optional[SchedulerType] = None,
) -> None:
# TODO(aryan): this needs refactoring
if transformer_state_dict is not None:
with init_empty_weights():
transformer_copy = CogView4Transformer2DModel.from_config(transformer.config)
transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
if scheduler is not None:
scheduler.save_pretrained(os.path.join(directory, "scheduler"))
@property
def _original_control_layer_in_features(self):
return self.transformer_config.in_channels
@property
def _original_control_layer_out_features(self):
return self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim
@property
def _qk_norm_identifiers(self):
return ["attn1.norm_q", "attn1.norm_k"]