# Provenance (file-viewer header captured with the source):
#   uploader: sunana · commit f97813d ("update files") · size 36.3 kB
import numpy as np
import math
import torch
from io import BytesIO
import numpy
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt
import os
import pandas as pd
import imageio
from torch.cuda.amp import autocast as autocast
def cart2pol(x, y):
    """Convert Cartesian coordinates to polar coordinates.

    Returns a ``(rho, phi)`` tuple where ``rho`` is the distance from the
    origin and ``phi`` is the angle given by ``np.arctan2(y, x)`` (radians).
    """
    radius = np.sqrt(x ** 2 + y ** 2)
    angle = np.arctan2(y, x)
    return radius, angle
def pol2cart(rho, phi):
    """Convert polar coordinates (radius ``rho``, angle ``phi`` in radians) to an ``(x, y)`` tuple."""
    return rho * np.cos(phi), rho * np.sin(phi)
def inverse_sigmoid(p):
    """Return the logit ``log(p / (1 - p))`` of ``p`` (scalar or NumPy array)."""
    odds = p / (1 - p)
    return np.log(odds)
def artanh(y):
    """Return the inverse hyperbolic tangent of ``y`` via ``0.5 * log((1 + y) / (1 - y))``."""
    ratio = (1 + y) / (1 - y)
    return 0.5 * np.log(ratio)
