|
|
|
""" |
|
Trains a model using one or more GPUs. |
|
""" |
|
from multiprocessing import Process |
|
|
|
import caffe |
|
|
|
|
|
def train(
        solver,  # path to the solver prototxt
        snapshot,  # solver snapshot to restore, or None / '' for a fresh run
        gpus,  # list of device ids; one worker process is started per entry
        timing=False,  # have rank 0 log per-layer and allreduce timings
):
    """Train a model with one worker process per GPU.

    Each entry of ``gpus`` gets its own process running ``solve`` with a rank
    equal to the entry's position in the list.  This call blocks until every
    worker has exited.

    Raises:
        RuntimeError: if any worker exits with a non-zero exit code.  The
            remaining workers are terminated before the error is raised.
    """
    # The same uid is passed to every worker's caffe.NCCL(solver, uid) call;
    # presumably it identifies the shared NCCL session/communicator.
    uid = caffe.NCCL.new_uid()

    caffe.init_log()
    caffe.log('Using devices %s' % str(gpus))

    procs = []
    for rank in range(len(gpus)):
        p = Process(target=solve,
                    args=(solver, snapshot, gpus, timing, uid, rank))
        # Daemonic: the parent tries to terminate these children when it exits.
        p.daemon = True
        p.start()
        procs.append(p)
    _join_workers(procs)


def _join_workers(procs, poll_interval=1.0):
    """Wait for all worker processes, failing fast if any of them dies.

    Joining the workers one after another can block indefinitely if one rank
    dies while its peers are waiting on it inside a collective (e.g. the
    NCCL broadcast/allreduce).  Therefore the workers are polled, and on the
    first abnormal exit the survivors are terminated and an error is raised.
    Otherwise a failed run would end with exit status 0.
    """
    while True:
        # Read liveness before exit codes: a worker that exits between the two
        # reads is then still reported in `failed` if its exit was abnormal.
        alive = [p for p in procs if p.is_alive()]
        failed = [(rank, p.exitcode) for rank, p in enumerate(procs)
                  if p.exitcode not in (None, 0)]
        if failed:
            for p in alive:
                p.terminate()
            for p in procs:
                p.join()
            raise RuntimeError('Training worker(s) failed: %s' % ', '.join(
                'rank %d (exit code %s)' % f for f in failed))
        if not alive:
            return
        alive[0].join(poll_interval)
|
|
|
|
|
def time(solver, nccl):
    """Register ``nccl`` on ``solver`` and wrap it with timing hooks.

    This is used instead of a plain ``solver.add_callback(nccl)`` on the worker
    that reports timings.  Every ``solver.param.display`` iterations it logs:

    * forward and backward time per layer;
    * "solver total", measured between the two hooks of the first callback;
    * "allreduce", measured around the ``nccl`` callback.

    A display interval of 0 disables the report.  All values are taken from
    ``caffe.Timer.ms``.
    """
    n_layers = len(solver.net.layers)
    fprop = [caffe.Timer() for _ in range(n_layers)]
    bprop = [caffe.Timer() for _ in range(n_layers)]
    total = caffe.Timer()
    allrd = caffe.Timer()
    display = solver.param.display

    def show_time():
        # Without the `not display` guard, display == 0 would make the
        # modulo raise ZeroDivisionError on every iteration.
        if not display or solver.iter % display != 0:
            return
        names = solver.net._layer_names
        parts = ['\n']
        for i in range(n_layers):
            parts.append('forw %3d %8s : %.2f\n' % (i, names[i], fprop[i].ms))
        # Backward timings are listed from the last layer to the first.
        for i in reversed(range(n_layers)):
            parts.append('back %3d %8s : %.2f\n' % (i, names[i], bprop[i].ms))
        parts.append('solver total: %.2f\n' % total.ms)
        parts.append('allreduce: %.2f\n' % allrd.ms)
        caffe.log(''.join(parts))

    solver.net.before_forward(lambda layer: fprop[layer].start())
    solver.net.after_forward(lambda layer: fprop[layer].stop())
    solver.net.before_backward(lambda layer: bprop[layer].start())
    solver.net.after_backward(lambda layer: bprop[layer].stop())
    # Registration order matters.  Assuming the solver runs callbacks in the
    # order they were added, `allrd` brackets exactly the nccl callback.  The
    # two-argument form is presumably (on_start, on_gradients_ready), as in
    # pycaffe's Solver.add_callback.
    solver.add_callback(lambda: total.start(), lambda: (total.stop(), allrd.start()))
    solver.add_callback(nccl)
    solver.add_callback(lambda: '', lambda: (allrd.stop(), show_time()))
|
|
|
|
|
def solve(proto, snapshot, gpus, timing, uid, rank):
    """Run one training worker (the body of each process started by ``train``).

    Args:
        proto: path to the solver prototxt.
        snapshot: solver snapshot to restore; falsy (None / '') means none.
        gpus: list of device ids shared by all workers.
        timing: if true, rank 0 installs the timing hooks from ``time``.
        uid: NCCL uid created by the parent; identical for every rank.
        rank: index of this worker; it runs on device ``gpus[rank]``.
    """
    caffe.set_device(gpus[rank])
    caffe.set_mode_gpu()
    caffe.set_solver_count(len(gpus))
    caffe.set_solver_rank(rank)
    caffe.set_multiprocess(True)

    # get_solver builds the solver class named by the prototxt's `type` field
    # (SGD, Nesterov, Adam, ...).  Constructing SGDSolver directly would always
    # run plain SGD regardless of what the prototxt asks for.
    solver = caffe.get_solver(proto)
    if snapshot:
        solver.restore(snapshot)

    nccl = caffe.NCCL(solver, uid)
    nccl.bcast()

    if timing and rank == 0:
        # time() registers the nccl callback itself, wrapped in timers.
        time(solver, nccl)
    else:
        solver.add_callback(nccl)

    # With layer-wise reduction, nccl is additionally hooked after each
    # layer's backward pass.
    if solver.param.layer_wise_reduce:
        solver.net.after_backward(nccl)
    solver.step(solver.param.max_iter)
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point; argparse is only needed when run as a script.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--solver", required=True, help="Solver proto definition.")
    cli.add_argument("--snapshot", help="Solver snapshot to restore.")
    cli.add_argument("--gpus", type=int, nargs='+', default=[0],
                     help="List of device ids.")
    cli.add_argument("--timing", action='store_true', help="Show timing info.")
    opts = cli.parse_args()

    train(opts.solver, opts.snapshot, opts.gpus, timing=opts.timing)
|
|