Files
Alistair Muldal d029e06aa2 Ensure pip is up to date in kfac_ferminet_alpha/run.sh
Also create the `venv` in `/tmp/` rather than messing with the source tree.

PiperOrigin-RevId: 368225759
2021-04-13 20:58:16 +01:00

186 lines
6.9 KiB
Python

# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model code. Provided settings are identical to what was used in the paper."""
import sonnet as snt
import tensorflow.compat.v1 as tf
from object_attention_for_reasoning import transformer
# Sizes of the tokenized CLEVRER language inputs.
QUESTION_VOCAB_SIZE = 82  # Number of distinct question/choice tokens.
ANSWER_VOCAB_SIZE = 22  # Number of answer words for descriptive questions.
MAX_QUESTION_LENGTH = 20  # Tokens per (padded) question.
MAX_CHOICE_LENGTH = 12  # Tokens per (padded) multiple-choice option.
NUM_CHOICES = 4  # Options per multiple-choice question.
# Token width once the two segment-id channels have been appended to the
# (EMBED_DIM - 2)-wide word embeddings.
EMBED_DIM = 16

# Constructor arguments of ClevrerTransformerModel for the pretrained setting;
# usable as `ClevrerTransformerModel(**PRETRAINED_MODEL_CONFIG)`.
PRETRAINED_MODEL_CONFIG = {
    "use_relative_positions": True,
    "shuffle_objects": True,
    "transformer_layers": 28,
    "head_size": 128,
    "num_heads": 10,
    "embed_dim": EMBED_DIM,
}
def append_ids(tensor, id_vector, axis):
  """Concatenates a constant id vector onto every slice of `tensor`.

  The values of `id_vector` are broadcast over every dimension of `tensor`
  except `axis` and concatenated to `tensor` along `axis`. For example,
  `append_ids(x, [1, 0], axis=2)` maps a `[B, T, D]` tensor to `[B, T, D + 2]`
  where every token gets the extra channels `[1, 0]`.

  Args:
    tensor: tensor of known rank.
    id_vector: sequence of `k` numbers to append.
    axis: axis along which to concatenate. Negative values count from the
      end, as in `tf.concat`.

  Returns:
    Tensor with the dtype of `tensor` and the same shape, except that
    dimension `axis` is larger by `k`.

  Raises:
    ValueError: if `axis` is out of range for the rank of `tensor`.
  """
  rank = len(tensor.shape)
  if not -rank <= axis < rank:
    raise ValueError(
        "axis {} is out of range for a rank-{} tensor.".format(axis, rank))
  # Normalize negative axes; a negative axis never matches an index below.
  axis %= rank
  # Build the ids in the input's dtype so tf.concat also works for inputs
  # that are not float32.
  ids = tf.constant(id_vector, dtype=tensor.dtype)
  # Same rank as `tensor`: size 1 everywhere except `axis`, which holds the ids.
  ids = tf.reshape(ids, [1] * axis + [-1] + [1] * (rank - axis - 1))
  # Tile up to the size of `tensor` on every other axis. Statically known sizes
  # are used as Python ints so downstream code keeps static shapes; unknown
  # sizes (e.g. a `None` batch dimension) fall back to the runtime shape.
  dynamic_shape = tf.shape(tensor)
  multiples = [
      1 if i == axis else (size if size is not None else dynamic_shape[i])
      for i, size in enumerate(tensor.shape.as_list())
  ]
  return tf.concat([tensor, tf.tile(ids, multiples)], axis=axis)