class V1(nn.Module):
    """Multi-scale spatio-temporal energy stage built from Gabor and temporal filter banks.

    For each of ``scale_num`` pyramid levels the input frames are convolved
    with a quadrature pair of spatial Gabor kernels (``GaborFilters``) and then
    with a quadrature pair of temporal kernels (``TemporalFilter``).  The
    squared responses are summed into an energy map, the per-level maps are
    resized to ``(H // 8, W // 8)``, concatenated along the channel axis and
    divisively normalised.

    Args:
        spatial_num: number of spatial (Gabor) units per scale.
        scale_num: number of pyramid levels.
        scale_factor: total down-sampling between the first and last level;
            each level is resized by ``(1 / scale_factor) ** (1 / (scale_num - 1))``.
        kernel_radius: Gabor kernel radius; kernels are ``2 * kernel_radius + 1`` wide.
        num_ft: number of temporal filters per scale; must equal ``spatial_num``.
        kernel_size: temporal kernel length in frames.
        average_time: if True, collapse the time axis with learnable pooling weights.
    """

    def __init__(self, spatial_num=32, scale_num=8, scale_factor=16, kernel_radius=7, num_ft=32,
                 kernel_size=6, average_time=True):
        super(V1, self).__init__()

        def make_param(in_channels, values, requires_grad=True, dtype=None):
            """Tile ``values`` into an ``(in_channels, len(values))`` parameter."""
            if dtype is None:
                dtype = 'float32'
            values = numpy.require(values, dtype=dtype)
            data = torch.from_numpy(values).view(1, -1)
            data = data.repeat(in_channels, 1)
            return torch.nn.Parameter(data=data, requires_grad=requires_grad)

        # Spatial unit i is paired with temporal unit i in infer_scale.
        assert spatial_num == num_ft
        scale_each_level = np.exp(1 / (scale_num - 1) * np.log(1 / scale_factor))
        self.scale_each_level = scale_each_level
        self.scale_num = scale_num
        self.cell_index = 0
        self.spatial_filter = nn.ModuleList(
            [GaborFilters(kernel_radius=kernel_radius, num_units=spatial_num, random=False)
             for _ in range(scale_num)])
        self.temporal_decay = 0.2
        self.spatial_decay = 0.2
        self.spatial_radius = kernel_radius
        self.spatial_kernel_size = kernel_radius * 2 + 1
        self.spatial_num = spatial_num
        self.temporal_filter = nn.ModuleList(
            [TemporalFilter(num_ft=num_ft, kernel_size=kernel_size, random=False)
             for _ in range(scale_num)])
        # Only used to size the outputs below; forward() itself accepts any list length.
        self.n_frames = 11
        self._num_after_st = spatial_num * scale_num
        if not average_time:
            self._num_after_st = self._num_after_st * (self.n_frames - kernel_size + 1)
        if average_time:
            # TODO: concentrate on middle frame
            # NOTE(review): these six weights assume n_frames - kernel_size + 1 == 6,
            # i.e. the default kernel_size=6 — confirm before using other sizes.
            self.temporal_pooling = make_param(self._num_after_st, [0.05, 0.1, 0.4, 0.4, 0.1, 0.05],
                                               requires_grad=True)
        self.norm_sigma = make_param(1, np.array([0.2]), requires_grad=True)
        self.spontaneous_firing = make_param(1, np.array([0.3]), requires_grad=True)
        self.norm_k = make_param(1, np.array([4.0]), requires_grad=True)
        self._average_time = average_time
        # Per-scale kernels, refreshed by forward() before each infer_scale() call.
        self.t_sin = None
        self.t_cos = None
        self.s_sin = None
        self.s_cos = None

    def infer_scale(self, x, scale):  # x should be list of B,1,H,W
        """Compute the energy maps of one pyramid level.

        Args:
            x: list of ``n`` frames, each of shape ``(B, C, H, W)``; the spatial
                kernels have a single input channel, so ``C`` must be 1.
            scale: index of the pyramid level (selects the pooling weights).

        Returns:
            Tensor of shape ``(B, spatial_num, H, W)`` when time is averaged,
            otherwise ``(B, spatial_num * T, H, W)`` with ``T`` the number of
            valid temporal positions.
        """
        n = len(x)
        B, C, H, W = x[0].shape
        x = torch.stack(x, dim=0).reshape(n * B, C, H, W)

        kernel_shape = (self.spatial_num, 1, self.spatial_kernel_size, self.spatial_kernel_size)
        # conv2d is a cross-correlation; flip the kernels to get a true convolution.
        gb_sin = torch.flip(self.s_sin.view(kernel_shape), dims=[-1, -2])
        gb_cos = torch.flip(self.s_cos.view(kernel_shape), dims=[-1, -2])
        res_sin = F.conv2d(input=x, weight=gb_sin, padding=self.spatial_radius, groups=1)
        res_cos = F.conv2d(input=x, weight=gb_cos, padding=self.spatial_radius, groups=1)
        g_asin_list = res_sin.reshape(n, B, -1, H, W)
        g_acos_list = res_cos.reshape(n, B, -1, H, W)

        # Reverse the impulse responses once for all channels (same reason as above).
        t_sin = torch.flip(self.t_sin, dims=(-1,))
        t_cos = torch.flip(self.t_cos, dims=(-1,))

        energy_list = []
        for channel in range(self.spatial_filter[0].n_channels_post_conv):
            k_sin = t_sin[channel, ...][None]
            k_cos = t_cos[channel, ...][None]
            # Treat every pixel as an independent 1-D time series: (B*H*W, 1, n).
            g_asin = g_asin_list[:, :, channel, :, :].reshape(n, B * H * W, 1).permute(1, 2, 0)
            g_acos = g_acos_list[:, :, channel, :, :].reshape(n, B * H * W, 1).permute(1, 2, 0)

            g_o = (F.conv1d(g_acos, k_sin, padding="valid", bias=None)
                   + F.conv1d(g_asin, k_cos, padding="valid", bias=None))
            g_e = (F.conv1d(g_acos, k_cos, padding="valid", bias=None)
                   - F.conv1d(g_asin, k_sin, padding="valid", bias=None))
            energy_component = g_o ** 2 + g_e ** 2 + self.spontaneous_firing.square()
            energy_component = energy_component.reshape(B, H, W, g_o.size(-1)).permute(0, 3, 1, 2)
            if self._average_time:  # average motion energy across time
                total_channel = scale * self.spatial_num + channel
                pooling = self.temporal_pooling[total_channel][None, ..., None, None]
                energy_component = torch.abs(torch.mean(energy_component * pooling, dim=1, keepdim=True))
            energy_list.append(energy_component)
        return torch.cat(energy_list, dim=1)

    def forward(self, image_list):
        """Return normalised multi-scale energy for a list of ``(B, 1, H, W)`` frames.

        Frames whose first element has a maximum above 10 are assumed to be in
        the 0-255 range and are divided by 255.  The output has spatial size
        ``(H // 8, W // 8)``.
        """
        _, _, H, W = image_list[0].shape
        MT_size = (H // 8, W // 8)
        self.cell_index = 0
        with torch.no_grad():
            if image_list[0].max() > 10:
                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W]  0-1
        ms_com = []
        for scale in range(self.scale_num):
            self.t_sin, self.t_cos = self.temporal_filter[scale].make_temporal_filter()
            self.s_sin, self.s_cos = self.spatial_filter[scale].make_gabor_filters(quadrature=True)
            st_component = self.infer_scale(image_list, scale)
            st_component = F.interpolate(st_component, size=MT_size, mode="bilinear", align_corners=True)
            ms_com.append(st_component)
            # The frames are only needed again if another level follows.
            if scale + 1 < self.scale_num:
                image_list = [F.interpolate(img, scale_factor=self.scale_each_level, mode="bilinear")
                              for img in image_list]
        return self.normalize(torch.cat(ms_com, dim=1))

    def normalize(self, x):  # TODO
        """Divide by the channel-mean activation plus ``norm_sigma ** 2``, scaled by ``|norm_k|``."""
        sum_activation = torch.mean(x, dim=[1], keepdim=True) + torch.square(self.norm_sigma)
        return self.norm_k.abs() * x / sum_activation

    def _get_v1_order(self):
        """List the tuning of every unit as dicts with ``theta``, ``fs``, ``ft``, ``speed``, ``index``.

        Units are enumerated scale by scale, in the same order as the channels
        produced by ``forward`` when time is averaged.
        """
        neural_representation = []
        index = 0
        for scale_idx, (gabor_scale, temporal_scale) in enumerate(zip(self.spatial_filter, self.temporal_filter)):
            # Same constraints as in the filter builders: orientation in 0-2pi,
            # frequencies in 0-0.25; the spatial frequency shrinks with the scale.
            theta_scale = torch.sigmoid(gabor_scale.thetas) * 2 * torch.pi
            fs_scale = torch.sigmoid(gabor_scale.fs) * 0.25
            fs_scale = fs_scale * (self.scale_each_level ** scale_idx)
            ft_scale = torch.sigmoid(temporal_scale.ft) * 0.25
            theta_scale = theta_scale.reshape(-1).cpu().detach().numpy()
            fs_scale = fs_scale.reshape(-1).cpu().detach().numpy()
            ft_scale = ft_scale.reshape(-1).cpu().detach().numpy()
            for gabor_idx in range(len(theta_scale)):
                speed = ft_scale[gabor_idx] / fs_scale[gabor_idx]
                assert speed >= 0
                angle = theta_scale[gabor_idx]
                neural_representation.append(
                    {"theta": -angle + np.pi, "fs": fs_scale[gabor_idx], "ft": ft_scale[gabor_idx],
                     "speed": speed, "index": index})
                index = index + 1
        return neural_representation

    def visualize_activation(self, activation, if_log=True):
        """Render the first sample's mean activation as a polar scatter and return it as an image array.

        Each unit is drawn at (preferred direction, preferred speed) and
        coloured by its activation averaged over the map interior.
        """
        neural_representation = self._get_v1_order()
        activation = activation[:, :, 14:-14, 14:-14]  # eliminate boundary
        activation = torch.mean(activation, dim=[2, 3], keepdim=False)[0]
        per_unit = activation.squeeze()
        ax1 = plt.subplot(111, projection='polar')
        theta_list = np.array([unit["theta"] for unit in neural_representation])
        v_list = np.array([unit["speed"] for unit in neural_representation])
        energy_list = np.array([per_unit[unit["index"]].cpu().detach().numpy()
                                for unit in neural_representation])
        plt.scatter(theta_list, v_list, c=energy_list, cmap="rainbow", s=(energy_list + 20), alpha=0.5)
        plt.axis('on')
        if if_log:
            ax1.set_rscale('symlog')
        plt.colorbar()
        with BytesIO() as buf:
            plt.savefig(buf, format='png')
            buf.seek(0)
            # read the buffer and convert to an image
            image = imageio.imread(buf)
        # Clear before closing: calling clf() after close() would create a
        # fresh empty figure that is never released.
        plt.clf()
        plt.close()
        return image

    @staticmethod
    def demo():
        """Time 100 forward/backward passes on constant CUDA input (requires a GPU)."""
        import time

        frames = [torch.ones(2, 1, 256, 256).cuda() for _ in range(11)]
        model = V1(spatial_num=16, scale_num=16, scale_factor=16, kernel_radius=7, num_ft=16,
                   kernel_size=6, average_time=True).cuda()
        for _ in range(100):
            start = time.time()
            with autocast(enabled=True):
                x = model(frames)
            print(x.shape)
            torch.mean(x).backward()
            end = time.time()
            print(end - start)
            print("#================================++#")

    @property
    def num_after_st(self):
        """Number of output channels produced by the spatio-temporal stage."""
        return self._num_after_st
