| import os |
| import requests |
| import torch |
| import torch.nn as nn |
# Module-level all-ones tensor of shape (64, 10), allocated at import time.
# It is not referenced in the code visible here; NOTE(review): confirm its users.
Gl_z = torch.ones((64, 10))
|
|
def download_file(url, local_filename, chunk_size=1024):
    """Download ``url`` to ``local_filename`` unless that path already exists.

    The response body is streamed to disk in ``chunk_size``-byte pieces. It is
    first written to a sibling ``<local_filename>.part`` file and moved into
    place only after the whole body has been received. An interrupted or
    failed download therefore never leaves a truncated file that a later call
    would mistake for a finished one.

    Args:
        url: Address fetched with ``requests.get(url, stream=True)``.
        local_filename: Destination path (``str`` or path-like). If it
            already exists, nothing is downloaded.
        chunk_size: Number of bytes requested per ``iter_content`` chunk.

    Returns:
        ``local_filename``, unchanged.

    Raises:
        requests.HTTPError: If the server answers with an error status. Before
            this check, the error page was silently saved and then reused as a
            "cached" download on every later call.
        requests.RequestException: On connection-level failures.
    """
    if os.path.exists(local_filename):
        # An existing file is treated as a completed earlier download.
        return local_filename

    tmp_filename = os.fspath(local_filename) + ".part"
    # The ``with`` block closes the streamed response (releasing its
    # connection) even if writing fails part-way.
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        try:
            with open(tmp_filename, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip empty chunks
                        f.write(chunk)
        except BaseException:
            # Best-effort cleanup of the partial file, then propagate the error.
            try:
                os.remove(tmp_filename)
            except FileNotFoundError:
                pass
            raise
    # Atomic on the same filesystem: readers see either no file or the full one.
    os.replace(tmp_filename, local_filename)
    return local_filename
|
|
class AverageMeter:
    """Keep the latest value plus a running weighted sum, count and mean.

    Attributes (all start at ``0``):
        val: The most recent value passed to ``update``.
        sum: Accumulated ``val * n`` over every ``update`` call.
        count: Accumulated weights ``n``.
        avg: ``sum / count`` as of the latest ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the mean.

        ``n`` is presumably the number of samples ``val`` was averaged over
        (e.g. a batch size) — confirm with callers. Raises
        ``ZeroDivisionError`` if the accumulated count is still 0 afterwards.
        """
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
|
|
class EDMLoss(nn.Module):
    """Earth Mover's Distance between discrete distributions.

    The name ``EDMLoss`` is kept for compatibility with existing callers.

    Cumulative sums are taken along ``dim=1``. For each sample the loss is

        EMD_r = (mean_k |CDF_estimate[k] - CDF_target[k]| ** r) ** (1 / r)

    and the result is averaged over all remaining elements (the batch).

    Args:
        r: Positive norm exponent. The default ``r=1`` reproduces the previous
            behaviour (mean absolute CDF difference). ``r=2`` gives the
            root-mean-square variant.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(self, r=1):
        super().__init__()
        if r <= 0:
            raise ValueError(f"r must be positive, got {r!r}")
        self.r = r

    def forward(self, p_target, p_estimate):
        """Return the mean per-sample EMD between the two inputs.

        Both tensors must share a shape with bins along ``dim=1``; presumably
        (batch, num_bins) probability vectors — confirm with callers.

        Returns:
            A 0-dim tensor.
        """
        assert p_target.shape == p_estimate.shape, (
            f"shape mismatch: {tuple(p_target.shape)} vs {tuple(p_estimate.shape)}"
        )
        cdf_target = torch.cumsum(p_target, dim=1)
        cdf_estimate = torch.cumsum(p_estimate, dim=1)
        cdf_diff = cdf_estimate - cdf_target

        # Reduce over the bin axis first so each sample gets its own EMD; only
        # then average across samples. For r == 1 this equals a global mean.
        samplewise_emd = torch.mean(torch.abs(cdf_diff) ** self.r, dim=1) ** (1.0 / self.r)
        return samplewise_emd.mean()