import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    """Small CNN: one 3x3 convolution, a ReLU, and a linear layer.

    The convolution uses stride 1 and padding 1, so it keeps the spatial
    size of its input. The linear layer is therefore sized for
    ``conv_channels * input_size * input_size`` flattened features.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of values produced per sample.
        input_size: Height and width (square) of the input images that
            the linear layer is sized for.
        conv_channels: Number of feature maps produced by the convolution.

    With the default arguments the layers are identical to
    ``Conv2d(3, 16, kernel_size=3, stride=1, padding=1)`` followed by
    ``Linear(16 * 4 * 4, 10)``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        input_size: int = 4,
        conv_channels: int = 16,
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, conv_channels, kernel_size=3, stride=1, padding=1
        )
        # Padding 1 with a 3x3 kernel and stride 1 preserves height/width,
        # so the flattened feature count depends only on the input size.
        self.fc1 = nn.Linear(conv_channels * input_size * input_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, input_size, input_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding the raw output
            of the linear layer (no softmax is applied).
        """
        x = torch.relu(self.conv1(x))
        # flatten() falls back to a copy when the activation is not
        # contiguous (e.g. channels_last inputs), where view() would raise.
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x)


# Demo: build a random batch and push it through SimpleModel once.

# Three independent batches shaped (2, 3, 4, 4): 2 samples, 3 channels,
# 4x4 pixels, matching what SimpleModel() expects by default.
tensor1 = torch.rand(2, 3, 4, 4)
tensor2 = torch.rand(2, 3, 4, 4)
tensor3 = torch.rand(2, 3, 4, 4)

# Element-wise sum of the three batches; the shape stays (2, 3, 4, 4).
input_tensor = tensor1 + tensor2 + tensor3

model = SimpleModel()

# Forward pass; the result has one row per sample.
output = model(input_tensor)

print(output)