File size: 349 Bytes
1801c3b
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

from pipeline.clip_wrapper import ClipWrapper


def test_ClipWrapper():
    """Smoke-test ClipWrapper: both image and text encoders emit 512-dim vectors."""
    wrapper = ClipWrapper()

    # Two random CHW images at CLIP's expected 224x224 resolution.
    sample_images = [torch.rand(3, 224, 224) for _ in range(2)]
    image_vecs = wrapper.images2vec(sample_images)
    assert image_vecs.shape[-1] == 512

    # Two short captions; only the embedding width is checked.
    sample_texts = ["a photo of a cat", "a photo of a dog"]
    text_vecs = wrapper.texts2vec(sample_texts)
    assert text_vecs.shape[-1] == 512