Mavthunder commited on
Commit
7dc26b7
·
verified ·
1 Parent(s): 2ff8432

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -10
app.py CHANGED
@@ -11,26 +11,38 @@ import numpy as np
11
  class ZeroDCE(nn.Module):
12
  def __init__(self):
13
  super(ZeroDCE, self).__init__()
14
- self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
15
- self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
16
- self.conv3 = nn.Conv2d(32, 32, 3, padding=1)
17
- self.conv4 = nn.Conv2d(32, 24, 3, padding=1)
18
  self.relu = nn.ReLU(inplace=True)
 
 
 
 
 
19
 
20
  def forward(self, x):
21
  x1 = self.relu(self.conv1(x))
22
  x2 = self.relu(self.conv2(x1))
23
  x3 = self.relu(self.conv3(x2))
24
- x_r = torch.tanh(self.conv4(x3))
25
- return x_r
 
26
 
27
  def enhance_image(img, model):
28
- img_tensor = torch.from_numpy(np.array(img)).float() / 255.0
29
- img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(device)
 
 
30
  with torch.no_grad():
31
- enhanced = model(img_tensor)
 
 
 
 
 
 
 
32
  enhanced = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
33
- enhanced = np.clip(enhanced * 255, 0, 255).astype(np.uint8)
 
34
  return Image.fromarray(enhanced)
35
 
36
  # -----------------------------
 
11
  class ZeroDCE(nn.Module):
12
  def __init__(self):
13
  super(ZeroDCE, self).__init__()
 
 
 
 
14
  self.relu = nn.ReLU(inplace=True)
15
+ self.conv1 = nn.Conv2d(3, 32, 3, padding=1, bias=True)
16
+ self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
17
+ self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=True)
18
+ self.conv4 = nn.Conv2d(32, 24, 3, padding=1, bias=True)
19
+ self.conv5 = nn.Conv2d(24, 8, 3, padding=1, bias=True)
20
 
21
  def forward(self, x):
22
  x1 = self.relu(self.conv1(x))
23
  x2 = self.relu(self.conv2(x1))
24
  x3 = self.relu(self.conv3(x2))
25
+ x4 = self.relu(self.conv4(x3))
26
+ out = torch.tanh(self.conv5(x4))
27
+ return out
28
 
29
  def enhance_image(img, model):
30
+ # Convert PIL -> Tensor
31
+ img_np = np.array(img).astype(np.float32) / 255.0
32
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
33
+
34
  with torch.no_grad():
35
+ enhancement_map = model(img_tensor)
36
+
37
+ # Apply enhancement: simple iterative curve
38
+ enhanced = img_tensor
39
+ for i in range(8):
40
+ enhanced = enhanced + enhancement_map[:, i*3:(i+1)*3, :, :] * (enhanced**2 - enhanced)
41
+
42
+ enhanced = torch.clamp(enhanced, 0, 1)
43
  enhanced = enhanced.squeeze(0).permute(1, 2, 0).cpu().numpy()
44
+ enhanced = (enhanced * 255).astype(np.uint8)
45
+
46
  return Image.fromarray(enhanced)
47
 
48
  # -----------------------------