mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2025-12-06 17:18:46 +08:00
186 lines
6.6 KiB
Python
186 lines
6.6 KiB
Python
# Copyright 2019 Deepmind Technologies Limited.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""A graph neural network based model to predict particle mobilities.
|
|
|
|
The architecture and performance of this model is described in our publication:
|
|
"Unveiling the predictive power of static structure in glassy systems".
|
|
"""
|
|
|
|
import functools
|
|
from typing import Any, Dict, Text, Tuple, Optional
|
|
|
|
from graph_nets import graphs
|
|
from graph_nets import modules as gn_modules
|
|
from graph_nets import utils_tf
|
|
|
|
import sonnet as snt
|
|
import tensorflow.compat.v1 as tf
|
|
|
|
|
|
def make_graph_from_static_structure(
|
|
positions: tf.Tensor,
|
|
types: tf.Tensor,
|
|
box: tf.Tensor,
|
|
edge_threshold: float) -> graphs.GraphsTuple:
|
|
"""Returns graph representing the static structure of the glass.
|
|
|
|
Each particle is represented by a node in the graph. The particle type is
|
|
stored as a node feature.
|
|
Two particles at a distance less than the threshold are connected by an edge.
|
|
The relative distance vector is stored as an edge feature.
|
|
|
|
Args:
|
|
positions: particle positions with shape [n_particles, 3].
|
|
types: particle types with shape [n_particles].
|
|
box: dimensions of the cubic box that contains the particles with shape [3].
|
|
edge_threshold: particles at distance less than threshold are connected by
|
|
an edge.
|
|
"""
|
|
# Calculate pairwise relative distances between particles: shape [n, n, 3].
|
|
cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]
|
|
# Enforces periodic boundary conditions.
|
|
box_ = box[tf.newaxis, tf.newaxis, :]
|
|
cross_positions += tf.cast(cross_positions < -box_ / 2., tf.float32) * box_
|
|
cross_positions -= tf.cast(cross_positions > box_ / 2., tf.float32) * box_
|
|
# Calculates adjacency matrix in a sparse format (indices), based on the given
|
|
# distances and threshold.
|
|
distances = tf.norm(cross_positions, axis=-1)
|
|
indices = tf.where(distances < edge_threshold)
|
|
|
|
# Defines graph.
|
|
nodes = types[:, tf.newaxis]
|
|
senders = indices[:, 0]
|
|
receivers = indices[:, 1]
|
|
edges = tf.gather_nd(cross_positions, indices)
|
|
|
|
return graphs.GraphsTuple(
|
|
nodes=tf.cast(nodes, tf.float32),
|
|
n_node=tf.reshape(tf.shape(nodes)[0], [1]),
|
|
edges=tf.cast(edges, tf.float32),
|
|
n_edge=tf.reshape(tf.shape(edges)[0], [1]),
|
|
globals=tf.zeros((1, 1), dtype=tf.float32),
|
|
receivers=tf.cast(receivers, tf.int32),
|
|
senders=tf.cast(senders, tf.int32)
|
|
)
|
|
|
|
|
|
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
|
|
"""Returns randomly rotated graph representation.
|
|
|
|
The rotation is an element of O(3) with rotation angles multiple of pi/2.
|
|
This function assumes that the relative particle distances are stored in
|
|
the edge features.
|
|
|
|
Args:
|
|
graph: The graphs tuple as defined in `graph_nets.graphs`.
|
|
"""
|
|
# Transposes edge features, so that the axes are in the first dimension.
|
|
# Outputs a tensor of shape [3, n_particles].
|
|
xyz = tf.transpose(graph.edges)
|
|
# Random pi/2 rotation(s)
|
|
permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
|
|
xyz = tf.gather(xyz, permutation)
|
|
# Random reflections.
|
|
symmetry = tf.random_uniform([3], minval=0, maxval=2, dtype=tf.int32)
|
|
symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), tf.float32)
|
|
xyz = xyz * symmetry
|
|
edges = tf.transpose(xyz)
|
|
return graph.replace(edges=edges)
|
|
|
|
|
|
class GraphBasedModel(snt.AbstractModule):
|
|
"""Graph based model which predicts particle mobilities from their positions.
|
|
|
|
This network encodes the nodes and edges of the input graph independently, and
|
|
then performs message-passing on this graph, updating its edges based on their
|
|
associated nodes, then updating the nodes based on the input nodes' features
|
|
and their associated updated edge features.
|
|
This update is repeated several times.
|
|
Afterwards the resulting node embeddings are decoded to predict the particle
|
|
mobility.
|
|
"""
|
|
|
|
def __init__(self,
|
|
n_recurrences: int,
|
|
mlp_sizes: Tuple[int],
|
|
mlp_kwargs: Optional[Dict[Text, Any]] = None,
|
|
name='Graph'):
|
|
"""Creates a new GraphBasedModel object.
|
|
|
|
Args:
|
|
n_recurrences: the number of message passing steps in the graph network.
|
|
mlp_sizes: the number of neurons in each layer of the MLP.
|
|
mlp_kwargs: additional keyword aguments passed to the MLP.
|
|
name: the name of the Sonnet module.
|
|
"""
|
|
super(GraphBasedModel, self).__init__(name=name)
|
|
self._n_recurrences = n_recurrences
|
|
|
|
if mlp_kwargs is None:
|
|
mlp_kwargs = {}
|
|
|
|
model_fn = functools.partial(
|
|
snt.nets.MLP,
|
|
output_sizes=mlp_sizes,
|
|
activate_final=True,
|
|
**mlp_kwargs)
|
|
|
|
final_model_fn = functools.partial(
|
|
snt.nets.MLP,
|
|
output_sizes=mlp_sizes + (1,),
|
|
activate_final=False,
|
|
**mlp_kwargs)
|
|
|
|
with self._enter_variable_scope():
|
|
self._encoder = gn_modules.GraphIndependent(
|
|
node_model_fn=model_fn,
|
|
edge_model_fn=model_fn)
|
|
|
|
if self._n_recurrences > 0:
|
|
self._propagation_network = gn_modules.GraphNetwork(
|
|
node_model_fn=model_fn,
|
|
edge_model_fn=model_fn,
|
|
# We do not use globals, hence we just pass the identity function.
|
|
global_model_fn=lambda: lambda x: x,
|
|
reducer=tf.unsorted_segment_sum,
|
|
edge_block_opt=dict(use_globals=False),
|
|
node_block_opt=dict(use_globals=False),
|
|
global_block_opt=dict(use_globals=False))
|
|
|
|
self._decoder = gn_modules.GraphIndependent(
|
|
node_model_fn=final_model_fn,
|
|
edge_model_fn=model_fn)
|
|
|
|
def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor:
|
|
"""Connects the model into the tensorflow graph.
|
|
|
|
Args:
|
|
graphs_tuple: input graph tensor as defined in `graphs_tuple.graphs`.
|
|
|
|
Returns:
|
|
tensor with shape [n_particles] containing the predicted particle
|
|
mobilities.
|
|
"""
|
|
encoded = self._encoder(graphs_tuple)
|
|
outputs = encoded
|
|
|
|
for _ in range(self._n_recurrences):
|
|
# Adds skip connections.
|
|
inputs = utils_tf.concat([outputs, encoded], axis=-1)
|
|
outputs = self._propagation_network(inputs)
|
|
|
|
decoded = self._decoder(outputs)
|
|
return tf.squeeze(decoded.nodes, axis=-1)
|