|
from transformers import PretrainedConfig, PreTrainedModel, Pipeline
|
|
import torch
|
|
|
|
from BeamDiffusionModel.beamInference import beam_inference
|
|
|
|
class BeamDiffusionConfig(PretrainedConfig):
    """Configuration for ``BeamDiffusionModel``.

    Stores the beam-search hyper-parameters as plain attributes so that
    ``PretrainedConfig`` can serialize them alongside the standard
    Hugging Face config fields. The precise meaning of each field is defined
    by ``beam_inference`` (not visible in this file).

    Args:
        latents_idx: Latent indices; ``[0, 1, 2, 3]`` when ``None``.
            An explicitly passed sequence (including an empty one) is kept.
        n_seeds: Number of seeds.
        seeds: Explicit seed values; an empty list when ``None``.
        steps_back: Beam-search ``steps_back`` setting.
        beam_width: Beam-search ``beam_width`` setting.
        window_size: Beam-search ``window_size`` setting.
        use_rand: Beam-search ``use_rand`` flag.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "beam_diffusion"

    def __init__(
        self,
        latents_idx=None,
        n_seeds=4,
        seeds=None,
        steps_back=2,
        beam_width=4,
        window_size=2,
        use_rand=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Test against None rather than truthiness so an explicitly supplied
        # empty list is respected, and copy so the config never aliases (and
        # is never mutated through) the caller's list.
        self.latents_idx = [0, 1, 2, 3] if latents_idx is None else list(latents_idx)
        self.n_seeds = n_seeds
        self.seeds = [] if seeds is None else list(seeds)
        self.steps_back = steps_back
        self.beam_width = beam_width
        self.window_size = window_size
        self.use_rand = use_rand
|
|
|
|
import torch.nn as nn
|
|
from huggingface_hub import ModelHubMixin
|
|
|
|
class BeamDiffusionModel(PreTrainedModel, ModelHubMixin):
    """Hugging Face model wrapper that delegates generation to ``beam_inference``.

    The wrapper has no learned weights of its own; ``forward`` simply
    gathers options and calls ``beam_inference``.
    """

    config_class = BeamDiffusionConfig
    model_type = "beam_diffusion"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Placeholder parameter: nothing here is trained. Presumably it exists
        # so PreTrainedModel utilities that inspect parameters (device/dtype
        # lookup, saving) have something to work with — confirm before removing.
        self.dummy_param = nn.Parameter(torch.zeros(1))

    def forward(self, input_data):
        """Run beam-search diffusion inference.

        Args:
            input_data: Mapping of per-call options. ``"steps"`` defaults to an
                empty list. Each of ``"latents_idx"``, ``"n_seeds"``,
                ``"seeds"``, ``"steps_back"``, ``"beam_width"``,
                ``"window_size"`` and ``"use_rand"`` overrides the
                corresponding value on ``self.config`` for this call only.

        Returns:
            dict: ``{"images": images}`` where ``images`` is whatever
            ``beam_inference`` returns.
        """
        # Fall back to the model's own config rather than hard-coded literals,
        # so a model built/loaded with a non-default config actually uses it.
        config = self.config
        images = beam_inference(
            steps=input_data.get("steps", []),
            # Copy list-valued config fields so beam_inference cannot mutate
            # the stored configuration.
            latents_idx=input_data.get("latents_idx", list(config.latents_idx)),
            n_seeds=input_data.get("n_seeds", config.n_seeds),
            seeds=input_data.get("seeds", list(config.seeds)),
            steps_back=input_data.get("steps_back", config.steps_back),
            beam_width=input_data.get("beam_width", config.beam_width),
            window_size=input_data.get("window_size", config.window_size),
            use_rand=input_data.get("use_rand", config.use_rand),
        )
        return {"images": images}
|
|
|
|
|
|
|
|
class BeamDiffusionPipeline(Pipeline, ModelHubMixin):
    """Thin pipeline wrapper around a BeamDiffusion model.

    NOTE(review): ``__call__`` is overridden to call ``_forward`` directly, so
    ``preprocess`` and ``postprocess`` are never invoked from it and the base
    ``Pipeline.__call__`` machinery is bypassed. Callers therefore receive the
    model's raw output dict, not the images that ``postprocess`` would
    extract. Kept as-is to preserve the current return type.
    """

    def __init__(self, model, tokenizer=None, device="cuda", framework="pt"):
        super().__init__(
            model=model,
            framework=framework,
            device=device,
            tokenizer=tokenizer,
        )

    def __call__(self, inputs):
        """Run the wrapped model on ``inputs`` and return its raw output."""
        model_outputs = self._forward(inputs)
        return model_outputs

    def preprocess(self, inputs):
        """Identity step: hand ``inputs`` to the model unchanged."""
        return inputs

    def postprocess(self, model_outputs):
        """Pull the ``"images"`` entry out of the model's output mapping."""
        images = model_outputs["images"]
        return images

    def _sanitize_parameters(self, **kwargs):
        """Return empty kwargs for the preprocess, forward and postprocess stages.

        Any extra call-time parameters are accepted and ignored.
        """
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def _forward(self, model_inputs):
        """Invoke the wrapped model on ``model_inputs``."""
        outputs = self.model(model_inputs)
        return outputs
|
|
|