|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import numpy as np | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  | from numpy.random import RandomState | 
					
						
						|  | from scipy.stats import chi | 
					
						
						|  | from torch.autograd import Variable | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def q_normalize(input, channel=1): | 
					
						
						|  | r = get_r(input) | 
					
						
						|  | i = get_i(input) | 
					
						
						|  | j = get_j(input) | 
					
						
						|  | k = get_k(input) | 
					
						
						|  |  | 
					
						
						|  | norm = torch.sqrt(r*r + i*i + j*j + k*k + 0.0001) | 
					
						
						|  | r = r / norm | 
					
						
						|  | i = i / norm | 
					
						
						|  | j = j / norm | 
					
						
						|  | k = k / norm | 
					
						
						|  |  | 
					
						
						|  | return torch.cat([r, i, j, k], dim=channel) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_input(input): | 
					
						
						|  | if input.dim() not in {2, 3, 4, 5}: | 
					
						
						|  | raise RuntimeError( | 
					
						
						|  | 'Quaternion linear accepts only input of dimension 2 or 3. Quaternion conv accepts up to 5 dim ' | 
					
						
						|  | ' input.dim = ' + str(input.dim()) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if input.dim() < 4: | 
					
						
						|  | nb_hidden = input.size()[-1] | 
					
						
						|  | else: | 
					
						
						|  | nb_hidden = input.size()[1] | 
					
						
						|  |  | 
					
						
						|  | if nb_hidden % 4 != 0: | 
					
						
						|  | raise RuntimeError( | 
					
						
						|  | 'Quaternion Tensors must be divisible by 4.' | 
					
						
						|  | ' input.size()[1] = ' + str(nb_hidden) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_r(input): | 
					
						
						|  | check_input(input) | 
					
						
						|  | if input.dim() < 4: | 
					
						
						|  | nb_hidden = input.size()[-1] | 
					
						
						|  | else: | 
					
						
						|  | nb_hidden = input.size()[1] | 
					
						
						|  |  | 
					
						
						|  | if input.dim() == 2: | 
					
						
						|  | return input.narrow(1, 0, nb_hidden // 4) | 
					
						
						|  | if input.dim() == 3: | 
					
						
						|  | return input.narrow(2, 0, nb_hidden // 4) | 
					
						
						|  | if input.dim() >= 4: | 
					
						
						|  | return input.narrow(1, 0, nb_hidden // 4) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_i(input): | 
					
						
						|  | if input.dim() < 4: | 
					
						
						|  | nb_hidden = input.size()[-1] | 
					
						
						|  | else: | 
					
						
						|  | nb_hidden = input.size()[1] | 
					
						
						|  | if input.dim() == 2: | 
					
						
						|  | return input.narrow(1, nb_hidden // 4, nb_hidden // 4) | 
					
						
						|  | if input.dim() == 3: | 
					
						
						|  | return input.narrow(2, nb_hidden // 4, nb_hidden // 4) | 
					
						
						|  | if input.dim() >= 4: | 
					
						
						|  | return input.narrow(1, nb_hidden // 4, nb_hidden // 4) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_j(input): | 
					
						
						|  | check_input(input) | 
					
						
						|  | if input.dim() < 4: | 
					
						
						|  | nb_hidden = input.size()[-1] | 
					
						
						|  | else: | 
					
						
						|  | nb_hidden = input.size()[1] | 
					
						
						|  | if input.dim() == 2: | 
					
						
						|  | return input.narrow(1, nb_hidden // 2, nb_hidden // 4) | 
					
						
						|  | if input.dim() == 3: | 
					
						
						|  | return input.narrow(2, nb_hidden // 2, nb_hidden // 4) | 
					
						
						|  | if input.dim() >= 4: | 
					
						
						|  | return input.narrow(1, nb_hidden // 2, nb_hidden // 4) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_k(input): | 
					
						
						|  | check_input(input) | 
					
						
						|  | if input.dim() < 4: | 
					
						
						|  | nb_hidden = input.size()[-1] | 
					
						
						|  | else: | 
					
						
						|  | nb_hidden = input.size()[1] | 
					
						
						|  | if input.dim() == 2: | 
					
						
						|  | return input.narrow(1, nb_hidden - nb_hidden // 4, nb_hidden // 4) | 
					
						
						|  | if input.dim() == 3: | 
					
						
						|  | return input.narrow(2, nb_hidden - nb_hidden // 4, nb_hidden // 4) | 
					
						
						|  | if input.dim() >= 4: | 
					
						
						|  | return input.narrow(1, nb_hidden - nb_hidden // 4, nb_hidden // 4) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_modulus(input, vector_form=False): | 
					
						
						|  | check_input(input) | 
					
						
						|  | r = get_r(input) | 
					
						
						|  | i = get_i(input) | 
					
						
						|  | j = get_j(input) | 
					
						
						|  | k = get_k(input) | 
					
						
						|  | if vector_form: | 
					
						
						|  | return torch.sqrt(r * r + i * i + j * j + k * k) | 
					
						
						|  | else: | 
					
						
						|  | return torch.sqrt((r * r + i * i + j * j + k * k).sum(dim=0)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_normalized(input, eps=0.0001): | 
					
						
						|  | check_input(input) | 
					
						
						|  | data_modulus = get_modulus(input) | 
					
						
						|  | if input.dim() == 2: | 
					
						
						|  | data_modulus_repeated = data_modulus.repeat(1, 4) | 
					
						
						|  | elif input.dim() == 3: | 
					
						
						|  | data_modulus_repeated = data_modulus.repeat(1, 1, 4) | 
					
						
						|  | return input / (data_modulus_repeated.expand_as(input) + eps) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def quaternion_exp(input): | 
					
						
						|  | r = get_r(input) | 
					
						
						|  | i = get_i(input) | 
					
						
						|  | j = get_j(input) | 
					
						
						|  | k = get_k(input) | 
					
						
						|  |  | 
					
						
						|  | norm_v = torch.sqrt(i*i+j*j+k*k) + 0.0001 | 
					
						
						|  | exp = torch.exp(r) | 
					
						
						|  |  | 
					
						
						|  | r = torch.cos(norm_v) | 
					
						
						|  | i = (i / norm_v) * torch.sin(norm_v) | 
					
						
						|  | j = (j / norm_v) * torch.sin(norm_v) | 
					
						
						|  | k = (k / norm_v) * torch.sin(norm_v) | 
					
						
						|  |  | 
					
						
						|  | return torch.cat([exp*r, exp*i, exp*j, exp*k], dim=1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def kronecker_conv(input, r_weight, i_weight, j_weight, k_weight, bias, stride, | 
					
						
						|  | padding, groups, dilatation, learn_A, cuda, first_layer=False): | 
					
						
						|  |  | 
					
						
						|  | """Applies a quaternion convolution to the incoming data:""" | 
					
						
						|  |  | 
					
						
						|  | if first_layer: | 
					
						
						|  | mat1 = torch.zeros((4, 4), requires_grad=False).view(4, 4, 1, 1) | 
					
						
						|  | else: | 
					
						
						|  | mat1 = torch.eye(4, requires_grad=False).view(4, 4, 1, 1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mat2 = torch.tensor([[0, -1, 0, 0], | 
					
						
						|  | [1, 0, 0, 0], | 
					
						
						|  | [0, 0, 0, -1], | 
					
						
						|  | [0, 0, 1, 0]], requires_grad=False).view(4, 4, 1, 1) | 
					
						
						|  | mat3 = torch.tensor([[0, 0, -1, 0], | 
					
						
						|  | [0, 0, 0, 1], | 
					
						
						|  | [1, 0, 0, 0], | 
					
						
						|  | [0, -1, 0, 0]], requires_grad=False).view(4, 4, 1, 1) | 
					
						
						|  | mat4 = torch.tensor([[0, 0, 0, -1], | 
					
						
						|  | [0, 0, -1, 0], | 
					
						
						|  | [0, 1, 0, 0], | 
					
						
						|  | [1, 0, 0, 0]], requires_grad=False).view(4, 4, 1, 1) | 
					
						
						|  |  | 
					
						
						|  | if cuda: | 
					
						
						|  | mat1, mat2, mat3, mat4 = mat1.cuda(), mat2.cuda(), mat3.cuda(), mat4.cuda() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | cat_kernels_4_quaternion = torch.kron(mat1, r_weight) + \ | 
					
						
						|  | torch.kron(mat2, i_weight) + \ | 
					
						
						|  | torch.kron(mat3, j_weight) + \ | 
					
						
						|  | torch.kron(mat4, k_weight) | 
					
						
						|  |  | 
					
						
						|  | if input.dim() == 3: | 
					
						
						|  | convfunc = F.conv1d | 
					
						
						|  | elif input.dim() == 4: | 
					
						
						|  | convfunc = F.conv2d | 
					
						
						|  | elif input.dim() == 5: | 
					
						
						|  | convfunc = F.conv3d | 
					
						
						|  | else: | 
					
						
						|  | raise Exception('The convolutional input is either 3, 4 or 5 dimensions.' | 
					
						
						|  | ' input.dim = ' + str(input.dim())) | 
					
						
						|  |  | 
					
						
						|  | return convfunc(input, cat_kernels_4_quaternion, bias, stride, padding, dilatation, groups) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def quaternion_conv(input, r_weight, i_weight, j_weight, k_weight, bias, stride, | 
					
						
						|  | padding, groups, dilatation): | 
					
						
						|  | """Applies a quaternion convolution to the incoming data:""" | 
					
						
						|  |  | 
					
						
						|  | cat_kernels_4_r = torch.cat( | 
					
						
						|  | [r_weight, -i_weight, -j_weight, -k_weight], dim=1) | 
					
						
						|  | cat_kernels_4_i = torch.cat( | 
					
						
						|  | [i_weight,  r_weight, -k_weight, j_weight], dim=1) | 
					
						
						|  | cat_kernels_4_j = torch.cat( | 
					
						
						|  | [j_weight,  k_weight, r_weight, -i_weight], dim=1) | 
					
						
						|  | cat_kernels_4_k = torch.cat( | 
					
						
						|  | [k_weight,  -j_weight, i_weight, r_weight], dim=1) | 
					
						
						|  |  | 
					
						
						|  | cat_kernels_4_quaternion = torch.cat( | 
					
						
						|  | [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=0) | 
					
						
						|  |  | 
					
						
						|  | if input.dim() == 3: | 
					
						
						|  | convfunc = F.conv1d | 
					
						
						|  | elif input.dim() == 4: | 
					
						
						|  | convfunc = F.conv2d | 
					
						
						|  | elif input.dim() == 5: | 
					
						
						|  | convfunc = F.conv3d | 
					
						
						|  | else: | 
					
						
						|  | raise Exception('The convolutional input is either 3, 4 or 5 dimensions.' | 
					
						
						|  | ' input.dim = ' + str(input.dim())) | 
					
						
						|  |  | 
					
						
						|  | return convfunc(input, cat_kernels_4_quaternion, bias, stride, padding, dilatation, groups) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def quaternion_transpose_conv(input, r_weight, i_weight, j_weight, k_weight, bias, stride, | 
					
						
						|  | padding, output_padding, groups, dilatation): | 
					
						
						|  | """Applies a quaternion transposed convolution to the incoming data:""" | 
					
						
						|  |  | 
					
						
						|  | cat_kernels_4_r = torch.cat( | 
					
						
						|  | [r_weight, -i_weight, -j_weight, -k_weight], dim=1) | 
					
						
						|  | cat_kernels_4_i = torch.cat( | 
					
						
						|  | [i_weight,  r_weight, -k_weight, j_weight], dim=1) | 
					
						
						|  | cat_kernels_4_j = torch.cat( | 
					
						
						|  | [j_weight,  k_weight, r_weight, -i_weight], dim=1) | 
					
						
						|  | cat_kernels_4_k = torch.cat( | 
					
						
						|  | [k_weight,  -j_weight, i_weight, r_weight], dim=1) | 
					
						
						|  | cat_kernels_4_quaternion = torch.cat( | 
					
						
						|  | [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=0) | 
					
						
						|  |  | 
					
						
						|  | if input.dim() == 3: | 
					
						
						|  | convfunc = F.conv_transpose1d | 
					
						
						|  | elif input.dim() == 4: | 
					
						
						|  | convfunc = F.conv_transpose2d | 
					
						
						|  | elif input.dim() == 5: | 
					
						
						|  | convfunc = F.conv_transpose3d | 
					
						
						|  | else: | 
					
						
						|  | raise Exception('The convolutional input is either 3, 4 or 5 dimensions.' | 
					
						
						|  | ' input.dim = ' + str(input.dim())) | 
					
						
						|  |  | 
					
						
						|  | return convfunc(input, cat_kernels_4_quaternion, | 
					
						
						|  | bias, stride, padding, output_padding, groups, dilatation) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def quaternion_conv_rotation(input, zero_kernel, r_weight, i_weight, j_weight, k_weight, bias, stride, | 
					
						
						|  | padding, groups, dilatation, quaternion_format, scale=None): | 
					
						
						|  | """Applies a quaternion rotation and convolution transformation to the incoming data: | 
					
						
						|  |  | 
					
						
						|  | The rotation W*x*W^t can be replaced by R*x following: | 
					
						
						|  | https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation | 
					
						
						|  |  | 
					
						
						|  | Works for unitary and non unitary weights. | 
					
						
						|  |  | 
					
						
						|  | The initial size of the input must be a multiple of 3 if quaternion_format = False and | 
					
						
						|  | 4 if quaternion_format = True. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | square_r = (r_weight*r_weight) | 
					
						
						|  | square_i = (i_weight*i_weight) | 
					
						
						|  | square_j = (j_weight*j_weight) | 
					
						
						|  | square_k = (k_weight*k_weight) | 
					
						
						|  |  | 
					
						
						|  | norm = torch.sqrt(square_r+square_i+square_j+square_k + 0.0001) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | r_n_weight = (r_weight / norm) | 
					
						
						|  | i_n_weight = (i_weight / norm) | 
					
						
						|  | j_n_weight = (j_weight / norm) | 
					
						
						|  | k_n_weight = (k_weight / norm) | 
					
						
						|  |  | 
					
						
						|  | norm_factor = 2.0 | 
					
						
						|  |  | 
					
						
						|  | square_i = norm_factor*(i_n_weight*i_n_weight) | 
					
						
						|  | square_j = norm_factor*(j_n_weight*j_n_weight) | 
					
						
						|  | square_k = norm_factor*(k_n_weight*k_n_weight) | 
					
						
						|  |  | 
					
						
						|  | ri = (norm_factor*r_n_weight*i_n_weight) | 
					
						
						|  | rj = (norm_factor*r_n_weight*j_n_weight) | 
					
						
						|  | rk = (norm_factor*r_n_weight*k_n_weight) | 
					
						
						|  |  | 
					
						
						|  | ij = (norm_factor*i_n_weight*j_n_weight) | 
					
						
						|  | ik = (norm_factor*i_n_weight*k_n_weight) | 
					
						
						|  |  | 
					
						
						|  | jk = (norm_factor*j_n_weight*k_n_weight) | 
					
						
						|  |  | 
					
						
						|  | if quaternion_format: | 
					
						
						|  | if scale is not None: | 
					
						
						|  | rot_kernel_1 = torch.cat([zero_kernel, scale * (1.0 - (square_j + square_k)), | 
					
						
						|  | scale * (ij-rk), scale * (ik+rj)], dim=1) | 
					
						
						|  | rot_kernel_2 = torch.cat([zero_kernel, scale * (ij+rk), scale * | 
					
						
						|  | (1.0 - (square_i + square_k)), scale * (jk-ri)], dim=1) | 
					
						
						|  | rot_kernel_3 = torch.cat([zero_kernel, scale * (ik-rj), scale * (jk+ri), | 
					
						
						|  | scale * (1.0 - (square_i + square_j))], dim=1) | 
					
						
						|  | else: | 
					
						
						|  | rot_kernel_1 = torch.cat( | 
					
						
						|  | [zero_kernel, (1.0 - (square_j + square_k)), (ij-rk), (ik+rj)], dim=1) | 
					
						
						|  | rot_kernel_2 = torch.cat( | 
					
						
						|  | [zero_kernel, (ij+rk), (1.0 - (square_i + square_k)), (jk-ri)], dim=1) | 
					
						
						|  | rot_kernel_3 = torch.cat( | 
					
						
						|  | [zero_kernel, (ik-rj), (jk+ri), (1.0 - (square_i + square_j))], dim=1) | 
					
						
						|  |  | 
					
						
						|  | zero_kernel2 = torch.cat( | 
					
						
						|  | [zero_kernel, zero_kernel, zero_kernel, zero_kernel], dim=1) | 
					
						
						|  | global_rot_kernel = torch.cat( | 
					
						
						|  | [zero_kernel2, rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0) | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  | if scale is not None: | 
					
						
						|  | rot_kernel_1 = torch.cat([scale * (1.0 - (square_j + square_k)), | 
					
						
						|  | scale * (ij-rk), scale * (ik+rj)], dim=0) | 
					
						
						|  | rot_kernel_2 = torch.cat( | 
					
						
						|  | [scale * (ij+rk), scale * (1.0 - (square_i + square_k)), scale * (jk-ri)], dim=0) | 
					
						
						|  | rot_kernel_3 = torch.cat([scale * (ik-rj), scale * (jk+ri), scale * | 
					
						
						|  | (1.0 - (square_i + square_j))], dim=0) | 
					
						
						|  | else: | 
					
						
						|  | rot_kernel_1 = torch.cat( | 
					
						
						|  | [1.0 - (square_j + square_k), (ij-rk), (ik+rj)], dim=0) | 
					
						
						|  | rot_kernel_2 = torch.cat( | 
					
						
						|  | [(ij+rk), 1.0 - (square_i + square_k), (jk-ri)], dim=0) | 
					
						
						|  | rot_kernel_3 = torch.cat( | 
					
						
						|  | [(ik-rj), (jk+ri), (1.0 - (square_i + square_j))], dim=0) | 
					
						
						|  |  | 
					
						
						|  | global_rot_kernel = torch.cat( | 
					
						
						|  | [rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if input.dim() == 3: | 
					
						
						|  | convfunc = F.conv1d | 
					
						
						|  | elif input.dim() == 4: | 
					
						
						|  | convfunc = F.conv2d | 
					
						
						|  | elif input.dim() == 5: | 
					
						
						|  | convfunc = F.conv3d | 
					
						
						|  | else: | 
					
						
						|  | raise Exception('The convolutional input is either 3, 4 or 5 dimensions.' | 
					
						
						|  | ' input.dim = ' + str(input.dim())) | 
					
						
						|  |  | 
					
						
						|  | return convfunc(input, global_rot_kernel, bias, stride, padding, dilatation, groups) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def quaternion_transpose_conv_rotation( | 
					
						
						|  | input, zero_kernel, r_weight, i_weight, j_weight, k_weight, bias, stride, | 
					
						
						|  | padding, output_padding, groups, dilatation, quaternion_format): | 
					
						
						|  | """Applies a quaternion rotation and transposed convolution transformation to the incoming data: | 
					
						
						|  |  | 
					
						
						|  | The rotation W*x*W^t can be replaced by R*x following: | 
					
						
						|  | https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation | 
					
						
						|  |  | 
					
						
						|  | Works for unitary and non unitary weights. | 
					
						
						|  |  | 
					
						
						|  | The initial size of the input must be a multiple of 3 if quaternion_format = False and | 
					
						
						|  | 4 if quaternion_format = True. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | square_r = (r_weight*r_weight) | 
					
						
						|  | square_i = (i_weight*i_weight) | 
					
						
						|  | square_j = (j_weight*j_weight) | 
					
						
						|  | square_k = (k_weight*k_weight) | 
					
						
						|  |  | 
					
						
						|  | norm = torch.sqrt(square_r+square_i+square_j+square_k + 0.0001) | 
					
						
						|  |  | 
					
						
						|  | r_weight = (r_weight / norm) | 
					
						
						|  | i_weight = (i_weight / norm) | 
					
						
						|  | j_weight = (j_weight / norm) | 
					
						
						|  | k_weight = (k_weight / norm) | 
					
						
						|  |  | 
					
						
						|  | norm_factor = 2.0 | 
					
						
						|  |  | 
					
						
						|  | square_i = norm_factor*(i_weight*i_weight) | 
					
						
						|  | square_j = norm_factor*(j_weight*j_weight) | 
					
						
						|  | square_k = norm_factor*(k_weight*k_weight) | 
					
						
						|  |  | 
					
						
						|  | ri = (norm_factor*r_weight*i_weight) | 
					
						
						|  | rj = (norm_factor*r_weight*j_weight) | 
					
						
						|  | rk = (norm_factor*r_weight*k_weight) | 
					
						
						|  |  | 
					
						
						|  | ij = (norm_factor*i_weight*j_weight) | 
					
						
						|  | ik = (norm_factor*i_weight*k_weight) | 
					
						
						|  |  | 
					
						
						|  | jk = (norm_factor*j_weight*k_weight) | 
					
						
						|  |  | 
					
						
						|  | if quaternion_format: | 
					
						
						|  | rot_kernel_1 = torch.cat( | 
					
						
						|  | [zero_kernel, 1.0 - (square_j + square_k), ij-rk, ik+rj], dim=1) | 
					
						
						|  | rot_kernel_2 = torch.cat( | 
					
						
						|  | [zero_kernel, ij+rk, 1.0 - (square_i + square_k), jk-ri], dim=1) | 
					
						
						|  | rot_kernel_3 = torch.cat( | 
					
						
						|  | [zero_kernel, ik-rj, jk+ri, 1.0 - (square_i + square_j)], dim=1) | 
					
						
						|  |  | 
					
						
						|  | zero_kernel2 = torch.zeros(rot_kernel_1.shape).cuda() | 
					
						
						|  | global_rot_kernel = torch.cat( | 
					
						
						|  | [zero_kernel2, rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0) | 
					
						
						|  | else: | 
					
						
						|  | rot_kernel_1 = torch.cat( | 
					
						
						|  | [1.0 - (square_j + square_k), ij-rk, ik+rj], dim=1) | 
					
						
						|  | rot_kernel_2 = torch.cat( | 
					
						
						|  | [ij+rk, 1.0 - (square_i + square_k), jk-ri], dim=1) | 
					
						
						|  | rot_kernel_3 = torch.cat( | 
					
						
						|  | [ik-rj, jk+ri, 1.0 - (square_i + square_j)], dim=1) | 
					
						
						|  | global_rot_kernel = torch.cat( | 
					
						
						|  | [rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0) | 
					
						
						|  |  | 
					
						
						|  | if input.dim() == 3: | 
					
						
						|  | convfunc = F.conv_transpose1d | 
					
						
						|  | elif input.dim() == 4: | 
					
						
						|  | convfunc = F.conv_transpose2d | 
					
						
						|  | elif input.dim() == 5: | 
					
						
						|  | convfunc = F.conv_transpose3d | 
					
						
						|  | else: | 
					
						
						|  | raise Exception('The convolutional input is either 3, 4 or 5 dimensions.' | 
					
						
						|  | ' input.dim = ' + str(input.dim())) | 
					
						
						|  |  | 
					
						
						|  | return convfunc(input, cat_kernels_4_quaternion, bias, stride, padding, output_padding, groups, dilatation) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def quaternion_linear(input, r_weight, i_weight, j_weight, k_weight, bias=True): | 
					
						
						|  | """Applies a quaternion linear transformation to the incoming data: | 
					
						
						|  |  | 
					
						
						|  | It is important to notice that the forward phase of a QNN is defined | 
					
						
						|  | as W * Inputs (with * equal to the Hamilton product). The constructed | 
					
						
						|  | cat_kernels_4_quaternion is a modified version of the quaternion representation | 
					
						
						|  | so when we do torch.mm(Input,W) it's equivalent to W * Inputs. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | cat_kernels_4_r = torch.cat( | 
					
						
						|  | [r_weight, -i_weight, -j_weight, -k_weight], dim=0) | 
					
						
						|  | cat_kernels_4_i = torch.cat( | 
					
						
						|  | [i_weight,  r_weight, -k_weight, j_weight], dim=0) | 
					
						
						|  | cat_kernels_4_j = torch.cat( | 
					
						
						|  | [j_weight,  k_weight, r_weight, -i_weight], dim=0) | 
					
						
						|  | cat_kernels_4_k = torch.cat( | 
					
						
						|  | [k_weight,  -j_weight, i_weight, r_weight], dim=0) | 
					
						
						|  | cat_kernels_4_quaternion = torch.cat( | 
					
						
						|  | [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=1) | 
					
						
						|  |  | 
					
						
						|  | if input.dim() == 2: | 
					
						
						|  |  | 
					
						
						|  | if bias is not None: | 
					
						
						|  | return torch.addmm(bias, input, cat_kernels_4_quaternion) | 
					
						
						|  | else: | 
					
						
						|  | return torch.mm(input, cat_kernels_4_quaternion) | 
					
						
						|  | else: | 
					
						
						|  | output = torch.matmul(input, cat_kernels_4_quaternion) | 
					
						
						|  | if bias is not None: | 
					
						
						|  | return output+bias | 
					
						
						|  | else: | 
					
						
						|  | return output | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def quaternion_linear_rotation(input, zero_kernel, r_weight, i_weight, j_weight, k_weight, bias=None, | 
					
						
						|  | quaternion_format=False, scale=None): | 
					
						
						|  | """Applies a quaternion rotation transformation to the incoming data: | 
					
						
						|  |  | 
					
						
						|  | The rotation W*x*W^t can be replaced by R*x following: | 
					
						
						|  | https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation | 
					
						
						|  |  | 
					
						
						|  | Works for unitary and non unitary weights. | 
					
						
						|  |  | 
					
						
						|  | The initial size of the input must be a multiple of 3 if quaternion_format = False and | 
					
						
						|  | 4 if quaternion_format = True. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | square_r = (r_weight*r_weight) | 
					
						
						|  | square_i = (i_weight*i_weight) | 
					
						
						|  | square_j = (j_weight*j_weight) | 
					
						
						|  | square_k = (k_weight*k_weight) | 
					
						
						|  |  | 
					
						
						|  | norm = torch.sqrt(square_r+square_i+square_j+square_k + 0.0001) | 
					
						
						|  |  | 
					
						
						|  | r_n_weight = (r_weight / norm) | 
					
						
						|  | i_n_weight = (i_weight / norm) | 
					
						
						|  | j_n_weight = (j_weight / norm) | 
					
						
						|  | k_n_weight = (k_weight / norm) | 
					
						
						|  |  | 
					
						
						|  | norm_factor = 2.0 | 
					
						
						|  |  | 
					
						
						|  | square_i = norm_factor*(i_n_weight*i_n_weight) | 
					
						
						|  | square_j = norm_factor*(j_n_weight*j_n_weight) | 
					
						
						|  | square_k = norm_factor*(k_n_weight*k_n_weight) | 
					
						
						|  |  | 
					
						
						|  | ri = (norm_factor*r_n_weight*i_n_weight) | 
					
						
						|  | rj = (norm_factor*r_n_weight*j_n_weight) | 
					
						
						|  | rk = (norm_factor*r_n_weight*k_n_weight) | 
					
						
						|  |  | 
					
						
						|  | ij = (norm_factor*i_n_weight*j_n_weight) | 
					
						
						|  | ik = (norm_factor*i_n_weight*k_n_weight) | 
					
						
						|  |  | 
					
						
						|  | jk = (norm_factor*j_n_weight*k_n_weight) | 
					
						
						|  |  | 
					
						
						|  | if quaternion_format: | 
					
						
						|  | if scale is not None: | 
					
						
						|  | rot_kernel_1 = torch.cat([zero_kernel, scale * (1.0 - (square_j + square_k)), | 
					
						
						|  | scale * (ij-rk), scale * (ik+rj)], dim=0) | 
					
						
						|  | rot_kernel_2 = torch.cat([zero_kernel, scale * (ij+rk), scale * | 
					
						
						|  | (1.0 - (square_i + square_k)), scale * (jk-ri)], dim=0) | 
					
						
						|  | rot_kernel_3 = torch.cat([zero_kernel, scale * (ik-rj), scale * (jk+ri), | 
					
						
						|  | scale * (1.0 - (square_i + square_j))], dim=0) | 
					
						
						|  | else: | 
					
						
						|  | rot_kernel_1 = torch.cat( | 
					
						
						|  | [zero_kernel, (1.0 - (square_j + square_k)), (ij-rk), (ik+rj)], dim=0) | 
					
						
						|  | rot_kernel_2 = torch.cat( | 
					
						
						|  | [zero_kernel, (ij+rk), (1.0 - (square_i + square_k)), (jk-ri)], dim=0) | 
					
						
						|  | rot_kernel_3 = torch.cat( | 
					
						
						|  | [zero_kernel, (ik-rj), (jk+ri), (1.0 - (square_i + square_j))], dim=0) | 
					
						
						|  |  | 
					
						
						|  | zero_kernel2 = torch.cat( | 
					
						
						|  | [zero_kernel, zero_kernel, zero_kernel, zero_kernel], dim=0) | 
					
						
						|  | global_rot_kernel = torch.cat( | 
					
						
						|  | [zero_kernel2, rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=1) | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  | if scale is not None: | 
					
						
						|  | rot_kernel_1 = torch.cat([scale * (1.0 - (square_j + square_k)), | 
					
						
						|  | scale * (ij-rk), scale * (ik+rj)], dim=0) | 
					
						
						|  | rot_kernel_2 = torch.cat( | 
					
						
						|  | [scale * (ij+rk), scale * (1.0 - (square_i + square_k)), scale * (jk-ri)], dim=0) | 
					
						
						|  | rot_kernel_3 = torch.cat([scale * (ik-rj), scale * (jk+ri), scale * | 
					
						
						|  | (1.0 - (square_i + square_j))], dim=0) | 
					
						
						|  | else: | 
					
						
						|  | rot_kernel_1 = torch.cat( | 
					
						
						|  | [1.0 - (square_j + square_k), (ij-rk), (ik+rj)], dim=0) | 
					
						
						|  | rot_kernel_2 = torch.cat( | 
					
						
						|  | [(ij+rk), 1.0 - (square_i + square_k), (jk-ri)], dim=0) | 
					
						
						|  | rot_kernel_3 = torch.cat( | 
					
						
						|  | [(ik-rj), (jk+ri), (1.0 - (square_i + square_j))], dim=0) | 
					
						
						|  |  | 
					
						
						|  | global_rot_kernel = torch.cat( | 
					
						
						|  | [rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=1) | 
					
						
						|  |  | 
					
						
						|  | if input.dim() == 2: | 
					
						
						|  | if bias is not None: | 
					
						
						|  | return torch.addmm(bias, input, global_rot_kernel) | 
					
						
						|  | else: | 
					
						
						|  | return torch.mm(input, global_rot_kernel) | 
					
						
						|  | else: | 
					
						
						|  | output = torch.matmul(input, global_rot_kernel) | 
					
						
						|  | if bias is not None: | 
					
						
						|  | return output+bias | 
					
						
						|  | else: | 
					
						
						|  | return output | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class QuaternionLinearFunction(torch.autograd.Function): | 
					
						
						|  | @staticmethod | 
					
						
						|  | def forward(ctx, input, r_weight, i_weight, j_weight, k_weight, bias=None): | 
					
						
						|  | ctx.save_for_backward(input, r_weight, i_weight, | 
					
						
						|  | j_weight, k_weight, bias) | 
					
						
						|  | check_input(input) | 
					
						
						|  | cat_kernels_4_r = torch.cat( | 
					
						
						|  | [r_weight, -i_weight, -j_weight, -k_weight], dim=0) | 
					
						
						|  | cat_kernels_4_i = torch.cat( | 
					
						
						|  | [i_weight,  r_weight, -k_weight, j_weight], dim=0) | 
					
						
						|  | cat_kernels_4_j = torch.cat( | 
					
						
						|  | [j_weight,  k_weight, r_weight, -i_weight], dim=0) | 
					
						
						|  | cat_kernels_4_k = torch.cat( | 
					
						
						|  | [k_weight,  -j_weight, i_weight, r_weight], dim=0) | 
					
						
						|  | cat_kernels_4_quaternion = torch.cat( | 
					
						
						|  | [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=1) | 
					
						
						|  | if input.dim() == 2: | 
					
						
						|  | if bias is not None: | 
					
						
						|  | return torch.addmm(bias, input, cat_kernels_4_quaternion) | 
					
						
						|  | else: | 
					
						
						|  | return torch.mm(input, cat_kernels_4_quaternion) | 
					
						
						|  | else: | 
					
						
						|  | output = torch.matmul(input, cat_kernels_4_quaternion) | 
					
						
						|  | if bias is not None: | 
					
						
						|  | return output+bias | 
					
						
						|  | else: | 
					
						
						|  | return output | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def backward(ctx, grad_output): | 
					
						
						|  | input, r_weight, i_weight, j_weight, k_weight, bias = ctx.saved_tensors | 
					
						
						|  | grad_input = grad_weight_r = grad_weight_i = grad_weight_j = grad_weight_k = grad_bias = None | 
					
						
						|  |  | 
					
						
						|  | input_r = torch.cat([r_weight, -i_weight, -j_weight, -k_weight], dim=0) | 
					
						
						|  | input_i = torch.cat([i_weight,  r_weight, -k_weight, j_weight], dim=0) | 
					
						
						|  | input_j = torch.cat([j_weight,  k_weight, r_weight, -i_weight], dim=0) | 
					
						
						|  | input_k = torch.cat([k_weight,  -j_weight, i_weight, r_weight], dim=0) | 
					
						
						|  | cat_kernels_4_quaternion_T = Variable( | 
					
						
						|  | torch.cat([input_r, input_i, input_j, input_k], dim=1).permute(1, 0), requires_grad=False) | 
					
						
						|  |  | 
					
						
						|  | r = get_r(input) | 
					
						
						|  | i = get_i(input) | 
					
						
						|  | j = get_j(input) | 
					
						
						|  | k = get_k(input) | 
					
						
						|  | input_r = torch.cat([r, -i, -j, -k], dim=0) | 
					
						
						|  | input_i = torch.cat([i,  r, -k, j], dim=0) | 
					
						
						|  | input_j = torch.cat([j,  k, r, -i], dim=0) | 
					
						
						|  | input_k = torch.cat([k,  -j, i, r], dim=0) | 
					
						
						|  | input_mat = Variable( | 
					
						
						|  | torch.cat([input_r, input_i, input_j, input_k], dim=1), requires_grad=False) | 
					
						
						|  |  | 
					
						
						|  | r = get_r(grad_output) | 
					
						
						|  | i = get_i(grad_output) | 
					
						
						|  | j = get_j(grad_output) | 
					
						
						|  | k = get_k(grad_output) | 
					
						
						|  | input_r = torch.cat([r, i, j, k], dim=1) | 
					
						
						|  | input_i = torch.cat([-i,  r, k, -j], dim=1) | 
					
						
						|  | input_j = torch.cat([-j,  -k, r, i], dim=1) | 
					
						
						|  | input_k = torch.cat([-k,  j, -i, r], dim=1) | 
					
						
						|  | grad_mat = torch.cat([input_r, input_i, input_j, input_k], dim=0) | 
					
						
						|  |  | 
					
						
						|  | if ctx.needs_input_grad[0]: | 
					
						
						|  | grad_input = grad_output.mm(cat_kernels_4_quaternion_T) | 
					
						
						|  | if ctx.needs_input_grad[1]: | 
					
						
						|  | grad_weight = grad_mat.permute(1, 0).mm(input_mat).permute(1, 0) | 
					
						
						|  | unit_size_x = r_weight.size(0) | 
					
						
						|  | unit_size_y = r_weight.size(1) | 
					
						
						|  | grad_weight_r = grad_weight.narrow( | 
					
						
						|  | 0, 0, unit_size_x).narrow(1, 0, unit_size_y) | 
					
						
						|  | grad_weight_i = grad_weight.narrow( | 
					
						
						|  | 0, 0, unit_size_x).narrow(1, unit_size_y, unit_size_y) | 
					
						
						|  | grad_weight_j = grad_weight.narrow( | 
					
						
						|  | 0, 0, unit_size_x).narrow(1, unit_size_y*2, unit_size_y) | 
					
						
						|  | grad_weight_k = grad_weight.narrow( | 
					
						
						|  | 0, 0, unit_size_x).narrow(1, unit_size_y*3, unit_size_y) | 
					
						
						|  | if ctx.needs_input_grad[5]: | 
					
						
						|  | grad_bias = grad_output.sum(0).squeeze(0) | 
					
						
						|  |  | 
					
						
						|  | return grad_input, grad_weight_r, grad_weight_i, grad_weight_j, grad_weight_k, grad_bias | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def hamilton_product(q0, q1): | 
					
						
						|  | """ | 
					
						
						|  | Applies a Hamilton product q0 * q1: | 
					
						
						|  | Shape: | 
					
						
						|  | - q0, q1 should be (batch_size, quaternion_number) | 
					
						
						|  | (rr' - xx' - yy' - zz')  + | 
					
						
						|  | (rx' + xr' + yz' - zy')i + | 
					
						
						|  | (ry' - xz' + yr' + zx')j + | 
					
						
						|  | (rz' + xy' - yx' + zr')k + | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | q1_r = get_r(q1) | 
					
						
						|  | q1_i = get_i(q1) | 
					
						
						|  | q1_j = get_j(q1) | 
					
						
						|  | q1_k = get_k(q1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | r_base = torch.mul(q0, q1) | 
					
						
						|  |  | 
					
						
						|  | r = get_r(r_base) - get_i(r_base) - get_j(r_base) - get_k(r_base) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | i_base = torch.mul(q0, torch.cat([q1_i, q1_r, q1_k, q1_j], dim=1)) | 
					
						
						|  |  | 
					
						
						|  | i = get_r(i_base) + get_i(i_base) + get_j(i_base) - get_k(i_base) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | j_base = torch.mul(q0, torch.cat([q1_j, q1_k, q1_r, q1_i], dim=1)) | 
					
						
						|  |  | 
					
						
						|  | j = get_r(j_base) - get_i(j_base) + get_j(j_base) + get_k(j_base) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | k_base = torch.mul(q0, torch.cat([q1_k, q1_j, q1_i, q1_r], dim=1)) | 
					
						
						|  |  | 
					
						
						|  | k = get_r(k_base) + get_i(k_base) - get_j(k_base) + get_k(k_base) | 
					
						
						|  |  | 
					
						
						|  | return torch.cat([r, i, j, k], dim=1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def unitary_init(in_features, out_features, rng, kernel_size=None, criterion='he'): | 
					
						
						|  | if kernel_size is not None: | 
					
						
						|  | receptive_field = np.prod(kernel_size) | 
					
						
						|  | fan_in = in_features * receptive_field | 
					
						
						|  | fan_out = out_features * receptive_field | 
					
						
						|  | else: | 
					
						
						|  | fan_in = in_features | 
					
						
						|  | fan_out = out_features | 
					
						
						|  |  | 
					
						
						|  | if kernel_size is None: | 
					
						
						|  | kernel_shape = (in_features, out_features) | 
					
						
						|  | else: | 
					
						
						|  | if type(kernel_size) is int: | 
					
						
						|  | kernel_shape = (out_features, in_features) + tuple((kernel_size,)) | 
					
						
						|  | else: | 
					
						
						|  | kernel_shape = (out_features, in_features) + (*kernel_size,) | 
					
						
						|  |  | 
					
						
						|  | number_of_weights = np.prod(kernel_shape) | 
					
						
						|  | v_r = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  | v_i = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  | v_j = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  | v_k = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for i in range(0, number_of_weights): | 
					
						
						|  | norm = np.sqrt(v_r[i]**2 + v_i[i]**2 + v_j[i]**2 + v_k[i]**2)+0.0001 | 
					
						
						|  | v_r[i] /= norm | 
					
						
						|  | v_i[i] /= norm | 
					
						
						|  | v_j[i] /= norm | 
					
						
						|  | v_k[i] /= norm | 
					
						
						|  | v_r = v_r.reshape(kernel_shape) | 
					
						
						|  | v_i = v_i.reshape(kernel_shape) | 
					
						
						|  | v_j = v_j.reshape(kernel_shape) | 
					
						
						|  | v_k = v_k.reshape(kernel_shape) | 
					
						
						|  |  | 
					
						
						|  | return (v_r, v_i, v_j, v_k) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def random_init(in_features, out_features, rng, kernel_size=None, criterion='glorot'): | 
					
						
						|  | if kernel_size is not None: | 
					
						
						|  | receptive_field = np.prod(kernel_size) | 
					
						
						|  | fan_in = in_features * receptive_field | 
					
						
						|  | fan_out = out_features * receptive_field | 
					
						
						|  | else: | 
					
						
						|  | fan_in = in_features | 
					
						
						|  | fan_out = out_features | 
					
						
						|  |  | 
					
						
						|  | if criterion == 'glorot': | 
					
						
						|  | s = 1. / np.sqrt(2*(fan_in + fan_out)) | 
					
						
						|  | elif criterion == 'he': | 
					
						
						|  | s = 1. / np.sqrt(2*fan_in) | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError('Invalid criterion: ' + criterion) | 
					
						
						|  |  | 
					
						
						|  | if kernel_size is None: | 
					
						
						|  | kernel_shape = (in_features, out_features) | 
					
						
						|  | else: | 
					
						
						|  | if type(kernel_size) is int: | 
					
						
						|  | kernel_shape = (out_features, in_features) + tuple((kernel_size,)) | 
					
						
						|  | else: | 
					
						
						|  | kernel_shape = (out_features, in_features) + (*kernel_size,) | 
					
						
						|  |  | 
					
						
						|  | number_of_weights = np.prod(kernel_shape) | 
					
						
						|  | v_r = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  | v_i = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  | v_j = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  | v_k = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  |  | 
					
						
						|  | v_r = v_r.reshape(kernel_shape) | 
					
						
						|  | v_i = v_i.reshape(kernel_shape) | 
					
						
						|  | v_j = v_j.reshape(kernel_shape) | 
					
						
						|  | v_k = v_k.reshape(kernel_shape) | 
					
						
						|  |  | 
					
						
						|  | weight_r = v_r | 
					
						
						|  | weight_i = v_i | 
					
						
						|  | weight_j = v_j | 
					
						
						|  | weight_k = v_k | 
					
						
						|  | return (weight_r, weight_i, weight_j, weight_k) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def quaternion_init(in_features, out_features, rng, kernel_size=None, criterion='glorot'): | 
					
						
						|  | if kernel_size is not None: | 
					
						
						|  | receptive_field = np.prod(kernel_size) | 
					
						
						|  | fan_in = in_features * receptive_field | 
					
						
						|  | fan_out = out_features * receptive_field | 
					
						
						|  | else: | 
					
						
						|  | fan_in = in_features | 
					
						
						|  | fan_out = out_features | 
					
						
						|  |  | 
					
						
						|  | if criterion == 'glorot': | 
					
						
						|  | s = 1. / np.sqrt(2*(fan_in + fan_out)) | 
					
						
						|  | elif criterion == 'he': | 
					
						
						|  | s = 1. / np.sqrt(2*fan_in) | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError('Invalid criterion: ' + criterion) | 
					
						
						|  |  | 
					
						
						|  | rng = RandomState(np.random.randint(1, 1234)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if kernel_size is None: | 
					
						
						|  | kernel_shape = (in_features, out_features) | 
					
						
						|  | else: | 
					
						
						|  | if type(kernel_size) is int: | 
					
						
						|  | kernel_shape = (out_features, in_features) + tuple((kernel_size,)) | 
					
						
						|  | else: | 
					
						
						|  | kernel_shape = (out_features, in_features) + (*kernel_size,) | 
					
						
						|  |  | 
					
						
						|  | modulus = chi.rvs(4, loc=0, scale=s, size=kernel_shape) | 
					
						
						|  | number_of_weights = np.prod(kernel_shape) | 
					
						
						|  | v_i = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  | v_j = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  | v_k = np.random.uniform(-1.0, 1.0, number_of_weights) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for i in range(0, number_of_weights): | 
					
						
						|  | norm = np.sqrt(v_i[i]**2 + v_j[i]**2 + v_k[i]**2 + 0.0001) | 
					
						
						|  | v_i[i] /= norm | 
					
						
						|  | v_j[i] /= norm | 
					
						
						|  | v_k[i] /= norm | 
					
						
						|  | v_i = v_i.reshape(kernel_shape) | 
					
						
						|  | v_j = v_j.reshape(kernel_shape) | 
					
						
						|  | v_k = v_k.reshape(kernel_shape) | 
					
						
						|  |  | 
					
						
						|  | phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape) | 
					
						
						|  |  | 
					
						
						|  | weight_r = modulus * np.cos(phase) | 
					
						
						|  | weight_i = modulus * v_i*np.sin(phase) | 
					
						
						|  | weight_j = modulus * v_j*np.sin(phase) | 
					
						
						|  | weight_k = modulus * v_k*np.sin(phase) | 
					
						
						|  |  | 
					
						
						|  | return (weight_r, weight_i, weight_j, weight_k) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def create_dropout_mask(dropout_p, size, rng, as_type, operation='linear'): | 
					
						
						|  | if operation == 'linear': | 
					
						
						|  | mask = rng.binomial(n=1, p=1-dropout_p, size=size) | 
					
						
						|  | return Variable(torch.from_numpy(mask).type(as_type)) | 
					
						
						|  | else: | 
					
						
						|  | raise Exception("create_dropout_mask accepts only 'linear'. Found operation = " | 
					
						
						|  | + str(operation)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def affect_init(r_weight, i_weight, j_weight, k_weight, init_func, rng, init_criterion): | 
					
						
						|  | if r_weight.size() != i_weight.size() or r_weight.size() != j_weight.size() or \ | 
					
						
						|  | r_weight.size() != k_weight.size(): | 
					
						
						|  | raise ValueError('The real and imaginary weights ' | 
					
						
						|  | 'should have the same size . Found: r:' | 
					
						
						|  | + str(r_weight.size()) + ' i:' | 
					
						
						|  | + str(i_weight.size()) + ' j:' | 
					
						
						|  | + str(j_weight.size()) + ' k:' | 
					
						
						|  | + str(k_weight.size())) | 
					
						
						|  |  | 
					
						
						|  | elif r_weight.dim() != 2: | 
					
						
						|  | raise Exception('affect_init accepts only matrices. Found dimension = ' | 
					
						
						|  | + str(r_weight.dim())) | 
					
						
						|  | kernel_size = None | 
					
						
						|  | r, i, j, k = init_func(r_weight.size(0), r_weight.size( | 
					
						
						|  | 1), rng, kernel_size, init_criterion) | 
					
						
						|  | r, i, j, k = torch.from_numpy(r), torch.from_numpy( | 
					
						
						|  | i), torch.from_numpy(j), torch.from_numpy(k) | 
					
						
						|  | r_weight.data = r.type_as(r_weight.data) | 
					
						
						|  | i_weight.data = i.type_as(i_weight.data) | 
					
						
						|  | j_weight.data = j.type_as(j_weight.data) | 
					
						
						|  | k_weight.data = k.type_as(k_weight.data) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def affect_init_conv(r_weight, i_weight, j_weight, k_weight, kernel_size, init_func, rng, | 
					
						
						|  | init_criterion): | 
					
						
						|  | if r_weight.size() != i_weight.size() or r_weight.size() != j_weight.size() or \ | 
					
						
						|  | r_weight.size() != k_weight.size(): | 
					
						
						|  | raise ValueError('The real and imaginary weights ' | 
					
						
						|  | 'should have the same size . Found: r:' | 
					
						
						|  | + str(r_weight.size()) + ' i:' | 
					
						
						|  | + str(i_weight.size()) + ' j:' | 
					
						
						|  | + str(j_weight.size()) + ' k:' | 
					
						
						|  | + str(k_weight.size())) | 
					
						
						|  |  | 
					
						
						|  | elif 2 >= r_weight.dim(): | 
					
						
						|  | raise Exception('affect_conv_init accepts only tensors that have more than 2 dimensions. Found dimension = ' | 
					
						
						|  | + str(real_weight.dim())) | 
					
						
						|  |  | 
					
						
						|  | r, i, j, k = init_func( | 
					
						
						|  | r_weight.size(1), | 
					
						
						|  | r_weight.size(0), | 
					
						
						|  | rng=rng, | 
					
						
						|  | kernel_size=kernel_size, | 
					
						
						|  | criterion=init_criterion | 
					
						
						|  | ) | 
					
						
						|  | r, i, j, k = torch.from_numpy(r), torch.from_numpy( | 
					
						
						|  | i), torch.from_numpy(j), torch.from_numpy(k) | 
					
						
						|  | r_weight.data = r.type_as(r_weight.data) | 
					
						
						|  | i_weight.data = i.type_as(i_weight.data) | 
					
						
						|  | j_weight.data = j.type_as(j_weight.data) | 
					
						
						|  | k_weight.data = k.type_as(k_weight.data) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_kernel_and_weight_shape(operation, in_channels, out_channels, kernel_size): | 
					
						
						|  | if operation == 'convolution1d': | 
					
						
						|  | if type(kernel_size) is not int: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | """An invalid kernel_size was supplied for a 1d convolution. The kernel size | 
					
						
						|  | must be integer in the case. Found kernel_size = """ + str(kernel_size) | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | ks = kernel_size | 
					
						
						|  | w_shape = (out_channels, in_channels) + tuple((ks,)) | 
					
						
						|  | else: | 
					
						
						|  | if operation == 'convolution2d' and type(kernel_size) is int: | 
					
						
						|  | ks = (kernel_size, kernel_size) | 
					
						
						|  | elif operation == 'convolution3d' and type(kernel_size) is int: | 
					
						
						|  | ks = (kernel_size, kernel_size, kernel_size) | 
					
						
						|  | elif type(kernel_size) is not int: | 
					
						
						|  | if operation == 'convolution2d' and len(kernel_size) != 2: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | """An invalid kernel_size was supplied for a 2d convolution. The kernel size | 
					
						
						|  | must be either an integer or a tuple of 2. Found kernel_size = """ + str(kernel_size) | 
					
						
						|  | ) | 
					
						
						|  | elif operation == 'convolution3d' and len(kernel_size) != 3: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | """An invalid kernel_size was supplied for a 3d convolution. The kernel size | 
					
						
						|  | must be either an integer or a tuple of 3. Found kernel_size = """ + str(kernel_size) | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | ks = kernel_size | 
					
						
						|  | w_shape = (out_channels, in_channels) + (*ks,) | 
					
						
						|  | return ks, w_shape | 
					
						
						|  |  |