class TemporalFilter(nn.Module):
    """Bank of learnable quadrature temporal filters (exponentially decaying sine/cosine).

    Args:
        in_channels: number of rows the parameters are tiled to; the filter
            builders reshape with ``view(1, num_ft, 1)``, so only 1 is supported.
        num_ft: number of temporal filters.
        kernel_size: filter length in frames.
        random: draw the initial frequencies uniformly at random instead of
            spacing them evenly.
    """

    def __init__(self, in_channels=1, num_ft=8, kernel_size=6, random=True):
        # 40ms per time unit, 200ms -> 5+1 frames
        # use exponential decay plus sin wave
        super().__init__()
        self.kernel_size = kernel_size

        def make_param(in_channels, values, requires_grad=True, dtype=None):
            """Tile ``values`` into an ``(in_channels, len(values))`` parameter."""
            if dtype is None:
                dtype = 'float32'
            values = numpy.require(values, dtype=dtype)
            data = torch.from_numpy(values).view(1, -1)
            data = data.repeat(in_channels, 1)
            return torch.nn.Parameter(data=data, requires_grad=requires_grad)

        indices = torch.arange(kernel_size, dtype=torch.float32)
        self.register_buffer('indices', indices)
        # Frequencies are stored as logits and squashed by a sigmoid when used.
        if random:
            ft_init = numpy.random.uniform(0.01, 0.99, num_ft)
        else:  # evenly distributed
            ft_init = numpy.linspace(0.01, 0.99, num_ft)
        self.ft = make_param(in_channels, values=inverse_sigmoid(ft_init), requires_grad=True)
        self.tao = make_param(in_channels, values=numpy.arange(num_ft) / 2 + 1, requires_grad=True)
        self.feat_dim = num_ft
        self.temporal_decay = 0.2

    def _constrained_params(self):
        """Return frequencies (0-0.25 cycle/frame) and time constants, each shaped ``(1, num_ft, 1)``."""
        fts = torch.sigmoid(self.ft) * 0.25
        # At the upper bound the envelope falls to ``temporal_decay`` after ``kernel_size`` frames.
        tao = torch.sigmoid(self.tao) * (-self.kernel_size / np.log(self.temporal_decay))
        return fts.view(1, fts.shape[1], 1), tao.view(1, tao.shape[1], 1)

    def _sample(self, fts, tao, t):
        """Evaluate the sine/cosine filters at times ``t``; each result is ``(num_ft, 1, len(t))``."""
        n_samples = t.shape[0]
        t = t.view(1, 1, n_samples)
        envelope = torch.exp(-t / tao)
        phase = 2 * torch.pi * fts * t
        temporal_sin = (envelope * torch.sin(phase)).view(self.feat_dim, 1, n_samples)
        temporal_cos = (envelope * torch.cos(phase)).view(self.feat_dim, 1, n_samples)
        return temporal_sin, temporal_cos

    def make_temporal_filter(self):
        """Build the ``(num_ft, 1, kernel_size)`` sine and cosine kernels at integer frame times."""
        fts, tao = self._constrained_params()
        return self._sample(fts, tao, self.indices)

    def demo_temporal_filter(self, points=100):
        """Sample the same filters densely (``points`` samples over the kernel span) for plotting.

        Uses the same time-constant scaling as ``make_temporal_filter`` so the
        curves shown are the ones the model actually applies.
        """
        fts, tao = self._constrained_params()
        t = torch.linspace(float(self.indices[0]), float(self.indices[-1]), steps=points,
                           device=self.indices.device)
        print("ft:" + str(fts))
        print("tao:" + str(tao))
        return self._sample(fts, tao, t)

    def forward(self, x_sin, x_cos):
        """Return a list with one energy tensor per temporal filter.

        Args:
            x_sin, x_cos: quadrature inputs laid out as ``(batch, channels, sequence)``.
        """
        in_channels = x_sin.size(1)
        t_sin, t_cos = self.make_temporal_filter()
        me = []
        for n_t in range(self.feat_dim):
            # Share one temporal kernel across all channels (depth-wise convolution).
            k_sin = t_sin[n_t, ...].expand(in_channels, -1, -1)
            k_cos = t_cos[n_t, ...].expand(in_channels, -1, -1)
            g_o = (F.conv1d(x_sin, weight=k_cos, padding="same", groups=in_channels, bias=None)
                   + F.conv1d(x_cos, weight=k_sin, padding="same", groups=in_channels, bias=None))
            g_e = (F.conv1d(x_sin, weight=k_sin, padding="same", groups=in_channels, bias=None)
                   - F.conv1d(x_cos, weight=k_cos, padding="same", groups=in_channels, bias=None))
            me.append(g_o ** 2 + g_e ** 2)
        return me
