From 817a4f29bf4fbecd71aa871f62e373605414c046 Mon Sep 17 00:00:00 2001
From: Allan Zhou <8732147+AllanYangZhou@users.noreply.github.com>
Date: Thu, 4 Jan 2024 22:19:58 +0000
Subject: [PATCH 1/2] Lopt conv+ln

---
 .../univ_nfn/learned_opt/learned_opts.py  | 28 ++++++++++---------
 learned_optimization/tasks/fixed/conv.py  | 15 ++++++++++
 2 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
index bec0f64..9b54a87 100644
--- a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
+++ b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
@@ -97,22 +97,24 @@ def make_hk_perm_spec(mlp_params):
   return perm_spec
 
 
-def make_hk_cnn_perm_spec(mlp_params):
+def make_hk_cnn_perm_spec(params):
   """Produces perm spec for a haiku cnn."""
   perm_spec = {}
-  for i in range(len(mlp_params)):
-    if i < len(mlp_params) - 1:
-      if i == 0:
-        name = 'conv2_d'
-      else:
-        name = f'conv2_d_{i}'
-      perm_spec[name] = {
-          'w': (-i, -(len(mlp_params) + i), i, i + 1),
-          'b': (i + 1,),
-      }
+  num_convs = len([k for k in params if k.startswith('conv2_d')])
+  for i in range(num_convs):
+    if i == 0:
+      conv_name = 'conv2_d'
+      ln_name = 'layer_norm'
     else:
-      name = 'linear'
-      perm_spec[name] = {'w': (i, i + 1), 'b': (i + 1,)}
+      conv_name = f'conv2_d_{i}'
+      ln_name = 'layer_norm_{i}'
+    perm_spec[conv_name] = {
+        'w': (-i, -(len(params) + i), i, i + 1),
+        'b': (i + 1,),
+    }
+    if ln_name in params:  # layernorm is optional
+      perm_spec[ln_name] = {'offset': (i + 1,), 'scale': (i + 1,)}
+  perm_spec['linear'] = {'w': (num_convs, num_convs + 1), 'b': (num_convs + 1,)}  # final linear layer
   return perm_spec
 
 
diff --git a/learned_optimization/tasks/fixed/conv.py b/learned_optimization/tasks/fixed/conv.py
index 3ffbc46..14753ae 100644
--- a/learned_optimization/tasks/fixed/conv.py
+++ b/learned_optimization/tasks/fixed/conv.py
@@ -100,6 +100,21 @@ def normalizer(self, loss):
                     1.5 * jnp.log(self.datasets.extra_info["num_classes"]))
 
 
+@gin.configurable
+def Conv_Cifar10_8_16x32_layernorm():
+  """A 3 hidden layer convnet with layernorm designed for 16x16 cifar10."""
+
+  def norm_fn(x):
+    return hk.LayerNorm(create_scale=True, create_offset=True, axis=-1)(x)
+
+  base_model_fn = _cross_entropy_pool_loss([16, 32],
+                                           jax.nn.relu,
+                                           num_classes=10,
+                                           norm_fn=norm_fn)
+  datasets = image.cifar10_datasets(batch_size=128, image_size=(8, 8))
+  return _ConvTask(base_model_fn, datasets)
+
+
 @gin.configurable
 def Conv_Cifar10_8_16x32():
   """A 3 hidden layer convnet designed for 16x16 cifar10."""

From a5b8c70b456c67adff5a4e7fc0af6a00b5e8927b Mon Sep 17 00:00:00 2001
From: Allan Zhou <8732147+AllanYangZhou@users.noreply.github.com>
Date: Thu, 4 Jan 2024 23:19:29 +0000
Subject: [PATCH 2/2] small bugfix

---
 .../research/univ_nfn/learned_opt/learned_opts.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
index 9b54a87..9c31a98 100644
--- a/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
+++ b/learned_optimization/research/univ_nfn/learned_opt/learned_opts.py
@@ -107,7 +107,7 @@ def make_hk_cnn_perm_spec(params):
       ln_name = 'layer_norm'
     else:
       conv_name = f'conv2_d_{i}'
-      ln_name = 'layer_norm_{i}'
+      ln_name = f'layer_norm_{i}'
     perm_spec[conv_name] = {
         'w': (-i, -(len(params) + i), i, i + 1),
         'b': (i + 1,),
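
For reference, here is a minimal, self-contained sketch of what the patched make_hk_cnn_perm_spec should produce once PATCH 2/2 is applied. It is not part of the patch: the function body is copied from the final version above, while the toy params dictionary is a hypothetical stand-in for a haiku parameter tree (keys like 'conv2_d', 'layer_norm_1', 'linear') and may not match the exact module names produced by Conv_Cifar10_8_16x32_layernorm.

# Sketch only -- not part of the patch. Function copied from the patched
# learned_opts.py (after PATCH 2/2); the `params` keys below are assumed
# haiku-style module names for a 2-conv CNN with layer norm.
def make_hk_cnn_perm_spec(params):
  """Produces perm spec for a haiku cnn."""
  perm_spec = {}
  num_convs = len([k for k in params if k.startswith('conv2_d')])
  for i in range(num_convs):
    if i == 0:
      conv_name = 'conv2_d'
      ln_name = 'layer_norm'
    else:
      conv_name = f'conv2_d_{i}'
      ln_name = f'layer_norm_{i}'
    perm_spec[conv_name] = {
        'w': (-i, -(len(params) + i), i, i + 1),
        'b': (i + 1,),
    }
    if ln_name in params:  # layernorm is optional
      perm_spec[ln_name] = {'offset': (i + 1,), 'scale': (i + 1,)}
  perm_spec['linear'] = {'w': (num_convs, num_convs + 1), 'b': (num_convs + 1,)}
  return perm_spec


# Toy parameter tree: 2 conv layers, each followed by layer norm, then a
# final linear layer. len(params) == 5, num_convs == 2.
params = {
    'conv2_d': {}, 'layer_norm': {},
    'conv2_d_1': {}, 'layer_norm_1': {},
    'linear': {},
}
spec = make_hk_cnn_perm_spec(params)
# spec == {
#     'conv2_d': {'w': (0, -5, 0, 1), 'b': (1,)},
#     'layer_norm': {'offset': (1,), 'scale': (1,)},
#     'conv2_d_1': {'w': (-1, -6, 1, 2), 'b': (2,)},
#     'layer_norm_1': {'offset': (2,), 'scale': (2,)},
#     'linear': {'w': (2, 3), 'b': (3,)},
# }

Each layer norm entry reuses the permutation index of its conv's output channels (i + 1), which is why the f-string fix in PATCH 2/2 matters: without it, every layer norm after the first is looked up under the literal key 'layer_norm_{i}', fails the `in params` check, and is silently left out of the perm spec.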