diff --git a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
index bec0f64..9c31a98 100644
--- a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
+++ b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
@@ -97,22 +97,24 @@ def make_hk_perm_spec(mlp_params):
   return perm_spec
 
 
-def make_hk_cnn_perm_spec(mlp_params):
+def make_hk_cnn_perm_spec(params):
   """Produces perm spec for a haiku cnn."""
   perm_spec = {}
-  for i in range(len(mlp_params)):
-    if i < len(mlp_params) - 1:
-      if i == 0:
-        name = 'conv2_d'
-      else:
-        name = f'conv2_d_{i}'
-      perm_spec[name] = {
-          'w': (-i, -(len(mlp_params) + i), i, i + 1),
-          'b': (i + 1,),
-      }
+  num_convs = len([k for k in params if k.startswith('conv2_d')])
+  for i in range(num_convs):
+    if i == 0:
+      conv_name = 'conv2_d'
+      ln_name = 'layer_norm'
     else:
-      name = 'linear'
-      perm_spec[name] = {'w': (i, i + 1), 'b': (i + 1,)}
+      conv_name = f'conv2_d_{i}'
+      ln_name = f'layer_norm_{i}'
+    perm_spec[conv_name] = {
+        'w': (-i, -(len(params) + i), i, i + 1),
+        'b': (i + 1,),
+    }
+    if ln_name in params:  # layernorm is optional
+      perm_spec[ln_name] = {'offset': (i + 1,), 'scale': (i + 1,)}
+  perm_spec['linear'] = {'w': (num_convs, num_convs + 1), 'b': (num_convs + 1,)}  # final linear layer
   return perm_spec
 
 
diff --git a/learned_optimization/tasks/fixed/conv.py b/learned_optimization/tasks/fixed/conv.py
index 3ffbc46..14753ae 100644
--- a/learned_optimization/tasks/fixed/conv.py
+++ b/learned_optimization/tasks/fixed/conv.py
@@ -100,6 +100,21 @@ def normalizer(self, loss):
     1.5 * jnp.log(self.datasets.extra_info["num_classes"]))
 
 
+@gin.configurable
+def Conv_Cifar10_8_16x32_layernorm():
+  """A 3 hidden layer convnet with layernorm designed for 16x16 cifar10."""
+
+  def norm_fn(x):
+    return hk.LayerNorm(create_scale=True, create_offset=True, axis=-1)(x)
+
+  base_model_fn = _cross_entropy_pool_loss([16, 32],
+                                           jax.nn.relu,
+                                           num_classes=10,
+                                           norm_fn=norm_fn)
+  datasets = image.cifar10_datasets(batch_size=128, image_size=(8, 8))
+  return _ConvTask(base_model_fn, datasets)
+
+
 @gin.configurable
 def Conv_Cifar10_8_16x32():
   """A 3 hidden layer convnet designed for 16x16 cifar10."""