class GaborFilters(nn.Module):
    """Bank of learnable quadrature Gabor filters.

    Args:
        in_channels: number of rows the parameters are tiled to (also used as
            the ``groups`` argument in ``forward``).
        kernel_radius: kernel radius; kernels are ``2 * kernel_radius + 1`` wide.
        num_units: number of Gabor units in the bank.
        random: draw the initial parameters at random instead of spacing them evenly.
    """

    def __init__(self,
                 in_channels=1,
                 kernel_radius=7,
                 num_units=512,
                 random=True
                 ):
        # the total number of or units for each scale
        super().__init__()
        self.in_channels = in_channels
        kernel_size = kernel_radius * 2 + 1
        self.kernel_size = kernel_size
        self.kernel_radius = kernel_radius

        def make_param(in_channels, values, requires_grad=True, dtype=None):
            """Tile ``values`` into an ``(in_channels, len(values))`` parameter."""
            if dtype is None:
                dtype = 'float32'
            values = numpy.require(values, dtype=dtype)
            data = torch.from_numpy(values).view(1, -1)
            data = data.repeat(in_channels, 1)
            return torch.nn.Parameter(data=data, requires_grad=requires_grad)

        # build all learnable parameters (stored as logits, squashed when used)
        if random:  # random distribution
            self.sigmas = make_param(in_channels, inverse_sigmoid(np.random.uniform(0.8, 0.99, num_units)))
            self.fs = make_param(in_channels, values=inverse_sigmoid(numpy.random.uniform(0.2, 0.8, num_units)))
            # maximum is 0.25 cycle/frame
            self.gammas = make_param(in_channels, numpy.ones(num_units))  # TODO: fix gamma or not
            self.psis = make_param(in_channels, np.zeros(num_units), requires_grad=False)  # fix phase
            self.thetas = make_param(in_channels, values=inverse_sigmoid(numpy.random.uniform(0.01, 0.99, num_units)),
                                     requires_grad=True)
        else:  # even distribution
            self.sigmas = make_param(in_channels, inverse_sigmoid(np.linspace(0.8, 0.99, num_units)))
            self.fs = make_param(in_channels, values=inverse_sigmoid(numpy.linspace(0.01, 0.99, num_units)))
            # maximum is 0.25 cycle/frame
            self.gammas = make_param(in_channels, numpy.ones(num_units))  # TODO: fix gamma or not
            self.psis = make_param(in_channels, np.zeros(num_units), requires_grad=False)  # fix phase
            # NOTE(review): the end points 0 and 1 map to -inf/+inf logits, so the
            # first and last units start at orientations 0 and 2*pi (the same
            # direction) with zero sigmoid gradient — confirm this is intended.
            self.thetas = make_param(in_channels, values=inverse_sigmoid(numpy.linspace(0, 1, num_units)),
                                     requires_grad=True)
        indices = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
        self.register_buffer('indices', indices)
        self.spatial_decay = 0.5
        # number of channels after the conv
        self.n_channels_post_conv = num_units

    def _constrained_params(self):
        """Map the raw parameters to their valid ranges; returns ``(sigmas, fs, gammas, psis, thetas)``."""
        # Upper bound on the Gaussian std: the envelope falls to ``spatial_decay``
        # at a distance of ``kernel_radius - 1`` pixels.
        sigmas = torch.sigmoid(self.sigmas) * np.sqrt(
            (self.kernel_radius - 1) ** 2 * 0.5 / np.log(1 / self.spatial_decay))
        fs = torch.sigmoid(self.fs) * 0.25  # spatial frequency in 0-0.25 cycle/pixel
        gammas = torch.abs(self.gammas)  # shape of gauss win, set as 1 by default
        psis = self.psis  # phase of cos wave
        thetas = torch.sigmoid(self.thetas) * 2 * torch.pi  # spatial orientation constrain to 0-2pi
        return sigmas, fs, gammas, psis, thetas

    def _render(self, params, y, x, quadrature):
        """Evaluate the Gabor bank on the 1-D coordinate grids ``y`` (rows) and ``x`` (columns)."""
        sigmas, fs, gammas, psis, thetas = params
        in_channels = sigmas.shape[0]
        assert in_channels == fs.shape[0]
        assert in_channels == gammas.shape[0]
        height, width = y.shape[0], x.shape[0]

        sigmas = sigmas.view(in_channels, sigmas.shape[1], 1, 1)
        fs = fs.view(in_channels, fs.shape[1], 1, 1)
        gammas = gammas.view(in_channels, gammas.shape[1], 1, 1)
        psis = psis.view(in_channels, psis.shape[1], 1, 1)
        thetas = thetas.view(in_channels, thetas.shape[1], 1, 1)
        y = y.view(1, 1, height, 1)
        x = x.view(1, 1, 1, width)

        sigma_x = sigmas
        sigma_y = sigmas / gammas
        sin_t = torch.sin(thetas)
        cos_t = torch.cos(thetas)
        # Rotate the sampling grid into each unit's preferred orientation.
        y_theta = -x * sin_t + y * cos_t
        x_theta = x * cos_t + y * sin_t

        envelope = torch.exp(-.5 * (x_theta ** 2 / sigma_x ** 2 + y_theta ** 2 / sigma_y ** 2))
        carrier_phase = 2.0 * math.pi * x_theta * fs + psis
        if not quadrature:
            gb = envelope * torch.cos(carrier_phase)
            return gb.view(-1, height, width)

        gb_cos = (envelope * torch.cos(carrier_phase)).reshape(-1, 1, height, width)
        gb_sin = (envelope * torch.sin(carrier_phase)).reshape(-1, 1, height, width)
        # remove DC
        gb_cos = gb_cos - torch.sum(gb_cos, dim=[-1, -2], keepdim=True) / (height * width)
        gb_sin = gb_sin - torch.sum(gb_sin, dim=[-1, -2], keepdim=True) / (height * width)
        return gb_sin, gb_cos

    def make_gabor_filters(self, quadrature=True):
        """Build the kernels on the integer pixel grid.

        Returns ``(gb_sin, gb_cos)``, each ``(num_units, 1, k, k)`` with the DC
        component removed, or a single cosine bank ``(num_units, k, k)`` when
        ``quadrature`` is False.
        """
        return self._render(self._constrained_params(), self.indices, self.indices, quadrature)

    def forward(self, x):
        """Convolve ``x`` (``(B, in_channels, H, W)``) with the bank; returns ``(res_sin, res_cos)``."""
        batch_size = x.size(0)
        sy = x.size(2)
        sx = x.size(3)
        gb_sin, gb_cos = self.make_gabor_filters(quadrature=True)
        assert gb_sin.shape[0] == self.n_channels_post_conv
        assert gb_sin.shape[2] == self.kernel_size
        assert gb_sin.shape[3] == self.kernel_size
        gb_sin = gb_sin.view(self.n_channels_post_conv, 1, self.kernel_size, self.kernel_size)
        gb_cos = gb_cos.view(self.n_channels_post_conv, 1, self.kernel_size, self.kernel_size)
        # conv2d is a cross-correlation; flip the kernels to get a true convolution.
        gb_sin = torch.flip(gb_sin, dims=[-1, -2])
        gb_cos = torch.flip(gb_cos, dims=[-1, -2])
        res_sin = F.conv2d(input=x, weight=gb_sin,
                           padding=self.kernel_radius, groups=self.in_channels)
        res_cos = F.conv2d(input=x, weight=gb_cos,
                           padding=self.kernel_radius, groups=self.in_channels)
        # ``rotation_invariant`` / ``n_thetas`` are never set by __init__; treat the
        # option as disabled unless a caller has attached both attributes.
        if getattr(self, "rotation_invariant", False):
            res_sin = res_sin.view(batch_size, self.in_channels, -1, self.n_thetas, sy, sx)
            res_sin, _ = res_sin.max(dim=3)
            res_cos = res_cos.view(batch_size, self.in_channels, -1, self.n_thetas, sy, sx)
            res_cos, _ = res_cos.max(dim=3)
        res_sin = res_sin.view(batch_size, -1, sy, sx)
        res_cos = res_cos.view(batch_size, -1, sy, sx)
        return res_sin, res_cos

    def demo_gabor_filters(self, quadrature=True, points=100):
        """Render the bank on a dense ``points`` x ``points`` grid spanning the kernel, for plotting."""
        params = self._constrained_params()
        print("theta:" + str(params[4]))
        print("fs:" + str(params[1]))
        lo, hi = float(self.indices[0]), float(self.indices[-1])
        x = torch.linspace(lo, hi, points, device=self.indices.device)
        y = torch.linspace(lo, hi, points, device=self.indices.device)
        return self._render(params, y, x, quadrature)
