# Source: Guest / tensor_network.py (uploaded by Prositron)
# Commit 8ab2d14 "Create tensor_network.py" (verified); original size 891 bytes.
import torch
import torch.nn as nn
# Define a simple neural network model
class SimpleModel(nn.Module):
    """Small CNN: one 3x3 convolution + ReLU followed by a linear layer.

    ``conv1`` uses ``kernel_size=3, stride=1, padding=1``, which keeps the
    spatial size of its input unchanged. The linear layer therefore expects
    ``hidden_channels * input_height * input_width`` features after
    flattening.

    Args:
        in_channels: Number of channels in the input images (default 3).
        hidden_channels: Output channels of ``conv1`` (default 16).
        num_classes: Number of output features / logits (default 10).
        input_height: Height of the inputs the model will receive (default 4).
        input_width: Width of the inputs the model will receive (default 4).

    The defaults reproduce the original hard-coded architecture
    (``Conv2d(3, 16, kernel_size=3)`` and ``Linear(16*4*4, 10)``). Inputs whose
    ``H * W`` differs from ``input_height * input_width`` make ``fc1`` fail
    with a shape mismatch.
    """

    def __init__(
        self,
        in_channels: int = 3,
        hidden_channels: int = 16,
        num_classes: int = 10,
        input_height: int = 4,
        input_width: int = 4,
    ) -> None:
        super().__init__()
        # kernel_size=3, stride=1, padding=1 preserves height and width.
        self.conv1 = nn.Conv2d(
            in_channels, hidden_channels, kernel_size=3, stride=1, padding=1
        )
        # Number of features after flattening the (C, H, W) conv output.
        self.fc1 = nn.Linear(hidden_channels * input_height * input_width, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of ``fc1`` with shape ``(N, num_classes)``.

        Args:
            x: Input batch, presumably ``(N, in_channels, input_height,
                input_width)`` — TODO confirm against callers.
        """
        x = torch.relu(self.conv1(x))
        # Flatten every dim except the batch dim; unlike ``view`` this also
        # works when the tensor is not contiguous.
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x)
# Demo: draw three random batches shaped (2, 3, 4, 4) — batch 2, 3 channels,
# 4x4 spatial — and combine them element-wise into one input tensor.
tensor1, tensor2, tensor3 = (torch.rand(2, 3, 4, 4) for _ in range(3))
input_tensor = tensor1 + tensor2 + tensor3

# Build the network, push the combined batch through it once, and show the result.
model = SimpleModel()
output = model(input_tensor)
print(output)