jiawei-ren committed on
Commit
e8481f2
1 Parent(s): 58d92ee
Files changed (4)
  1. .gitignore +1 -0
  2. app.py +273 -0
  3. packages.txt +1 -0
  4. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea/
app.py ADDED
@@ -0,0 +1,273 @@
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import torch
+ import seaborn as sns
+ import pandas as pd
+ import os
+ import os.path as osp
+ import ffmpeg
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.modules.loss import _Loss
+ from torch.utils.data import Dataset, DataLoader
+
+ NUM_PER_BUCKET = 1000
+ NOISE_SIGMA = 1
+ Y_UB = 10
+ Y_LB = 0
+ K = 1
+ B = 0
+ NUM_SEG = 5
+ sns.set_theme(palette='colorblind')
+ NUM_EPOCHS = 100
+ PRINT_FREQ = NUM_EPOCHS // 20
+ NUM_TRAIN_SAMPLES = NUM_PER_BUCKET * NUM_SEG
+ BATCH_SIZE = 256
+
+
+ def make_dataframe(x, y, method=None):
+     x = list(x[:, 0].detach().numpy())
+     y = list(y[:, 0].detach().numpy())
+     if method is not None:
+         method = [method for _ in range(len(x))]
+         df = pd.DataFrame({'x': x, 'y': y, 'Method': method})
+     else:
+         df = pd.DataFrame({'x': x, 'y': y})
+     return df
+
+ Y_demo = torch.linspace(Y_LB, Y_UB, 2).unsqueeze(-1)
+ X_demo = (Y_demo - B) / K
+
+ df_oracle = make_dataframe(X_demo, Y_demo, 'Oracle')
+
+ def prepare_data():
+     interval = (Y_UB - Y_LB) / NUM_SEG
+     all_x, all_y = [], []
+     for i in range(NUM_SEG):
+         uniform_y_distribution = torch.distributions.Uniform(Y_UB - (i+1)*interval, Y_UB - i*interval)
+         y_uniform = uniform_y_distribution.sample((NUM_TRAIN_SAMPLES, 1))
+
+         noise_distribution = torch.distributions.Normal(loc=0, scale=NOISE_SIGMA)
+         noise = noise_distribution.sample((NUM_TRAIN_SAMPLES, 1))
+         y_uniform_oracle = y_uniform - noise
+
+         x_uniform = (y_uniform_oracle - B) / K
+         all_x.append(x_uniform)
+         all_y.append(y_uniform)
+     return all_x, all_y
+
+ def select_data(all_x, all_y, sel_num):
+     sel_x, sel_y = [], []
+     prob = []
+     for i in range(NUM_SEG):
+         sel_x += all_x[i][:sel_num[i]]
+         sel_y += all_y[i][:sel_num[i]]
+         prob += [torch.tensor(sel_num[i]).float() for _ in range(sel_num[i])]
+     sel_x = torch.stack(sel_x)
+     sel_y = torch.stack(sel_y)
+     prob = torch.stack(prob)
+     return sel_x, sel_y, prob
+
+
+ def unzip_dataloader(training_loader):
+     all_x = []
+     all_y = []
+     for data, label, _ in training_loader:
+         all_x.append(data)
+         all_y.append(label)
+     all_x = torch.cat(all_x)
+     all_y = torch.cat(all_y)
+     return all_x, all_y
+
+ # Train the model
+ def train(train_loader, training_bundle, num_epochs):
+     training_df = make_dataframe(*unzip_dataloader(train_loader))
+     for epoch in range(num_epochs):
+         for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
+             model.train()
+             for data, target, prob in train_loader:
+                 optimizer.zero_grad()
+                 pred = model(data)
+                 if criterion_name == 'Reweight':
+                     loss = criterion(pred, target, prob)
+                 else:
+                     loss = criterion(pred, target)
+                 loss.backward()
+                 optimizer.step()
+             scheduler.step()
+         if (epoch + 1) % PRINT_FREQ == 0:
+             visualize(training_df, training_bundle, epoch)
+
+ def visualize(training_df, training_bundle, epoch):
+     df = df_oracle
+     for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
+         model.eval()
+         y = model(X_demo)
+         df = pd.concat([df, make_dataframe(X_demo, y, criterion_name)], ignore_index=True)
+     sns.lineplot(data=df, x='x', y='y', hue='Method', estimator=None, ci=None)
+     sns.scatterplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.05, linewidths=0, s=100)
+     plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
+     plt.ylim(Y_LB, Y_UB)
+     plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
+     plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
+     plt.savefig('train_log/{:05d}.png'.format(epoch+1), bbox_inches='tight')
+     plt.close()
+
+
+
+ def make_video():
+     if osp.isfile('movie.mp4'):
+         os.remove('movie.mp4')
+     (
+         ffmpeg
+         .input('train_log/*.png', pattern_type='glob', framerate=3)
+         .output('movie.mp4')
+         .run()
+     )
+
+ class ReweightL2(_Loss):
+     def __init__(self, reweight='inverse'):
+         super(ReweightL2, self).__init__()
+         self.reweight = reweight
+
+     def forward(self, pred, target, prob):
+         reweight = self.reweight
+         if reweight == 'inverse':
+             inv_prob = prob.pow(-1)
+         elif reweight == 'sqrt_inv':
+             inv_prob = prob.pow(-0.5)
+         else:
+             raise NotImplementedError
+         inv_prob = inv_prob / inv_prob.sum()
+         loss = F.mse_loss(pred, target, reduction='none').sum(-1) * inv_prob
+         loss = loss.sum()
+         return loss
+
+ # a single linear layer regresses y from x
+ class LinearModel(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super(LinearModel, self).__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(input_dim, output_dim),
+         )
+
+     def forward(self, x):
+         x = self.mlp(x)
+         return x
+
+ def prepare_model():
+     model = LinearModel(input_dim=1, output_dim=1)
+     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
+     return model, optimizer, scheduler
+
+
+ class BMCLoss(_Loss):
+     def __init__(self):
+         super(BMCLoss, self).__init__()
+         self.noise_sigma = NOISE_SIGMA
+
+     def forward(self, pred, target):
+         pred = pred.reshape(-1, 1)
+         target = target.reshape(-1, 1)
+         noise_var = self.noise_sigma ** 2
+         loss = bmc_loss(pred, target, noise_var)
+         return loss
+
+
+ def bmc_loss(pred, target, noise_var):
+     logits = - 0.5 * (pred - target.T).pow(2) / noise_var
+     loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))
+
+     return loss * (2 * noise_var)
+
+ def regress(train_loader):
+     training_bundle = []
+     criterions = {
+         'MSE': torch.nn.MSELoss(),
+         'Reweight': ReweightL2(),
+         'Balanced MSE': BMCLoss(),
+     }
+     for criterion_name in criterions:
+         criterion = criterions[criterion_name]
+         model, optimizer, scheduler = prepare_model()
+         training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
+     train(train_loader, training_bundle, NUM_EPOCHS)
+
+ class DummyDataset(Dataset):
+     def __init__(self, inputs, targets, prob):
+         self.inputs = inputs
+         self.targets = targets
+         self.prob = prob
+
+     def __getitem__(self, index):
+         return self.inputs[index], self.targets[index], self.prob[index]
+
+     def __len__(self):
+         return len(self.inputs)
+
+ def run(num1, num2, num3, num4, num5, random_seed, submit):
+     sel_num = [num1, num2, num3, num4, num5]
+     sel_num = [int(num/100*NUM_PER_BUCKET) for num in sel_num]
+     torch.manual_seed(int(random_seed))
+     all_x, all_y = prepare_data()
+     sel_x, sel_y, prob = select_data(all_x, all_y, sel_num)
+     train_loader = DataLoader(DummyDataset(sel_x, sel_y, prob), BATCH_SIZE, shuffle=True)
+
+     training_df = make_dataframe(sel_x, sel_y)
+     g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
+                       marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG+1), rug=True),
+                       xlim=((Y_LB - B) / K, (Y_UB - B) / K),
+                       ylim=(Y_LB, Y_UB),
+                       space=0.1,
+                       height=8,
+                       ratio=2
+                       )
+     g.ax_marg_x.remove()
+     sns.lineplot(data=df_oracle, x='x', y='y', hue='Method', ax=g.ax_joint, legend=False)
+     plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
+     plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
+     plt.savefig('training_data.png', bbox_inches='tight')
+     plt.close()
+
+     if submit == 0:
+         text = "Press \"Start Regressing!\" if you are happy with the training data"
+     else:
+         text = "Press \"Prepare Training Data\" to change the training data"
+     if submit == 1:
+         if not osp.exists('train_log'):
+             os.mkdir('train_log')
+         for f in os.listdir('train_log'):
+             os.remove(osp.join('train_log', f))
+         regress(train_loader)
+         make_video()
+     output = 'train_log/{:05d}.png'.format(NUM_EPOCHS) if submit == 1 else None
+     video = "movie.mp4" if submit == 1 else None
+     return 'training_data.png', text, output, video
+
+
+ iface = gr.Interface(
+     fn=run,
+     inputs=[
+         gr.inputs.Slider(0, 100, default=2, step=1, label='Label percentage in [0, 2)'),
+         gr.inputs.Slider(0, 100, default=20, step=1, label='Label percentage in [2, 4)'),
+         gr.inputs.Slider(0, 100, default=100, step=1, label='Label percentage in [4, 6)'),
+         gr.inputs.Slider(0, 100, default=20, step=1, label='Label percentage in [6, 8)'),
+         gr.inputs.Slider(0, 100, default=2, step=1, label='Label percentage in [8, 10)'),
+         gr.inputs.Number(default=0, label='Random Seed', optional=False),
+         gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
+                         type="index", default=None, label='Mode', optional=False),
+     ],
+     outputs=[
+         gr.outputs.Image(type="file", label="Training data"),
+         gr.outputs.Textbox(type="auto", label="What's next?"),
+         gr.outputs.Image(type="file", label="Regression result"),
+         gr.outputs.Video(type='mp4', label='Training process')
+     ],
+     live=True,
+     allow_flagging='never',
+     title="Balanced MSE for Imbalanced Visual Regression [CVPR 2022]",
+     description="Welcome to the demo for Balanced MSE &#9878;. In this demo, we will work on a simple task: imbalanced <i>linear</i> regression. <br>"
+                 "To get started, drag the sliders &#128071;&#128071; and create your label distribution!"
+ )
+ iface.launch()
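
Note: the core objective in this demo is implemented by bmc_loss in app.py above. It treats every target in the batch as its own class, scores each prediction against all targets by negative squared distance, and applies a cross-entropy whose normalizer runs over the whole batch of targets, which is how the loss counteracts label imbalance. A minimal, self-contained sketch of calling it outside the Space (the batch values below are illustrative, not taken from app.py):

    import torch
    import torch.nn.functional as F

    def bmc_loss(pred, target, noise_var):
        # logits[i, j]: scaled negative squared distance between prediction i and target j
        logits = -0.5 * (pred - target.T).pow(2) / noise_var
        # each prediction should match its own target, i.e. class index i
        loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))
        return loss * (2 * noise_var)

    # illustrative batch of 4 samples; noise_var = NOISE_SIGMA ** 2 = 1 in the demo
    pred = torch.tensor([[0.9], [2.1], [5.0], [7.8]])
    target = torch.tensor([[1.0], [2.0], [5.0], [8.0]])
    print(bmc_loss(pred, target, noise_var=1.0))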
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ matplotlib
+ torch
+ seaborn
+ pandas
+ ffmpeg-python
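
A note on dependencies: packages.txt installs the ffmpeg system binary, which the ffmpeg-python package listed in requirements.txt wraps to assemble movie.mp4; gradio itself is not pinned here and is presumably provided by the Spaces runtime. For reference, a small sketch of the inverse-frequency weighting applied by the Reweight baseline (ReweightL2 in app.py), using illustrative per-sample bucket counts rather than values from the demo:

    import torch
    import torch.nn.functional as F

    # illustrative counts: two samples from a rare bucket (20 labels), one from a common bucket (1000 labels)
    prob = torch.tensor([20., 20., 1000.])
    pred = torch.tensor([[1.0], [2.0], [5.0]])
    target = torch.tensor([[1.2], [1.9], [5.1]])

    inv_prob = prob.pow(-1)               # 'inverse' mode; 'sqrt_inv' would use prob.pow(-0.5)
    inv_prob = inv_prob / inv_prob.sum()  # normalize the weights over the batch
    loss = (F.mse_loss(pred, target, reduction='none').sum(-1) * inv_prob).sum()
    print(loss)  # the rare-bucket samples dominate the weighted loss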