dcher95 committed · verified
Commit e237035 · Parent(s): 72b2774

Upload render.py with huggingface_hub

Files changed (1)
  1. render.py +91 -0
render.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class ResidualRenderBlock(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.block = nn.Sequential(
+             nn.Conv2d(dim, dim, kernel_size=3, padding=1),
+             nn.GroupNorm(8, dim),
+             nn.SiLU(),
+             nn.Conv2d(dim, dim, kernel_size=3, padding=1),
+             nn.GroupNorm(8, dim)
+         )
+
+     def forward(self, x):
+         return x + self.block(x)
+
+ class RenderEncoder(nn.Module):
+     def __init__(self, encoder_type="1d", in_channels=768, out_channels=3):
+         super().__init__()
+         self.encoder_type = encoder_type
+
+         if encoder_type == "1d":
+             self.model = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1),
+                 nn.Sigmoid()
+             )
+
+         elif encoder_type == "residual":
+             self.model = ResidualBlockRender(in_channels, out_channels)
+
+         elif encoder_type == "expressive":
+             mid_channels = 256
+             self.model = nn.Sequential(
+                 nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
+                 nn.GroupNorm(8, mid_channels),
+                 nn.SiLU(),
+                 ResidualRenderBlock(mid_channels),
+                 ResidualRenderBlock(mid_channels),
+                 ResidualRenderBlock(mid_channels),
+                 nn.Conv2d(mid_channels, out_channels, kernel_size=1),
+                 nn.Sigmoid()
+             )
+
+         else:
+             raise ValueError(f"Unknown encoder_type '{encoder_type}'. Use '1d', 'residual', or 'expressive'.")
+
+     def forward(self, x):
+         return self.model(x)
+
+ class ResidualBlockRender(nn.Module):
+     def __init__(self, in_channels=768, out_channels=3):
+         super().__init__()
+         self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
+         self.relu1 = nn.ReLU()
+         self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+         self.relu2 = nn.ReLU()
+         self.conv3 = nn.Conv2d(256, out_channels, kernel_size=1)
+         self.out = nn.Sigmoid()
+
+         if in_channels != out_channels:
+             self.residual_proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+         else:
+             self.residual_proj = nn.Identity()
+
+     def forward(self, x):
+         residual = self.residual_proj(x)
+         h = self.relu1(self.conv1(x))
+         h = self.relu2(self.conv2(h))
+         h = self.conv3(h)
+         h = h + residual
+         return self.out(h)
+
+ def load_render_encoder(checkpoint_path, device='cpu'):
+     """Load standalone RenderEncoder from checkpoint"""
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     config = checkpoint['model_config']
+     model = RenderEncoder(
+         encoder_type=config['encoder_type'],
+         in_channels=config['in_channels'],
+         out_channels=config['out_channels']
+     )
+
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.to(device)
+     model.eval()
+
+     print(f"Loaded RenderEncoder: {config}")
+     return model
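
For reference, a minimal usage sketch of the uploaded module (not part of the commit; the checkpoint path, batch size, and spatial resolution below are illustrative assumptions). It constructs the "expressive" variant directly to map 768-channel feature maps to RGB in [0, 1], and shows the checkpoint-based path via load_render_encoder, which expects a checkpoint dict containing 'model_config' and 'model_state_dict' keys.

import torch
from render import RenderEncoder, load_render_encoder

# Direct construction: decode 768-channel features to a 3-channel image.
encoder = RenderEncoder(encoder_type="expressive", in_channels=768, out_channels=3)
features = torch.randn(2, 768, 32, 32)   # illustrative batch of feature maps
with torch.no_grad():
    rgb = encoder(features)              # shape (2, 3, 32, 32), squashed to [0, 1] by the final Sigmoid

# Checkpoint-based loading (path is hypothetical; the file must have been saved
# with 'model_config' and 'model_state_dict' keys, as load_render_encoder expects).
# model = load_render_encoder("render_encoder.pt", device="cpu")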