Paul Engstler
Initial commit
92f0e98
import numpy, torch, PIL, io, base64, re
from torchvision import transforms
def as_tensor(data, source='zc', target='zc'):
renorm = renormalizer(source=source, target=target)
return renorm(data)
def as_image(data, source='zc', target='byte'):
assert len(data.shape) == 3
renorm = renormalizer(source=source, target=target)
return PIL.Image.fromarray(renorm(data).
permute(1,2,0).cpu().numpy())
def as_url(data, source='zc', size=None):
if isinstance(data, PIL.Image.Image):
img = data
else:
img = as_image(data, source)
if size is not None:
img = img.resize(size, resample=PIL.Image.BILINEAR)
buffered = io.BytesIO()
img.save(buffered, format='png')
b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
return 'data:image/png;base64,%s' % (b64)
def from_image(im, target='zc', size=None):
if im.format != 'RGB':
im = im.convert('RGB')
if size is not None:
im = im.resize(size, resample=PIL.Image.BILINEAR)
pt = transforms.functional.to_tensor(im)
renorm = renormalizer(source='pt', target=target)
return renorm(pt)
def from_url(url, target='zc', size=None):
image_data = re.sub('^data:image/.+;base64,', '', url)
im = PIL.Image.open(io.BytesIO(base64.b64decode(image_data)))
if target == 'image' and size is None:
return im
return from_image(im, target, size=size)
def renormalizer(source='zc', target='zc'):
'''
Returns a function that imposes a standard normalization on
the image data. The returned renormalizer operates on either
3d tensor (single image) or 4d tensor (image batch) data.
The normalization target choices are:
zc (default) - zero centered [-1..1]
pt - pytorch [0..1]
imagenet - zero mean, unit stdev imagenet stats (approx [-2.1...2.6])
byte - as from an image file, [0..255]
If a source is provided (a dataset or transform), then, the renormalizer
first reverses any normalization found in the data source before
imposing the specified normalization. When no source is provided,
the input data is assumed to be pytorch-normalized (range [0..1]).
'''
if isinstance(source, str):
oldoffset, oldscale = OFFSET_SCALE[source]
else:
normalizer = find_normalizer(source)
oldoffset, oldscale = (
(normalizer.mean, normalizer.std) if normalizer is not None
else OFFSET_SCALE['pt'])
newoffset, newscale = (target if isinstance(target, tuple)
else OFFSET_SCALE[target])
return Renormalizer(oldoffset, oldscale, newoffset, newscale,
tobyte=(target == 'byte'))
# The three commonly-seen image normalization schemes.
OFFSET_SCALE=dict(
pt=([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
zc=([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
imagenet=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
imagenet_meanonly=([0.485, 0.456, 0.406],
[1.0/255, 1.0/255, 1.0/255]),
places_meanonly=([0.475, 0.441, 0.408],
[1.0/255, 1.0/255, 1.0/255]),
byte=([0.0, 0.0, 0.0], [1.0/255, 1.0/255, 1.0/255]))
NORMALIZER={k: transforms.Normalize(*OFFSET_SCALE[k]) for k in OFFSET_SCALE}
def find_normalizer(source=None):
'''
Crawl around the transforms attached to a dataset looking for a
Normalize transform to return.
'''
if source is None:
return None
if isinstance(source, (transforms.Normalize, Renormalizer)):
return source
t = getattr(source, 'transform', None)
if t is not None:
return find_normalizer(t)
ts = getattr(source, 'transforms', None)
if ts is not None:
for t in reversed(ts):
result = find_normalizer(t)
if result is not None:
return result
return None
class Renormalizer:
def __init__(self, oldoffset, oldscale, newoffset, newscale, tobyte=False):
self.mul = torch.from_numpy(
numpy.array(oldscale) / numpy.array(newscale))
self.add = torch.from_numpy(
(numpy.array(oldoffset) - numpy.array(newoffset))
/ numpy.array(newscale))
self.tobyte = tobyte
# Store these away to allow the data to be renormalized again
self.mean = newoffset
self.std = newscale
def __call__(self, data):
mul, add = [d.to(data.device, data.dtype) for d in [self.mul, self.add]]
if data.ndimension() == 3:
mul, add = [d[:, None, None] for d in [mul, add]]
elif data.ndimension() == 4:
mul, add = [d[None, :, None, None] for d in [mul, add]]
result = data.mul(mul).add_(add)
if self.tobyte:
result = result.clamp(0, 255).byte()
return result