Upload links to pre-trained NFNet weights and add utility functions for loading them, as well as a demo Colab notebook.

PiperOrigin-RevId: 357692801
This commit is contained in:
Andy Brock
2021-02-16 12:50:59 +00:00
committed by Louise Deason
parent 68a1754d29
commit ba761289c1
19 changed files with 1378 additions and 25 deletions

View File

@@ -1,21 +1,57 @@
# Code for Normalizer-Free Networks
This repository contains code for the ICLR 2021 paper
"Characterizing signal propagation to close the performance gap in unnormalized
ResNets," by Andrew Brock, Soham De, and Samuel L. Smith, and the arXiv preprint
"High-Performance Large-Scale Image Recognition Without Normalization" by
["Characterizing signal propagation to close the performance gap in unnormalized
ResNets,"](https://arxiv.org/abs/2102.06171) by Andrew Brock, Soham De, and
Samuel L. Smith, and the arXiv preprint ["High-Performance Large-Scale Image
Recognition Without Normalization"](http://dpmd.ai/06171) by
Andrew Brock, Soham De, Samuel L. Smith, and Karen Simonyan.
## Running this code
Install using pip install -r requirements.txt and use one of the experiment.py
files in combination with [JAXline](https://github.com/deepmind/jaxline) to
train models. Optionally copy test.py into
a dir one level up and run it to ensure you can take a single experiment step
with fake data.
Running `run.sh` will create and activate a virtualenv, install all necessary
dependencies, and run a test program to ensure that you can import all the
modules and take a single experiment step. To train models, activate this
virtualenv and use one of the `experiment.py` files in combination with
[JAXline](https://github.com/deepmind/jaxline). The provided demo Colab can be
run online, or by starting a Jupyter notebook within this virtualenv.
Note that to train on ImageNet you will need a local copy of the dataset,
compatible with the TFDS format used in dataset.py.
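For reference, a quick way to check that a local TFDS-prepared copy is visible
is something like the sketch below. The `data_dir` path is a placeholder, and
this assumes dataset.py reads the standard `imagenet2012` TFDS dataset from the
same directory; it is not code from this repository:
```python
import tensorflow_datasets as tfds

# Sketch: verify a locally prepared TFDS copy of ImageNet can be read.
# Assumes imagenet2012 has already been downloaded and prepared, and that
# dataset.py points at the same data_dir (placeholder path below).
ds = tfds.load('imagenet2012', split='train',
               data_dir='/path/to/tensorflow_datasets')
example = next(iter(tfds.as_numpy(ds)))
print(example['image'].shape, example['label'])
```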
## Pre-Trained Weights
We provide pre-trained weights for NFNet-F0 through F5 (trained without SAM),
and for NFNet-F6 trained with SAM. All models are pre-trained on ImageNet for
360 epochs at batch size 4096, and are provided as numpy files containing
parameter trees compatible with haiku. In utils.py we provide a
`load_haiku_file` function which loads these parameter trees, and a
`flatten_haiku_tree` function which converts them to flat dictionaries that may
prove easier to port to other frameworks. Note that we do not provide model
`states`, as these models, lacking batchnorm, do not have running stats. Note
also that the conv layer weights are stored in HWIO format, so for frameworks
like PyTorch, which use OIHW, you will need to transpose the axes to match the
layout you use, as in the sketch below.
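As a minimal sketch (not part of this repo), loading the F0 weights and
transposing the conv kernels for an OIHW framework might look like the
following. It assumes `F0_haiku.npz` has been downloaded to the working
directory and that utils.py is importable as `utils`:
```python
# Illustrative sketch: load the F0 parameter tree and transpose conv kernels.
import numpy as np

import utils  # utils.py from this directory

params = utils.load_haiku_file('F0_haiku.npz')  # nested haiku parameter tree
flat = utils.flatten_haiku_tree(params)         # {'module.submodule.w': array}

for name, value in flat.items():
  if value.ndim == 4:  # conv kernels are stored as HWIO
    flat[name] = np.transpose(value, (3, 2, 0, 1))  # HWIO -> OIHW
```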
| Model | #FLOPs | #Params | Top-1 | Top-5 | TPUv3 Train Latency | GPU Train Latency | Weights |
|---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| F0 | 12.38B | 71.5M | 83.6 | 96.8 | 73.3ms | 56.7ms | [haiku](https://storage.googleapis.com/dm-nfnets/F0_haiku.npz) |
| F1 | 35.54B | 132.6M | 84.7 | 97.1 | 158.5ms | 133.9ms | [haiku](https://storage.googleapis.com/dm-nfnets/F1_haiku.npz) |
| F2 | 62.59B | 193.8M | 85.1 | 97.3 | 295.8ms | 226.3ms | [haiku](https://storage.googleapis.com/dm-nfnets/F2_haiku.npz) |
| F3 | 114.76B | 254.9M | 85.7 | 97.5 | 532.2ms | 524.5ms | [haiku](https://storage.googleapis.com/dm-nfnets/F3_haiku.npz) |
| F4 | 215.24B | 316.1M | 85.9 | 97.6 | 1033.3ms | 1190.6ms | [haiku](https://storage.googleapis.com/dm-nfnets/F4_haiku.npz) |
| F5 | 289.76B | 377.2M | 86.0 | 97.6 | 1398.5ms | 2177.1ms | [haiku](https://storage.googleapis.com/dm-nfnets/F5_haiku.npz) |
| F6+SAM | 377.28B | 438.4M | 86.5 | 97.9 | 2774.1ms | - | [haiku](https://storage.googleapis.com/dm-nfnets/F6_haiku.npz) |
## Demo Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/nfnets/nfnet_demo_colab.ipynb)
We also include a Colab notebook with a demo showing how to run an NFNet to
classify an image.
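If you cannot open the notebook, classification with pre-trained weights boils
down to something like the sketch below. The `NFNet` constructor arguments and
the assumption that the model returns class logits are inferred from context,
not a verbatim excerpt; consult nfnet.py and the notebook for the exact API:
```python
import haiku as hk
import jax.numpy as jnp
import numpy as np

import nfnet  # model definitions in this directory
import utils


def forward(inputs):
  # Assumed constructor signature; check nfnet.py for the exact arguments.
  model = nfnet.NFNet(num_classes=1000, variant='F0')
  return model(inputs, is_training=False)  # assumed to return class logits


net = hk.without_apply_rng(hk.transform(forward))
params = utils.load_haiku_file('F0_haiku.npz')
image = np.zeros((1, 224, 224, 3), np.float32)  # stand-in for a real image
logits = net.apply(params, image)
print(jnp.argmax(logits, axis=-1))
```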
## Giving Credit
If you use this code in your work, we ask that you cite one or both of the
@@ -39,7 +75,7 @@ The reference for Adaptive Gradient Clipping (AGC) and the NFNets models:
@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}
```

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,6 @@
absl-py==0.10.0
chex>=0.0.2
dill>=0.3.3
dm-haiku>=0.0.3
jax>=0.2.8
jaxlib>=0.1.58

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
#!/bin/sh
# Copyright 2020 Deepmind Technologies Limited.
# Copyright 2021 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Utils."""
import dill
import jax
import jax.numpy as jnp
import tree
@@ -105,3 +106,23 @@ def split_tree(tuple_tree, base_tree, n):
"""Splits tuple_tree with n-tuple leaves into n trees."""
return [tree.map_structure_up_to(base_tree, lambda x: x[i], tuple_tree) # pylint: disable=cell-var-from-loop
for i in range(n)]
def load_haiku_file(filename):
"""Loads a haiku parameter tree, using dill."""
with open(filename, 'rb') as in_file:
output = dill.load(in_file)
return output
def flatten_haiku_tree(haiku_dict):
"""Flattens a haiku parameter tree into a flat dictionary."""
out = {}
for module in haiku_dict.keys():
out_module = module.replace('/~/', '.').replace('/', '.')
for key in haiku_dict[module]:
out_key = f'{out_module}.{key}'
out[out_key] = haiku_dict[module][key]
return out
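As a quick illustration of what `flatten_haiku_tree` produces (the module name
below is made up for the example, not taken from the real parameter files):
```python
import numpy as np

# A made-up parameter tree in haiku's nested layout.
params = {'nfnet/~/stem/conv0': {'w': np.zeros((3, 3, 3, 16))}}

flat = flatten_haiku_tree(params)
print(list(flat))  # ['nfnet.stem.conv0.w']
```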