| | import numpy as np |
| |
|
| |
|
class SimpleTransformer:

    """
    SimpleTransformer is a simple class for preprocessing and deprocessing
    images for caffe.

    preprocess() turns an image with channels on the last axis into a
    float32, channels-first array: it reverses the channel order, subtracts
    the mean and multiplies by the scale. deprocess() undoes those steps and
    returns a uint8 image. Neither method modifies its input array.
    """

    def __init__(self, mean=(128, 128, 128)):
        # A tuple default avoids the shared-mutable-default pitfall; the value
        # is copied into a float32 array either way.
        self.mean = np.array(mean, dtype=np.float32)
        self.scale = 1.0

    def set_mean(self, mean):
        """
        Set the mean to subtract for centering the data.

        The mean is stored as a float32 array, exactly as the constructor
        stores it. It is subtracted after the channel order has been reversed
        (see preprocess()), so it must be given in that reversed order.
        """
        self.mean = np.array(mean, dtype=np.float32)

    def set_scale(self, scale):
        """
        Set the data scaling factor (applied after mean subtraction).
        """
        self.scale = scale

    def preprocess(self, im):
        """
        preprocess() emulates the pre-processing occurring in the vgg16 caffe
        prototxt.

        Parameters
        ----------
        im : array_like
            Image with channels on the last axis, shape (H, W, C); presumably
            RGB, which the channel reversal turns into BGR.

        Returns
        -------
        numpy.ndarray
            float32 array of shape (C, H, W) equal to
            (channel-reversed im - mean) * scale.
        """
        # np.array always copies. np.float32(im) may hand back the input
        # object itself when it is already a float32 array, and the in-place
        # arithmetic below would then overwrite the caller's image.
        im = np.array(im, dtype=np.float32)
        im = im[:, :, ::-1]  # reverse channel order, e.g. RGB -> BGR
        im -= self.mean
        im *= self.scale
        im = im.transpose((2, 0, 1))  # (H, W, C) -> (C, H, W)
        return im

    def deprocess(self, im):
        """
        Inverse of preprocess().

        Parameters
        ----------
        im : array_like
            Array of shape (C, H, W), as produced by preprocess().

        Returns
        -------
        numpy.ndarray
            uint8 image of shape (H, W, C) with the channel order restored.
            Values are clipped to [0, 255] and then truncated toward zero.
        """
        # Out-of-place arithmetic: transpose() returns a view, so in-place
        # operators would write through to the caller's array ('/=' would
        # also fail on integer input).
        im = np.asarray(im).transpose(1, 2, 0) / self.scale
        im = im + self.mean
        im = im[:, :, ::-1]
        # Casting out-of-range floats to uint8 is not well defined, so clamp
        # to the representable range first.
        return np.clip(im, 0, 255).astype(np.uint8)
| |
|
| |
|
class CaffeSolver:

    """
    Caffesolver is a class for creating a solver.prototxt file. It sets default
    values and can export a solver parameter file.

    All parameters live in the ``sp`` dict, keyed by solver field name, and
    every value is stored as a string. String-valued fields additionally
    carry their own double quotes, e.g. ``sp['lr_policy'] == '"fixed"'``.
    """

    def __init__(self, testnet_prototxt_path="testnet.prototxt",
                 trainnet_prototxt_path="trainnet.prototxt", debug=False):

        self.sp = {}

        # learning rate and momentum
        self.sp['base_lr'] = '0.001'
        self.sp['momentum'] = '0.9'

        # test schedule
        self.sp['test_iter'] = '100'
        self.sp['test_interval'] = '250'

        # logging and snapshots
        self.sp['display'] = '25'
        self.sp['snapshot'] = '2500'
        self.sp['snapshot_prefix'] = '"snapshot"'  # quoted: string field

        # learning rate policy
        self.sp['lr_policy'] = '"fixed"'

        self.sp['gamma'] = '0.1'
        self.sp['weight_decay'] = '0.0005'
        self.sp['train_net'] = '"' + trainnet_prototxt_path + '"'
        self.sp['test_net'] = '"' + testnet_prototxt_path + '"'

        self.sp['max_iter'] = '100000'
        self.sp['test_initialization'] = 'false'
        self.sp['average_loss'] = '25'
        self.sp['iter_size'] = '1'

        if debug:
            # Much shorter schedule and more frequent output when debugging.
            self.sp['max_iter'] = '12'
            self.sp['test_iter'] = '1'
            self.sp['test_interval'] = '4'
            self.sp['display'] = '1'

    def add_from_file(self, filepath):
        """
        Reads a caffe solver prototxt file and updates the CaffeSolver
        instance parameters.

        Only flat ``key: value`` lines are understood. Blank lines and lines
        whose first non-blank character is '#' are skipped. Each line is
        split on its first ':' only, so values that themselves contain a
        colon (e.g. quoted paths) are kept intact. Values are stored verbatim,
        quotes included. A key that appears several times keeps only its last
        value.

        Raises
        ------
        ValueError
            If a non-blank, non-comment line contains no ':'.
        """
        with open(filepath, 'r') as f:
            for lineno, line in enumerate(f, 1):
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    continue
                key, sep, value = stripped.partition(':')
                if not sep:
                    raise ValueError('%s:%d: expected "key: value", got %r'
                                     % (filepath, lineno, stripped))
                self.sp[key.strip()] = value.strip()

    def write(self, filepath):
        """
        Export solver parameters to INPUT "filepath". Sorted alphabetically.

        Every value is checked before the file is opened, so a non-string
        parameter raises TypeError without creating or truncating the file.

        Raises
        ------
        TypeError
            If any parameter value is not a string.
        """
        items = sorted(self.sp.items())
        for key, value in items:
            if not isinstance(value, str):
                raise TypeError('All solver parameters must be strings')
        # 'with' guarantees the file is flushed and closed.
        with open(filepath, 'w') as f:
            for key, value in items:
                f.write('%s: %s\n' % (key, value))
| |
|