class ClevrerTransformerModel(object):
  """Model from Ding et al. 2020 (https://arxiv.org/abs/2012.08508).

  Language tokens and per-frame object latents are joined into one sequence,
  prefixed with a learned dummy token, and processed by a
  `transformer.TransformerTower`. The tower output at the dummy token feeds a
  small MLP head, which produces either one logit per choice
  (`apply_model_mc`) or one logit per answer word (`apply_model_descriptive`).
  """

  def __init__(self, use_relative_positions, shuffle_objects,
               transformer_layers, num_heads, head_size, embed_dim):
    """Instantiate Sonnet modules.

    Args:
      use_relative_positions: forwarded to `transformer.TransformerTower`.
      shuffle_objects: if True, the object axis of the vision input is
        randomly permuted on every `apply_model_*` call.
      transformer_layers: forwarded as `num_layers` to the transformer tower.
      num_heads: forwarded as `num_heads` to the transformer tower.
      head_size: hidden width of the two MLP output heads. It is not passed
        to the transformer tower.
      embed_dim: token width after the two segment-id channels are appended.
        Word embeddings themselves have `embed_dim - 2` channels.
    """
    # NOTE(review): module/variable construction order is kept exactly as
    # before; it presumably determines variable names and hence checkpoint
    # compatibility. Confirm before reordering.
    self._embed_dim = embed_dim
    self._embed = snt.Embed(QUESTION_VOCAB_SIZE, embed_dim - 2)
    self._shuffle_objects = shuffle_objects
    # Tokens reaching the tower are embed_dim wide plus 2 modality-id channels.
    self._memory_transformer = transformer.TransformerTower(
        value_size=embed_dim + 2,
        num_heads=num_heads,
        num_layers=transformer_layers,
        use_relative_positions=use_relative_positions,
        causal=False)
    self._final_layer_mc = snt.Sequential(
        [snt.Linear(head_size), tf.nn.relu, snt.Linear(1)])
    self._final_layer_descriptive = snt.Sequential(
        [snt.Linear(head_size), tf.nn.relu,
         snt.Linear(ANSWER_VOCAB_SIZE)])
    # Learned summary token prepended to every input sequence.
    self._dummy = tf.get_variable("dummy", [embed_dim + 2], tf.float32,
                                  initializer=tf.zeros_initializer)
    # NOTE(review): `_infill_linear` and `_mask_embedding` are not used by the
    # inference methods of this class; presumably kept so the variable set
    # matches the training setup. Confirm.
    self._infill_linear = snt.Linear(embed_dim + 2)
    self._mask_embedding = tf.get_variable(
        "mask", [embed_dim + 2], tf.float32, initializer=tf.zeros_initializer)

  def _maybe_shuffle_objects(self, vision_embedding):
    """Randomly permutes axis 2 (objects) of `vision_embedding` if enabled.

    `tf.random.shuffle` permutes along the leading axis, so the object axis is
    moved to the front and back again. As a result, one permutation is shared
    by every batch element and every frame.

    Args:
      vision_embedding: tensor of shape [batch, frames, num_objects, dim].

    Returns:
      `vision_embedding`, with its object axis shuffled when the model was
      built with `shuffle_objects=True`, and unchanged otherwise.
    """
    if not self._shuffle_objects:
      return vision_embedding
    vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
    vision_embedding = tf.random.shuffle(vision_embedding)
    return tf.transpose(vision_embedding, [2, 1, 0, 3])

  def _apply_transformers(self, lang_embedding, vision_embedding):
    """Applies the transformer to a joint language and vision sequence.

    Args:
      lang_embedding: tensor of shape [batch, num_words, embed_dim].
      vision_embedding: tensor of shape [batch, frames, num_objects,
        embed_dim].

    Returns:
      Transformer output at the leading dummy token, of shape
      [batch, output_dim]. The width `output_dim` is determined by
      `transformer.TransformerTower`.
    """

    def _unroll(tensor):
      """Flattens [batch, frames, objects, d] to [batch, frames * objects, d]."""
      return tf.reshape(
          tensor, [tensor.shape[0], -1, tensor.shape[3]])

    # Modality ids: [1, 0] marks language tokens, [0, 1] marks vision tokens.
    words = append_ids(lang_embedding, [1, 0], axis=2)
    dummy_word = tf.tile(self._dummy[None, None, :], [tf.shape(words)[0], 1, 1])
    vision_embedding = append_ids(vision_embedding, [0, 1], axis=3)
    vision_over_time = _unroll(vision_embedding)
    # Sequence layout: [dummy, words..., objects of frame 0, frame 1, ...].
    transformer_input = tf.concat([dummy_word, words, vision_over_time], axis=1)
    output, _ = self._memory_transformer(transformer_input,
                                         is_training=False)
    # Only the dummy token's output is used as the sequence summary.
    return output[:, 0, :]

  def apply_model_descriptive(self, inputs):
    """Applies model to CLEVRER descriptive questions.

    Args:
      inputs: dict of form: {
        "question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH],
        "monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16],
      }

    Returns:
      Tensor of shape [batch, ANSWER_VOCAB_SIZE], representing logits for each
      possible answer word.
    """
    question = inputs["question"]
    # Shape: [batch, question_len, embed_dim - 2]
    question_embedding = self._embed(question)
    # Shape: [batch, question_len, embed_dim]
    # NOTE(review): apply_model_mc tags question words with [1, 0], while this
    # method uses [0, 1]. It is kept unchanged because the provided weights
    # presumably expect this layout. Confirm it is intentional.
    question_embedding = append_ids(question_embedding, [0, 1], 2)
    # Placeholder "choice" made of MAX_CHOICE_LENGTH copies of token 0. It
    # presumably mirrors the question + choice layout of apply_model_mc.
    choices_embedding = self._embed(
        tf.zeros([question.shape[0], MAX_CHOICE_LENGTH], tf.int64))
    choices_embedding = append_ids(choices_embedding, [0, 1], 2)
    # Shape: [batch, question_len + choice_len, embed_dim]
    lang_embedding = tf.concat([question_embedding, choices_embedding], axis=1)
    # Shape: [batch, frames, num_objects, embed_dim]
    vision_embedding = self._maybe_shuffle_objects(inputs["monet_latents"])
    output = self._apply_transformers(lang_embedding, vision_embedding)
    return self._final_layer_descriptive(output)

  def apply_model_mc(self, inputs):
    """Applies model to CLEVRER multiple-choice questions.

    Each choice is scored independently. The question is paired with one
    choice, joined with the video latents, and run through the shared
    transformer.

    Args:
      inputs: dict of form: {
        "question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH],
        "choices": tf.int32 tensor of shape [batch, 4, MAX_CHOICE_LENGTH],
        "monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16],
      }

    Returns:
      Tensor of shape [batch, num_choices] (num_choices is 4 for CLEVRER),
      representing logits for each choice.
    """
    question = inputs["question"]
    choices = inputs["choices"]
    # Shape: [batch, question_len, embed_dim - 2]
    question_embedding = self._embed(question)
    # Shape: [batch, question_len, embed_dim]; [1, 0] marks question words.
    question_embedding = append_ids(question_embedding, [1, 0], 2)
    # Shape: [batch, choices, choice_len, embed_dim - 2]
    choices_embedding = snt.BatchApply(self._embed)(choices)
    # Shape: [batch, choices, choice_len, embed_dim]; [0, 1] marks choice words.
    choices_embedding = append_ids(choices_embedding, [0, 1], 3)
    # Pair a copy of the question with every choice.
    # Shape: [batch, choices, question_len + choice_len, embed_dim]
    lang_embedding = tf.concat([
        tf.tile(question_embedding[:, None],
                [1, choices_embedding.shape[1], 1, 1]),
        choices_embedding], axis=2)
    # Shape: [batch, frames, num_objects, embed_dim]
    vision_embedding = self._maybe_shuffle_objects(inputs["monet_latents"])
    # Score every choice present in the input instead of assuming exactly
    # NUM_CHOICES. The same (shared-weight) transformer is applied to each.
    output_per_choice = [
        self._apply_transformers(lang_choice, vision_embedding)
        for lang_choice in tf.unstack(lang_embedding, axis=1)
    ]
    # Shape: [batch, choices, output_dim]
    output = tf.stack(output_per_choice, axis=1)
    # Shape: [batch, choices]
    return tf.squeeze(snt.BatchApply(self._final_layer_mc)(output), axis=2)