Switch training from mx.model.FeedForward to mx.mod.Module.

The Module is constructed from the context and symbol only; the data
iterators, metrics, kvstore, parameters, epoch range, optimizer and
initializer are all passed to fit(). Learning rate, momentum, weight decay
and the LR scheduler move into the optimizer_params dict.

The Module stays bound to the name `model`, so fit() is called on a defined
object. optimizer_params ends with a comma so the following initializer
argument still parses.

NOTE(review): leading whitespace inside the hunk was lost in this copy and
has been re-indented conventionally -- confirm against train_resnet.py or
apply with whitespace tolerance (e.g. `git apply --ignore-whitespace`).

diff --git a/train_resnet.py b/train_resnet.py
index 024dbb6..93c708e 100644
--- a/train_resnet.py
+++ b/train_resnet.py
@@ -106,29 +106,28 @@ def main():
         rand_mirror = False,
         num_parts = kv.num_workers,
         part_index = kv.rank)
-    model = mx.model.FeedForward(
-        ctx = devs,
-        symbol = symbol,
+    model = mx.mod.Module(
+        context = devs,
+        symbol = symbol)
+    model.fit(
+        train_data = train,
+        eval_data = val,
+        eval_metric = ['acc', 'ce'] if args.data_type=='cifar10' else
+                      ['acc', mx.metric.create('top_k_accuracy', top_k = 5)],
+        kvstore = kv,
         arg_params = arg_params,
         aux_params = aux_params,
         num_epoch = 200 if args.data_type == "cifar10" else 120,
         begin_epoch = begin_epoch,
-        learning_rate = args.lr,
-        momentum = args.mom,
-        wd = args.wd,
         optimizer = 'nag',
         # optimizer = 'sgd',
+        optimizer_params = {'learning_rate': args.lr,
+                            'momentum': args.mom,
+                            'wd': args.wd,
+                            'lr_scheduler': multi_factor_scheduler(begin_epoch, epoch_size, step=[120, 160], factor=0.1)
+                                            if args.data_type=='cifar10' else
+                                            multi_factor_scheduler(begin_epoch, epoch_size, step=[30, 60, 90], factor=0.1)},
         initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
-        lr_scheduler = multi_factor_scheduler(begin_epoch, epoch_size, step=[120, 160], factor=0.1)
-                       if args.data_type=='cifar10' else
-                       multi_factor_scheduler(begin_epoch, epoch_size, step=[30, 60, 90], factor=0.1),
-        )
-    model.fit(
-        X = train,
-        eval_data = val,
-        eval_metric = ['acc', 'ce'] if args.data_type=='cifar10' else
-                      ['acc', mx.metric.create('top_k_accuracy', top_k = 5)],
-        kvstore = kv,
         batch_end_callback = mx.callback.Speedometer(args.batch_size, args.frequent),
         epoch_end_callback = checkpoint)
     # logging.info("top-1 and top-5 acc is {}".format(model.score(X = val,