| import numpy as np |
|
|
|
|
class SimpleTransformer:
    """
    SimpleTransformer is a simple class for preprocessing and deprocessing
    images for caffe.

    The transformer holds two pieces of state:

    * ``mean``  -- value(s) subtracted from the image; stored as a float32
      numpy array and broadcast against the image's last axis.
    * ``scale`` -- multiplicative factor applied after mean subtraction.

    Neither ``preprocess()`` nor ``deprocess()`` modifies the array it is
    given; both return new data.
    """

    def __init__(self, mean=(128, 128, 128)):
        """
        Create a transformer.

        mean: sequence (or array) of per-channel means. Defaults to 128 for
              each of three channels. The scale starts at 1.0.
        """
        self.mean = np.array(mean, dtype=np.float32)
        self.scale = 1.0

    def set_mean(self, mean):
        """
        Set the mean to subtract for centering the data.

        The value is converted to a float32 numpy array, exactly as the
        constructor does, so lists, tuples and arrays are all accepted.
        """
        self.mean = np.array(mean, dtype=np.float32)

    def set_scale(self, scale):
        """
        Set the data scaling.
        """
        self.scale = scale

    def preprocess(self, im):
        """
        preprocess() emulate the pre-processing occurring in the vgg16 caffe
        prototxt.

        im: 3-D array-like indexed as (rows, cols, channels).

        Returns a float32 array laid out as (channels, rows, cols) with the
        channel order reversed, the mean subtracted and the scale applied.
        The input is left untouched.
        """
        # np.array() always copies, so the in-place arithmetic below can
        # never write through to the caller's image (np.float32(im) may
        # alias the input when it is already a float32 array).
        out = np.array(im, dtype=np.float32)
        out = out[:, :, ::-1]  # reverse the channel axis (e.g. RGB <-> BGR)
        out -= self.mean
        out *= self.scale
        return out.transpose((2, 0, 1))

    def deprocess(self, im):
        """
        inverse of preprocess()

        im: 3-D array laid out as (channels, rows, cols).

        Returns a uint8 array laid out as (rows, cols, channels) with the
        scaling undone, the mean added back and the channel order reversed.
        Values are rounded to the nearest integer and clipped to [0, 255]
        before the uint8 conversion, so float round-off cannot cause an
        off-by-one and out-of-range values saturate instead of wrapping.
        The input is left untouched.
        """
        # transpose() returns a view; use out-of-place arithmetic so the
        # caller's array (typically a network blob) is not modified.
        out = np.asarray(im).transpose(1, 2, 0) / self.scale + self.mean
        out = out[:, :, ::-1]
        return np.clip(np.rint(out), 0, 255).astype(np.uint8)
|
|
|
|
class CaffeSolver:
    """
    Caffesolver is a class for creating a solver.prototxt file. It sets default
    values and can export a solver parameter file.
    Note that all parameters are stored as strings. Strings variables are
    stored as strings in strings.

    Parameters live in the ``sp`` dict, mapping parameter name to the exact
    text written after ``name: `` in the exported file. A prototxt string
    value therefore carries its own double quotes, e.g. ``'"fixed"'``.
    """

    def __init__(self, testnet_prototxt_path="testnet.prototxt",
                 trainnet_prototxt_path="trainnet.prototxt", debug=False):
        """
        Populate ``self.sp`` with default solver parameters.

        testnet_prototxt_path:  path stored (quoted) under 'test_net'.
        trainnet_prototxt_path: path stored (quoted) under 'train_net'.
        debug: when true, override 'max_iter', 'test_iter', 'test_interval'
               and 'display' with very small values.
        """
        self.sp = {}

        # Optimisation hyper-parameters.
        self.sp['base_lr'] = '0.001'
        self.sp['momentum'] = '0.9'

        # Test schedule.
        self.sp['test_iter'] = '100'
        self.sp['test_interval'] = '250'

        # Logging and snapshotting.
        self.sp['display'] = '25'
        self.sp['snapshot'] = '2500'
        self.sp['snapshot_prefix'] = '"snapshot"'

        # Learning-rate policy.
        self.sp['lr_policy'] = '"fixed"'

        # Regularisation and network definitions.
        self.sp['gamma'] = '0.1'
        self.sp['weight_decay'] = '0.0005'
        self.sp['train_net'] = '"' + trainnet_prototxt_path + '"'
        self.sp['test_net'] = '"' + testnet_prototxt_path + '"'

        # Remaining defaults.
        self.sp['max_iter'] = '100000'
        self.sp['test_initialization'] = 'false'
        self.sp['average_loss'] = '25'
        self.sp['iter_size'] = '1'

        if debug:
            self.sp['max_iter'] = '12'
            self.sp['test_iter'] = '1'
            self.sp['test_interval'] = '4'
            self.sp['display'] = '1'

    def add_from_file(self, filepath):
        """
        Reads a caffe solver prototxt file and updates the Caffesolver
        instance parameters.

        Each non-blank, non-comment line must have the form ``key: value``.
        Blank lines and lines whose first non-whitespace character is '#'
        are skipped. Only the first ':' separates key from value, so values
        that themselves contain a colon are kept intact.

        Raises ValueError if a line contains no ':' separator.
        """
        with open(filepath, 'r') as f:
            for lineno, raw_line in enumerate(f, 1):
                line = raw_line.strip()
                if not line or line.startswith('#'):
                    continue
                key, sep, value = line.partition(':')
                if not sep:
                    raise ValueError(
                        '%s:%d: expected "key: value", got %r'
                        % (filepath, lineno, line))
                self.sp[key.strip()] = value.strip()

    def write(self, filepath):
        """
        Export solver parameters to INPUT "filepath". Sorted alphabetically.

        Raises TypeError if any parameter value is not a string. Validation
        happens before the file is opened, so an invalid parameter never
        leaves a truncated or partially written file behind.
        """
        items = sorted(self.sp.items())
        for _, value in items:
            if not isinstance(value, str):
                raise TypeError('All solver parameters must be strings')
        # The context manager guarantees the file is flushed and closed.
        with open(filepath, 'w') as f:
            f.writelines('%s: %s\n' % (key, value) for key, value in items)
|
|