def te_gabor_(num_units=48):
    """Show, for every unit of a fresh Gabor bank, its cosine kernel, sine kernel and their squared sum."""
    s_point = 100
    s_kz = 7
    bank = GaborFilters(num_units=num_units, kernel_radius=s_kz)
    gb_sin, gb_cos = bank.demo_gabor_filters(points=s_point)
    gb = gb_sin ** 2 + gb_cos ** 2
    print(gb_sin.shape)
    for unit in range(gb_sin.size(0)):
        # Panels left to right: cosine, sine, energy envelope.
        panels = (gb_cos, gb_sin, gb)
        for column, source in enumerate(panels, start=1):
            plt.subplot(1, 3, column)
            curve = source[unit].detach().cpu().squeeze().numpy()
            plt.imshow(curve)
        # One window per unit; otherwise later units would overdraw earlier ones.
        plt.show()
def te_spatial_temporal():
    """Plot every spatial/temporal filter pairing of freshly built banks and assemble the frames into a GIF.

    For each (Gabor unit, temporal filter) pair a six-panel figure is saved as
    ``filter_<index>.png``; afterwards all frames are written to
    ``filters_orientation.gif`` and the PNG files are removed.

    NOTE(review): the source indentation was lost; the nesting below (figure
    creation through ``plt.show()`` inside the inner loop) is a reconstruction.
    """
    t_point = 6 * 100
    s_point = 14 * 100
    s_kz = 7
    t_kz = 6
    filenames = []
    gb_sin_b, gb_cos_b = GaborFilters(num_units=48, kernel_radius=s_kz).demo_gabor_filters(points=s_point)
    temporal = TemporalFilter(num_ft=2, kernel_size=t_kz)
    t_sin, t_cos = temporal.demo_temporal_filter(points=t_point)
    # NOTE(review): demo_temporal_filter samples from indices[0] to indices[-1]
    # (0 .. t_kz - 1), while this axis runs 0 .. t_kz — confirm the intended span.
    x = np.linspace(0, t_kz, t_point)
    index = 0
    for i in range(gb_sin_b.size(0)):
        for j in range(t_sin.size(0)):
            plt.figure(figsize=(14, 9), dpi=80)
            # Top row: the two spatial kernels and the temporal impulse responses.
            plt.subplot(2, 3, 1)
            curve = gb_sin_b[i].squeeze().detach().numpy()
            plt.imshow(curve)
            plt.title("Gabor Sin")
            plt.subplot(2, 3, 2)
            curve = gb_cos_b[i].squeeze().detach().numpy()
            plt.imshow(curve)
            plt.title("Gabor Cos")
            plt.subplot(2, 3, 3)
            curve = t_sin[j].squeeze().detach().numpy()
            plt.plot(x, curve, label='sin')
            plt.title("Temporal Sin")
            curve = t_cos[j].squeeze().detach().numpy()
            plt.plot(x, curve, label='cos')
            plt.xlabel('Time (s)')
            plt.ylabel('Response to pulse at t=0')
            plt.legend()
            plt.title("Temporal filter")
            # Take a single image row (index 5) of each spatial kernel as a 1-D profile.
            gb_sin = gb_sin_b[i].squeeze().detach()[5, :]
            gb_cos = gb_cos_b[i].squeeze().detach()[5, :]
            # Space-time kernels as outer products (time x space) of the
            # temporal and spatial profiles, combined in quadrature.
            a = np.outer(t_cos[j].detach(), gb_sin)
            b = np.outer(t_sin[j].detach(), gb_cos)
            g_o = a + b
            a = np.outer(t_sin[j].detach(), gb_sin)
            b = np.outer(t_cos[j].detach(), gb_cos)
            g_e = a - b
            energy_component = g_o ** 2 + g_e ** 2
            # Bottom row: the two space-time kernels and their summed squares.
            plt.subplot(2, 3, 4)
            curve = g_o
            plt.imshow(curve, cmap="gray")
            plt.title("Spatial Temporal even")
            plt.subplot(2, 3, 5)
            curve = g_e
            plt.imshow(curve, cmap="gray")
            plt.title("Spatial Temporal odd")
            plt.subplot(2, 3, 6)
            curve = energy_component
            plt.imshow(curve, cmap="gray")
            plt.title("energy")
            plt.savefig('filter_%d.png' % (index))
            filenames.append('filter_%d.png' % (index))
            index += 1
            plt.show()
    # build gif
    with imageio.get_writer('filters_orientation.gif', mode='I') as writer:
        for filename in filenames:
            image = imageio.imread(filename)
            writer.append_data(image)
    # Remove files
    for filename in set(filenames):
        os.remove(filename)
