import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from typing import Union, Tuple
from einops import rearrange
from PIL import Image

from .layers import attn, layer_norm, linear, mlp
from .image_crops import overlap_crop_image
from .config import VisionConfig

if torch.backends.mps.is_available():
    # Non-divisible input sizes are not implemented on MPS device yet.
    # https://github.com/pytorch/pytorch/issues/96056
    def adaptive_avg_pool2d(input, output_size):
        return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")

else:
    adaptive_avg_pool2d = F.adaptive_avg_pool2d

DeviceLike = Union[str, torch.device, int]


def prepare_crops(
    image: Image.Image, config: VisionConfig, device: DeviceLike
) -> Tuple[torch.Tensor, Tuple[int, int]]:
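    """Convert a PIL image into normalized overlapping crops for the encoder.

    The image is converted to RGB, split into overlapping crops by
    ``overlap_crop_image`` (bounded by ``config.max_crops`` with
    ``config.overlap_margin`` overlap), moved to channels-first layout, and
    normalized from [0, 255] bytes to [-1, 1] float16 on ``device``. Returns
    the stacked crop tensor and the crop tiling reported by
    ``overlap_crop_image``.
    """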
    np_image = np.array(image.convert("RGB"))
    overlap_crops = overlap_crop_image(
        np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
    )
    all_crops = overlap_crops["crops"]
    all_crops = np.transpose(all_crops, (0, 3, 1, 2))
    all_crops = (
        torch.from_numpy(all_crops)
        .to(device=device, dtype=torch.float16)
        .div_(255.0)
        .sub_(0.5)
        .div_(0.5)
    )
    return all_crops, overlap_crops["tiling"]


def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
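    """Encode a batch of crops with a pre-norm ViT-style transformer.

    Each crop in the B×C×H×W batch is split into ``enc_patch_size`` ×
    ``enc_patch_size`` patches, linearly embedded, and offset by a learned
    positional embedding. Each block adds residual attention and a residual
    MLP around layer norms; a final layer norm yields the per-patch features.
    """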
    x = rearrange(
        input_BCHW,
        "b c (h p1) (w p2) -> b (h w) (c p1 p2)",
        p1=config.enc_patch_size,
        p2=config.enc_patch_size,
    )  # B3HW -> B(HxW)(3xP1xP2), aka BTC

    x = linear(x, w.patch_emb)
    x = x + w.pos_emb
    for block in w.blocks:
        x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
        x = x + mlp(layer_norm(x, block.ln2), block.mlp)
    x = layer_norm(x, w.post_ln)
    return x


def vision_projection(
    global_features: torch.Tensor,
    reconstructed: torch.Tensor,
    w: nn.Module,
    config: VisionConfig,
):
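    """Fuse global and reconstructed local features through the projection MLP.

    The reconstructed full-image feature map is average-pooled to an
    ``enc_n_layers`` × ``enc_n_layers`` grid and flattened to 729 tokens
    (which assumes ``enc_n_layers ** 2 == 729``), concatenated channel-wise
    with the global crop's features, and passed through the two-layer
    projection MLP.
    """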
    reconstructed = reconstructed.permute(2, 0, 1)
    reconstructed = adaptive_avg_pool2d(
        reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
    )
    reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
    final_features = torch.cat([global_features, reconstructed], dim=-1)
    return mlp(final_features, w.proj_mlp)


def build_vision_model(config: VisionConfig, dtype: torch.dtype):
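    """Build the weights used by ``vision_encoder`` and ``vision_projection``.

    Returns an ``nn.ModuleDict`` with the patch embedding, ``enc_n_layers``
    transformer blocks (``ln1``/``attn``/``ln2``/``mlp``), the final layer
    norm, the projection MLP, and a zero-initialized positional embedding
    covering every patch of a ``crop_size`` × ``crop_size`` crop, all created
    in ``dtype``.
    """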
    patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
    grid_size = config.crop_size // config.enc_patch_size
    num_patches = grid_size * grid_size

    vision = nn.ModuleDict(
        {
            "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
            "blocks": nn.ModuleList(
                [
                    nn.ModuleDict(
                        {
                            "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
                            "attn": nn.ModuleDict(
                                {
                                    "qkv": nn.Linear(
                                        config.enc_dim, 3 * config.enc_dim, dtype=dtype
                                    ),
                                    "proj": nn.Linear(
                                        config.enc_dim, config.enc_dim, dtype=dtype
                                    ),
                                }
                            ),
                            "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
                            "mlp": nn.ModuleDict(
                                {
                                    "fc1": nn.Linear(
                                        config.enc_dim, config.enc_ff_dim, dtype=dtype
                                    ),
                                    "fc2": nn.Linear(
                                        config.enc_ff_dim, config.enc_dim, dtype=dtype
                                    ),
                                }
                            ),
                        }
                    )
                    for _ in range(config.enc_n_layers)
                ]
            ),
            "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
            "proj_mlp": nn.ModuleDict(
                {
                    "fc1": nn.Linear(
                        config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
                    ),
                    "fc2": nn.Linear(
                        config.proj_inner_dim, config.proj_out_dim, dtype=dtype
                    ),
                }
            ),
        }
    )
    vision.pos_emb = nn.Parameter(
        torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
    )
    return vision
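

# Rough wiring sketch (for orientation only; the crop-reconstruction step is
# not defined in this file and is assumed to live next to overlap_crop_image):
#
#     vision = build_vision_model(config, torch.float16)
#     crops, tiling = prepare_crops(image, config, device="cuda")
#     crop_features = vision_encoder(crops, vision, config)
#     # ...stitch per-crop features back into one full-image map using tiling,
#     # then fuse it with the global crop's features:
#     emb = vision_projection(global_features, reconstructed, vision, config)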