# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet V2 modules.
Equivalent to hk.Resnet except accepting a final_endpoint to return
intermediate activations.
"""

from typing import Optional, Sequence, Text, Type, Union

import haiku as hk
import jax
import jax.numpy as jnp

from mmv.models import types


class BottleneckBlock(hk.Module):
  """Implements a bottleneck residual block (ResNet50 and ResNet101)."""

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None):
    super(BottleneckBlock, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn

    if self._use_projection:
      self._proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding='SAME',
          name='shortcut_conv')

    self._conv_0 = hk.Conv2D(
        output_channels=channels // 4,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_0')
    self._conv_1 = hk.Conv2D(
        output_channels=channels // 4,
        kernel_shape=3,
        stride=stride,
        with_bias=False,
        padding='SAME',
        name='conv_1')
    self._conv_2 = hk.Conv2D(
        output_channels=channels,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_2')

  def __call__(self, inputs, is_training):
    net = inputs
    shortcut = inputs
    for i, conv_i in enumerate([self._conv_0, self._conv_1, self._conv_2]):
      if self._normalize_fn is not None:
        net = self._normalize_fn(net, is_training=is_training)
      net = jax.nn.relu(net)
      if i == 0 and self._use_projection:
        shortcut = self._proj_conv(net)
      # Now do the convs.
      net = conv_i(net)
    return net + shortcut
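

# Dataflow of the pre-activation ("V2") bottleneck above, written out
# functionally for reference (norm is the identity when normalize_fn is None):
#
#   h = relu(norm(x));  shortcut = proj_conv(h) if use_projection else x
#   h = conv_0(h)               # 1x1, channels // 4
#   h = conv_1(relu(norm(h)))   # 3x3, channels // 4, carries the stride
#   h = conv_2(relu(norm(h)))   # 1x1, back to channels
#   out = h + shortcut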


class BasicBlock(hk.Module):
  """Implements a basic residual block (ResNet18 and ResNet34)."""

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None):
    super(BasicBlock, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn

    if self._use_projection:
      self._proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding='SAME',
          name='shortcut_conv')
    # A basic block stacks two 3x3 convolutions (rather than the 1x1/3x3/1x1
    # stack of the bottleneck block); the second one carries the stride.
    self._conv_0 = hk.Conv2D(
        output_channels=channels,
        kernel_shape=3,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_0')
    self._conv_1 = hk.Conv2D(
        output_channels=channels,
        kernel_shape=3,
        stride=stride,
        with_bias=False,
        padding='SAME',
        name='conv_1')

  def __call__(self, inputs, is_training):
    net = inputs
    shortcut = inputs
    for i, conv_i in enumerate([self._conv_0, self._conv_1]):
      if self._normalize_fn is not None:
        net = self._normalize_fn(net, is_training=is_training)
      net = jax.nn.relu(net)
      if i == 0 and self._use_projection:
        shortcut = self._proj_conv(net)
      # Now do the convs.
      net = conv_i(net)
    return net + shortcut


class ResNetUnit(hk.Module):
  """Unit (group of blocks) for ResNet."""

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               num_blocks: int,
               stride: Union[int, Sequence[int]],
               block_module: Type[Union[BasicBlock, BottleneckBlock]],
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None,
               remat: bool = False):
    super(ResNetUnit, self).__init__(name=name)
    self._channels = channels
    self._num_blocks = num_blocks
    self._stride = stride
    self._normalize_fn = normalize_fn
    self._block_module = block_module
    self._remat = remat

  def __call__(self, inputs, is_training):
    input_channels = inputs.shape[-1]
    self._blocks = []
    for id_block in range(self._num_blocks):
      use_projection = id_block == 0 and self._channels != input_channels
      self._blocks.append(
          self._block_module(
              channels=self._channels,
              stride=self._stride if id_block == 0 else 1,
              use_projection=use_projection,
              normalize_fn=self._normalize_fn,
              name='block_%d' % id_block))

    net = inputs
    for block in self._blocks:
      if self._remat:
        # Note: we can ignore cell-var-from-loop because the lambda is
        # evaluated inside every iteration of the loop. This is needed to
        # work around the way variables are passed to jax.remat.
        net = hk.remat(lambda x: block(x, is_training=is_training))(net)  # pylint: disable=cell-var-from-loop
      else:
        net = block(net, is_training=is_training)
    return net


class ResNetV2(hk.Module):
  """ResNetV2 model."""

  # Endpoints of the model in order.
  VALID_ENDPOINTS = (
      'resnet_stem',
      'resnet_unit_0',
      'resnet_unit_1',
      'resnet_unit_2',
      'resnet_unit_3',
      'last_conv',
      'output',
  )
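
  # Reference endpoint shapes, assuming depth=50, width_mult=1 and NHWC
  # inputs of shape (N, 224, 224, 3); these follow from the strides and
  # channel widths configured below:
  #   resnet_stem   -> (N, 56, 56, 64)
  #   resnet_unit_0 -> (N, 56, 56, 256)
  #   resnet_unit_1 -> (N, 28, 28, 512)
  #   resnet_unit_2 -> (N, 14, 14, 1024)
  #   resnet_unit_3 -> (N, 7, 7, 2048)
  #   last_conv     -> (N, 7, 7, 2048)
  #   output        -> (N, 2048) embedding, or (N, num_classes) logits.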

  # pylint:disable=g-bare-generic
  def __init__(self,
               depth=50,
               num_classes: Optional[int] = 1000,
               width_mult: int = 1,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None,
               remat: bool = False):
    """Creates a ResNetV2 Haiku module.

    Args:
      depth: Depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
      num_classes: Number of outputs in the final layer. If None, no
        classification head is added and the output embedding is returned.
      width_mult: Multiplier for the channel width.
      normalize_fn: Normalization function, see helpers/utils.py.
      name: Name of the module.
      remat: Whether to rematerialize intermediate activations (saves memory).
    """
    super(ResNetV2, self).__init__(name=name)
    self._normalize_fn = normalize_fn
    self._num_classes = num_classes
    self._width_mult = width_mult
    self._strides = [1, 2, 2, 2]
    num_blocks = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    if depth not in num_blocks:
      raise ValueError(
          f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
    self._num_blocks = num_blocks[depth]

    if depth >= 50:
      self._block_module = BottleneckBlock
      self._channels = [256, 512, 1024, 2048]
    else:
      self._block_module = BasicBlock
      self._channels = [64, 128, 256, 512]

    self._initial_conv = hk.Conv2D(
        output_channels=64 * self._width_mult,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding='SAME',
        name='initial_conv')
    if remat:
      self._initial_conv = hk.remat(self._initial_conv)

    self._block_groups = []
    for i in range(4):
      self._block_groups.append(
          ResNetUnit(
              channels=self._channels[i] * self._width_mult,
              num_blocks=self._num_blocks[i],
              block_module=self._block_module,
              stride=self._strides[i],
              normalize_fn=self._normalize_fn,
              name='block_group_%d' % i,
              remat=remat))

    if num_classes is not None:
      self._logits_layer = hk.Linear(
          output_size=num_classes, w_init=jnp.zeros, name='logits')

  def __call__(self, inputs, is_training, final_endpoint='output'):
    self._final_endpoint = final_endpoint

    net = self._initial_conv(inputs)
    net = hk.max_pool(
        net,
        window_shape=(1, 3, 3, 1),
        strides=(1, 2, 2, 1),
        padding='SAME')

    end_point = 'resnet_stem'
    if self._final_endpoint == end_point:
      return net

    for i_group, block_group in enumerate(self._block_groups):
      net = block_group(net, is_training=is_training)
      end_point = f'resnet_unit_{i_group}'
      if self._final_endpoint == end_point:
        return net

    end_point = 'last_conv'
    if self._final_endpoint == end_point:
      return net

    if self._normalize_fn is not None:
      net = self._normalize_fn(net, is_training=is_training)
    net = jax.nn.relu(net)

    # The actual representation: global average pooling over space.
    net = jnp.mean(net, axis=[1, 2])

    assert self._final_endpoint == 'output'
    if self._num_classes is None:
      # If num_classes was None, we just return the output of the last
      # block, without a fully connected layer.
      return net
    return self._logits_layer(net)
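

if __name__ == '__main__':
  # Minimal usage sketch. The batch size, input resolution, and the absence
  # of a normalization function are illustrative assumptions, not fixed by
  # this module. The model is wrapped with hk.transform in the usual Haiku
  # way; `final_endpoint` selects an intermediate activation instead of the
  # classifier output.
  def _forward(images):
    model = ResNetV2(depth=50, num_classes=None)
    return model(images, is_training=False, final_endpoint='resnet_unit_2')

  forward = hk.transform(_forward)
  rng = jax.random.PRNGKey(42)
  dummy_images = jnp.zeros((2, 224, 224, 3))  # NHWC batch of two RGB images.
  params = forward.init(rng, dummy_images)
  features = forward.apply(params, rng, dummy_images)
  # With depth=50, `resnet_unit_2` features have shape (2, 14, 14, 1024).
  print(features.shape)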