def te_temporal_():
    """Plot the densely sampled sine/cosine impulse response of each filter in a fresh temporal bank."""
    k_size = 6
    points = k_size * 100
    # TemporalFilter has no ``n_tao`` argument; passing it raised a TypeError.
    temporal = TemporalFilter(num_ft=8, kernel_size=k_size)
    # Sample as many points as the time axis has, otherwise plotting fails on
    # mismatched lengths.
    sin, cos = temporal.demo_temporal_filter(points=points)
    print(sin.shape)
    # The demo filter is sampled between the first and last frame index.
    x = np.linspace(0, k_size - 1, points)
    # plot temporal filters to illustrate what they look like.
    for c in range(sin.size(0)):
        curve = cos[c].detach().cpu().squeeze().numpy()
        plt.plot(x, curve, label='cos')
        curve = sin[c].detach().cpu().squeeze().numpy()
        plt.plot(x, curve, label='sin')
        plt.xlabel('Time (s)')
        plt.ylabel('Response to pulse at t=0')
        plt.legend()
        plt.show()
def circular_hist(ax, x, bins=16, density=True, offset=0, gaps=True):
    """
    Produce a circular histogram of angles on ax.

    Parameters
    ----------
    ax : matplotlib.axes._subplots.PolarAxesSubplot
        axis instance created with subplot_kw=dict(projection='polar').
    x : array-like
        Angles to plot, expected in units of radians. Any sequence that
        ``numpy.asarray`` accepts (array, list, tuple) may be passed.
    bins : int, optional
        Defines the number of equal-width bins in the range. The default is 16.
    density : bool, optional
        If True plot frequency proportional to area. If False plot frequency
        proportional to radius. The default is True.
    offset : float, optional
        Sets the offset for the location of the 0 direction in units of
        radians. The default is 0.
    gaps : bool, optional
        Whether to allow gaps between bins. When gaps = False the bins are
        forced to partition the entire [-pi, pi] range. The default is True.

    Returns
    -------
    n : array or list of arrays
        The number of values in each bin.
    bins : array
        The edges of the bins.
    patches : `.BarContainer` or list of a single `.Polygon`
        Container of individual artists used to create the histogram
        or list of such containers if there are multiple input datasets.
    """
    # Plain lists do not support ``x + np.pi`` and have no ``.size``.
    x = np.asarray(x, dtype=float)

    # Wrap angles to [-pi, pi)
    x = (x + np.pi) % (2 * np.pi) - np.pi

    # Force bins to partition entire circle
    if not gaps:
        bins = np.linspace(-np.pi, np.pi, num=bins + 1)

    # Bin data and record counts
    n, bins = np.histogram(x, bins=bins)

    # Compute width of each bin
    widths = np.diff(bins)

    if density:
        # By default plot frequency proportional to area: each bin is assigned
        # its share of the samples (all zero for empty input, avoiding 0/0).
        area = n / x.size if x.size else np.zeros(n.shape, dtype=float)
        # Calculate corresponding bin radius
        radius = (area / np.pi) ** .5
    else:
        # Otherwise plot frequency proportional to radius
        radius = n

    # Plot data on ax
    patches = ax.bar(bins[:-1], radius, zorder=1, align='edge', width=widths,
                     edgecolor='C0', fill=False, linewidth=1)

    # Set the direction of the zero angle
    ax.set_theta_offset(offset)

    # Remove ylabels for area plots (they are mostly obstructive)
    if density:
        ax.set_yticks([])

    return n, bins, patches
