Spaces:
Running
Running
# -*- coding: utf-8 -*-
# Copyright 2024 Wen-Chin Huang
# MIT License (https://opensource.org/licenses/MIT)
# LDNet modules
# Taken from https://github.com/unilight/LDNet/blob/main/models/modules.py (originally written by the same author).
import torch | |
from torch import nn | |
# Module-level constant; it is not referenced by Projection in this file chunk.
# NOTE(review): presumably a stride used by other LDNet modules — confirm at its use sites.
STRIDE = 3
class Projection(nn.Module):
    """Two-layer MLP prediction head producing a scalar score or class logits.

    The network is ``Linear(in_dim, hidden_dim) -> activation() -> Dropout ->
    Linear(hidden_dim, output_dim)``. It is applied to the last dimension of the
    input, with ``output_dim`` equal to 1 for ``"scalar"`` and ``_output_dim``
    for ``"categorical"``.

    The keyword-only arguments default to the values that were previously
    hard-coded, so existing callers get identical behaviour:
    dropout 0.3, clipping via ``tanh(y) * 2.0 + 3`` (open range (1, 5)) and a
    categorical index offset of ``+1``.

    Args:
        in_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the hidden layer.
        activation: Activation *class* (e.g. ``nn.ReLU``); it is instantiated
            with no arguments.
        output_type: ``"scalar"`` or ``"categorical"``.
        _output_dim: Number of output classes for ``"categorical"``; ignored
            for ``"scalar"``.
        output_step: Spacing between consecutive class values, used to map an
            argmax index back to a score in categorical inference.
        range_clipping: For ``"scalar"`` only, squash the output with tanh into
            the open interval ``(output_min, output_max)``.
        dropout: Dropout probability between the two linear layers.
        output_min: Lower end of the score range. It is the value of class
            index 0 in categorical inference and the lower bound when
            ``range_clipping`` is on.
        output_max: Upper bound of the score range when ``range_clipping`` is on.

    Raises:
        NotImplementedError: If ``output_type`` is not supported.
        ValueError: If ``range_clipping`` is on and ``output_max <= output_min``.
    """

    def __init__(
        self,
        in_dim,
        hidden_dim,
        activation,
        output_type,
        _output_dim,
        output_step=1.0,
        range_clipping=False,
        *,
        dropout=0.3,
        output_min=1,
        output_max=5,
    ):
        super().__init__()
        self.output_type = output_type
        self.range_clipping = range_clipping
        self.output_step = output_step
        # Plain attributes (not buffers) so state_dict keys stay unchanged and
        # existing checkpoints keep loading with strict=True.
        self.output_min = output_min
        self.output_max = output_max

        if output_type == "scalar":
            output_dim = 1
            if range_clipping:
                if output_max <= output_min:
                    raise ValueError(
                        f"output_max ({output_max}) must be greater than "
                        f"output_min ({output_min}) when range_clipping is on"
                    )
                # nn.Tanh has no parameters, so it adds nothing to state_dict.
                self.proj = nn.Tanh()
        elif output_type == "categorical":
            output_dim = _output_dim
        else:
            raise NotImplementedError(f"wrong output_type: {output_type}")

        # Layer order/indices (net.0, net.3) must stay fixed: checkpoint keys depend on them.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x, inference=False):
        """Run the head on ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_dim``.
            inference: For ``"categorical"``, return decoded scores
                (``argmax * output_step + output_min``) instead of raw logits.
                Ignored for ``"scalar"``.

        Returns:
            For ``"scalar"``: a tensor with last dimension 1, tanh-mapped into
            ``(output_min, output_max)`` when range clipping is enabled.
            For ``"categorical"``: logits over the last dimension, or, when
            ``inference`` is True, a tensor with the last dimension reduced
            by argmax.
        """
        output = self.net(x)

        if self.output_type == "scalar":
            if not self.range_clipping:
                return output
            # Affine map from tanh's (-1, 1) onto (output_min, output_max);
            # with the defaults this is exactly tanh(y) * 2.0 + 3.0.
            half_span = (self.output_max - self.output_min) / 2.0
            center = (self.output_max + self.output_min) / 2.0
            return self.proj(output) * half_span + center

        # categorical
        if inference:
            return torch.argmax(output, dim=-1) * self.output_step + self.output_min
        return output