# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model code. Provided settings are identical to what was used in the paper."""

import sonnet as snt
import tensorflow.compat.v1 as tf

from object_attention_for_reasoning import transformer


QUESTION_VOCAB_SIZE = 82
ANSWER_VOCAB_SIZE = 22

MAX_QUESTION_LENGTH = 20
MAX_CHOICE_LENGTH = 12

NUM_CHOICES = 4
EMBED_DIM = 16

PRETRAINED_MODEL_CONFIG = dict(
    use_relative_positions=True,
    shuffle_objects=True,
    transformer_layers=28,
    head_size=128,
    num_heads=10,
    embed_dim=EMBED_DIM,
)
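
# A minimal construction sketch, assuming the released pretrained checkpoint
# was trained with exactly these settings (the module docstring states the
# settings match the paper). The checkpoint path and the TF1 graph-mode
# session handling below are illustrative assumptions, not part of this file:
#
#   model = ClevrerTransformerModel(**PRETRAINED_MODEL_CONFIG)
#   logits = model.apply_model_descriptive(inputs)  # builds the variables;
#                                                   # `inputs` as documented
#                                                   # on the method itself.
#   saver = tf.train.Saver()
#   with tf.Session() as sess:
#     saver.restore(sess, "/path/to/checkpoint")  # hypothetical path
#     answer_logits = sess.run(logits, feed_dict={...})
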
def append_ids(tensor, id_vector, axis):
  """Broadcasts `id_vector` over all other axes and appends it along `axis`."""
  id_vector = tf.constant(id_vector, tf.float32)
  for a in range(len(tensor.shape)):
    if a != axis:
      id_vector = tf.expand_dims(id_vector, axis=a)
  tiling_vector = [s if i != axis else 1 for i, s in enumerate(tensor.shape)]
  id_tensor = tf.tile(id_vector, tiling_vector)
  return tf.concat([tensor, id_tensor], axis=axis)
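
# A small worked example of append_ids (shapes chosen arbitrarily): appending
# the id [1, 0] along the last axis tags every position with the same two
# extra channels.
#
#   x = tf.zeros([2, 5, 14])
#   y = append_ids(x, [1, 0], axis=2)  # shape [2, 5, 16]; y[..., -2:] equals
#                                      # (1, 0) at every position.
#
# The model uses such ids to tag token types, e.g. question vs. choice tokens
# below, and language tokens vs. object slots inside _apply_transformers.
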
class ClevrerTransformerModel(object):
  """Model from Ding et al. 2020 (https://arxiv.org/abs/2012.08508)."""

  def __init__(self, use_relative_positions, shuffle_objects,
               transformer_layers, num_heads, head_size, embed_dim):
    """Instantiate Sonnet modules."""
    self._embed_dim = embed_dim
    self._embed = snt.Embed(QUESTION_VOCAB_SIZE, embed_dim - 2)
    self._shuffle_objects = shuffle_objects
    # A single transformer attends jointly over language tokens and object
    # slots; two extra id channels are appended to every input, hence
    # value_size=embed_dim + 2.
    self._memory_transformer = transformer.TransformerTower(
        value_size=embed_dim + 2,
        num_heads=num_heads,
        num_layers=transformer_layers,
        use_relative_positions=use_relative_positions,
        causal=False)

    # Output heads: one logit per multiple-choice option, and logits over the
    # answer vocabulary for descriptive questions.
    self._final_layer_mc = snt.Sequential(
        [snt.Linear(head_size), tf.nn.relu, snt.Linear(1)])
    self._final_layer_descriptive = snt.Sequential(
        [snt.Linear(head_size), tf.nn.relu,
         snt.Linear(ANSWER_VOCAB_SIZE)])

    # Learned dummy token prepended to the input sequence; its output
    # embedding serves as the readout in _apply_transformers.
    self._dummy = tf.get_variable("dummy", [embed_dim + 2], tf.float32,
                                  initializer=tf.zeros_initializer)
    # Infill head and mask embedding; not used by the inference methods below.
    self._infill_linear = snt.Linear(embed_dim + 2)
    self._mask_embedding = tf.get_variable(
        "mask", [embed_dim + 2], tf.float32, initializer=tf.zeros_initializer)

  def _apply_transformers(self, lang_embedding, vision_embedding):
    """Applies transformer to language and vision input.

    Args:
      lang_embedding: tf.float32 tensor of shape [batch, num_tokens,
        embed_dim] holding question (and choice) token embeddings.
      vision_embedding: tf.float32 tensor of shape [batch, frames,
        num_objects, embed_dim] holding per-object MONet latents.

    Returns:
      The transformer output at the dummy token position, summarizing the
      whole language + vision sequence.
    """
    def _unroll(tensor):
      """Unroll the time dimension into the object dimension."""
      return tf.reshape(
          tensor, [tensor.shape[0], -1, tensor.shape[3]])

    # Tag language tokens with id (1, 0) and object slots with id (0, 1) so
    # the transformer can tell the two modalities apart.
    words = append_ids(lang_embedding, [1, 0], axis=2)
    dummy_word = tf.tile(self._dummy[None, None, :], [tf.shape(words)[0], 1, 1])
    vision_embedding = append_ids(vision_embedding, [0, 1], axis=3)
    vision_over_time = _unroll(vision_embedding)
    transformer_input = tf.concat([dummy_word, words, vision_over_time], axis=1)

    output, _ = self._memory_transformer(transformer_input,
                                          is_training=False)
    # Only the output at the dummy (first) token is used downstream.
    return output[:, 0, :]
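
  # Sequence layout produced above: [dummy token] + language tokens +
  # (frames * object slots). For example, with MAX_QUESTION_LENGTH=20,
  # MAX_CHOICE_LENGTH=12, 8 MONet slots per frame, and 25 sampled frames (an
  # illustrative number, not fixed by this file), the transformer sees
  # 1 + 32 + 200 = 233 tokens per example.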
  def apply_model_descriptive(self, inputs):
    """Applies model to CLEVRER descriptive questions.

    Args:
      inputs: dict of form: {
          "question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH],
          "monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16],
      }

    Returns:
      Tensor of shape [batch, ANSWER_VOCAB_SIZE], representing logits for each
      possible answer word.
    """
    question = inputs["question"]

    # Shape: [batch, question_len, embed_dim-2]
    question_embedding = self._embed(question)
    # Shape: [batch, question_len, embed_dim]
    question_embedding = append_ids(question_embedding, [0, 1], 2)
    # Descriptive questions have no choices, so an all-zero choice sequence is
    # embedded to keep the same language layout as the multiple-choice case.
    choices_embedding = self._embed(
        tf.zeros([question.shape[0], MAX_CHOICE_LENGTH], tf.int64))
    choices_embedding = append_ids(choices_embedding, [0, 1], 2)
    # Shape: [batch, question_len + choice_len, embed_dim]
    lang_embedding = tf.concat([question_embedding, choices_embedding], axis=1)

    # Shape: [batch, frames, num_objects, embed_dim]
    vision_embedding = inputs["monet_latents"]

    if self._shuffle_objects:
      # Shuffle the object-slot order (the same permutation is applied across
      # frames and batch).
      vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
      vision_embedding = tf.random.shuffle(vision_embedding)
      vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])

    output = self._apply_transformers(lang_embedding, vision_embedding)
    output = self._final_layer_descriptive(output)
    return output
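
  # A sketch of descriptive-question inference; question tokenization and
  # MONet feature extraction happen outside this file, so `question_ids` and
  # `latents` below are assumed inputs with the documented shapes:
  #
  #   inputs = {
  #       "question": tf.constant(question_ids, tf.int32),    # [1, 20]
  #       "monet_latents": tf.constant(latents, tf.float32),  # [1, F, 8, 16]
  #   }
  #   logits = model.apply_model_descriptive(inputs)          # [1, 22]
  #   answer = tf.argmax(logits, axis=-1)  # index into the answer vocabulary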
  def apply_model_mc(self, inputs):
    """Applies model to CLEVRER multiple-choice questions.

    Args:
      inputs: dict of form: {
          "question": tf.int32 tensor of shape [batch, MAX_QUESTION_LENGTH],
          "choices": tf.int32 tensor of shape [batch, 4, MAX_CHOICE_LENGTH],
          "monet_latents": tf.float32 tensor of shape [batch, frames, 8, 16],
      }

    Returns:
      Tensor of shape [batch, 4], representing logits for each choice.
    """
    question = inputs["question"]
    choices = inputs["choices"]

    # Shape: [batch, question_len, embed_dim-2]
    question_embedding = self._embed(question)
    # Shape: [batch, question_len, embed_dim]
    question_embedding = append_ids(question_embedding, [1, 0], 2)
    # Shape: [batch, choices, choice_len, embed_dim-2]
    choices_embedding = snt.BatchApply(self._embed)(choices)
    # Shape: [batch, choices, choice_len, embed_dim]
    choices_embedding = append_ids(choices_embedding, [0, 1], 3)
    # Shape: [batch, choices, question_len + choice_len, embed_dim]
    lang_embedding = tf.concat([
        tf.tile(question_embedding[:, None],
                [1, choices_embedding.shape[1], 1, 1]),
        choices_embedding], axis=2)

    # Shape: [batch, frames, num_objects, embed_dim]
    vision_embedding = inputs["monet_latents"]

    if self._shuffle_objects:
      # Shuffle the object-slot order (the same permutation is applied across
      # frames and batch).
      vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])
      vision_embedding = tf.random.shuffle(vision_embedding)
      vision_embedding = tf.transpose(vision_embedding, [2, 1, 0, 3])

    # Run the transformer once per choice, pairing the question with each
    # candidate answer, then score every (question, choice) pair with a
    # single logit.
    output_per_choice = []
    for c in range(NUM_CHOICES):
      output = self._apply_transformers(
          lang_embedding[:, c, :, :], vision_embedding)
      output_per_choice.append(output)

    output = tf.stack(output_per_choice, axis=1)
    output = tf.squeeze(snt.BatchApply(self._final_layer_mc)(output), axis=2)
    return output
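
# A sketch of multiple-choice inference. Each choice is scored independently;
# turning the per-choice logits into correct/incorrect decisions with a
# sigmoid threshold is an assumption here, not a rule taken from this file:
#
#   mc_inputs = {
#       "question": tf.constant(question_ids, tf.int32),     # [1, 20]
#       "choices": tf.constant(choice_ids, tf.int32),         # [1, 4, 12]
#       "monet_latents": tf.constant(latents, tf.float32),   # [1, F, 8, 16]
#   }
#   choice_logits = model.apply_model_mc(mc_inputs)          # [1, 4]
#   choice_is_correct = tf.sigmoid(choice_logits) > 0.5      # assumed rule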