from transformers import PretrainedConfig, PreTrainedModel, Pipeline
from huggingface_hub import ModelHubMixin
import torch
import torch.nn as nn

from BeamDiffusionModel.beamInference import beam_inference

# Custom configuration for the BeamDiffusion model
class BeamDiffusionConfig(PretrainedConfig):
    model_type = "beam_diffusion"

    def __init__(self, latents_idx=None, n_seeds=4, seeds=None, steps_back=2,
                 beam_width=4, window_size=2, use_rand=True, **kwargs):
        super().__init__(**kwargs)
        # None sentinels avoid mutable default arguments; these values also act as
        # the fall-back defaults for BeamDiffusionModel.forward.
        self.latents_idx = latents_idx if latents_idx is not None else [0, 1, 2, 3]
        self.n_seeds = n_seeds
        self.seeds = seeds if seeds is not None else []
        self.steps_back = steps_back
        self.beam_width = beam_width
        self.window_size = window_size
        self.use_rand = use_rand

# Custom model that wraps beam_inference: every forward call runs the full
# beam-search diffusion process for the given steps. PreTrainedModel already
# provides Hub integration (save_pretrained / from_pretrained / push_to_hub),
# so no extra mixin is needed here.
class BeamDiffusionModel(PreTrainedModel):
    config_class = BeamDiffusionConfig
    model_type = "beam_diffusion"

    def __init__(self, config):
        super().__init__(config)
        self.dummy_param = nn.Parameter(torch.zeros(1))  # ensure the model has at least one parameter

    def forward(self, input_data):
        # input_data is a plain dict; any key that is missing falls back to the
        # corresponding value stored in the config.
        images = beam_inference(
            steps=input_data.get("steps", []),
            latents_idx=input_data.get("latents_idx", self.config.latents_idx),
            n_seeds=input_data.get("n_seeds", self.config.n_seeds),
            seeds=input_data.get("seeds", self.config.seeds),
            steps_back=input_data.get("steps_back", self.config.steps_back),
            beam_width=input_data.get("beam_width", self.config.beam_width),
            window_size=input_data.get("window_size", self.config.window_size),
            use_rand=input_data.get("use_rand", self.config.use_rand),
        )
        return {"images": images}


# Custom pipeline that routes dict inputs straight through the model
class BeamDiffusionPipeline(Pipeline, ModelHubMixin):
    def __init__(self, model, tokenizer=None, device=None, framework="pt"):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(model=model, tokenizer=tokenizer, device=device, framework=framework)

    def __call__(self, inputs):
        # Skip the batching machinery of Pipeline.__call__, but still run the
        # preprocess -> forward -> postprocess chain so postprocess is applied.
        model_inputs = self.preprocess(inputs)
        model_outputs = self._forward(model_inputs)
        return self.postprocess(model_outputs)

    def preprocess(self, inputs):
        """Converts raw input data into model-ready format."""
        return inputs  # Keep as-is

    def postprocess(self, model_outputs):
        """Extracts the generated images from the model's output dict."""
        return model_outputs["images"]

    def _sanitize_parameters(self, **kwargs):
        """Required Pipeline hook; this pipeline takes no extra parameters, so any kwargs are ignored."""
        return {}, {}, {}

    def _forward(self, model_inputs):
        return self.model(model_inputs)
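

# Usage sketch (illustrative only): assumes the BeamDiffusionModel package providing
# beam_inference is installed, and that "steps" is a list of text prompts describing
# the sequence to generate; adjust the input dict to whatever beam_inference expects.
if __name__ == "__main__":
    config = BeamDiffusionConfig(n_seeds=4, beam_width=4)
    model = BeamDiffusionModel(config)
    pipe = BeamDiffusionPipeline(model=model)  # picks "cuda" when available, else "cpu"

    outputs = pipe({"steps": ["whisk the eggs", "pour the mixture into the pan"]})
    print(f"Generated {len(outputs)} image(s)")

    # The model can be saved or shared like any PreTrainedModel:
    # model.save_pretrained("beam_diffusion_checkpoint")
    # model.push_to_hub("your-username/beam-diffusion")  # hypothetical repo id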