yeq6x committed on
Commit
02ba63a
1 Parent(s): 471c8d4
Dockerfile ADDED
@@ -0,0 +1,20 @@
+ # Use Python 3.9 as the base image
+ FROM python:3.9-slim
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Copy the dependency file used to install the required Python libraries
+ COPY requirements.txt /app/requirements.txt
+
+ # Install the required Python packages
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the application code into the container
+ COPY . /app
+
+ # Expose the port (Gradio's default port 7860)
+ EXPOSE 7860
+
+ # Start the application
+ CMD ["python", "app.py"]
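+
+ # Local usage sketch (the image tag "feature-map-demo" is arbitrary):
+ #   docker build -t feature-map-demo .
+ #   docker run -p 7860:7860 feature-map-demo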
app.py ADDED
@@ -0,0 +1,229 @@
+ import gradio as gr
+ import spaces
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ import matplotlib.pyplot as plt
+ from model_module import AutoencoderModule
+ from dataset import MyDataset, load_filenames
+ from utils import DistanceMapLogger
+ import numpy as np
+ from PIL import Image
+ import base64
+ from io import BytesIO
+
+ # Load the model and data
+ def load_model():
+     model_path = "checkpoints/ae_model_tf_2024-03-05_00-35-21.pth"
+     feature_dim = 32
+     model = AutoencoderModule(feature_dim=feature_dim)
+     state_dict = torch.load(model_path)
+
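+     # The checkpoint stores weights for the bare Autoencoder, while
+     # AutoencoderModule wraps it as `self.model`, so each key needs a
+     # "model." prefix before loading into the LightningModule.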
+     # Fix the state_dict keys
+     new_state_dict = {}
+     for key in state_dict:
+         new_key = "model." + key
+         new_state_dict[new_key] = state_dict[key]
+     model.load_state_dict(new_state_dict)
+     model.eval()
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+     print("Model loaded successfully.")
+     return model, device
+
+ def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=32):
+     filenames = load_filenames(img_dir)
+     train_X = filenames[:1000]
+     train_ds = MyDataset(train_X, img_dir=img_dir, img_size=image_size)
+
+     train_loader = DataLoader(
+         train_ds,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=0,
+     )
+
+     iterator = iter(train_loader)
+     x, _, _ = next(iterator)
+     x = x.to(device)
+     # Each dataset item holds two augmented views; keep only the first view per sample
+     x = x[:, 0].to(device)
+     print("Data loaded successfully.")
+     return x
+
+ model, device = load_model()
+ image_size = 112
+ batch_size = 32
+ x = load_data(device)
+
+ # Preprocess the uploaded image
+ def preprocess_uploaded_image(uploaded_image, image_size):
+     uploaded_image = Image.fromarray(uploaded_image)
+     uploaded_image = uploaded_image.convert("RGB")
+     uploaded_image = uploaded_image.resize((image_size, image_size))
+     uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
+     uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
+     return uploaded_image
+
+ # Heatmap generation function
+ @spaces.GPU
+ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
+     with torch.no_grad():
+         dec5, _ = model(x)
+         img = x
+         feature_map = dec5
+         batch_size = feature_map.size(0)
+         feature_dim = feature_map.size(1)
+
+         # Preprocess the uploaded image
+         if uploaded_image is not None:
+             uploaded_image = preprocess_uploaded_image(uploaded_image, image_size)
+             target_feature_map, _ = model(uploaded_image)
+             img = torch.cat((img, uploaded_image))
+             feature_map = torch.cat((feature_map, target_feature_map))
+             batch_size += 1
+         else:
+             uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
+
+         target_num = batch_size - 1
+
+         x_coords = [x_coords] * batch_size
+         y_coords = [y_coords] * batch_size
+
+         # Feature vector at the selected (x, y) location of the source image
+         vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
+         vector = vectors[source_num]
+
+         reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
+         batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
+
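+         # Convert distances to a similarity heatmap: min-max normalize to [0, 1],
+         # then apply sech^2(20 * d) so pixels whose features are close to the
+         # selected vector approach 1 and everything else falls off sharply.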
+         norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
+
+         source_map = norm_batch_distance_map[source_num]
+         target_map = norm_batch_distance_map[target_num]
+
+         alpha = 0.8
+         blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
+         blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
+
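+         # The overlays place the normalized heatmap in the red channel (green and
+         # blue are zeroed) and alpha-blend it over the corresponding input image.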
+         # Plot with Matplotlib and return the figure
+         fig, axs = plt.subplots(2, 2, figsize=(10, 10))
+         axs[0, 0].imshow(source_map.cpu(), cmap='hot')
+         axs[0, 0].set_title("Source Map")
+         axs[0, 1].imshow(target_map.cpu(), cmap='hot')
+         axs[0, 1].set_title("Target Map")
+         axs[1, 0].imshow(blended_source.permute(1, 2, 0).cpu())
+         axs[1, 0].set_title("Blended Source")
+         axs[1, 1].imshow(blended_target.permute(1, 2, 0).cpu())
+         axs[1, 1].set_title("Blended Target")
+         for ax in axs.flat:
+             ax.axis('off')
+
+         plt.tight_layout()
+         plt.close(fig)
+         return fig
+
+ def process_image(cropped_image_data):
+     # Convert the Base64 data to a PIL image
+     header, base64_data = cropped_image_data.split(',', 1)
+     image_data = base64.b64decode(base64_data)
+     image = Image.open(BytesIO(image_data))
+     return image
+
+ # JavaScript code
+ scripts = """
+ async () => {
+     const script = document.createElement("script");
+     script.src = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.js";
+     document.head.appendChild(script);
+
+     const style = document.createElement("link");
+     style.rel = "stylesheet";
+     style.href = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.css";
+     document.head.appendChild(style);
+
+     script.onload = () => {
+         let cropper;
+
+         document.getElementById("input_file_button").onclick = function() {
+             document.querySelector("#input_file").click();
+         };
+
+         // Load the image selected via the hidden file input
+         document.querySelector("#input_file").addEventListener("change", function(e) {
+             const files = e.target.files;
+             console.log(files);
+             if (files && files.length > 0) {
+                 console.log("File selected");
+                 document.querySelector("#crop_view").style.display = "block";
+                 document.querySelector("#crop_button").style.display = "block";
+                 const url = URL.createObjectURL(files[0]);
+                 const crop_view = document.getElementById("crop_view");
+                 crop_view.src = url;
+
+                 if (cropper) {
+                     cropper.destroy();
+                 }
+                 cropper = new Cropper(crop_view, {
+                     aspectRatio: 1,
+                     viewMode: 1,
+                 });
+             }
+         });
+
+         // Attach JavaScript behavior to the Gradio crop button
+         document.getElementById("crop_button").onclick = function() {
+             if (cropper) {
+                 const canvas = cropper.getCroppedCanvas();
+                 const croppedImageData = canvas.toDataURL();
+
+                 // Send the cropped image to Gradio
+                 const textbox = document.querySelector("#cropped_image_data textarea");
+                 textbox.value = croppedImageData;
+                 textbox.dispatchEvent(new Event("input", { bubbles: true }));
+
+                 document.getElementById("crop_view").style.display = "none";
+                 document.getElementById("crop_button").style.display = "none";
+
+                 cropper.destroy();
+             }
+         };
+         document.getElementById("crop_view").style.display = "none";
+         document.getElementById("crop_button").style.display = "none";
+     };
+ }
+ """
+
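+ # UI flow: the hidden <input type="file"> and Cropper.js handle selection and
+ # cropping in the browser; the crop is written as a Base64 data URL into the
+ # hidden "cropped_image_data" textbox, whose change event feeds process_image
+ # and ultimately get_heatmaps below.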
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             source_num = gr.Slider(0, batch_size - 1, step=1, label="Source Image Index")
+             x_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate")
+             y_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate")
+
+             # Hidden file input used as the file-selection entry point
+             gr.HTML('<input type="file" id="input_file" style="display:none;">')
+             input_file_button = gr.Button("Select Image", elem_id="input_file_button")
+             # HTML image tag used by Cropper.js to display the selected image
+             gr.HTML('<img id="crop_view" style="max-width:100%;">')
+             # Gradio button component with an element ID so the JS can hook into it
+             crop_button = gr.Button("Crop", elem_id="crop_button", variant="primary")
+             # Hidden textbox holding the cropped image data (Base64)
+             cropped_image_data = gr.Textbox(visible=False, elem_id="cropped_image_data")
+             input_image = gr.Image(label="Cropped Image", interactive=False)
+             # Call process_image whenever cropped_image_data is updated
+             cropped_image_data.change(process_image, inputs=cropped_image_data, outputs=input_image)
+
+         with gr.Column():
+             output_plot = gr.Plot()
+
+     # Event wiring in place of a gr.Interface
+     source_num.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+     x_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+     y_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+     input_image.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+
+     # Load the JavaScript code on page load
+     demo.load(None, None, None, js=scripts)
+
+ demo.launch()
+
checkpoints/ae_model_tf_2024-03-05_00-35-21.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:77b020cb89ad2ccf7a7bf654d86fb975793cbe168bf73cd011e93cf22f63204c
+ size 2629576
checkpoints/autoencoder-epoch=09-train_loss=1.00.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af08cd5fbb1c832824b7466be4da59d2a40e6e9eef864097514a2806a24bb92b
+ size 3046514
checkpoints/autoencoder-epoch=29-train_loss=1.01.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2dc5dcf07a6a66cd4f0773af5fff29903d5fb9fa340221cd59083462e1ae77b7
+ size 3046959
checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a4f218b652ba63e9891b73efca5faffaabe4692f1b78755860ff46b113d09ecd
+ size 3046959
datamodule.py ADDED
@@ -0,0 +1,25 @@
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader
+ from dataset import MyDataset, load_filenames  # based on dataset.py
+
+ class DataModule(pl.LightningDataModule):
+     def __init__(self, img_dir, batch_size, img_size=112, num_workers=0):
+         super().__init__()
+         self.img_dir = img_dir
+         self.batch_size = batch_size
+         self.img_size = img_size
+         self.num_workers = num_workers
+         self.file_num = 1000  # or 3400
+
+     def setup(self, stage=None):
+         filenames = load_filenames(self.img_dir)
+         self.train_dataset = MyDataset(filenames[:self.file_num], img_dir=self.img_dir, img_size=self.img_size)
+
+     def train_dataloader(self):
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             num_workers=self.num_workers,
+             # persistent_workers requires num_workers > 0; enable it only then
+             persistent_workers=self.num_workers > 0,
+         )
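+
+ # Minimal training sketch (not used by this Space; the values below are
+ # assumptions, not the settings used to produce the bundled checkpoints):
+ #
+ #   import pytorch_lightning as pl
+ #   from model_module import AutoencoderModule
+ #
+ #   dm = DataModule(img_dir="resources/trainB/", batch_size=16, num_workers=4)
+ #   model = AutoencoderModule(feature_dim=32)
+ #   pl.Trainer(max_epochs=50).fit(model, datamodule=dm)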
dataset.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+ import torchvision
+ from torchvision import transforms
+ import random
+ from PIL import Image
+ import os
+
+ from utils import RandomAffineAndRetMat
+
+ def load_filenames(data_dir):
+     # label_data = pd.read_json(INPUT_DIR+'DataList.json')
+     # label_data = label_data.sort_index()
+     # tmp_points = []
+     # filenames = []
+
+     # for o in tqdm(label_data.data[0:1000]):
+     #     filenames.append(o['filename'])
+     #     a = o['filename']
+
+     #     tmps = []
+     #     for i in range(60):
+     #         tmps.append(o['points'][str(i)]['x'])
+     #         tmps.append(o['points'][str(i)]['y'])
+     #     tmp_points.append(tmps)  # datanum
+
+     # filenames = pd.Series(filenames)
+     # filenames = [str(i).zfill(4)+'.jpg' for i in range(3400)]
+     # df_points = pd.DataFrame(tmp_points)
+
+     # Load from data_dir, keeping image extensions only
+     img_exts = ['.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm', '.tif', '.tiff']
+     filenames = [f for f in os.listdir(data_dir) if os.path.splitext(f)[1].lower() in img_exts]
+
+     return filenames
+
+
+ class MyDataset:
+     def __init__(self, X, valid=False, img_dir='resources/trainB/', img_size=256):
+         self.X = X
+         self.valid = valid
+         self.img_dir = img_dir
+         self.img_size = img_size
+
+     def __len__(self):
+         return len(self.X)
+
+     def __getitem__(self, index):
+         # Load the image and apply the transforms
+         f = self.img_dir + self.X[index]
+         original_X = Image.open(f)
+         trans = [
+             transforms.ToTensor(),
+             # transforms.Normalize(mean=means, std=stds),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+
+             transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.15),
+             transforms.RandomGrayscale(0.3),
+         ]
+         transform = transforms.Compose(trans)
+         xlist = []
+         matlist = []
+         is_flip = random.randint(0, 1)  # the same flip is applied to both views
+ for i in range(2):
65
+ af = RandomAffineAndRetMat(
66
+ degrees=[-30, 30],
67
+ translate=(0.1, 0.1), scale=(0.8, 1.2),
68
+ # fill=(random.random(), random.random(), random.random()),
69
+ fill=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
70
+ shear=[-10, 10],
71
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
72
+ )
73
+ X, affine_matrix = af(transforms.Resize(self.img_size)(original_X))
74
+
75
+ # randomflip
76
+ if is_flip == 1:
77
+ X = transforms.RandomHorizontalFlip(1.)(X)
78
+ flip_matrix = torch.tensor([[-1., 0., 0.],
79
+ [0., 1., 0.],
80
+ [0., 0., 1.]])
81
+ affine_matrix = torch.matmul(flip_matrix, affine_matrix)
82
+
83
+ xlist.append(transform(X))
84
+ matlist.append(affine_matrix)
85
+
86
+ X = torch.stack(xlist)
87
+ mat = torch.stack(matlist)
88
+ return X, mat, f
model.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ from torch import nn
+
+ class ConvBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
+         super(ConvBlock, self).__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
+         self.batchnorm = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         return self.relu(self.batchnorm(self.conv(x)))
+
+ class DeconvBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding):
+         super(DeconvBlock, self).__init__()
+         self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding)
+         self.batchnorm = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         return self.relu(self.batchnorm(self.deconv(x)))
+
+ class Autoencoder(nn.Module):
+     def __init__(self, feature_dim=32):
+         super(Autoencoder, self).__init__()
+         self.feature_dim = feature_dim
+
+         # Encoder
+         self.enc1 = ConvBlock(3, 16, 10, 1, 0)
+         self.enc2 = ConvBlock(16, 32, 10, 1, 0)
+         self.enc3 = ConvBlock(32, 64, 2, 2, 0)
+         self.enc4 = ConvBlock(64, 128, 2, 2, 0)
+         self.enc5 = ConvBlock(128, 256, 2, 2, 0)
+
+         # Decoder (inputs are concatenated with encoder features, U-Net style)
+         self.dec1 = DeconvBlock(256, 128, 2, 2, 0, 1)
+         self.dec2 = DeconvBlock(256, 64, 2, 2, 0, 1)  # 128 + 128
+         self.dec3 = DeconvBlock(128, 32, 2, 2, 0, 0)  # 64 + 64
+         self.dec4 = DeconvBlock(64, 16, 10, 1, 0, 0)  # 32 + 32
+         self.dec5 = DeconvBlock(32, self.feature_dim, 10, 1, 0, 0)
+         self.dec6 = nn.Conv2d(self.feature_dim, 32, 1, 1, 0)
+         self.dec7 = nn.Conv2d(32, 3, 1, 1, 0)
+
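+     # forward returns (dec5, dec7): dec5 is the per-pixel feature map used for
+     # correspondence heatmaps, and dec7 is the RGB reconstruction.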
+     def forward(self, x):
+         # Encoder
+         enc1 = self.enc1(x)
+         enc2 = self.enc2(enc1)
+         enc3 = self.enc3(enc2)
+         enc4 = self.enc4(enc3)
+         enc5 = self.enc5(enc4)
+
+         # Decoder
+         dec1 = self.dec1(enc5)
+         dec2 = self.dec2(torch.cat((dec1, enc4), 1))
+         dec3 = self.dec3(torch.cat((dec2, enc3), 1))
+         dec4 = self.dec4(torch.cat((dec3, enc2), 1))
+         dec5 = self.dec5(torch.cat((dec4, enc1), 1))
+         dec6 = self.dec6(dec5)
+         dec7 = self.dec7(dec6)
+
+         return dec5, dec7
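+
+ # Rough shape sketch (assuming the 112x112 RGB inputs used in app.py):
+ #   x:    (B, 3, 112, 112)
+ #   dec5: (B, feature_dim, 112, 112)  -- per-pixel feature map, same resolution as the input
+ #   dec7: (B, 3, 112, 112)            -- RGB reconstruction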
model_module.py ADDED
@@ -0,0 +1,145 @@
+ import pytorch_lightning as pl
+ import torch
+ from torch import nn
+ from torch.optim import SGD
+ from torchvision.utils import save_image
+ import os
+ from utils import TripletLossBatch, pairwise_distance_squared, GetTransformedCoords, DistanceMapLogger
+ from model import Autoencoder
+
+ class AutoencoderModule(pl.LightningModule):
+     def __init__(self, feature_dim=64, learning_rate=0.1, lambda_c=0.97, initial_margin=1.0, initial_threshold=2.0, save_interval=100, output_dir="output_images"):
+         super(AutoencoderModule, self).__init__()
+         self.feature_dim = feature_dim
+         self.learning_rate = learning_rate
+         self.lambda_c = lambda_c
+         self.margin_img = initial_margin
+         self.margin_img_init = initial_margin
+         self.threshold = initial_threshold
+         self.model = Autoencoder(self.feature_dim)
+         self.criterion = nn.MSELoss()
+         self.triplet_loss = TripletLossBatch()
+         self.losses = []
+         self.save_interval = save_interval  # output interval in batches
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         img, mat, _ = batch
+         batch_size, _, _, size, size = img.shape
+         img = img.view(batch_size*2, 3, size, size)
+         mat = mat.view(batch_size*2, 3, 3)
+
+         dec5_output, output = self.model(img)
+         mse_loss = self.criterion(output, img)
+
+         # Intra-image processing (triplet loss over pixel locations)
+         num_anchor_sets = 2**12
+         trip_loss = 0
+         std_list = [2.5*1.025**self.current_epoch, 5*1.025**self.current_epoch]
+         for c in std_list:
+             std = size / c
+             anchors = torch.randint(0, size, (batch_size*2, num_anchor_sets, 1, 2))
+             coords = anchors + torch.normal(0, std, (batch_size*2, num_anchor_sets, 2, 2)).long()
+             valid_coords_idx = (((coords >= 0) & (coords < size)).sum(3) == 2).sum(2) != 2
+             coords[valid_coords_idx] = 0
+             anchors[valid_coords_idx] = 0
+
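+             # Each anchor pixel gets two nearby points drawn from a Gaussian around
+             # it; the closer one becomes the positive and the farther one the
+             # negative for the triplet loss. Sets that fall outside the image are
+             # zeroed out above.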
+             # Select the closest coordinate
+             d = pairwise_distance_squared(anchors.float(), coords.float())
+             idx = torch.argmin(d, dim=2)
+             anchors, positives, negatives = self._get_triplet_coordinates(anchors, coords, idx)
+
+             # Extract the feature vectors from dec5_output
+             anchor_vectors, positive_vectors, negative_vectors = self._extract_feature_vectors(dec5_output, batch_size, anchors, positives, negatives)
+             trip_loss += self.triplet_loss(anchor_vectors, positive_vectors, negative_vectors, self.margin_img)
+
+         trip_loss /= len(std_list)
+         self.margin_img = self.margin_img_init + self.margin_img - trip_loss.detach()
+
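+         # Adaptive margin: the margin grows whenever the triplet loss drops below
+         # its initial value and shrinks when it rises above it, which tends to
+         # keep the loss near margin_img_init.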
+         # Transformation-consistency loss
+         num_samples = 2**20
+         tf_loss = self._compute_transformation_loss(dec5_output, mat, batch_size, size, num_samples)
+
+         # Batch-direction (cross-image) loss
+         bat_dist_loss = self._compute_batch_direction_loss(dec5_output, batch_size, size)
+
+         # Total loss
+         loss = mse_loss + trip_loss + 0.001 * bat_dist_loss + (0.001 * 1.**self.current_epoch) * tf_loss
+         self.log("train_loss", loss)
+
+         # VRAM management
+         del img, output
+         torch.cuda.empty_cache()
+
+         return loss
+
+
+     def _get_triplet_coordinates(self, anchors, coords, idx):
+         anchors = anchors.squeeze(2)
+         positives = coords[torch.arange(coords.size(0))[:, None, None], torch.arange(coords.size(1))[None, :, None], idx[:, :, None]].squeeze(2)
+         negatives = coords[torch.arange(coords.size(0))[:, None, None], torch.arange(coords.size(1))[None, :, None], (1 - idx)[:, :, None]].squeeze(2)
+         return anchors, positives, negatives
+
+     def _extract_feature_vectors(self, dec5_output, batch_size, anchors, positives, negatives):
+         y_anchors = anchors[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
+         x_anchors = anchors[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
+         y_positives = positives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
+         x_positives = positives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
+         y_negatives = negatives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
+         x_negatives = negatives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
+
+         anchor_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_anchors, x_anchors]
+         positive_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_positives, x_positives]
+         negative_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_negatives, x_negatives]
+         return anchor_vectors, positive_vectors, negative_vectors
+
+     def _compute_transformation_loss(self, dec5_output, mat, batch_size, size, num_samples=2**12):
+         anchor_indices = torch.randint(batch_size, (num_samples, 1), device=self.device).repeat(1, 2).reshape(num_samples*2)
+         coords_x = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
+         coords_y = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
+         anchor_coords = torch.cat((coords_x, coords_y), 1)
+         anchor_mat = mat[anchor_indices]
+         tf_anchor_coords = GetTransformedCoords(anchor_mat, [size/2, size/2])(anchor_coords)
+
+         anchor_vectors = torch.zeros([num_samples*2, self.feature_dim], device=self.device)
+         inner_idx_flat = ((0 <= tf_anchor_coords[:,0]) & (tf_anchor_coords[:,0] < size)) & ((0 <= tf_anchor_coords[:,1]) & (tf_anchor_coords[:,1] < size))
+         anchor_vectors[inner_idx_flat] = dec5_output[anchor_indices[inner_idx_flat], :, tf_anchor_coords[inner_idx_flat, 0], tf_anchor_coords[inner_idx_flat, 1]]
+
+         inner_idx_and = inner_idx_flat.view(num_samples, 2).t()[0] & inner_idx_flat.view(num_samples, 2).t()[1]
+         anchor_vectors = anchor_vectors.view(num_samples, 2, self.feature_dim)[inner_idx_and]
+         return pairwise_distance_squared(anchor_vectors[:,0], anchor_vectors[:,1]).mean()
+
+     def _compute_batch_direction_loss(self, dec5_output, batch_size, size):
+         N = 2**12
+         anchor_indices = torch.randint(0, batch_size, (N,)) * 2 + torch.randint(0, 2, (N,))
+         anchor_coords = torch.randint(0, size, (N, 2))
+         other_indices = torch.randint(0, batch_size-1, (N, 2)) * 2 + torch.randint(0, 2, (N, 2))
+         other_indices += (other_indices >= anchor_indices.unsqueeze(1)).long() * 2
+         other_coords = torch.randint(0, size, (N, 2, 2))
+
+         anchor_vectors = dec5_output[anchor_indices, :, anchor_coords[:, 0], anchor_coords[:, 1]]
+         other_vectors = dec5_output[other_indices, :, other_coords[:, :, 0], other_coords[:, :, 1]]
+         distances = pairwise_distance_squared(anchor_vectors.unsqueeze(1), other_vectors)
+         return distances[distances < self.threshold].sum() / ((distances < self.threshold).sum() + 1e-10)
+
+     def configure_optimizers(self):
+         optimizer = SGD(self.parameters(), lr=self.learning_rate)
+         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
+         return [optimizer], [scheduler]
+
+     # def save_intermediate_image(self, output, epoch):
+     #     save_image(output[:4], os.path.join(self.output_dir, f"epoch_{epoch}_output.png"), nrow=1)
+     #     print(f"Saved intermediate image at epoch {epoch}")
+
+     # def distance_map(self, _input, feature_map, epoch, x_coords=None, y_coords=None):
+     #     save_path = os.path.join(self.output_dir, f"epoch_{epoch}_distance_map.png")
+     #     DistanceMapLogger()(_input, feature_map, save_path, x_coords, y_coords)
+
+     def configure_optimizers(self):
+         optimizer = SGD(self.parameters(), lr=self.learning_rate)
+         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
+         return [optimizer], [scheduler]
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ --extra-index-url https://download.pytorch.org/whl/cu121
+ torch==2.2.0
+ torchvision==0.17.0
+ torchaudio==2.2.0
+ matplotlib==3.9.2
+ numpy==1.26.4
+ pytorch-lightning==2.4.0
+ scikit-learn==1.0.2
+ gradio==5.5.0
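+ # Note: app.py also imports `spaces` and PIL (Pillow), which are not pinned here;
+ # they may be installed transitively or provided by the runtime, otherwise they
+ # need to be added to this file.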
utils.py ADDED
@@ -0,0 +1,252 @@
+ import torch
+ from torch import Tensor, nn
+ import torch.nn.functional as F
+ import torchvision
+ from torchvision import transforms
+ from PIL import Image
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from sklearn.decomposition import PCA
+
+ class RandomAffineAndRetMat(torch.nn.Module):
+     def __init__(
+         self,
+         degrees,
+         translate=None,
+         scale=None,
+         shear=None,
+         interpolation=torchvision.transforms.InterpolationMode.NEAREST,
+         fill=0,
+         center=None,
+     ):
+         super().__init__()
+         self.degrees = degrees
+         self.translate = translate
+         self.scale = scale
+         self.shear = shear
+         self.interpolation = interpolation
+         self.fill = fill
+         self.center = center
+
+     def forward(self, img):
+         """
+         img (PIL Image or Tensor): Image to be transformed.
+
+         Returns:
+             tuple: (affine-transformed image, 3x3 affine matrix).
+         """
+         fill = self.fill
+         if isinstance(img, Tensor):
+             if isinstance(fill, (int, float)):
+                 fill = [float(fill)] * transforms.functional.get_image_num_channels(img)
+             else:
+                 fill = [float(f) for f in fill]
+
+         img_size = transforms.functional.get_image_size(img)
+
+         ret = transforms.RandomAffine.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
+         transformed_image = transforms.functional.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
+
+         affine_matrix = self.get_affine_matrix_from_params(ret)
+
+         return transformed_image, affine_matrix
+
+     def get_affine_matrix_from_params(self, params):
+         degrees, translate, scale, shear = params
+         degrees = torch.tensor(degrees)
+         shear = torch.tensor(shear)
+
+         # Convert the sampled parameters into 3x3 transformation matrices
+         rotation_matrix = torch.tensor([[torch.cos(torch.deg2rad(degrees)), -torch.sin(torch.deg2rad(degrees)), 0],
+                                         [torch.sin(torch.deg2rad(degrees)), torch.cos(torch.deg2rad(degrees)), 0],
+                                         [0, 0, 1]])
+
+         translation_matrix = torch.tensor([[1, 0, translate[0]],
+                                            [0, 1, translate[1]],
+                                            [0, 0, 1]]).to(torch.float32)
+
+         scaling_matrix = torch.tensor([[scale, 0, 0],
+                                        [0, scale, 0],
+                                        [0, 0, 1]])
+
+         shearing_matrix = torch.tensor([[1, -torch.tan(torch.deg2rad(shear[0])), 0],
+                                         [-torch.tan(torch.deg2rad(shear[1])), 1, 0],
+                                         [0, 0, 1]])
+
+         # Compose the transformation matrices
+         affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)
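+         # Applied to column vectors, the product acts right-to-left: shear, then
+         # scale, then rotation, then translation.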
+
+         return affine_matrix
+
+ class GetTransformedCoords(nn.Module):
+     def __init__(self, affine_matrix, center):
+         super().__init__()
+         self.affine_matrix = affine_matrix
+         self.center = center
+
+     def forward(self, _coords):
+         # coords: like tensor([[43, 26], [44, 27], [45, 28]])
+         center_x, center_y = self.center
+         # Shift the coordinates so the image center becomes the origin
+         coords = _coords.clone()
+         coords[:, 0] -= center_x
+         coords[:, 1] -= center_y
+
+         # Apply the transform to each batch element in homogeneous coordinates
+         homogeneous_coordinates = torch.cat([coords, torch.ones(coords.shape[0], 1, dtype=torch.float32, device=coords.device)], dim=1)
+         transformed_coordinates = torch.bmm(self.affine_matrix, homogeneous_coordinates.unsqueeze(-1)).squeeze(-1)
+
+         # Clamp to the image bounds (currently disabled)
+         # transformed_x = max(0, min(width - 1, transformed_coordinates[:, 0]))
+         # transformed_y = max(0, min(height - 1, transformed_coordinates[:, 1]))
+         transformed_x = transformed_coordinates[:, 0]
+         transformed_y = transformed_coordinates[:, 1]
+
+         transformed_x += center_x
+         transformed_y += center_y
+         return torch.stack([transformed_x, transformed_y]).t().to(torch.long)
+
+ # Version of pairwise_distance that does not take the square root
+ def pairwise_distance_squared(a, b):
+     return torch.sum((a - b) ** 2, dim=-1)
+
+ def cosine_similarity(a, b):
+     # Dot product of a and b
+     dot_product = torch.matmul(a, b)
+     # Norms (magnitudes) of a and b
+     norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
+     norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
+     # Cosine similarity: dot product divided by the product of the norms
+     return dot_product / (norm_a * norm_b)
+
+ def batch_cosine_similarity(a, b):
+     # Dot product of a and b
+     dot_product = torch.einsum('bnd,bnd->bn', a, b)
+     # Norms (magnitudes) of a and b
+     norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
+     norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
+     # Cosine similarity: dot product divided by the product of the norms
+     return dot_product / (norm_a * norm_b)
+
+ class TripletLossBatch(nn.Module):
+     def __init__(self):
+         super(TripletLossBatch, self).__init__()
+
+     def forward(self, anchor, positive, negative, margin=1.0):
+         distance_positive = F.pairwise_distance(anchor, positive, p=2)
+         distance_negative = F.pairwise_distance(anchor, negative, p=2)
+         losses = torch.relu(distance_positive - distance_negative + margin)
+         return losses.mean()
+
+ class TripletLossCosineSimilarity(nn.Module):
+     def __init__(self):
+         super(TripletLossCosineSimilarity, self).__init__()
+
+     def forward(self, anchor, positive, negative, margin=1.0):
+         distance_positive = 1 - batch_cosine_similarity(anchor, positive)
+         distance_negative = 1 - batch_cosine_similarity(anchor, negative)
+         losses = torch.relu(distance_positive - distance_negative + margin)
+         return losses.mean()
+
+ def imsave(img):
+     img = torchvision.utils.make_grid(img)
+     img = img / 2 + 0.5
+     npimg = img.detach().cpu().numpy()
+     # plt.imshow(np.transpose(npimg, (1, 2, 0)))
+     # plt.show()
+     # save image
+     npimg = np.transpose(npimg, (1, 2, 0))
+     npimg = npimg * 255
+     npimg = npimg.astype(np.uint8)
+     Image.fromarray(npimg).save('sample.png')
+
+ def norm_img(img):
+     return (img - img.min()) / (img.max() - img.min())
+
+ def norm_img2(img):
+     return (img - img.min()) / (img.max() - img.min()) * 255
+
+ class DistanceMapLogger:
+     def __call__(self, img, feature_map, save_path, x_coords=None, y_coords=None):
+         device = feature_map.device
+         batch_size = feature_map.size(0)
+         feature_dim = feature_map.size(1)
+         image_size = feature_map.size(2)
+
+         if x_coords is None:
+             x_coords = [69] * batch_size
+         if y_coords is None:
+             y_coords = [42] * batch_size
+
+         # Project the features onto 3 PCA components for an RGB visualization
+         pca = PCA(n_components=3)
+         pca_result = pca.fit_transform(feature_map.permute(0, 2, 3, 1).reshape(-1, feature_dim).detach().cpu().numpy())  # run PCA
+         reshaped_pca_result = pca_result.reshape(batch_size, image_size, image_size, 3)  # back to (B, H, W, 3)
+
+         sample_num = 0
+         vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # feature vectors at the given (y, x) locations
+         vector = vectors[sample_num]
+
+         # Compute the distance to the selected vector for every feature map in the batch:
+         # permute the feature_map dimensions and flatten height and width
+         reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
+         batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
+         # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65, size*size, 32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
+         norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
+         # norm_batch_distance_map[:,0,0] = 0.001
+         # Visualization and saving
+         fig, axes = plt.subplots(5, 4, figsize=(20, 25))
+         for ax in axes.flatten():
+             ax.axis('off')
+         # Remove spacing between subplots
+         plt.subplots_adjust(wspace=0, hspace=0)
+         # Remove the outer margins as well
+         plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
+
+         # Visualize the distance maps
+         for i in range(5):
+             axes[i, 0].imshow(norm_batch_distance_map[i].detach().cpu(), cmap='hot')
+             if i == sample_num:
+                 axes[i, 0].scatter(x_coords[i], y_coords[i], c='b', s=7)
+
+             distance_map = torch.cat(((norm_batch_distance_map[i] / norm_batch_distance_map[i].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
+             alpha = 0.9  # Transparency factor for the heatmap overlay
+             blended_tensor = (1 - alpha) * img[i] + alpha * distance_map
+             axes[i, 1].imshow(norm_img(blended_tensor.permute(1, 2, 0).detach().cpu()))
+
+             axes[i, 2].imshow(norm_img(img[i].permute(1, 2, 0).detach().cpu()))
+
+             axes[i, 3].imshow(norm_img(reshaped_pca_result[i]))
+
+         plt.savefig(save_path)
+
+
+
+     def get_heatmaps(self, img, feature_map, source_num=0, target_num=1, x_coords=69, y_coords=42):
+         device = feature_map.device
+         batch_size = feature_map.size(0)
+         feature_dim = feature_map.size(1)
+         image_size = feature_map.size(2)
+
+         x_coords = [x_coords] * batch_size
+         y_coords = [y_coords] * batch_size
+
+         vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # feature vectors at the given (y, x) locations
+         vector = vectors[source_num]
+
+         # Compute the distance to the selected vector for every feature map in the batch:
+         # permute the feature_map dimensions and flatten height and width
+         reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
+         batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
+         # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65, size*size, 32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
+         norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
+         # norm_batch_distance_map[:,0,0] = 0.001
+
+         source_map = norm_batch_distance_map[source_num]
+         target_map = norm_batch_distance_map[target_num]
+
+         alpha = 0.9
+         blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
+         blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
+
+         return source_map, target_map, blended_source, blended_target