def show_trained_model(file_name="/home/2TSSD/experiment/FFMEDNN/Sintel_fixv1_10.62_ckpt.pth.tar"):
    """Visualise the filters and pooling weights of the ``ffv1`` stage of an ``FFV1DNN`` model.

    Shows the temporal pooling weights, the distribution of preferred
    direction / speed / frequency of the units, and then renders one figure per
    unit that is finally assembled into ``filters_orientation.gif``.

    ``file_name`` is currently unused because the checkpoint restore call is
    commented out, so the freshly initialised model is what gets plotted.

    NOTE(review): the source indentation was lost; loop nesting below is a
    reconstruction.  The constants 256, 8 and 32 presumably match
    ``num_cells=256`` = 8 scales x 32 units — confirm against FFV1DNN.
    """
    # Project-local modules; not available outside the original repository.
    import utils.torch_utils as utils
    from model.fle_version_2_3.FFV1MT_MS import FFV1DNN
    model = FFV1DNN(num_scales=8,
                    num_cells=256,
                    upsample_factor=8,
                    feature_channels=256,
                    scale_factor=16,
                    num_layers=6)
    # model = utils.restore_model(model, file_name)
    model = model.ffv1
    t_point = 100
    s_point = 100
    t_kz = 6
    filenames = []
    # Time axis for the six pooling weights: 0, 40, ..., 200 (labelled as ms below).
    x = np.arange(0, 6) * 40
    x = np.repeat(x[None], axis=0, repeats=256)
    temporal = model.temporal_pooling.data.cpu().squeeze().numpy()
    mean = np.mean(temporal, axis=0)
    plt.figure(figsize=(10, 10))
    # Upper panel: one curve per unit; lower panel: the mean over units.
    plt.subplot(2, 1, 1)
    for idx in range(0, 256):
        plt.plot(x[idx], temporal[idx])
    plt.subplot(2, 1, 2)
    plt.plot(x[0], mean, label="mean")
    plt.xlabel("times (ms)")
    plt.ylabel("temporal pooling weight")
    plt.legend()
    plt.grid(True)
    plt.show()
    neural_representation = model._get_v1_order()
    fs = np.array([ne["fs"] for ne in neural_representation])
    ft = np.array([ne["ft"] for ne in neural_representation])
    # Polar scatter of each unit at (preferred direction, preferred speed).
    ax1 = plt.subplot(131, projection='polar')
    theta_list = []
    v_list = []
    energy_list = []
    for index in range(len(neural_representation)):
        v = neural_representation[index]["speed"]
        theta = neural_representation[index]["theta"]
        theta_list.append(theta)
        v_list.append(v)
    v_list, theta_list = np.array(v_list), np.array(theta_list)
    x, y = pol2cart(v_list, theta_list)
    plt.scatter(theta_list, v_list, c=v_list, cmap="rainbow", s=(v_list + 20), alpha=0.8)
    plt.axis('on')
    # plt.colorbar()
    plt.grid(True)
    # plt.subplot(132, projection="polar")
    # plt.scatter(theta_list, np.ones_like(theta_list))
    # Directions only, all drawn at unit radius.
    plt.subplot(132, projection='polar')
    plt.scatter(theta_list, np.ones_like(v_list))
    # Label every unit with the scale it belongs to (8 scales of 32 units).
    lst = []
    for scale in range(8):
        lst += ["scale %d" % scale] * 32
    data = {"Spatial Frequency": fs, 'Temporal Frequency': ft, "Class": lst}
    df = pd.DataFrame(data=data)
    ax = plt.subplot(133, projection='polar')
    # theta_list = theta_list[v_list > (ft * v_list.mean())]
    print(len(theta_list))
    bins_number = 8  # the [0, 360) interval will be subdivided into this
    # number of equal bins
    zone = np.pi / 8
    # Shift the bin edges by half a bin (pi / 8) and wrap the angles that fall
    # below the first edge around by a full turn.
    theta_list[theta_list < (-np.pi + zone)] = theta_list[theta_list < (-np.pi + zone)] + np.pi * 2
    bins = np.linspace(-np.pi + zone, np.pi + zone, bins_number + 1)
    n, _, _ = plt.hist(theta_list, bins, edgecolor="black")
    # ax.set_theta_offset(-np.pi / 8 - np.pi)
    ax.set_yticklabels([])
    plt.grid(True)
    import seaborn as sns
    # Joint distribution of spatial vs. temporal frequency, with and without per-scale colouring.
    sns.jointplot(data=df, x="Spatial Frequency", y="Temporal Frequency", hue="Class", xlim=[0, 0.3], ylim=[0, 0.3])
    plt.grid(True)
    g = sns.jointplot(data=df, x="Spatial Frequency", y="Temporal Frequency", xlim=[0, 0.25], ylim=[0, 0.25])
    # g.plot_joint(sns.kdeplot, color="r", zorder=0, levels=6)
    plt.grid(True)
    plt.show()
    # show spatial frequency preference and temporal frequency preference.
    # NOTE(review): demo_temporal_filter samples from indices[0] to indices[-1],
    # while this axis runs 0 .. t_kz — confirm the intended span.
    x = np.linspace(0, t_kz, t_point)
    index = 0
    for scale in range(len(model.spatial_filter)):
        t_sin, t_cos = model.temporal_filter[scale].demo_temporal_filter(points=t_point)
        gb_sin_b, gb_cos_b = model.spatial_filter[scale].demo_gabor_filters(points=s_point)
        # Spatial unit i is shown together with temporal filter i of the same scale.
        for i in range(gb_sin_b.size(0)):
            plt.figure(figsize=(14, 9), dpi=80)
            plt.subplot(2, 3, 1)
            curve = gb_sin_b[i].squeeze().detach().numpy()
            plt.imshow(curve)
            plt.title("Gabor Sin")
            plt.subplot(2, 3, 2)
            curve = gb_cos_b[i].squeeze().detach().numpy()
            plt.imshow(curve)
            plt.title("Gabor Cos")
            plt.subplot(2, 3, 3)
            curve = t_sin[i].squeeze().detach().numpy()
            plt.plot(x, curve, label='sin')
            plt.title("Temporal Sin")
            curve = t_cos[i].squeeze().detach().numpy()
            plt.plot(x, curve, label='cos')
            plt.xlabel('Time (s)')
            plt.ylabel('Response to pulse at t=0')
            plt.legend()
            plt.title("Temporal filter")
            # Take a single image row (index 5) of each spatial kernel as a 1-D profile.
            gb_sin = gb_sin_b[i].squeeze().detach()[5, :]
            gb_cos = gb_cos_b[i].squeeze().detach()[5, :]
            # Space-time kernels as outer products (time x space), combined in quadrature.
            a = np.outer(t_cos[i].detach(), gb_sin)
            b = np.outer(t_sin[i].detach(), gb_cos)
            g_o = a + b
            a = np.outer(t_sin[i].detach(), gb_sin)
            b = np.outer(t_cos[i].detach(), gb_cos)
            g_e = a - b
            energy_component = g_o ** 2 + g_e ** 2
            plt.subplot(2, 3, 4)
            curve = g_o
            plt.imshow(curve, cmap="gray")
            plt.title("Spatial Temporal even")
            plt.subplot(2, 3, 5)
            curve = g_e
            plt.imshow(curve, cmap="gray")
            plt.title("Spatial Temporal odd")
            plt.subplot(2, 3, 6)
            curve = energy_component
            plt.imshow(curve, cmap="gray")
            plt.title("energy")
            plt.savefig('filter_%d.png' % (index))
            filenames.append('filter_%d.png' % (index))
            index += 1
            # plt.show()
    # build gif
    with imageio.get_writer('filters_orientation.gif', mode='I') as writer:
        for filename in filenames:
            image = imageio.imread(filename)
            writer.append_data(image)
    # Remove files
    for filename in set(filenames):
        os.remove(filename)
if __name__ == "__main__":
    # Run the trained-model visualisation once; it was previously invoked twice,
    # which only regenerated the same figures and GIF a second time.
    show_trained_model()
    # Alternative entry points:
    # V1.demo()
    # draw_polar()
    # te_spatial_temporal()