from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
from diffusers.configuration_utils import ConfigMixin
from einops import rearrange
from torch import Tensor

from xora.utils.torch_utils import append_dims


class Patchifier(ConfigMixin, ABC):
    def __init__(self, patch_size: int):
        super().__init__()
        # Temporal patch size is fixed at 1: only the spatial axes are patchified.
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(
        self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
    ) -> Tensor:
        pass

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tensor:
        pass

    @property
    def patch_size(self):
        return self._patch_size

    def get_grid(self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device):
        # Number of patches along the frame, height, and width axes.
        f = orig_num_frames // self._patch_size[0]
        h = orig_height // self._patch_size[1]
        w = orig_width // self._patch_size[2]
        grid_h = torch.arange(h, dtype=torch.float32, device=device)
        grid_w = torch.arange(w, dtype=torch.float32, device=device)
        grid_f = torch.arange(f, dtype=torch.float32, device=device)
        # Explicit "ij" indexing matches the previous default behavior and
        # avoids the deprecation warning newer torch versions emit without it.
        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
        grid = torch.stack(grid, dim=0)
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

        if scale_grid is not None:
            # Rescale each axis independently; per-sample scales arrive as tensors.
            for i in range(3):
                if isinstance(scale_grid[i], Tensor):
                    scale = append_dims(scale_grid[i], grid.ndim - 1)
                else:
                    scale = scale_grid[i]
                grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]

        # Flatten to one (frame, row, col) coordinate triple per token: (b, 3, f*h*w).
        grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
        return grid


def pixart_alpha_patchify(
    latents: Tensor,
    patch_size: Tuple[int, int, int],
) -> Tensor:
    # Fold each (p1, p2, p3) spatio-temporal patch into the channel axis,
    # producing one token per patch: (b, c, f, h, w) -> (b, f*h*w, c*p1*p2*p3).
    latents = rearrange(
        latents,
        "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
        p1=patch_size[0],
        p2=patch_size[1],
        p3=patch_size[2],
    )
    return latents


class SymmetricPatchifier(Patchifier):
    def patchify(
        self,
        latents: Tensor,
        frame_rates: Optional[Tensor] = None,  # unused here; accepted to match the base signature
        scale_grid: bool = False,  # unused here; accepted to match the base signature
    ) -> Tensor:
        return pixart_alpha_patchify(latents, self._patch_size)

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        output_num_frames: int,
        out_channels: int,
    ) -> Tensor:
        # Convert pixel dimensions into patch-grid dimensions.
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        # Unfold the spatial patch factors back out of the channel axis;
        # passing c=out_channels lets einops validate the channel split.
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=output_num_frames,
            h=output_height,
            w=output_width,
            c=out_channels,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents
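

if __name__ == "__main__":
    # A minimal round-trip sketch (hypothetical shapes, not part of the library):
    # patchify a video latent into tokens, then restore the original layout.
    patchifier = SymmetricPatchifier(patch_size=2)

    b, c, f, h, w = 1, 4, 8, 32, 32
    latents = torch.randn(b, c, f, h, w)

    tokens = patchifier.patchify(latents)
    # One token per (1, 2, 2) patch: (1, 8 * 16 * 16, 4 * 1 * 2 * 2).
    assert tokens.shape == (b, f * (h // 2) * (w // 2), c * 2 * 2)

    restored = patchifier.unpatchify(
        tokens,
        output_height=h,
        output_width=w,
        output_num_frames=f,
        out_channels=c,
    )
    assert torch.equal(restored, latents)

    # get_grid yields one (frame, row, col) coordinate per token, shaped
    # (b, 3, num_tokens), typically consumed by positional embeddings.
    grid = patchifier.get_grid(f, h, w, b, scale_grid=None, device=latents.device)
    assert grid.shape == (b, 3, tokens.shape[1])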