# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
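# Smoke tests for the embedding helpers in xformers.components: build every
# registered positional embedding and a patch embedding from their configs,
# then run a forward pass to check construction and output dimensions.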
import pytest
import torch
from xformers.components import PatchEmbeddingConfig, build_patch_embedding
from xformers.components.positional_embedding import (
    POSITION_EMBEDDING_REGISTRY,
    build_positional_embedding,
)
BATCH = 20
SEQ = 512
MODEL = 384
assert (
    POSITION_EMBEDDING_REGISTRY.keys()
), "Positional encoding layers should have been registered"


@pytest.mark.parametrize("encoding_name", POSITION_EMBEDDING_REGISTRY.keys())
@pytest.mark.parametrize("dropout", [0.0, 0.2])
def test_dimensions(encoding_name: str, dropout: float):
    test_config = {
        "name": encoding_name,
        "dim_model": MODEL,
        "vocab_size": 32,
        "dropout": dropout,
        "seq_len": SEQ,
    }

    # dummy, just check construction and dimensions in the FW pass
    encoding = build_positional_embedding(test_config)
    inputs = (torch.rand(BATCH, SEQ) * 10).abs().to(torch.int)
    _ = encoding(inputs)

    # Test that inputs which already carry an embedding dimension also work
    if encoding_name == "sine":
        inputs = (torch.rand(BATCH, SEQ, MODEL) * 10).abs().to(torch.int)
        _ = encoding(inputs)


def test_patch_embedding():
    patch_embedding_config = {
        "in_channels": 3,
        "out_channels": 64,
        "kernel_size": 7,
        "stride": 4,
        "padding": 2,
    }

    # dummy, just check construction and dimensions in the FW pass
    patch_emb = build_patch_embedding(PatchEmbeddingConfig(**patch_embedding_config))

    # Check BHWC
    inputs = torch.rand(BATCH, 32 * 32, 3)
    out = patch_emb(inputs)
    assert out.shape[-1] == 64

    # Check BCHW
    inputs = torch.rand(BATCH, 3, 32, 32)
    out = patch_emb(inputs)
    assert out.shape[-1] == 64