Mirror of https://github.com/google-deepmind/deepmind-research.git, synced 2025-12-12 11:01:38 +08:00

Upload links to pre-trained NFNet weights and add utility functions for loading them, as well as a demo Colab notebook.

PiperOrigin-RevId: 357692801

Committed by Louise Deason
Parent: 68a1754d29
Commit: ba761289c1
@@ -1,21 +1,57 @@

# Code for Normalizer-Free Networks

This repository contains code for the ICLR 2021 paper
["Characterizing signal propagation to close the performance gap in unnormalized
ResNets,"](https://arxiv.org/abs/2102.06171) by Andrew Brock, Soham De, and
Samuel L. Smith, and the arXiv preprint ["High-Performance Large-Scale Image
Recognition Without Normalization"](http://dpmd.ai/06171) by
Andrew Brock, Soham De, Samuel L. Smith, and Karen Simonyan.
## Running this code

Using `run.sh` will create and activate a virtualenv, install all necessary
dependencies, and run a test program to ensure that you can import all the
modules and take a single experiment step. To train with this code, use this
virtualenv and one of the experiment.py files in combination with
[JAXline](https://github.com/deepmind/jaxline). The provided demo Colab can be
run online, or by starting a jupyter notebook within this virtualenv.

Note that you will need a local copy of ImageNet compatible with the TFDS format
used in dataset.py in order to train on ImageNet.

## Pre-Trained Weights

We provide pre-trained weights for NFNet-F0 through F5 (trained without SAM),
and for NFNet-F6 trained with SAM. All models are pre-trained on ImageNet for
360 epochs at batch size 4096, and are provided as numpy files containing
parameter trees compatible with haiku. In utils.py we provide a
`load_haiku_file` function which loads these parameter trees, and a
`flatten_haiku_tree` function which converts them to flat dictionaries, which
may prove easier to port to other frameworks. Note that we do not provide model
`states`, as these models, lacking batchnorm, do not have running stats. Note
also that the conv layer weights are in the HWIO format, so for frameworks like
PyTorch, which use OIHW, you will need to swap the axes to the layout you use.
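As an illustration, here is a minimal sketch of the HWIO-to-OIHW conversion described above. A toy parameter tree stands in for a downloaded checkpoint; the module name and kernel shape are invented for the example, not taken from the real model:

```python
import numpy as np

# Toy stand-in for a parameter tree as returned by utils.load_haiku_file;
# the module name and kernel shape here are illustrative only.
params = {'nfnet/~/stem_conv0': {'w': np.zeros((3, 3, 3, 16))}}  # HWIO layout


def hwio_to_oihw(w):
  """Swap a conv kernel from haiku's HWIO layout to PyTorch-style OIHW."""
  return np.transpose(w, (3, 2, 0, 1))


w_oihw = hwio_to_oihw(params['nfnet/~/stem_conv0']['w'])
print(w_oihw.shape)  # (16, 3, 3, 3)
```

The transpose order `(3, 2, 0, 1)` moves the output-channel axis first and the spatial axes last, which is the layout `torch.nn.Conv2d` expects.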
| Model | #FLOPs | #Params | Top-1 | Top-5 | TPUv3 Train | GPU Train | link |
|---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| F0 | 12.38B | 71.5M | 83.6 | 96.8 | 73.3ms | 56.7ms | [haiku](https://storage.googleapis.com/dm-nfnets/F0_haiku.npz) |
| F1 | 35.54B | 132.6M | 84.7 | 97.1 | 158.5ms | 133.9ms | [haiku](https://storage.googleapis.com/dm-nfnets/F1_haiku.npz) |
| F2 | 62.59B | 193.8M | 85.1 | 97.3 | 295.8ms | 226.3ms | [haiku](https://storage.googleapis.com/dm-nfnets/F2_haiku.npz) |
| F3 | 114.76B | 254.9M | 85.7 | 97.5 | 532.2ms | 524.5ms | [haiku](https://storage.googleapis.com/dm-nfnets/F3_haiku.npz) |
| F4 | 215.24B | 316.1M | 85.9 | 97.6 | 1033.3ms | 1190.6ms | [haiku](https://storage.googleapis.com/dm-nfnets/F4_haiku.npz) |
| F5 | 289.76B | 377.2M | 86.0 | 97.6 | 1398.5ms | 2177.1ms | [haiku](https://storage.googleapis.com/dm-nfnets/F5_haiku.npz) |
| F6+SAM | 377.28B | 438.4M | 86.5 | 97.9 | 2774.1ms | - | [haiku](https://storage.googleapis.com/dm-nfnets/F6_haiku.npz) |
## Demo Colab [Open In Colab](https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/nfnets/nfnet_demo_colab.ipynb)

We also include a Colab notebook with a demo showing how to run an NFNet to
classify an image.

## Giving Credit

If you use this code in your work, we ask you to please cite one or both of the
following papers:
@@ -39,7 +75,7 @@ The reference for Adaptive Gradient Clipping (AGC) and the NFNets models:

```
@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}
```
@@ -1,4 +1,4 @@
-# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
+# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
nfnets/nfnet_demo_colab.ipynb — new file (1295 lines). File diff suppressed because one or more lines are too long.
@@ -1,5 +1,6 @@
 absl-py==0.10.0
 chex>=0.0.2
+dill>=0.3.3
 dm-haiku>=0.0.3
 jax>=0.2.8
 jaxlib>=0.1.58
@@ -1,5 +1,5 @@
 #!/bin/sh
-# Copyright 2020 Deepmind Technologies Limited.
+# Copyright 2021 Deepmind Technologies Limited.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Utils."""
+import dill
 import jax
 import jax.numpy as jnp
 import tree
@@ -105,3 +106,23 @@ def split_tree(tuple_tree, base_tree, n):
   """Splits tuple_tree with n-tuple leaves into n trees."""
   return [tree.map_structure_up_to(base_tree, lambda x: x[i], tuple_tree)  # pylint: disable=cell-var-from-loop
           for i in range(n)]
+
+
+def load_haiku_file(filename):
+  """Loads a haiku parameter tree, using dill."""
+  with open(filename, 'rb') as in_file:
+    output = dill.load(in_file)
+  return output
+
+
+def flatten_haiku_tree(haiku_dict):
+  """Flattens a haiku parameter tree into a flat dictionary."""
+  out = {}
+  for module in haiku_dict.keys():
+    out_module = module.replace('/~/', '.').replace('/', '.')
+    for key in haiku_dict[module]:
+      out_key = f'{out_module}.{key}'
+      out[out_key] = haiku_dict[module][key]
+  return out
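For reference, a small usage sketch of `flatten_haiku_tree` on a toy parameter tree. The module names use haiku's `/~/` separator convention but are invented for the example, not actual NFNet layer names:

```python
import numpy as np


def flatten_haiku_tree(haiku_dict):
  """Flattens a nested haiku-style parameter dict into {'module.key': array}."""
  out = {}
  for module in haiku_dict.keys():
    out_module = module.replace('/~/', '.').replace('/', '.')
    for key in haiku_dict[module]:
      out[f'{out_module}.{key}'] = haiku_dict[module][key]
  return out


# Toy parameter tree with haiku-style '/~/' separators (names invented).
params = {
    'nfnet/~/stem_conv0': {'w': np.ones((3, 3, 3, 16))},
    'nfnet/~/linear': {'w': np.ones((16, 10)), 'b': np.zeros(10)},
}
flat = flatten_haiku_tree(params)
print(sorted(flat))  # ['nfnet.linear.b', 'nfnet.linear.w', 'nfnet.stem_conv0.w']
```

The flat `module.key` names make it straightforward to iterate over all weights when porting a checkpoint to another framework.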