# BeamDiffusion / beam_diffusion.py
# Uploaded by Gui28F ("uploaded all project files"), commit 173ea2b.
from transformers import PretrainedConfig, PreTrainedModel, Pipeline
import torch
from BeamDiffusionModel.beamInference import beam_inference
# Custom configuration for the BeamDiffusion model.
class BeamDiffusionConfig(PretrainedConfig):
    """Hugging Face configuration object for the BeamDiffusion model.

    Stores the settings that ``BeamDiffusionModel`` exposes to
    ``beam_inference``. Their exact meaning is defined by ``beam_inference``.

    Args:
        latents_idx: Latent index list. Any falsy value (``None`` or an empty
            list) is replaced by ``[0, 1, 2, 3]``.
        n_seeds: Stored as-is (default ``4``).
        seeds: Seed list. Any falsy value is replaced by ``[]``.
        steps_back: Stored as-is (default ``2``).
        beam_width: Stored as-is (default ``4``).
        window_size: Stored as-is (default ``2``).
        use_rand: Stored as-is (default ``True``).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "beam_diffusion"

    def __init__(
        self,
        latents_idx=None,
        n_seeds=4,
        seeds=None,
        steps_back=2,
        beam_width=4,
        window_size=2,
        use_rand=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # `value or default` is the same truthiness test as the original
        # `value if value else default`: None *and* [] both fall back.
        self.latents_idx = latents_idx or [0, 1, 2, 3]
        self.n_seeds = n_seeds
        self.seeds = seeds or []
        self.steps_back = steps_back
        self.beam_width = beam_width
        self.window_size = window_size
        self.use_rand = use_rand
import torch.nn as nn
from huggingface_hub import ModelHubMixin
# Custom BeamDiffusionModel that delegates inference to beam_inference.
class BeamDiffusionModel(PreTrainedModel, ModelHubMixin):
    """``PreTrainedModel`` wrapper that runs BeamDiffusion via ``beam_inference``.

    The module holds no real weights, only a placeholder parameter. All work is
    delegated to ``beam_inference``. Each search option is taken from the
    per-call ``input_data`` when present, otherwise from ``self.config``. This
    lets a saved or loaded ``BeamDiffusionConfig`` actually control inference.
    """

    config_class = BeamDiffusionConfig
    model_type = "beam_diffusion"

    # Sentinel distinguishing "key absent" from "key present with value None".
    _MISSING = object()

    # Last-resort defaults, used only when the attached config lacks an
    # attribute. They mirror BeamDiffusionConfig's constructor defaults.
    _OPTION_FALLBACKS = {
        "latents_idx": [0, 1, 2, 3],
        "n_seeds": 4,
        "seeds": [],
        "steps_back": 2,
        "beam_width": 4,
        "window_size": 2,
        "use_rand": True,
    }

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Ensure the module has at least one parameter. Presumably this is so
        # PreTrainedModel helpers that inspect parameters (device/dtype) work
        # -- TODO confirm.
        self.dummy_param = nn.Parameter(torch.zeros(1))

    def _resolve_option(self, input_data, name):
        """Return *name* from *input_data*, else from the config, else the fallback."""
        value = input_data.get(name, self._MISSING)
        if value is not self._MISSING:
            # An explicit per-call value always wins, exactly as before.
            return value
        value = getattr(self.config, name, self._OPTION_FALLBACKS[name])
        # Copy list values so beam_inference can never mutate the config or
        # the class-level fallback table shared by all instances.
        return list(value) if isinstance(value, list) else value

    def forward(self, input_data):
        """Run beam-search diffusion inference for one request.

        Args:
            input_data: Mapping of inference options. ``"steps"`` is passed
                to ``beam_inference`` (default ``[]``). Any of
                ``latents_idx``, ``n_seeds``, ``seeds``, ``steps_back``,
                ``beam_width``, ``window_size`` and ``use_rand`` given here
                override the corresponding ``self.config`` value.

        Returns:
            ``{"images": ...}``, where the value is whatever
            ``beam_inference`` returns.
        """
        images = beam_inference(
            steps=input_data.get("steps", []),
            latents_idx=self._resolve_option(input_data, "latents_idx"),
            n_seeds=self._resolve_option(input_data, "n_seeds"),
            seeds=self._resolve_option(input_data, "seeds"),
            steps_back=self._resolve_option(input_data, "steps_back"),
            beam_width=self._resolve_option(input_data, "beam_width"),
            window_size=self._resolve_option(input_data, "window_size"),
            use_rand=self._resolve_option(input_data, "use_rand"),
        )
        return {"images": images}
# Custom pipeline to handle inference
class BeamDiffusionPipeline(Pipeline, ModelHubMixin):
    """Minimal ``transformers.Pipeline`` wrapper around a BeamDiffusion model.

    ``__call__`` is overridden to call ``_forward`` directly. The base class's
    preprocess/postprocess chain is therefore bypassed, and callers receive the
    raw model output (``{"images": ...}`` when the model is BeamDiffusionModel).
    ``preprocess``, ``postprocess`` and ``_sanitize_parameters`` implement the
    hooks transformers expects from a custom pipeline. ``__call__`` itself does
    not use ``preprocess`` or ``postprocess``.
    """

    def __init__(self, model, tokenizer=None, device="cuda", framework="pt"):
        # All setup is delegated to the transformers Pipeline constructor.
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
        )

    def __call__(self, inputs):
        """Run the model on *inputs* and return its raw output dict."""
        model_outputs = self._forward(inputs)
        return model_outputs

    def preprocess(self, inputs):
        """Return *inputs* unchanged; no preprocessing is performed."""
        model_inputs = inputs
        return model_inputs

    def postprocess(self, model_outputs):
        """Return the ``"images"`` entry of *model_outputs*."""
        images = model_outputs["images"]
        return images

    def _sanitize_parameters(self, **kwargs):
        """Ignore all call-time keyword arguments.

        Returns three empty dicts: preprocess, forward and postprocess params.
        """
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}
        return preprocess_params, forward_params, postprocess_params

    def _forward(self, model_inputs):
        """Invoke the wrapped model on *model_inputs* and return its output."""
        model = self.model
        return model(model_inputs)