File size: 652 Bytes
05b4fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def batch_broadcast(a, x):
    """Reshape `a` so that it broadcasts against `x` along the batch dimension.

    The first dimension of `x` is treated as the batch dimension. `a` is
    reshaped to have as many dimensions as `x`, with every dimension after
    the first set to 1, so that `a * x` (or similar) applies one value of `a`
    per batch entry.

    Args:
        a: Tensor with one value per batch entry, or a single value shared by
            all batch entries. Singleton dimensions are squeezed away, so
            shapes such as `(B,)`, `(B, 1)`, `(1,)`, `(1, 1)` or a 0-d scalar
            are all accepted. Assumed to be a torch-like tensor providing
            `.shape`, `.squeeze()` and `.view()`.
        x: Tensor whose shape `a` should be made broadcastable to.

    Returns:
        A view of `a` with shape `(a.shape[0], 1, ..., 1)` and
        `len(x.shape)` dimensions. The leading size is `x.shape[0]` when `a`
        holds one value per batch entry, and 1 when `a` holds a single value
        (which still broadcasts over the batch dimension of `x`).

    Raises:
        ValueError: If `a` has more than one non-singleton dimension, or if
            its length is neither 1 nor `x.shape[0]`.
    """
    if len(a.shape) != 1:
        a = a.squeeze()
        if len(a.shape) == 0:
            # A single-element tensor squeezes down to 0-d; treat it as a
            # length-1 vector instead of rejecting it.
            a = a.view(1)
        elif len(a.shape) != 1:
            raise ValueError(
                f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
            )

    if a.shape[0] != x.shape[0] and a.shape[0] != 1:
        raise ValueError(
            f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching")

    # Keep `a`'s own leading size rather than forcing x.shape[0]: a length-1
    # `a` cannot be viewed as x.shape[0] elements, but (1, 1, ..., 1) still
    # broadcasts correctly against `x`.
    trailing_ones = [1] * (len(x.shape) - 1)
    return a.view(a.shape[0], *trailing_ones)