# dummy / serkan.py
# Serkan Ozturk
# sdfgsdfg
# 687e8c3
import torch
import torch.nn as nn
class SimpleUpscaleModel(nn.Module):
    """Network that enlarges its input with bilinear interpolation.

    The model consists of a single ``nn.Upsample`` layer configured with
    ``mode='bilinear'`` and ``align_corners=True``. It defines no other
    layers.

    Args:
        scale_factor: Multiplier applied to the spatial dimensions
            (height and width) of the input. Defaults to 2.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        # Keep the submodule name ``upsample``: renaming it would change the
        # attribute / submodule path visible to callers.
        self.upsample = nn.Upsample(
            mode='bilinear',
            align_corners=True,
            scale_factor=scale_factor,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale ``x`` by the configured scale factor.

        Args:
            x: Input tensor shaped ``(batch_size, channels, height, width)``.
                Bilinear upsampling operates on 4-D input.

        Returns:
            The interpolated tensor, with height and width scaled by
            ``scale_factor`` and the batch/channel dimensions unchanged.
        """
        upscaled = self.upsample(x)
        return upscaled
if __name__ == "__main__":
    # Build an upscaler that doubles the spatial size of its input.
    upscaler = SimpleUpscaleModel(scale_factor=2)

    # Persist the model's state_dict to disk.
    # NOTE(review): the model's only submodule is nn.Upsample, which has no
    # learnable parameters, so this checkpoint is presumably empty. Confirm
    # this is the intended artifact.
    weights_file = "model_weights.pth"
    torch.save(upscaler.state_dict(), weights_file)
    print(f"Model saved to {weights_file}")