|
import torch |
|
import torch.nn as nn |
|
|
|
class SimpleUpscaleModel(nn.Module):
    """Upscale batches of images with fixed (non-learnable) bilinear interpolation.

    The module wraps a single ``nn.Upsample`` layer. It has no learnable
    parameters or buffers, so ``state_dict()`` is empty.
    """

    def __init__(self, scale_factor=2, *, align_corners: bool = True):
        """
        A simple model for upscaling inputs using bilinear interpolation.

        Args:
            scale_factor (int | float | tuple): The factor by which to upscale
                the input. A tuple gives separate factors for height and
                width. Every factor must be positive.
            align_corners (bool): Forwarded to ``nn.Upsample``. ``True`` (the
                default, and the only behavior of earlier versions of this
                class) aligns the corner pixels of input and output; ``False``
                uses half-pixel sampling.

        Raises:
            ValueError: If ``scale_factor`` is ``None``, empty, or contains a
                non-positive factor.
        """
        super().__init__()

        factors = (
            scale_factor
            if isinstance(scale_factor, (tuple, list))
            else (scale_factor,)
        )
        # Validate eagerly: nn.Upsample silently turns a falsy factor (e.g. 0)
        # into None, so a bad value would otherwise only fail later in forward().
        if not factors or any(f is None or f <= 0 for f in factors):
            raise ValueError(
                f"scale_factor must be positive, got {scale_factor!r}"
            )

        self.upsample = nn.Upsample(
            scale_factor=scale_factor,
            mode='bilinear',
            align_corners=align_corners,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the network.

        Args:
            x (torch.Tensor): Input tensor of shape
                (batch_size, channels, height, width). Bilinear mode requires
                4-D input.

        Returns:
            torch.Tensor: Upscaled tensor of shape
                (batch_size, channels, floor(height * s_h), floor(width * s_w)),
                where s_h / s_w are the height / width scale factors.
        """
        return self.upsample(x)
|
|
|
if __name__ == "__main__":
    # Script entry point: build a 2x upscaler and write its state dict to disk.
    # NOTE: the model has no learnable parameters, so the saved state dict is
    # empty; the scale factor itself is not stored in the file.
    factor = 2
    upscaler = SimpleUpscaleModel(scale_factor=factor)

    weights_file = "model_weights.pth"
    torch.save(upscaler.state_dict(), weights_file)
    print(f"Model saved to {weights_file}")