PAC Bayes Quadratic bound open sourcing.

PiperOrigin-RevId: 302629851
Vikram Tankasali, 2020-03-24 10:10:46 +00:00
committed by Louise Deason
commit f6395d709e (parent afcdc77239)
13 changed files with 1080 additions and 0 deletions

glassy_dynamics/README.md Normal file

@@ -0,0 +1,80 @@
# Unveiling the predictive power of static structure in glassy systems
This repository contains an open-source implementation of the graph neural
network model described in our paper. The model can be trained with the
training binary included in this repository and the dataset published
alongside the paper.
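
For orientation, training can also be invoked programmatically. The sketch
below is illustrative only (the paths are placeholders); the keyword arguments
mirror the `train.train_model` function released in this repository, and the
dataset layout is described in the Dataset section below.

```python
# Illustrative sketch; point the patterns at your local copy of the dataset.
from glassy_dynamics import train

best_test_correlation = train.train_model(
    train_file_pattern='path/to/train/aggregated*',  # placeholder path
    test_file_pattern='path/to/test/aggregated*',    # placeholder path
    time_index=9,  # index of the target time point, see "Data format" below
    checkpoint_path='checkpoints/best_model.ckpt')   # optional
print('Best correlation on the test set:', best_test_correlation)
```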
## Abstract
Despite decades of theoretical studies, the nature of the glass transition
remains elusive and debated, while the existence of structural predictors of the
dynamics is a major open question. Recent approaches propose inferring
predictors from a variety of human-defined features using machine learning.
We learn the long time evolution of a glassy system solely from the initial
particle positions and without any hand-crafted features, using a powerful
model: graph neural networks. We show that this method strongly outperforms
state-of-the-art methods, generalizing over a wide range of temperatures,
pressures, and densities. In shear experiments, it predicts the location of
rearranging particles. The structural predictors learned by our network unveil a
correlation length which increases with larger timescales to reach the size of
our system. Beyond glasses, our method could apply to many other physical
systems that map to a graph of local interactions.
## Dataset
### System description
The dataset was generated with the LAMMPS molecular dynamics package.
The simulated system has periodic boundaries and is a binary mixture of 4096
large (A) and small (B) particles that interact via a 6-12 Lennard-Jones
potential.
The interaction coefficients are set to the standard Kob-Andersen values.
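
For reference, the 6-12 Lennard-Jones pair potential between species
α, β ∈ {A, B} is

    V_αβ(r) = 4 ε_αβ [ (σ_αβ / r)^12 − (σ_αβ / r)^6 ]

The standard Kob-Andersen parameter choice (quoted here for convenience;
consult the paper for the exact simulation settings) is ε_AA = 1.0,
ε_AB = 1.5, ε_BB = 0.5 and σ_AA = 1.0, σ_AB = 0.8, σ_BB = 0.88, with an
80:20 ratio of A to B particles, all in Lennard-Jones units.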
### Data format
The data is stored using Python's pickle format (protocol version 3).
Each file contains the data for one equilibrated system as a Python
dictionary with the following entries:
- `positions`: the particle positions of the equilibrated system.
- `types`: the particle types (0 == type A and 1 == type B) of the
  equilibrated system.
- `box`: the dimensions of the periodic cubic simulation box.
- `time`: the logarithmically sampled time points.
- `time_indices`: the indices of the time points for which the sampled
  trajectories on average reach a certain value of the intermediate
  scattering function.
- `is_values`: the values of the intermediate scattering function associated
  with each time index.
- `trajectory_start_velocities`: the velocities drawn from a Boltzmann
  distribution at the start of each trajectory.
- `trajectory_target_positions`: the positions of the particles for each of
  the trajectories at selected time points (as defined by the `time_indices`
  array and the corresponding values of the intermediate scattering function
  stored in `is_values`).
- `metadata`: a dictionary containing additional metadata:
  - `temperature`: the temperature at which the system was equilibrated.
  - `pressure`: the pressure at which the system was equilibrated.
  - `fluid`: the type of fluid which was simulated (Kob-Andersen).
All quantities are given in Lennard-Jones units. The positions are stored in
absolute coordinates, i.e. a particle's coordinates lie outside the simulation
box if it crossed a periodic boundary during the simulation.
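
For example, a single file can be loaded and inspected as follows (a sketch
only; the filename is a placeholder):

```python
import pickle

import numpy as np

with open('train/aggregated_example.pickle', 'rb') as f:  # placeholder name
  data = pickle.load(f)

print(data['positions'].shape)          # (n_particles, 3), absolute coords
print(data['types'].shape)              # (n_particles,), 0 == A, 1 == B
print(data['box'])                      # (3,), box dimensions
print(data['metadata']['temperature'])  # equilibration temperature

# Particle positions of every sampled trajectory at the last stored time
# point (indexed as described for `time_indices` above).
final_positions = data['trajectory_target_positions'][-1]

# If needed, wrap the absolute positions back into the periodic box.
wrapped_positions = np.mod(data['positions'], data['box'])
```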
## Reference
If this repository is helpful for your research, please cite the following
publication:

Unveiling the predictive power of static structure in glassy systems,
V. Bapst, T. Keck, A. Grabska-Barwinska, C. Donner, E. D. Cubuk,
S. S. Schoenholz, A. Obika, A. W. R. Nelson, T. Back, D. Hassabis and P. Kohli
## Disclaimer
This is not an official Google product.

glassy_dynamics/apply_binary.py Normal file

@@ -0,0 +1,62 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Applies a graph-based network to predict particle mobilities in glasses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from glassy_dynamics import train
FLAGS = flags.FLAGS
flags.DEFINE_string(
'data_directory',
'',
'Directory which contains the train or test datasets.')
flags.DEFINE_integer(
'time_index',
9,
'The time index of the target mobilities.')
flags.DEFINE_integer(
'max_files_to_load',
None,
'The maximum number of files to load.')
flags.DEFINE_string(
'checkpoint_path',
'checkpoints/t044_s09.ckpt',
'Path used to load the model.')
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
file_pattern = os.path.join(FLAGS.data_directory, 'aggregated*')
train.apply_model(
checkpoint_path=FLAGS.checkpoint_path,
file_pattern=file_pattern,
max_files_to_load=FLAGS.max_files_to_load,
time_index=FLAGS.time_index)
if __name__ == '__main__':
app.run(main)

Binary file not shown.

Binary file not shown.

glassy_dynamics/graph_model.py Normal file

@@ -0,0 +1,190 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A graph neural network based model to predict particle mobilities.
The architecture and performance of this model is described in our publication:
"Unveiling the predictive power of static structure in glassy systems".
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from graph_nets import graphs
from graph_nets import modules as gn_modules
from graph_nets import utils_tf
import sonnet as snt
import tensorflow.compat.v1 as tf
from typing import Any, Dict, Text, Tuple, Optional
def make_graph_from_static_structure(
positions,
types,
box,
edge_threshold):
"""Returns graph representing the static structure of the glass.
Each particle is represented by a node in the graph. The particle type is
stored as a node feature.
Two particles at a distance less than the threshold are connected by an edge.
The relative distance vector is stored as an edge feature.
Args:
positions: particle positions with shape [n_particles, 3].
types: particle types with shape [n_particles].
box: dimensions of the cubic box that contains the particles with shape [3].
edge_threshold: particles at distance less than threshold are connected by
an edge.
"""
# Calculate pairwise relative distances between particles: shape [n, n, 3].
cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]
# Enforces periodic boundary conditions.
box_ = box[tf.newaxis, tf.newaxis, :]
cross_positions += tf.cast(cross_positions < -box_ / 2., tf.float32) * box_
cross_positions -= tf.cast(cross_positions > box_ / 2., tf.float32) * box_
# Calculates adjacency matrix in a sparse format (indices), based on the given
# distances and threshold.
distances = tf.norm(cross_positions, axis=-1)
indices = tf.where(distances < edge_threshold)
# Defines graph.
nodes = types[:, tf.newaxis]
senders = indices[:, 0]
receivers = indices[:, 1]
edges = tf.gather_nd(cross_positions, indices)
return graphs.GraphsTuple(
nodes=tf.cast(nodes, tf.float32),
n_node=tf.reshape(tf.shape(nodes)[0], [1]),
edges=tf.cast(edges, tf.float32),
n_edge=tf.reshape(tf.shape(edges)[0], [1]),
globals=tf.zeros((1, 1), dtype=tf.float32),
receivers=tf.cast(receivers, tf.int32),
senders=tf.cast(senders, tf.int32)
)
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
  """Returns randomly rotated graph representation.

  The rotation is an element of O(3) with rotation angles that are multiples
  of pi/2. This function assumes that the relative particle distances are
  stored in the edge features.

  Args:
    graph: The graphs tuple as defined in `graph_nets.graphs`.
  """
  # Transposes edge features, so that the axes are in the first dimension.
  # Outputs a tensor of shape [3, n_edges].
  xyz = tf.transpose(graph.edges)
  # Random pi/2 rotation(s): a random permutation of the three axes.
  permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
  xyz = tf.gather(xyz, permutation)
  # Random reflections: each axis is flipped independently with probability
  # one half. Together with the permutation this samples one of the 48 signed
  # permutation matrices, i.e. the symmetries of the cube within O(3).
  symmetry = tf.random_uniform([3], minval=0, maxval=2, dtype=tf.int32)
  symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), tf.float32)
  xyz = xyz * symmetry
  edges = tf.transpose(xyz)
  return graph.replace(edges=edges)
class GraphBasedModel(snt.AbstractModule):
  """Graph based model which predicts particle mobilities from their positions.

  This network encodes the nodes and edges of the input graph independently,
  and then performs message-passing on this graph, updating its edges based on
  their associated nodes, then updating the nodes based on the input nodes'
  features and their associated updated edge features.
  This update is repeated several times.
  Afterwards the resulting node embeddings are decoded to predict the particle
  mobility.
  """

  def __init__(self,
               n_recurrences: int,
               mlp_sizes: Tuple[int, ...],
               mlp_kwargs: Optional[Dict[Text, Any]] = None,
               name='Graph'):
    """Creates a new GraphBasedModel object.

    Args:
      n_recurrences: the number of message passing steps in the graph network.
      mlp_sizes: the number of neurons in each layer of the MLP.
      mlp_kwargs: additional keyword arguments passed to the MLP.
      name: the name of the Sonnet module.
    """
    super(GraphBasedModel, self).__init__(name=name)
    self._n_recurrences = n_recurrences

    if mlp_kwargs is None:
      mlp_kwargs = {}

    model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes,
        activate_final=True,
        **mlp_kwargs)

    final_model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes + (1,),
        activate_final=False,
        **mlp_kwargs)

    with self._enter_variable_scope():
      self._encoder = gn_modules.GraphIndependent(
          node_model_fn=model_fn,
          edge_model_fn=model_fn)

      if self._n_recurrences > 0:
        self._propagation_network = gn_modules.GraphNetwork(
            node_model_fn=model_fn,
            edge_model_fn=model_fn,
            # We do not use globals, hence we just pass the identity function.
            global_model_fn=lambda: lambda x: x,
            reducer=tf.unsorted_segment_sum,
            edge_block_opt=dict(use_globals=False),
            node_block_opt=dict(use_globals=False),
            global_block_opt=dict(use_globals=False))

      self._decoder = gn_modules.GraphIndependent(
          node_model_fn=final_model_fn,
          edge_model_fn=model_fn)

  def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor:
    """Connects the model into the tensorflow graph.

    Args:
      graphs_tuple: input graph tensor as defined in `graph_nets.graphs`.

    Returns:
      tensor with shape [n_particles] containing the predicted particle
      mobilities.
    """
    encoded = self._encoder(graphs_tuple)
    outputs = encoded

    for _ in range(self._n_recurrences):
      # Adds skip connections to the encoded features.
      inputs = utils_tf.concat([outputs, encoded], axis=-1)
      outputs = self._propagation_network(inputs)

    decoded = self._decoder(outputs)
    return tf.squeeze(decoded.nodes, axis=-1)
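
A minimal usage sketch for the model above, assuming random placeholder data
and illustrative hyperparameters (TF1 graph mode, as pinned in
requirements.txt):

# Sketch: build a graph from (random) particle data and predict mobilities.
import numpy as np
import tensorflow.compat.v1 as tf

from glassy_dynamics import graph_model

positions = tf.constant(np.random.rand(64, 3) * 5.0, dtype=tf.float32)
types = tf.constant(np.random.randint(0, 2, size=[64]), dtype=tf.int32)
box = tf.constant([5.0, 5.0, 5.0], dtype=tf.float32)

graph = graph_model.make_graph_from_static_structure(
    positions, types, box, edge_threshold=2.0)
model = graph_model.GraphBasedModel(n_recurrences=7, mlp_sizes=(64, 64))
mobilities = model(graph)  # Tensor with shape [64].

with tf.Session() as session:
  session.run(tf.global_variables_initializer())
  print(session.run(mobilities).shape)  # (64,)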

glassy_dynamics/graph_model_test.py Normal file

@@ -0,0 +1,147 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for graph_model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl.testing import parameterized
from graph_nets import graphs
import numpy as np
import tensorflow.compat.v1 as tf
from glassy_dynamics import graph_model
class GraphModelTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
"""Initializes a small tractable test (particle) system."""
super(GraphModelTest, self).setUp()
# Fixes random seed to ensure deterministic outputs.
tf.random.set_random_seed(1234)
# In this test we use a small tractable set of particles covering all corner
# cases:
# a) eight particles with different types,
# b) periodic box is not cubic,
# c) three disjoint cluster of particles separated by a threshold > 2,
# d) first two clusters overlap with the periodic boundary,
# e) first cluster is not fully connected,
# f) second cluster is fully connected,
# g) and third cluster is a single isolated particle.
#
# The formatting of the code below separates the three clusters by
# adding linebreaks after each cluster.
self._positions = np.array(
[[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, 9.0],
[0.0, 5.0, 0.0], [0.0, 5.0, 1.0], [3.0, 5.0, 0.0],
[2.0, 3.0, 3.0]])
self._types = np.array([0.0, 0.0, 1.0, 0.0,
0.0, 1.0, 0.0,
0.0])
self._box = np.array([4.0, 10.0, 10.0])
# Creates the corresponding graph elements, assuming a threshold of 2 and
# the conventions described in `graph_nets.graphs`.
self._edge_threshold = 2
self._nodes = np.array(
[[0.0], [0.0], [1.0], [0.0],
[0.0], [1.0], [0.0],
[0.0]])
self._edges = np.array(
[[0.0, 0.0, 0.0], [-1.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, -1.0],
[1.5, 0.0, 0.0], [0.0, 0.0, 0.0], [1.5, 0.0, -1.0],
[0.0, -1.5, 0.0], [0.0, 0.0, 0.0], [0.0, -1.5, -1.0],
[0.0, 0.0, 1.0], [-1.5, 0.0, 1.0], [0.0, 1.5, 1.0], [0.0, 0.0, 0.0],
[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [-1.0, 0.0, 0.0],
[0.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, -1.0],
[1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])
self._receivers = np.array(
[0, 1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2, 3,
4, 5, 6, 4, 5, 6, 4, 5, 6,
7])
self._senders = np.array(
[0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3,
4, 4, 4, 5, 5, 5, 6, 6, 6,
7])
  def _get_graphs_tuple(self):
    """Returns a GraphsTuple containing a graph based on the test system."""
    return graphs.GraphsTuple(
        nodes=tf.constant(self._nodes, dtype=tf.float32),
        edges=tf.constant(self._edges, dtype=tf.float32),
        globals=tf.constant(np.array([[0.0]]), dtype=tf.float32),
        receivers=tf.constant(self._receivers, dtype=tf.int32),
        senders=tf.constant(self._senders, dtype=tf.int32),
        n_node=tf.constant([len(self._nodes)], dtype=tf.int32),
        n_edge=tf.constant([len(self._edges)], dtype=tf.int32))

  def test_make_graph_from_static_structure(self):
    graphs_tuple_op = graph_model.make_graph_from_static_structure(
        tf.constant(self._positions, dtype=tf.float32),
        tf.constant(self._types, dtype=tf.int32),
        tf.constant(self._box, dtype=tf.float32),
        self._edge_threshold)
    graphs_tuple = self.evaluate(graphs_tuple_op)

    self.assertLen(self._nodes, graphs_tuple.n_node)
    self.assertLen(self._edges, graphs_tuple.n_edge)
    np.testing.assert_almost_equal(graphs_tuple.nodes, self._nodes)
    np.testing.assert_equal(graphs_tuple.senders, self._senders)
    np.testing.assert_equal(graphs_tuple.receivers, self._receivers)
    np.testing.assert_almost_equal(graphs_tuple.globals, np.array([[0.0]]))
    np.testing.assert_almost_equal(graphs_tuple.edges, self._edges)

  def _is_equal_up_to_rotation(self, x, y):
    for axes in itertools.permutations([0, 1, 2]):
      for mirrors in itertools.product([1, -1], repeat=3):
        if np.allclose(x, y[:, axes] * mirrors):
          return True
    return False

  def test_apply_random_rotation(self):
    graphs_tuple = self._get_graphs_tuple()
    rotated_graphs_tuple_op = graph_model.apply_random_rotation(graphs_tuple)
    rotated_graphs_tuple = self.evaluate(rotated_graphs_tuple_op)

    # Nodes, senders, receivers and globals are unaffected by the rotation.
    np.testing.assert_almost_equal(rotated_graphs_tuple.nodes, self._nodes)
    np.testing.assert_almost_equal(rotated_graphs_tuple.senders, self._senders)
    np.testing.assert_almost_equal(
        rotated_graphs_tuple.receivers, self._receivers)
    np.testing.assert_almost_equal(
        rotated_graphs_tuple.globals, np.array([[0.0]]))
    # The edges (relative distance vectors) must match the originals up to a
    # permutation of the axes and reflections.
    self.assertTrue(self._is_equal_up_to_rotation(rotated_graphs_tuple.edges,
                                                  self._edges))

  @parameterized.named_parameters(('no_propagation', 0, (30,)),
                                  ('multi_propagation', 5, (15,)),
                                  ('multi_layer', 1, (20, 30)))
  def test_GraphModel(self, n_recurrences, mlp_sizes):
    graphs_tuple = self._get_graphs_tuple()
    output_op = graph_model.GraphBasedModel(n_recurrences=n_recurrences,
                                            mlp_sizes=mlp_sizes)(graphs_tuple)
    self.assertListEqual(output_op.shape.as_list(), [len(self._types)])

    # Tests if the model runs without crashing.
    with self.session():
      tf.global_variables_initializer().run()
      output_op.eval()


if __name__ == '__main__':
  tf.test.main()

glassy_dynamics/requirements.txt Normal file

@@ -0,0 +1,30 @@
# GlassyDynamics
absl-py==0.8.1
astor==0.8.0
cloudpickle==1.1.1
contextlib2==0.6.0.post1
decorator==4.4.0
dm-sonnet==1.35
future==0.18.0
gast==0.3.2
google-pasta==0.1.7
graph-nets==1.0.4
grpcio==1.24.1
h5py==2.10.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
Markdown==3.1.1
mock==3.0.5
networkx==2.3
numpy==1.17.2
protobuf==3.10.0
semantic-version==2.8.2
six==1.12.0
tensorboard==1.13.1
tensorflow==1.13.1
tensorflow-estimator==1.13.0
tensorflow-probability==0.6.0
termcolor==1.1.0
Werkzeug==0.16.0
wrapt==1.11.2

Binary file not shown.

Binary file not shown.

glassy_dynamics/train.py Normal file

@@ -0,0 +1,381 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training pipeline for the prediction of particle mobilities in glasses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import pickle
from absl import logging
import enum
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from typing import Any, Dict, List, Optional, Text, Tuple, Sequence
from glassy_dynamics import graph_model
tf.enable_resource_variables()
LossCollection = collections.namedtuple('LossCollection',
'l1_loss, l2_loss, correlation')
GlassSimulationData = collections.namedtuple('GlassSimulationData',
'positions, targets, types, box')
class ParticleType(enum.IntEnum):
"""The simulation contains two particle types, identified as type A and B.
The dataset encodes the particle type in an integer.
- 0 corresponds to particle type A.
- 1 corresponds to particle type B.
"""
A = 0
B = 1
def get_targets(
    initial_positions: np.ndarray,
    trajectory_target_positions: Sequence[np.ndarray]) -> np.ndarray:
  """Returns the averaged particle mobilities from the sampled trajectories.

  Args:
    initial_positions: the initial positions of the particles with shape
      [n_particles, 3].
    trajectory_target_positions: the absolute positions of the particles at
      the target time for all sampled trajectories, each with shape
      [n_particles, 3].
  """
  # The mobility of a particle is the norm of its displacement from the
  # initial position, averaged over all sampled trajectories.
  targets = np.mean([np.linalg.norm(t - initial_positions, axis=-1)
                     for t in trajectory_target_positions], axis=0)
  return targets.astype(np.float32)
def load_data(
    file_pattern: Text,
    time_index: int,
    max_files_to_load: Optional[int] = None) -> List[GlassSimulationData]:
  """Returns a list containing the training or test dataset.

  Each returned `GlassSimulationData` element contains:
    `positions`: `np.ndarray` containing the particle positions with shape
      [n_particles, 3].
    `targets`: `np.ndarray` containing particle mobilities with shape
      [n_particles].
    `types`: `np.ndarray` containing the particle types with shape
      [n_particles].
    `box`: `np.ndarray` containing the dimensions of the periodic box with
      shape [3].

  Args:
    file_pattern: pattern matching the files with the simulation data.
    time_index: the time index of the targets.
    max_files_to_load: the maximum number of files to load.
  """
  filenames = tf.io.gfile.glob(file_pattern)
  if max_files_to_load:
    filenames = filenames[:max_files_to_load]

  static_structures = []
  for filename in filenames:
    with tf.io.gfile.GFile(filename, 'rb') as f:
      data = pickle.load(f)
    static_structures.append(GlassSimulationData(
        positions=data['positions'].astype(np.float32),
        targets=get_targets(
            data['positions'],
            data['trajectory_target_positions'][time_index]),
        types=data['types'].astype(np.int32),
        box=data['box'].astype(np.float32)))
  return static_structures
def get_loss_ops(
    prediction: tf.Tensor,
    target: tf.Tensor,
    types: tf.Tensor) -> LossCollection:
  """Returns L1/L2 loss and correlation for type A particles.

  Args:
    prediction: tensor with shape [n_particles] containing the predicted
      particle mobilities.
    target: tensor with shape [n_particles] containing the true particle
      mobilities.
    types: tensor with shape [n_particles] containing the particle types.
  """
  # Considers only type A particles.
  mask = tf.equal(types, ParticleType.A)
  prediction = tf.boolean_mask(prediction, mask)
  target = tf.boolean_mask(target, mask)

  return LossCollection(
      l1_loss=tf.reduce_mean(tf.abs(prediction - target)),
      l2_loss=tf.reduce_mean((prediction - target)**2),
      correlation=tf.squeeze(tfp.stats.correlation(
          prediction[:, tf.newaxis], target[:, tf.newaxis])))


def get_minimize_op(
    loss: tf.Tensor,
    learning_rate: float,
    grad_clip: Optional[float] = None) -> tf.Operation:
  """Returns minimization operation.

  Args:
    loss: the loss tensor which is minimized.
    learning_rate: the learning rate used by the optimizer.
    grad_clip: all gradients are clipped to the given value if not None or 0.
  """
  optimizer = tf.train.AdamOptimizer(learning_rate)
  grads_and_vars = optimizer.compute_gradients(loss)
  if grad_clip:
    grads, _ = tf.clip_by_global_norm(
        [g for g, _ in grads_and_vars], grad_clip)
    grads_and_vars = [(g, pair[1]) for g, pair in zip(grads, grads_and_vars)]
  minimize = optimizer.apply_gradients(grads_and_vars)
  return minimize
def _log_stats_and_return_mean_correlation(
    label: Text,
    stats: List[LossCollection]) -> float:
  """Logs performance statistics and returns the mean correlation.

  Args:
    label: label printed before the combined statistics e.g. train or test.
    stats: statistics calculated for each batch in a dataset.

  Returns:
    mean correlation
  """
  for key in LossCollection._fields:
    values = [getattr(s, key) for s in stats]
    mean = np.mean(values)
    std = np.std(values)
    logging.info('%s: %s: %.4f +/- %.4f', label, key, mean, std)
  return np.mean([s.correlation for s in stats])
def train_model(train_file_pattern: Text,
                test_file_pattern: Text,
                max_files_to_load: Optional[int] = None,
                n_epochs: int = 1000,
                time_index: int = 9,
                augment_data_using_rotations: bool = True,
                learning_rate: float = 1e-4,
                grad_clip: Optional[float] = 1.0,
                n_recurrences: int = 7,
                mlp_sizes: Tuple[int, ...] = (64, 64),
                mlp_kwargs: Optional[Dict[Text, Any]] = None,
                edge_threshold: float = 2.0,
                measurement_store_interval: int = 1000,
                checkpoint_path: Optional[Text] = None) -> float:
  """Trains GraphModel using tensorflow.

  Args:
    train_file_pattern: pattern matching the files with the training data.
    test_file_pattern: pattern matching the files with the test data.
    max_files_to_load: the maximum number of train and test files to load.
      If None, all files will be loaded.
    n_epochs: the number of passes through the training dataset (epochs).
    time_index: the time index (0-9) of the target mobilities.
    augment_data_using_rotations: data is augmented by using random rotations.
    learning_rate: the learning rate used by the optimizer.
    grad_clip: all gradients are clipped to the given value.
    n_recurrences: the number of message passing steps in the graph network.
    mlp_sizes: the number of neurons in each layer of the MLP.
    mlp_kwargs: additional keyword arguments passed to the MLP.
    edge_threshold: particles at distance less than threshold are connected by
      an edge.
    measurement_store_interval: number of steps between storing objective
      values (loss and correlation).
    checkpoint_path: path used to store the checkpoint with the highest
      correlation on the test set.

  Returns:
    Correlation on the test dataset of the best model encountered during
    training.
  """
  if mlp_kwargs is None:
    mlp_kwargs = dict(
        initializers=dict(w=tf.variance_scaling_initializer(1.0),
                          b=tf.variance_scaling_initializer(0.1)))

  # Loads train and test dataset.
  dataset_kwargs = dict(
      time_index=time_index,
      max_files_to_load=max_files_to_load)
  training_data = load_data(train_file_pattern, **dataset_kwargs)
  test_data = load_data(test_file_pattern, **dataset_kwargs)

  # Defines wrapper functions, which can directly be passed to the
  # tf.data.Dataset.map function.
  def _make_graph_from_static_structure(static_structure):
    """Converts static structure to graph, targets and types."""
    return (graph_model.make_graph_from_static_structure(
        static_structure.positions,
        static_structure.types,
        static_structure.box,
        edge_threshold),
            static_structure.targets,
            static_structure.types)

  def _apply_random_rotation(graph, targets, types):
    """Applies random rotations to the graph and forwards targets and types."""
    return graph_model.apply_random_rotation(graph), targets, types

  # Defines the data-pipeline based on tf.data.Dataset following the official
  # guideline:
  # https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays.
  # We use initializable iterators to avoid embedding the training and test
  # data directly into the graph.
  # Instead we feed the data to the iterators during the initialization of the
  # iterators before the main training loop.
  placeholders = GlassSimulationData._make(
      tf.placeholder(s.dtype, (None,) + s.shape) for s in training_data[0])
  dataset = tf.data.Dataset.from_tensor_slices(placeholders)
  dataset = dataset.map(_make_graph_from_static_structure)
  dataset = dataset.cache()
  dataset = dataset.shuffle(400)
  # Augments data. This has to be done after calling dataset.cache!
  if augment_data_using_rotations:
    dataset = dataset.map(_apply_random_rotation)
  dataset = dataset.repeat()
  train_iterator = dataset.make_initializable_iterator()

  dataset = tf.data.Dataset.from_tensor_slices(placeholders)
  dataset = dataset.map(_make_graph_from_static_structure)
  dataset = dataset.cache()
  dataset = dataset.repeat()
  test_iterator = dataset.make_initializable_iterator()

  # Creates tensorflow graph.
  # Note: We decouple the training and test datasets from the input pipeline
  # by creating a new iterator from a string-handle placeholder with the same
  # output types and shapes as the training dataset.
  dataset_handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      dataset_handle, train_iterator.output_types,
      train_iterator.output_shapes)
  graph, targets, types = iterator.get_next()

  model = graph_model.GraphBasedModel(
      n_recurrences, mlp_sizes, mlp_kwargs)
  prediction = model(graph)

  # Defines loss and minimization operations.
  loss_ops = get_loss_ops(prediction, targets, types)
  minimize_op = get_minimize_op(loss_ops.l2_loss, learning_rate, grad_clip)

  best_so_far = -1
  train_stats = []
  test_stats = []

  saver = tf.train.Saver()

  with tf.train.SingularMonitoredSession() as session:
    # Initializes train and test iterators with the training and test
    # datasets. The obtained training and test string-handles can be passed
    # to the dataset_handle placeholder to select the dataset.
    train_handle = session.run(train_iterator.string_handle())
    test_handle = session.run(test_iterator.string_handle())
    feed_dict = {p: [x[i] for x in training_data]
                 for i, p in enumerate(placeholders)}
    session.run(train_iterator.initializer, feed_dict=feed_dict)
    feed_dict = {p: [x[i] for x in test_data]
                 for i, p in enumerate(placeholders)}
    session.run(test_iterator.initializer, feed_dict=feed_dict)

    # Trains model using stochastic gradient descent on the training dataset.
    n_training_steps = len(training_data) * n_epochs
    for i in range(n_training_steps):
      feed_dict = {dataset_handle: train_handle}
      train_loss, _ = session.run((loss_ops, minimize_op),
                                  feed_dict=feed_dict)
      train_stats.append(train_loss)

      if (i + 1) % measurement_store_interval == 0:
        # Evaluates model on test dataset.
        for _ in range(len(test_data)):
          feed_dict = {dataset_handle: test_handle}
          test_stats.append(session.run(loss_ops, feed_dict=feed_dict))

        # Outputs performance statistics on training and test dataset.
        _log_stats_and_return_mean_correlation('Train', train_stats)
        correlation = _log_stats_and_return_mean_correlation(
            'Test', test_stats)
        train_stats = []
        test_stats = []

        # Updates best model based on the observed correlation on the test
        # dataset.
        if correlation > best_so_far:
          best_so_far = correlation
          if checkpoint_path:
            saver.save(session.raw_session(), checkpoint_path)

  return best_so_far
def apply_model(checkpoint_path: Text,
                file_pattern: Text,
                max_files_to_load: Optional[int] = None,
                time_index: int = 9) -> List[np.ndarray]:
  """Applies trained GraphModel using tensorflow.

  Args:
    checkpoint_path: path from which the model is loaded.
    file_pattern: pattern matching the files with the data.
    max_files_to_load: the maximum number of files to load.
      If None, all files will be loaded.
    time_index: the time index (0-9) of the target mobilities.

  Returns:
    Predictions of the model for all files.
  """
  dataset_kwargs = dict(
      time_index=time_index,
      max_files_to_load=max_files_to_load)
  data = load_data(file_pattern, **dataset_kwargs)

  tf.reset_default_graph()

  # Restores the graph from the meta-graph stored next to the checkpoint, and
  # retrieves the relevant tensors by the default names tensorflow assigned
  # when the training graph was built (see `train_model` above).
  saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
  graph = tf.get_default_graph()
  placeholders = GlassSimulationData(
      positions=graph.get_tensor_by_name('Placeholder:0'),
      targets=graph.get_tensor_by_name('Placeholder_1:0'),
      types=graph.get_tensor_by_name('Placeholder_2:0'),
      box=graph.get_tensor_by_name('Placeholder_3:0'))
  prediction_tensor = graph.get_tensor_by_name('Graph_1/Squeeze:0')
  correlation_tensor = graph.get_tensor_by_name('Squeeze:0')
  dataset_handle = graph.get_tensor_by_name('Placeholder_4:0')
  test_initializer = graph.get_operation_by_name('MakeIterator_1')
  test_string_handle = graph.get_tensor_by_name('IteratorToStringHandle_1:0')

  with tf.Session() as session:
    saver.restore(session, checkpoint_path)
    handle = session.run(test_string_handle)
    feed_dict = {p: [x[i] for x in data] for i, p in enumerate(placeholders)}
    session.run(test_initializer, feed_dict=feed_dict)

    predictions = []
    correlations = []
    for _ in range(len(data)):
      p, c = session.run((prediction_tensor, correlation_tensor),
                         feed_dict={dataset_handle: handle})
      predictions.append(p)
      correlations.append(c)

  logging.info('Correlation: %.4f +/- %.4f',
               np.mean(correlations),
               np.std(correlations))
  return predictions

glassy_dynamics/train_binary.py Normal file

@@ -0,0 +1,64 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trains a graph-based network to predict particle mobilities in glasses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from glassy_dynamics import train
FLAGS = flags.FLAGS
flags.DEFINE_string(
'data_directory',
'',
'Directory which contains the train and test datasets.')
flags.DEFINE_integer(
'time_index',
9,
'The time index of the target mobilities.')
flags.DEFINE_integer(
'max_files_to_load',
None,
'The maximum number of files to load from the train and test datasets.')
flags.DEFINE_string(
'checkpoint_path',
None,
'Path used to store a checkpoint of the best model.')
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
train_file_pattern = os.path.join(FLAGS.data_directory, 'train/aggregated*')
test_file_pattern = os.path.join(FLAGS.data_directory, 'test/aggregated*')
train.train_model(
train_file_pattern=train_file_pattern,
test_file_pattern=test_file_pattern,
max_files_to_load=FLAGS.max_files_to_load,
time_index=FLAGS.time_index,
checkpoint_path=FLAGS.checkpoint_path)
if __name__ == '__main__':
app.run(main)

glassy_dynamics/train_test.py Normal file

@@ -0,0 +1,126 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow.compat.v1 as tf
from glassy_dynamics import train
class TrainTest(tf.test.TestCase):
def test_get_targets(self):
initial_positions = np.array([[0, 0, 0], [1, 2, 3]])
trajectory_target_positions = [
np.array([[1, 0, 0], [1, 2, 4]]),
np.array([[0, 1, 0], [1, 0, 3]]),
np.array([[0, 0, 5], [1, 2, 3]]),
]
expected_targets = np.array([7.0 / 3.0, 1.0])
targets = train.get_targets(initial_positions, trajectory_target_positions)
np.testing.assert_almost_equal(expected_targets, targets)
def test_load_data(self):
file_pattern = os.path.join(os.path.dirname(__file__), 'testdata',
'test_small.pickle')
with self.subTest('ContentAndShapesAreAsExpected'):
data = train.load_data(file_pattern, 0)
self.assertEqual(len(data), 1)
element = data[0]
self.assertTupleEqual(element.positions.shape, (20, 3))
self.assertTupleEqual(element.box.shape, (3,))
self.assertTupleEqual(element.targets.shape, (20,))
self.assertTupleEqual(element.types.shape, (20,))
with self.subTest('TargetsGrowAsAFunctionOfTime'):
previous_mean_target = 0.0
# Time index 9 refers to 1/e = 0.36 in the IS, and therefore it is between
# Time index 5 (0.4) and time index 6 (0.3).
for time_index in [0, 1, 2, 3, 4, 5, 9, 6, 7, 8]:
data = train.load_data(file_pattern, time_index)[0]
current_mean_target = data.targets.mean()
self.assertGreater(current_mean_target, previous_mean_target)
previous_mean_target = current_mean_target
class TensorflowTrainTest(tf.test.TestCase):
def test_get_loss_op(self):
"""Tests the correct calculation of the loss operations."""
prediction = tf.constant([0.0, 1.0, 2.0, 1.0, 2.0], dtype=tf.float32)
target = tf.constant([1.0, 25.0, 0.0, 4.0, 2.0], dtype=tf.float32)
types = tf.constant([0, 1, 0, 0, 0], dtype=tf.int32)
loss_ops = train.get_loss_ops(prediction, target, types)
loss = self.evaluate(loss_ops)
self.assertAlmostEqual(loss.l1_loss, 1.5)
self.assertAlmostEqual(loss.l2_loss, 14.0 / 4.0)
self.assertAlmostEqual(loss.correlation, -0.15289416)
def test_get_minimize_op(self):
"""Tests the minimize operation by minimizing a single variable."""
var = tf.Variable([1.0], name='test')
loss = var**2
minimize = train.get_minimize_op(loss, 1e-1)
with self.session():
tf.global_variables_initializer().run()
for _ in range(100):
minimize.run()
value = var.eval()
self.assertLess(abs(value[0]), 0.01)
def test_train_model(self):
"""Tests if we can overfit to a small test dataset."""
file_pattern = os.path.join(os.path.dirname(__file__), 'testdata',
'test_small.pickle')
best_correlation_value = train.train_model(
train_file_pattern=file_pattern,
test_file_pattern=file_pattern,
n_epochs=1000,
augment_data_using_rotations=False,
learning_rate=1e-4,
n_recurrences=2,
edge_threshold=5,
mlp_sizes=(32, 32),
measurement_store_interval=1000)
# The test dataset contains only a single sample with 20 particles.
# Therefore we expect the model to be able to memorize the targets perfectly
# if the model works correctly.
self.assertGreater(best_correlation_value, 0.99)
def test_apply_model(self):
"""Tests if we can apply a model to a small test dataset."""
checkpoint_path = os.path.join(os.path.dirname(__file__), 'checkpoints',
't044_s09.ckpt')
file_pattern = os.path.join(os.path.dirname(__file__), 'testdata',
'test_large.pickle')
predictions = train.apply_model(checkpoint_path=checkpoint_path,
file_pattern=file_pattern,
time_index=0)
data = train.load_data(file_pattern, 0)
targets = data[0].targets
correlation_value = np.corrcoef(predictions[0], targets)[0, 1]
self.assertGreater(correlation_value, 0.5)
if __name__ == '__main__':
tf.test.main()