haoning.wu committed
Commit f8ea2c9
1 Parent(s): bf92928

Add InstructIR plugin!

app.py CHANGED
@@ -1,10 +1,85 @@
+import os, yaml
 import gradio as gr
 import requests
+import argparse
+
 from PIL import Image
 
+import numpy as np
 import torch
 from transformers import AutoModelForCausalLM
 
+from huggingface_hub import hf_hub_download
+
+
+## InstructIR Plugin ##
+from insir_models import instructir
+from insir_text.models import LanguageModel, LMHead
+
+hf_hub_download(repo_id="marcosv/InstructIR", filename="im_instructir-7d.pt", local_dir="./")
+hf_hub_download(repo_id="marcosv/InstructIR", filename="lm_instructir-7d.pt", local_dir="./")
+
+CONFIG = "eval5d.yml"
+LM_MODEL = "lm_instructir-7d.pt"
+MODEL_NAME = "im_instructir-7d.pt"
+
+def dict2namespace(config):
+    namespace = argparse.Namespace()
+    for key, value in config.items():
+        if isinstance(value, dict):
+            new_value = dict2namespace(value)
+        else:
+            new_value = value
+        setattr(namespace, key, new_value)
+    return namespace
+
+
+# parse config file
+with open(os.path.join(CONFIG), "r") as f:
+    config = yaml.safe_load(f)
+
+cfg = dict2namespace(config)
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ir_model = instructir.create_model(input_channels=cfg.model.in_ch, width=cfg.model.width, enc_blks=cfg.model.enc_blks,
+                                   middle_blk_num=cfg.model.middle_blk_num, dec_blks=cfg.model.dec_blks, txtdim=cfg.model.textdim)
+ir_model = ir_model.to(device)
+print("IMAGE MODEL CKPT:", MODEL_NAME)
+ir_model.load_state_dict(torch.load(MODEL_NAME, map_location="cpu"), strict=True)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+LMODEL = cfg.llm.model
+language_model = LanguageModel(model=LMODEL)
+lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses)
+lm_head = lm_head.to(device)
+
+print("LMHEAD MODEL CKPT:", LM_MODEL)
+lm_head.load_state_dict(torch.load(LM_MODEL, map_location="cpu"), strict=True)
+
+def process_img(image, prompt=None):
+    if prompt is None:
+        prompt = chat("How to improve the quality of the image?", [], image, None, None, None)
+        prompt += "Please help me improve its quality!"
+    print(prompt)
+    img = np.array(image)
+    img = img / 255.
+    img = img.astype(np.float32)
+    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
+
+    lm_embd = language_model(prompt)
+    lm_embd = lm_embd.to(device)
+
+    with torch.no_grad():
+        text_embd, deg_pred = lm_head(lm_embd)
+        x_hat = ir_model(y, text_embd)
+
+    restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
+    restored_img = np.clip(restored_img, 0., 1.)
+
+    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 to uint8
+    return Image.fromarray(restored_img)  # (image, Image.fromarray(restored_img))
+
+## InstructIR Plugin ##
 model = AutoModelForCausalLM.from_pretrained("q-future/co-instruct-preview",
                                              trust_remote_code=True,
                                              torch_dtype=torch.float16,
@@ -15,7 +90,7 @@ def chat(message, history, image_1, image_2, image_3, image_4):
     print(history)
     if history:
         if image_1 is not None and image_2 is None:
-            past_message = "USER: The image: <|image|> " + history[0][0] + " ASSISTANT:" + history[0][1]
+            past_message = "USER: The input image: <|image|>" + history[0][0] + " ASSISTANT:" + history[0][1]
             for i in range((len(history) - 1)):
                 past_message += "USER:" + history[i][0] + " ASSISTANT:" + history[i][1] + "</s>"
             message = past_message + "USER:" + message + " ASSISTANT:"
@@ -42,7 +117,7 @@ def chat(message, history, image_1, image_2, image_3, image_4):
             images = [image_1, image_2, image_3, image_4]
     else:
         if image_1 is not None and image_2 is None:
-            message = "USER: The image: <|image|> " + message + " ASSISTANT:"
+            message = "USER: The input image: <|image|>" + message + " ASSISTANT:"
            images = [image_1]
        if image_1 is not None and image_2 is not None:
            if image_3 is None:
@@ -58,14 +133,24 @@ def chat(message, history, image_1, image_2, image_3, image_4):
 
     print(message)
 
-    return model.tokenizer.batch_decode(model.chat(message, images, max_new_tokens=300).clamp(0, 100000))[0].split("ASSISTANT:")[-1]
+    return model.tokenizer.batch_decode(model.chat(message, images, max_new_tokens=600).clamp(0, 100000))[0].split("ASSISTANT:")[-1]
+
+#### Image,Prompts examples
+examples = [
+    ["Which part of the image is relatively clearer, the upper part or the lower part? Please analyze in details.", Image.open("examples/sausage.jpg"), None],
+    ["Which image is noisy, and which one is with motion blur? Please analyze in details.", Image.open("examples/211.jpg"), Image.open("examples/frog.png")],
+    ["What is the problem in this image, and how to fix it? Please answer my questions one by one.", Image.open("examples/lol_748.png"), None],
+]
+
 
 
+title = "Q-Instruct🧑‍🏫"
 with gr.Blocks(title="img") as demo:
     title_markdown = ("""
-    <h3 align="center">*Super Version of Q-Instruct with Multi-image (up to 4, same as GPT-4V) Support!*</h3>
+
     <h1 align="center"><a href="https://github.com/Q-Future/Q-Instruct"><img src="https://github.com/Q-Future/Q-Instruct/blob/main/q_instruct_logo.png?raw=true", alt="Q-Instruct (mPLUG-Owl-2)" border="0" style="margin: 0 auto; height: 85px;" /></a> </h1>
     <h2 align="center">Q-Instruct: Improving Low-level Visual Abilities for Multi-modality Foundation Models</h2>
+    <div align="center">Super Version of Q-Instruct with Multi-image (up to 4, same as GPT-4V) Support! We also support <a href='https://huggingface.co/marcosv/InstructIR'>InstructIR</a> as PLUGIN!</div>
     <h5 align="center"> Please find our more accurate visual scoring demo on <a href='https://huggingface.co/spaces/teowu/OneScorer'>[OneScorer]</a>!</h5>
     <div align="center">
         <div style="display:flex; gap: 0.25rem;" align="center">
@@ -81,5 +166,16 @@ with gr.Blocks(title="img") as demo:
         input_img_2 = gr.Image(type='pil', label="Image 2 (Second image)")
         input_img_3 = gr.Image(type='pil', label="Image 3 (Third image)")
         input_img_4 = gr.Image(type='pil', label="Image 4 (Fourth image)")
-    gr.ChatInterface(fn=chat, additional_inputs=[input_img_1, input_img_2, input_img_3, input_img_4])
+    with gr.Row():
+        with gr.Column(scale=2):
+            gr.ChatInterface(fn=chat, additional_inputs=[input_img_1, input_img_2, input_img_3, input_img_4], examples=examples)
+        with gr.Column(scale=1):
+            input_image_ir = gr.Image(type="pil", label="Image for Auto Restoration")
+            output_image_ir = gr.Image(type="pil", label="Output of Auto Restoration")
+            gr.Interface(
+                fn=process_img,
+                inputs=[input_image_ir],
+                outputs=[output_image_ir],
+                examples=[Image.open("examples/gopro.png"), Image.open("examples/noise50.png"), Image.open("examples/lol_748.png")],
+            )
 demo.launch(share=True)
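The restoration path this diff wires into the Space can also be exercised outside Gradio. Below is a minimal sketch (not part of the commit) that mirrors process_img() with a hand-written instruction instead of one generated by chat(); it assumes the two InstructIR checkpoints and eval5d.yml have already been fetched as in app.py.

# Sketch: the InstructIR restoration path from app.py, run standalone.
import numpy as np
import torch
from PIL import Image

from insir_models import instructir
from insir_text.models import LanguageModel, LMHead

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image model with the hyper-parameters from eval5d.yml.
ir_model = instructir.create_model(input_channels=3, width=32,
                                   enc_blks=[2, 2, 4, 8], middle_blk_num=4,
                                   dec_blks=[2, 2, 2, 2], txtdim=256).to(device)
ir_model.load_state_dict(torch.load("im_instructir-7d.pt", map_location="cpu"), strict=True)

# Frozen sentence encoder + trained projection head.
language_model = LanguageModel(model="TaylorAI/bge-micro-v2")
lm_head = LMHead(embedding_dim=384, hidden_dim=256, num_classes=7).to(device)
lm_head.load_state_dict(torch.load("lm_instructir-7d.pt", map_location="cpu"), strict=True)

image = Image.open("examples/lol_748.png").convert("RGB")
y = torch.from_numpy(np.array(image).astype(np.float32) / 255.).permute(2, 0, 1).unsqueeze(0).to(device)

with torch.no_grad():
    lm_embd = language_model("The photo is too dark, improve exposure").to(device)
    text_embd, _ = lm_head(lm_embd)
    x_hat = ir_model(y, text_embd)

out = (x_hat.squeeze().permute(1, 2, 0).clamp(0, 1).cpu().numpy() * 255.0).round().astype(np.uint8)
Image.fromarray(out).save("restored.png")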
eval5d.yml ADDED
@@ -0,0 +1,40 @@
+llm:
+  model: 'TaylorAI/bge-micro-v2'  # See Paper Sec. 3.2 and Appendix
+  model_dim: 384
+  embd_dim: 256
+  nclasses: 7  # noise, blur, rain, haze, lol, enhancement, upsampling (Paper Sec. 4.3)
+  weights: False
+
+model:
+  arch: "instructir"
+  use_text: True
+  in_ch: 3
+  out_ch: 3
+  width: 32
+  enc_blks: [2, 2, 4, 8]
+  middle_blk_num: 4
+  dec_blks: [2, 2, 2, 2]
+  textdim: 256
+  weights: False
+
+test:
+  batch_size: 1
+  num_workers: 3
+
+  dn_datapath: "data/denoising_testsets/"
+  dn_datasets: ["CBSD68", "urban100", "Kodak24", "McMaster"]
+  dn_sigmas: [15, 25, 50]
+
+  rain_targets: ["data/Rain/rain_test/Rain100L/target/"]
+  rain_inputs: ["data/Rain/rain_test/Rain100L/input/"]
+
+  haze_targets: "data/SOTS-OUT/GT/"
+  haze_inputs: "data/SOTS-OUT/IN/"
+
+  lol_targets: "data/LOL/eval15/high/"
+  lol_inputs: "data/LOL/eval15/low/"
+
+  gopro_targets: "data/gopro_test/GoPro/target/"
+  gopro_inputs: "data/gopro_test/GoPro/input/"
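For context, app.py does not read these fields as a plain dict; it wraps the parsed YAML in a tree of argparse.Namespace objects via dict2namespace(). A small sketch of that access pattern, with values taken from the file above:

# Sketch: attribute-style access to eval5d.yml, as done in app.py.
import argparse
import yaml

def dict2namespace(config):
    namespace = argparse.Namespace()
    for key, value in config.items():
        setattr(namespace, key, dict2namespace(value) if isinstance(value, dict) else value)
    return namespace

with open("eval5d.yml") as f:
    cfg = dict2namespace(yaml.safe_load(f))

print(cfg.llm.model)       # 'TaylorAI/bge-micro-v2'
print(cfg.model.width)     # 32
print(cfg.model.enc_blks)  # [2, 2, 4, 8]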
examples/211.jpg CHANGED

Git LFS Details

  • SHA256: 7980c3c75b6eccd5519918344d03c6e8ba654f3faab2a4aae96e3baddd649a18
  • Pointer size: 130 Bytes
  • Size of remote file: 43.2 kB
examples/extreme_ironing.jpg CHANGED

Git LFS Details

  • SHA256: a54caa21bc513ed25c8ca7f5747555c05dfd4e33f6a3cf5c08b3d9138a4da1d9
  • Pointer size: 130 Bytes
  • Size of remote file: 62.6 kB
examples/frog.png ADDED

Git LFS Details

  • SHA256: 36adda1ff6c39824e480eb92583ca3e2ceea29d9cb206cca880781a102611b11
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
examples/gopro.png ADDED

Git LFS Details

  • SHA256: 2b844eac02ac3499bea0dbccb382e8d4caea026ec6d2092d375e6d4c09f17b09
  • Pointer size: 131 Bytes
  • Size of remote file: 388 kB
examples/lol_748.png ADDED

Git LFS Details

  • SHA256: 325c720df5669e37b9f192bfa9a60add144b82e5e68d9f684c0010a0047b0056
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
examples/noise50.png ADDED

Git LFS Details

  • SHA256: fa84462babeaafdebae7709f71fc048f415e2abeb4e263c69f908265923f3301
  • Pointer size: 131 Bytes
  • Size of remote file: 251 kB
examples/sausage.jpg CHANGED

Git LFS Details

  • SHA256: f5808fb71099077067cf92b3e4bbd8ddc4c179fa575091ff69dca9c96c175741
  • Pointer size: 131 Bytes
  • Size of remote file: 491 kB
insir_models/.ipynb_checkpoints/instructir-checkpoint.py ADDED
insir_models/.ipynb_checkpoints/nafnet-checkpoint.py ADDED
insir_models/.ipynb_checkpoints/nafnet_utils-checkpoint.py ADDED
insir_models/__pycache__/instructir.cpython-39.pyc ADDED
Binary file (4.22 kB)

insir_models/__pycache__/nafnet.cpython-39.pyc ADDED
Binary file (5.53 kB)

insir_models/__pycache__/nafnet_utils.cpython-39.pyc ADDED
Binary file (5.4 kB)
insir_models/instructir.py ADDED
@@ -0,0 +1,134 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from insir_models.nafnet_utils import Local_Base, LayerNorm2d
+from insir_models.nafnet import SimpleGate, NAFBlock
+
+
+class ICB(nn.Module):
+    """
+    Instruction Condition Block (ICB)
+    Paper Section 3.3
+    """
+    def __init__(self, feature_dim, text_dim=768):
+        super(ICB, self).__init__()
+        self.fc = nn.Linear(text_dim, feature_dim)
+        self.block = NAFBlock(feature_dim)
+        self.beta = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)
+        self.gamma = nn.Parameter(torch.zeros((1, feature_dim, 1, 1)), requires_grad=True)
+
+    def forward(self, x, text_embedding):
+        gating_factors = torch.sigmoid(self.fc(text_embedding))
+        gating_factors = gating_factors.unsqueeze(-1).unsqueeze(-1)
+
+        f = x * self.gamma + self.beta  # 1) learned feature scaling/modulation
+        f = f * gating_factors          # 2) (soft) feature routing based on text
+        f = self.block(f)               # 3) block feature enhancement
+        return f + x
+
+
+class InstructIR(nn.Module):
+    """
+    InstructIR model using NAFNet (ECCV 2022) as backbone.
+    The model takes as input an RGB image and a text embedding (encoded instruction).
+    Described in Paper Section 3.3
+    """
+    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], txtdim=768):
+        super().__init__()
+
+        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
+        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
+
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        self.middle_blks = nn.ModuleList()
+        self.ups = nn.ModuleList()
+        self.downs = nn.ModuleList()
+        self.enc_cond = nn.ModuleList()
+        self.dec_cond = nn.ModuleList()
+
+        chan = width
+        for num in enc_blk_nums:
+            self.encoders.append(
+                nn.Sequential(
+                    *[NAFBlock(chan) for _ in range(num)]
+                )
+            )
+            self.enc_cond.append(ICB(chan, txtdim))
+            self.downs.append(
+                nn.Conv2d(chan, 2*chan, 2, 2)
+            )
+            chan = chan * 2
+
+        self.middle_blks = nn.Sequential(
+            *[NAFBlock(chan) for _ in range(middle_blk_num)]
+        )
+
+        for num in dec_blk_nums:
+            self.ups.append(
+                nn.Sequential(
+                    nn.Conv2d(chan, chan * 2, 1, bias=False),
+                    nn.PixelShuffle(2)
+                )
+            )
+            chan = chan // 2
+            self.decoders.append(
+                nn.Sequential(
+                    *[NAFBlock(chan) for _ in range(num)]
+                )
+            )
+            # Add text embedding as modulation
+            self.dec_cond.append(ICB(chan, txtdim))
+
+        self.padder_size = 2 ** len(self.encoders)
+
+    def forward(self, inp, txtembd):
+        B, C, H, W = inp.shape
+        inp = self.check_image_size(inp)
+
+        x = self.intro(inp)
+        encs = []
+
+        for encoder, enc_mod, down in zip(self.encoders, self.enc_cond, self.downs):
+            x = encoder(x)
+            x = enc_mod(x, txtembd)
+            encs.append(x)
+            x = down(x)
+
+        x = self.middle_blks(x)
+
+        for decoder, up, enc_skip, dec_mod in zip(self.decoders, self.ups, encs[::-1], self.dec_cond):
+            x = up(x)
+            x = x + enc_skip
+            x = decoder(x)
+            x = dec_mod(x, txtembd)
+
+        x = self.ending(x)
+        x = x + inp
+
+        return x[:, :, :H, :W]
+
+    def check_image_size(self, x):
+        _, _, h, w = x.size()
+        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
+        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
+        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
+        return x
+
+
+def create_model(input_channels=3, width=32, enc_blks=[2, 2, 4, 8], middle_blk_num=12, dec_blks=[2, 2, 2, 2], txtdim=768):
+    net = InstructIR(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
+                     enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, txtdim=txtdim)
+    return net
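As a quick sanity check (not part of the commit), the model can be instantiated with the same hyper-parameters as eval5d.yml and run on dummy tensors; forward() pads the input to a multiple of 2**len(enc_blks) and crops the output back to the original size.

# Sketch: shape check of InstructIR with the eval5d.yml configuration.
import torch
from insir_models import instructir

net = instructir.create_model(input_channels=3, width=32,
                              enc_blks=[2, 2, 4, 8], middle_blk_num=4,
                              dec_blks=[2, 2, 2, 2], txtdim=256)

img = torch.rand(1, 3, 321, 481)  # arbitrary size; padded internally to a multiple of 16
txt = torch.rand(1, 256)          # stands in for the LMHead text embedding (dim == txtdim)
with torch.no_grad():
    out = net(img, txt)
print(out.shape)                  # torch.Size([1, 3, 321, 481]), cropped back to the input size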
insir_models/nafnet.py ADDED
@@ -0,0 +1,201 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Source: https://github.com/megvii-research/NAFNet
+
+'''
+Simple Baselines for Image Restoration
+
+@article{chen2022simple,
+  title={Simple Baselines for Image Restoration},
+  author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
+  journal={arXiv preprint arXiv:2204.04676},
+  year={2022}
+}
+'''
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+from insir_models.nafnet_utils import Local_Base, LayerNorm2d
+
+
+class SimpleGate(nn.Module):
+    def forward(self, x):
+        x1, x2 = x.chunk(2, dim=1)
+        return x1 * x2
+
+
+class NAFBlock(nn.Module):
+    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
+        super().__init__()
+        dw_channel = c * DW_Expand
+        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, bias=True)
+        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+        # Simplified Channel Attention
+        self.sca = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, groups=1, bias=True),
+        )
+
+        # SimpleGate
+        self.sg = SimpleGate()
+
+        ffn_channel = FFN_Expand * c
+        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+        self.norm1 = LayerNorm2d(c)
+        self.norm2 = LayerNorm2d(c)
+
+        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+
+        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+
+    def forward(self, inp):
+        x = inp
+
+        x = self.norm1(x)
+
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.sg(x)
+        x = x * self.sca(x)
+        x = self.conv3(x)
+
+        x = self.dropout1(x)
+
+        y = inp + x * self.beta
+
+        x = self.conv4(self.norm2(y))
+        x = self.sg(x)
+        x = self.conv5(x)
+
+        x = self.dropout2(x)
+
+        return y + x * self.gamma
+
+
+class NAFNet(nn.Module):
+
+    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]):
+        super().__init__()
+
+        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
+        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
+
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        self.middle_blks = nn.ModuleList()
+        self.ups = nn.ModuleList()
+        self.downs = nn.ModuleList()
+
+        chan = width
+        for num in enc_blk_nums:
+            self.encoders.append(
+                nn.Sequential(
+                    *[NAFBlock(chan) for _ in range(num)]
+                )
+            )
+            self.downs.append(
+                nn.Conv2d(chan, 2*chan, 2, 2)
+            )
+            chan = chan * 2
+
+        self.middle_blks = \
+            nn.Sequential(
+                *[NAFBlock(chan) for _ in range(middle_blk_num)]
+            )
+
+        for num in dec_blk_nums:
+            self.ups.append(
+                nn.Sequential(
+                    nn.Conv2d(chan, chan * 2, 1, bias=False),
+                    nn.PixelShuffle(2)
+                )
+            )
+            chan = chan // 2
+            self.decoders.append(
+                nn.Sequential(
+                    *[NAFBlock(chan) for _ in range(num)]
+                )
+            )
+
+        self.padder_size = 2 ** len(self.encoders)
+
+    def forward(self, inp):
+        B, C, H, W = inp.shape
+        inp = self.check_image_size(inp)
+
+        x = self.intro(inp)
+
+        encs = []
+
+        for encoder, down in zip(self.encoders, self.downs):
+            x = encoder(x)
+            encs.append(x)
+            x = down(x)
+
+        x = self.middle_blks(x)
+
+        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
+            x = up(x)
+            x = x + enc_skip
+            x = decoder(x)
+
+        x = self.ending(x)
+        x = x + inp
+
+        return x[:, :, :H, :W]
+
+    def check_image_size(self, x):
+        _, _, h, w = x.size()
+        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
+        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
+        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
+        return x
+
+
+class NAFNetLocal(Local_Base, NAFNet):
+    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
+        Local_Base.__init__(self)
+        NAFNet.__init__(self, *args, **kwargs)
+
+        N, C, H, W = train_size
+        base_size = (int(H * 1.5), int(W * 1.5))
+
+        self.eval()
+        with torch.no_grad():
+            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
+
+
+def create_nafnet(input_channels=3, width=32, enc_blks=[2, 2, 4, 8], middle_blk_num=12, dec_blks=[2, 2, 2, 2]):
+    """
+    Create NAFNet model
+    https://github.com/megvii-research/NAFNet/blob/main/options/test/SIDD/NAFNet-width32.yml
+    """
+    net = NAFNet(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
+                 enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
+
+    # inp_shape = (3, 256, 256)
+    # from ptflops import get_model_complexity_info
+    # macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
+    # params = float(params[:-3])
+    # macs = float(macs[:-4])
+    # print(macs, params)
+
+    return net
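The plain NAFNet baseline kept here has the same encoder/decoder interface as InstructIR, just without the text embedding. A brief shape check, assuming the default width-32 configuration from the NAFNet repository:

# Sketch: plain NAFNet forward pass (no instruction conditioning).
import torch
from insir_models.nafnet import create_nafnet

net = create_nafnet(input_channels=3, width=32,
                    enc_blks=[2, 2, 4, 8], middle_blk_num=12, dec_blks=[2, 2, 2, 2])
x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 256, 256])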
insir_models/nafnet_utils.py ADDED
@@ -0,0 +1,146 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Source: https://github.com/megvii-research/NAFNet
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+class LayerNormFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, weight, bias, eps):
+        ctx.eps = eps
+        N, C, H, W = x.size()
+        mu = x.mean(1, keepdim=True)
+        var = (x - mu).pow(2).mean(1, keepdim=True)
+        y = (x - mu) / (var + eps).sqrt()
+        ctx.save_for_backward(y, var, weight)
+        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        eps = ctx.eps
+
+        N, C, H, W = grad_output.size()
+        y, var, weight = ctx.saved_variables
+        g = grad_output * weight.view(1, C, 1, 1)
+        mean_g = g.mean(dim=1, keepdim=True)
+
+        mean_gy = (g * y).mean(dim=1, keepdim=True)
+        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None
+
+
+class LayerNorm2d(nn.Module):
+
+    def __init__(self, channels, eps=1e-6):
+        super(LayerNorm2d, self).__init__()
+        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
+        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
+        self.eps = eps
+
+    def forward(self, x):
+        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
+
+
+class AvgPool2d(nn.Module):
+    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.base_size = base_size
+        self.auto_pad = auto_pad
+
+        # only used for fast implementation
+        self.fast_imp = fast_imp
+        self.rs = [5, 4, 3, 2, 1]
+        self.max_r1 = self.rs[0]
+        self.max_r2 = self.rs[0]
+        self.train_size = train_size
+
+    def extra_repr(self) -> str:
+        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
+            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
+        )
+
+    def forward(self, x):
+        if self.kernel_size is None and self.base_size:
+            train_size = self.train_size
+            if isinstance(self.base_size, int):
+                self.base_size = (self.base_size, self.base_size)
+            self.kernel_size = list(self.base_size)
+            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
+            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]
+
+            # only used for fast implementation
+            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
+            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])
+
+        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
+            return F.adaptive_avg_pool2d(x, 1)
+
+        if self.fast_imp:  # Non-equivalent implementation but faster
+            h, w = x.shape[2:]
+            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
+                out = F.adaptive_avg_pool2d(x, 1)
+            else:
+                r1 = [r for r in self.rs if h % r == 0][0]
+                r2 = [r for r in self.rs if w % r == 0][0]
+                # reduction constraint
+                r1 = min(self.max_r1, r1)
+                r2 = min(self.max_r2, r2)
+                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
+                n, c, h, w = s.shape
+                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
+                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
+                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
+        else:
+            n, c, h, w = x.shape
+            s = x.cumsum(dim=-1).cumsum_(dim=-2)
+            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
+            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
+            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
+            out = s4 + s1 - s2 - s3
+            out = out / (k1 * k2)
+
+        if self.auto_pad:
+            n, c, h, w = x.shape
+            _h, _w = out.shape[2:]
+            # print(x.shape, self.kernel_size)
+            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
+            out = torch.nn.functional.pad(out, pad2d, mode='replicate')
+
+        return out
+
+
+def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
+    for n, m in model.named_children():
+        if len(list(m.children())) > 0:
+            ## compound module, go inside it
+            replace_layers(m, base_size, train_size, fast_imp, **kwargs)
+
+        if isinstance(m, nn.AdaptiveAvgPool2d):
+            pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
+            assert m.output_size == 1
+            setattr(model, n, pool)
+
+
+'''
+ref.
+@article{chu2021tlsc,
+  title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
+  author={Chu, Xiaojie and Chen, Liangyu and Chen, Chengpeng and Lu, Xin},
+  journal={arXiv preprint arXiv:2112.04491},
+  year={2021}
+}
+'''
+class Local_Base():
+    def convert(self, *args, train_size, **kwargs):
+        replace_layers(self, *args, train_size=train_size, **kwargs)
+        imgs = torch.rand(train_size)
+        with torch.no_grad():
+            self.forward(imgs)
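LayerNorm2d above normalises each spatial position over the channel dimension of an NCHW tensor via the custom autograd Function, then applies a per-channel affine weight and bias. A tiny check of that behaviour:

# Sketch: LayerNorm2d zero-centres the channel vector at every pixel.
import torch
from insir_models.nafnet_utils import LayerNorm2d

ln = LayerNorm2d(channels=32)
x = torch.randn(2, 32, 64, 64)
with torch.no_grad():
    y = ln(x)
print(y.shape)                    # torch.Size([2, 32, 64, 64])
print(y.mean(dim=1).abs().max())  # ~0: per-pixel channel mean after normalisation (bias is initialised to zero)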
insir_text/.ipynb_checkpoints/models-checkpoint.py ADDED
insir_text/__pycache__/models.cpython-39.pyc ADDED
Binary file (2.72 kB)
insir_text/models.py ADDED
@@ -0,0 +1,65 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer
+import os
+
+# Models that use mean pooling
+POOL_MODELS = {"sentence-transformers/all-MiniLM-L6-v2", "TaylorAI/bge-micro-v2"}
+
+# Mean Pooling - Take attention mask into account for correct averaging
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, model='distilbert-base-uncased'):
+        super(LanguageModel, self).__init__()
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.model = AutoModel.from_pretrained(model)
+        self.model_name = model
+        # Remove the CLIP vision tower
+        if "clip" in self.model_name:
+            self.model.vision_model = None
+        # Freeze the pre-trained parameters (very important)
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        # Make sure to set evaluation mode (also important)
+        self.model.eval()
+
+    def forward(self, text_batch):
+        inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")
+        with torch.no_grad():  # Ensure no gradients are computed for this forward pass
+
+            if "clip" in self.model_name:
+                sentence_embedding = self.model.get_text_features(**inputs)
+                return sentence_embedding
+
+            outputs = self.model(**inputs)
+
+            if any(model in self.model_name for model in POOL_MODELS):
+                sentence_embeddings = mean_pooling(outputs, inputs['attention_mask'])
+                # Normalize embeddings
+                sentence_embedding = F.normalize(sentence_embeddings, p=2, dim=1)
+            else:
+                sentence_embedding = outputs.last_hidden_state[:, 0, :]
+            return sentence_embedding
+
+
+class LMHead(nn.Module):
+    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
+        super(LMHead, self).__init__()
+
+        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
+        # self.gelu = nn.GELU()
+        self.fc2 = nn.Linear(hidden_dim, num_classes)
+
+    def forward(self, x):
+        embd = self.fc1(x)
+        embd = F.normalize(embd, p=2, dim=1)
+        deg_pred = self.fc2(embd)
+        return embd, deg_pred
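Putting the two classes together: the frozen sentence encoder produces a pooled 384-d embedding (for TaylorAI/bge-micro-v2, per eval5d.yml), and LMHead projects it to the 256-d vector InstructIR consumes plus degradation-class logits. A short sketch with randomly initialised head weights (the app loads lm_instructir-7d.pt instead):

# Sketch: instruction -> sentence embedding -> (text_embd, degradation logits).
import torch
from insir_text.models import LanguageModel, LMHead

lm = LanguageModel(model="TaylorAI/bge-micro-v2")       # mean-pooled encoder, frozen
head = LMHead(embedding_dim=384, hidden_dim=256, num_classes=7)  # untrained here; app.py loads the checkpoint

with torch.no_grad():
    sent = lm(["Clean up this noisy image, it's an eyesore."])
    text_embd, deg_pred = head(sent)
print(sent.shape, text_embd.shape, deg_pred.shape)      # (1, 384) (1, 256) (1, 7)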
insir_text/sample_prompts.json ADDED
@@ -0,0 +1,55 @@
+{
+    "denoising": [
+        "Help me reduce the fuzziness in this image.",
+        "I need this image denoised ASAP.",
+        "Clean up this noisy image, it's an eyesore.",
+        "Can you clean the dots from my image?",
+        "Help me with my picture, it's full of tiny spots.",
+        "Clean up this image, it's all grainy."
+    ],
+    "deblurring": [
+        "Please, clean up this blurry photo.",
+        "My picture's not sharp, fix it.",
+        "Deblur my picture, it's too fuzzy.",
+        "Help, my photo is too blurry.",
+        "Please, make my image less smudgy."
+    ],
+    "dehazing": [
+        "Please, fix the haziness in my image.",
+        "I need to remove the haziness from this image.",
+        "Get rid of the fog in my image.",
+        "Fix my photo, it's too misty.",
+        "Help me, my photo is all hazy."
+    ],
+    "deraining": [
+        "I want to eliminate the water from this image.",
+        "Clear the rain from my picture.",
+        "I need to clear the rain from this image.",
+        "Can you get rid of the raindrops in my picture?"
+    ],
+    "sr": [
+        "I need to enhance the size and quality of this image.",
+        "My photo is lacking size and clarity; can you improve it?",
+        "I'd appreciate it if you could upscale this photo.",
+        "My picture is too little, enlarge it."
+    ],
+    "ambiguous": [
+        "Please, clear up the mess on this image.",
+        "I want this image to look good.",
+        "make it pop",
+        "Fix my photo, it's all messed up."
+    ],
+    "lol": [
+        "I took this photo during night, enhance it",
+        "The photo is too dark, improve exposure",
+        "my image has poor lighting conditions, can you fix it?",
+        "Can you make the image brighter?"
+    ],
+    "enhancement": [
+        "make my image look like DSLR",
+        "improve the colors of my image",
+        "enhance the colors of the image",
+        "Can you edit this to look like an award-winning photo?",
+        "I want the picture to be retouched for a professional portfolio."
+    ]
+}
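These prompts are grouped by degradation type; a small sketch of loading the file and sampling one instruction per task (path relative to the Space root):

# Sketch: pick a random instruction for a given degradation type.
import json
import random

with open("insir_text/sample_prompts.json") as f:
    prompts = json.load(f)

print(sorted(prompts.keys()))         # ['ambiguous', 'deblurring', 'dehazing', 'denoising', ...]
print(random.choice(prompts["lol"]))  # e.g. "The photo is too dark, improve exposure"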