# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet V2 modules.
|
|
|
|
Equivalent to hk.Resnet except accepting a final_endpoint to return
|
|
intermediate activations.
|
|
"""
|
|
|
|

from typing import Optional, Sequence, Text, Type, Union

import haiku as hk
import jax
import jax.numpy as jnp

from mmv.models import types


class BottleneckBlock(hk.Module):
  """Implements a bottleneck residual block (ResNet50 and deeper)."""

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None):
    super(BottleneckBlock, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn

    if self._use_projection:
      self._proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding='SAME',
          name='shortcut_conv')

    self._conv_0 = hk.Conv2D(
        output_channels=channels // 4,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_0')

    self._conv_1 = hk.Conv2D(
        output_channels=channels // 4,
        kernel_shape=3,
        stride=stride,
        with_bias=False,
        padding='SAME',
        name='conv_1')

    self._conv_2 = hk.Conv2D(
        output_channels=channels,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_2')

  def __call__(self,
               inputs,
               is_training):
    net = inputs
    shortcut = inputs

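    # ResNet V2 pre-activation: normalization and ReLU precede each
    # convolution, and the projection shortcut consumes the pre-activated
    # input rather than the raw inputs.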
    for i, conv_i in enumerate([self._conv_0, self._conv_1, self._conv_2]):
      if self._normalize_fn is not None:
        net = self._normalize_fn(net, is_training=is_training)
      net = jax.nn.relu(net)
      if i == 0 and self._use_projection:
        shortcut = self._proj_conv(net)

      # Now do the convs.
      net = conv_i(net)

    return net + shortcut


class BasicBlock(hk.Module):
  """Implements a basic residual block (ResNet18 and ResNet34)."""

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               stride: Union[int, Sequence[int]],
               use_projection: bool,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None):
    super(BasicBlock, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._normalize_fn = normalize_fn

    if self._use_projection:
      self._proj_conv = hk.Conv2D(
          output_channels=channels,
          kernel_shape=1,
          stride=stride,
          with_bias=False,
          padding='SAME',
          name='shortcut_conv')

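    # Note: conv_0 below is a 1x1 convolution, whereas the canonical basic
    # block of He et al. uses two 3x3 convolutions.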
    self._conv_0 = hk.Conv2D(
        output_channels=channels,
        kernel_shape=1,
        stride=1,
        with_bias=False,
        padding='SAME',
        name='conv_0')

    self._conv_1 = hk.Conv2D(
        output_channels=channels,
        kernel_shape=3,
        stride=stride,
        with_bias=False,
        padding='SAME',
        name='conv_1')

  def __call__(self,
               inputs,
               is_training):
    net = inputs
    shortcut = inputs

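    # Same pre-activation ordering as in BottleneckBlock above.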
    for i, conv_i in enumerate([self._conv_0, self._conv_1]):
      if self._normalize_fn is not None:
        net = self._normalize_fn(net, is_training=is_training)
      net = jax.nn.relu(net)
      if i == 0 and self._use_projection:
        shortcut = self._proj_conv(net)

      # Now do the convs.
      net = conv_i(net)

    return net + shortcut


class ResNetUnit(hk.Module):
  """Unit (group of blocks) for ResNet."""

  # pylint:disable=g-bare-generic
  def __init__(self,
               channels: int,
               num_blocks: int,
               stride: Union[int, Sequence[int]],
               block_module: Type[BottleneckBlock],
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None,
               remat: bool = False):
    super(ResNetUnit, self).__init__(name=name)
    self._channels = channels
    self._num_blocks = num_blocks
    self._stride = stride
    self._normalize_fn = normalize_fn
    self._block_module = block_module
    self._remat = remat

  def __call__(self,
               inputs,
               is_training):
    input_channels = inputs.shape[-1]

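    # Blocks are created at call time, rather than in the constructor, so
    # that the projection decision can depend on the number of input
    # channels.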
    self._blocks = []
    for id_block in range(self._num_blocks):
      use_projection = id_block == 0 and self._channels != input_channels
      self._blocks.append(
          self._block_module(
              channels=self._channels,
              stride=self._stride if id_block == 0 else 1,
              use_projection=use_projection,
              normalize_fn=self._normalize_fn,
              name='block_%d' % id_block))

    net = inputs
    for block in self._blocks:
      if self._remat:
        # Note: we can ignore cell-var-from-loop because the lambda is
        # evaluated inside every iteration of the loop. This is needed to
        # work around the way variables are passed to jax.remat.
        net = hk.remat(lambda x: block(x, is_training=is_training))(net)  # pylint: disable=cell-var-from-loop
      else:
        net = block(net, is_training=is_training)
    return net


class ResNetV2(hk.Module):
  """ResNetV2 model."""

  # Endpoints of the model in order.
  VALID_ENDPOINTS = (
      'resnet_stem',
      'resnet_unit_0',
      'resnet_unit_1',
      'resnet_unit_2',
      'resnet_unit_3',
      'last_conv',
      'output',
  )

  # pylint:disable=g-bare-generic
  def __init__(self,
               depth: int = 50,
               num_classes: Optional[int] = 1000,
               width_mult: int = 1,
               normalize_fn: Optional[types.NormalizeFn] = None,
               name: Optional[Text] = None,
               remat: bool = False):
    """Creates a ResNetV2 Haiku module.

    Args:
      depth: Depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
      num_classes: Number of outputs in the final layer. If None, no
        classification head is added and the output embedding is returned.
      width_mult: Multiplier for the channel width.
      normalize_fn: Normalization function, see helpers/utils.py.
      name: Name of the module.
      remat: Whether to rematerialize intermediate activations (saves memory).
    """
    super(ResNetV2, self).__init__(name=name)
    self._normalize_fn = normalize_fn
    self._num_classes = num_classes
    self._width_mult = width_mult

    self._strides = [1, 2, 2, 2]
    num_blocks = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    if depth not in num_blocks:
      raise ValueError(
          f'`depth` should be in {list(num_blocks.keys())} ({depth} given).')
    self._num_blocks = num_blocks[depth]

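    # For bottleneck blocks, these channel counts are the post-expansion
    # output widths; the convs inside each block use `channels // 4`.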
    if depth >= 50:
      self._block_module = BottleneckBlock
      self._channels = [256, 512, 1024, 2048]
    else:
      self._block_module = BasicBlock
      self._channels = [64, 128, 256, 512]

    self._initial_conv = hk.Conv2D(
        output_channels=64 * self._width_mult,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding='SAME',
        name='initial_conv')

    if remat:
      self._initial_conv = hk.remat(self._initial_conv)

    self._block_groups = []
    for i in range(4):
      self._block_groups.append(
          ResNetUnit(
              channels=self._channels[i] * self._width_mult,
              num_blocks=self._num_blocks[i],
              block_module=self._block_module,
              stride=self._strides[i],
              normalize_fn=self._normalize_fn,
              name='block_group_%d' % i,
              remat=remat))

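    # The classifier head, when present, is zero-initialised, so the initial
    # logits are zero for any input embedding.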
    if num_classes is not None:
      self._logits_layer = hk.Linear(
          output_size=num_classes, w_init=jnp.zeros, name='logits')

  def __call__(self, inputs, is_training, final_endpoint='output'):
    if final_endpoint not in self.VALID_ENDPOINTS:
      raise ValueError(f'Unknown final_endpoint: {final_endpoint}')
    self._final_endpoint = final_endpoint
    net = self._initial_conv(inputs)
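    # Inputs are assumed to be NHWC: the pooling window and strides below
    # leave the batch and channel dimensions untouched.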
    net = hk.max_pool(
        net, window_shape=(1, 3, 3, 1),
        strides=(1, 2, 2, 1),
        padding='SAME')
    end_point = 'resnet_stem'
    if self._final_endpoint == end_point:
      return net

    for i_group, block_group in enumerate(self._block_groups):
      net = block_group(net, is_training=is_training)
      end_point = f'resnet_unit_{i_group}'
      if self._final_endpoint == end_point:
        return net

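    # 'last_conv' aliases the output of the final unit, i.e. the activations
    # before the trailing normalization, ReLU and global average pooling.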
    end_point = 'last_conv'
    if self._final_endpoint == end_point:
      return net

    if self._normalize_fn is not None:
      net = self._normalize_fn(net, is_training=is_training)
    net = jax.nn.relu(net)

    # The actual representation: global average pooling over space.
    net = jnp.mean(net, axis=(1, 2))

    assert self._final_endpoint == 'output'
    if self._num_classes is None:
      # If num_classes was None, we just return the output of the last
      # block, without the fully connected layer.
      return net

    return self._logits_layer(net)
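

# Example usage (a minimal sketch, assuming NHWC image batches; the shapes
# below are illustrative):
#
#   def forward(images, is_training):
#     model = ResNetV2(depth=50, num_classes=None)
#     return model(images, is_training, final_endpoint='resnet_unit_2')
#
#   forward_fn = hk.transform(forward)
#   images = jnp.zeros([2, 224, 224, 3])
#   params = forward_fn.init(jax.random.PRNGKey(0), images, True)
#   features = forward_fn.apply(params, None, images, True)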