File size: 3,034 Bytes
173ea2b
 
 
 
90bb09b
 
173ea2b
 
 
90bb09b
173ea2b
90bb09b
 
 
173ea2b
 
 
 
 
 
 
 
90bb09b
 
 
 
 
 
173ea2b
 
 
 
 
 
 
 
 
 
 
 
 
 
90bb09b
173ea2b
90bb09b
 
 
 
 
 
 
173ea2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from transformers import PretrainedConfig, PreTrainedModel, Pipeline
import torch

from BeamDiffusionModel.beamInference import beam_inference
from BeamDiffusionModel.models.diffusionModel.StableDiffusion import StableDiffusion
from BeamDiffusionModel.models.diffusionModel.Flux import Flux
# Your custom configuration for the BeamDiffusion model
class BeamDiffusionConfig(PretrainedConfig):
    """Configuration for the BeamDiffusion model.

    Besides the beam-search hyper-parameters, construction eagerly
    instantiates the diffusion backend named by ``sd`` and stores it on
    ``self.sd``; ``BeamDiffusionModel.forward`` passes that object straight to
    ``beam_inference``.

    Args:
        sd: Backend name. ``"SD-2.1"`` selects ``StableDiffusion`` and
            ``"flux"`` selects ``Flux``. ``None`` leaves ``self.sd`` unset.
        latents_idx: Forwarded to ``beam_inference``; a falsy value
            (``None`` or empty) falls back to ``[0, 1, 2, 3]``.
        n_seeds: Forwarded to ``beam_inference``.
        seeds: Forwarded to ``beam_inference``; a falsy value becomes ``[]``.
        steps_back: Forwarded to ``beam_inference``.
        beam_width: Forwarded to ``beam_inference``.
        window_size: Forwarded to ``beam_inference``.
        use_rand: Forwarded to ``beam_inference``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``sd`` is not ``None`` and names no supported backend.
    """

    model_type = "beam_diffusion"

    def __init__(self, sd="SD-2.1", latents_idx=None, n_seeds=4, seeds=None, steps_back=2,
                 beam_width=4, window_size=2, use_rand=True, **kwargs):
        super().__init__(**kwargs)
        # Keep the backend's name next to the instantiated object in ``self.sd``.
        self.sd_name = sd
        self.sd = None
        # ``sd=None`` preserves the "no backend loaded" state; any other value
        # must name a supported backend (checked in ``get_model``).
        if sd is not None:
            self.get_model(sd)
        # Falsy values (None or empty) fall back to the defaults.
        self.latents_idx = latents_idx if latents_idx else [0, 1, 2, 3]
        self.n_seeds = n_seeds
        self.seeds = seeds if seeds else []
        self.steps_back = steps_back
        self.beam_width = beam_width
        self.window_size = window_size
        self.use_rand = use_rand

    def get_model(self, sd=None):
        """Instantiate the diffusion backend named ``sd`` and store it on ``self.sd``.

        Args:
            sd: Backend name (``"flux"`` or ``"SD-2.1"``). When ``None``, the
                name stored in ``self.sd_name`` is used.

        Returns:
            The instantiated backend, which is also assigned to ``self.sd``.

        Raises:
            ValueError: If the name does not match a supported backend.
        """
        # Honour the argument (previously ignored in favour of self.sd_name).
        name = self.sd_name if sd is None else sd
        if name == "flux":
            backend = Flux()
        elif name == "SD-2.1":
            backend = StableDiffusion()
        else:
            # Fail early instead of leaving self.sd as None, which would only
            # surface later as an obscure error inside beam_inference.
            raise ValueError(
                f"Unsupported diffusion backend {name!r}; expected 'flux' or 'SD-2.1'."
            )
        self.sd_name = name
        self.sd = backend
        return backend

import torch.nn as nn
from huggingface_hub import ModelHubMixin
# Custom BeamDiffusionModel that performs inference for each step
class BeamDiffusionModel(PreTrainedModel, ModelHubMixin):
    """``PreTrainedModel`` wrapper that delegates generation to ``beam_inference``.

    The model owns no real weights. Every hyper-parameter, and the diffusion
    backend object itself (``config.sd``), is read from the attached
    ``BeamDiffusionConfig`` at call time.
    """

    config_class = BeamDiffusionConfig
    model_type = "beam_diffusion"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Placeholder so the module has at least one parameter.
        self.dummy_param = nn.Parameter(torch.zeros(1))

    def forward(self, input_data):
        """Run beam-search diffusion for the requested steps.

        Args:
            input_data: Mapping queried with ``.get``; its ``'steps'`` entry
                (``[]`` when absent) is passed to ``beam_inference`` as ``steps``.

        Returns:
            dict: ``{"images": ...}`` wrapping whatever ``beam_inference`` returns.
        """
        cfg = self.config
        # Beam-search settings are all taken verbatim from the config.
        search_kwargs = dict(
            latents_idx=cfg.latents_idx,
            n_seeds=cfg.n_seeds,
            seeds=cfg.seeds,
            steps_back=cfg.steps_back,
            beam_width=cfg.beam_width,
            window_size=cfg.window_size,
            use_rand=cfg.use_rand,
        )
        requested_steps = input_data.get('steps', [])
        generated = beam_inference(cfg.sd, steps=requested_steps, **search_kwargs)
        return {"images": generated}


# Custom pipeline to handle inference
class BeamDiffusionPipeline(Pipeline, ModelHubMixin):
    """Minimal ``transformers.Pipeline`` wrapper around the BeamDiffusion model.

    NOTE(review): ``__call__`` is overridden to call ``_forward`` directly, so
    the ``preprocess``/``postprocess`` hooks are not applied on that path.
    Callers therefore get the model's raw ``{"images": ...}`` dict rather than
    the bare images ``postprocess`` would return. Confirm this is intended.
    """

    def __init__(self, model, tokenizer=None, device="cuda", framework="pt"):
        # NOTE(review): default device is "cuda"; presumably CPU-only hosts
        # must pass ``device`` explicitly — confirm.
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
        )

    def __call__(self, inputs):
        """Run the wrapped model on ``inputs`` and return its raw output dict."""
        raw_outputs = self._forward(inputs)
        return raw_outputs

    def preprocess(self, inputs):
        """Pass ``inputs`` through untouched; no preprocessing is required."""
        prepared = inputs
        return prepared

    def postprocess(self, model_outputs):
        """Pull the generated images out of the model's output dict."""
        images = model_outputs["images"]
        return images

    def _sanitize_parameters(self, **kwargs):
        """Accept and ignore any call-time parameters.

        Returns:
            Three fresh empty dicts, one each for the preprocess, forward and
            postprocess stages (the ``transformers.Pipeline`` convention).
        """
        preprocess_params, forward_params, postprocess_params = {}, {}, {}
        return preprocess_params, forward_params, postprocess_params

    def _forward(self, model_inputs):
        """Invoke the wrapped model with ``model_inputs`` unchanged."""
        wrapped_model = self.model
        return wrapped_model(model_inputs)