import os
import sys


sys.path.append(
    os.path.join(
        os.path.dirname(__file__),
        "..",
    )
)

import torch
from config import Args
from PIL import Image
from pydantic import BaseModel, Field

from live2diff.utils.config import load_config
from live2diff.utils.wrapper import StreamAnimateDiffusionDepthWrapper


default_prompt = "masterpiece, best quality, felted, 1man with glasses, glasses, play with his pen"

page_content = """<h1 class="text-3xl font-bold">Live2Diff: </h1>
<h2 class="text-xl font-bold">Live Stream Translation via Uni-directional Attention in Video Diffusion Models</h2>
<p class="text-sm">
    This demo showcases the
    <a
    href="https://github.com/open-mmlab/Live2Diff"
    target="_blank"
    class="text-blue-500 underline hover:no-underline">Live2Diff
</a>
pipeline using
    <a
    href="https://huggingface.co/latent-consistency/lcm-lora-sdv1-5"
    target="_blank"
    class="text-blue-500 underline hover:no-underline">LCM-LoRA</a
    > with an MJPEG stream server.
</p>
"""


# Number of frames buffered before `prepare()` is called on the stream.
WARMUP_FRAMES = 8
# Window size constant; not referenced elsewhere in this file.
WINDOW_SIZE = 16


class Pipeline:
    class Info(BaseModel):
        name: str = "Live2Diff"
        input_mode: str = "image"
        page_content: str = page_content

    def build_input_params(self, default_prompt: str = default_prompt, width=512, height=512):
        class InputParams(BaseModel):
            prompt: str = Field(
                default_prompt,
                title="Prompt",
                field="textarea",
                id="prompt",
            )
            width: int = Field(
                512,
                min=2,
                max=15,
                title="Width",
                disabled=True,
                hide=True,
                id="width",
            )
            height: int = Field(
                512,
                min=2,
                max=15,
                title="Height",
                disabled=True,
                hide=True,
                id="height",
            )

        return InputParams

    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
        config_path = args.config

        cfg = load_config(config_path)
        prompt = args.prompt or cfg.prompt or default_prompt

        self.InputParams = self.build_input_params(default_prompt=prompt)
        params = self.InputParams()

        # Values passed on the command line take precedence over the config file.
        num_inference_steps = args.num_inference_steps or cfg.get("num_inference_steps", None)
        strength = args.strength or cfg.get("strength", None)
        t_index_list = args.t_index_list or cfg.get("t_index_list", None)

        self.stream = StreamAnimateDiffusionDepthWrapper(
            few_step_model_type="lcm",
            config_path=config_path,
            cfg_type="none",
            strength=strength,
            num_inference_steps=num_inference_steps,
            t_index_list=t_index_list,
            frame_buffer_size=1,
            width=params.width,
            height=params.height,
            acceleration=args.acceleration,
            do_add_noise=True,
            output_type="pil",
            enable_similar_image_filter=True,
            similar_image_filter_threshold=0.98,
            use_denoising_batch=True,
            use_tiny_vae=True,
            seed=args.seed,
            engine_dir=args.engine_dir,
        )

        self.last_prompt = prompt

        # Buffer of preprocessed frames collected during warm-up; `prepare()`
        # is called once the buffer holds WARMUP_FRAMES entries.
        self.warmup_frame_list = []
        self.has_prepared = False

    def predict(self, params: "Pipeline.InputParams") -> Image.Image:
        prompt = params.prompt
        if prompt != self.last_prompt:
            # A new prompt invalidates the current stream: re-collect warm-up
            # frames and re-run `prepare()` with the updated prompt.
            self.last_prompt = prompt
            self.warmup_frame_list.clear()
            self.has_prepared = False

        if len(self.warmup_frame_list) < WARMUP_FRAMES:
            # Warm-up phase: buffer incoming frames until enough are collected.
            self.warmup_frame_list.append(self.stream.preprocess_image(params.image))

        elif len(self.warmup_frame_list) == WARMUP_FRAMES and not self.has_prepared:
            # Buffer is full: initialize the stream once with the warm-up frames.
            warmup_frames = torch.stack(self.warmup_frame_list)
            self.stream.prepare(
                warmup_frames=warmup_frames,
                prompt=prompt,
                guidance_scale=1,
            )
            self.has_prepared = True

        if self.has_prepared:
            # Streaming phase: translate the current frame and return it.
            image_tensor = self.stream.preprocess_image(params.image)
            output_image = self.stream(image=image_tensor)
            return output_image
        else:
            # Still warming up: return a blank frame of the configured size.
            return Image.new("RGB", (params.width, params.height))
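

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original demo): how a server
# loop might drive this Pipeline. The `Args` field values and the per-frame
# `params.image` attribute are assumptions inferred from how they are read
# above, not guaranteed APIs.
#
#   args = Args(config="path/to/config.yaml")          # hypothetical settings
#   pipeline = Pipeline(args, torch.device("cuda"), torch.float16)
#   params = pipeline.InputParams()                     # default prompt/size
#   for frame in video_frames:                          # e.g. PIL images
#       params.image = frame                            # attached by the server
#       output = pipeline.predict(params)               # blank until warmed up
#
# The first WARMUP_FRAMES calls only buffer frames; after `prepare()` runs,
# each call returns one stylized PIL image per input frame.
# ---------------------------------------------------------------------------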