mirror of https://github.com/google-deepmind/deepmind-research.git
PAC Bayes Quadratic bound open sourcing.
PiperOrigin-RevId: 302629851
committed by Louise Deason
parent afcdc77239
commit f6395d709e
80
glassy_dynamics/README.md
Normal file
@@ -0,0 +1,80 @@
# Unveiling the predictive power of static structure in glassy systems

This repository contains an open source implementation of the graph neural
network model described in our paper.
The model can be trained using the training binary included in this repository,
and the dataset published with our paper.

## Abstract

Despite decades of theoretical studies, the nature of the glass transition
remains elusive and debated, while the existence of structural predictors of
the dynamics is a major open question. Recent approaches propose inferring
predictors from a variety of human-defined features using machine learning.
We learn the long time evolution of a glassy system solely from the initial
particle positions and without any hand-crafted features, using a powerful
model: graph neural networks. We show that this method strongly outperforms
state-of-the-art methods, generalizing over a wide range of temperatures,
pressures, and densities. In shear experiments, it predicts the location of
rearranging particles. The structural predictors learned by our network unveil
a correlation length which increases with larger timescales to reach the size
of our system. Beyond glasses, our method could apply to many other physical
systems that map to a graph of local interactions.

## Dataset

### System description

The dataset was generated with the LAMMPS molecular dynamics package.
The simulated system has periodic boundaries and is a binary mixture of 4096
large (A) and small (B) particles that interact via a 6-12 Lennard-Jones
potential.
The interaction coefficients are set for a typical Kob-Andersen configuration.

### Data format

The data is stored in Python's pickle format, protocol version 3.
Each file contains the data for one of the equilibrated systems in a Python
dictionary. The dictionary contains the following entries:

- `positions`: the particle positions of the equilibrated system.
- `types`: the particle types (0 == type A and 1 == type B) of the
  equilibrated system.
- `box`: the dimensions of the periodic cubic simulation box.
- `time`: the logarithmically sampled time points.
- `time_indices`: the indices of the time points for which the sampled
  trajectories on average reach a certain value of the intermediate
  scattering function.
- `is_values`: the values of the intermediate scattering function associated
  with each time index.
- `trajectory_start_velocities`: the velocities drawn from a Boltzmann
  distribution at the start of each trajectory.
- `trajectory_target_positions`: the positions of the particles for each of
  the trajectories at selected time points (as defined by the `time_indices`
  array and the corresponding values of the intermediate scattering function
  stored in `is_values`).
- `metadata`: a dictionary containing additional metadata:
    - `temperature`: the temperature at which the system was equilibrated.
    - `pressure`: the pressure at which the system was equilibrated.
    - `fluid`: the type of fluid which was simulated (Kob-Andersen).

All quantities are given in Lennard-Jones units. The positions are stored in
the absolute coordinate system, i.e. a particle lies outside of the simulation
box if it crossed a periodic boundary during the simulation.
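For concreteness, here is a minimal sketch of loading and inspecting one of
these files; the filename below is a placeholder, not a name from the dataset:

```python
import pickle

# 'aggregated_0.pickle' is an illustrative name; substitute a dataset file.
with open('aggregated_0.pickle', 'rb') as f:
    data = pickle.load(f)  # Stored with pickle protocol 3.

print(data['positions'].shape)          # (n_particles, 3), absolute coords.
print(data['types'])                    # 0 == type A, 1 == type B.
print(data['box'])                      # Dimensions of the periodic box.
print(data['metadata']['temperature'])  # Equilibration temperature.
```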
## Reference

If this repository is helpful for your research, please cite the following
publication:

Unveiling the predictive power of static structure in glassy systems
V. Bapst, T. Keck, A. Grabska-Barwinska, C. Donner, E. D. Cubuk,
S. S. Schoenholz, A. Obika, A. W. R. Nelson, T. Back, D. Hassabis and P. Kohli

## Disclaimer

This is not an official Google product.
62
glassy_dynamics/apply_binary.py
Normal file
@@ -0,0 +1,62 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Applies a graph-based network to predict particle mobilities in glasses."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags

from glassy_dynamics import train

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'data_directory',
    '',
    'Directory which contains the train or test datasets.')
flags.DEFINE_integer(
    'time_index',
    9,
    'The time index of the target mobilities.')
flags.DEFINE_integer(
    'max_files_to_load',
    None,
    'The maximum number of files to load.')
flags.DEFINE_string(
    'checkpoint_path',
    'checkpoints/t044_s09.ckpt',
    'Path used to load the model.')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  file_pattern = os.path.join(FLAGS.data_directory, 'aggregated*')
  train.apply_model(
      checkpoint_path=FLAGS.checkpoint_path,
      file_pattern=file_pattern,
      max_files_to_load=FLAGS.max_files_to_load,
      time_index=FLAGS.time_index)


if __name__ == '__main__':
  app.run(main)
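The same functionality can also be reached from Python directly. A minimal
sketch, assuming the dataset has been downloaded to an illustrative path (the
checkpoint below ships with this commit):

```python
from glassy_dynamics import train

# The dataset path is illustrative; adjust it to your local copy.
predictions = train.apply_model(
    checkpoint_path='glassy_dynamics/checkpoints/t044_s09.ckpt',
    file_pattern='/path/to/dataset/test/aggregated*',
    time_index=9)
```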
BIN
glassy_dynamics/checkpoints/t044_s09.ckpt.data-00000-of-00001
Normal file
Binary file not shown.
BIN
glassy_dynamics/checkpoints/t044_s09.ckpt.index
Normal file
Binary file not shown.
BIN
glassy_dynamics/checkpoints/t044_s09.ckpt.meta
Normal file
Binary file not shown.
190
glassy_dynamics/graph_model.py
Normal file
@@ -0,0 +1,190 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A graph neural network based model to predict particle mobilities.

The architecture and performance of this model is described in our publication:
"Unveiling the predictive power of static structure in glassy systems".
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from graph_nets import graphs
from graph_nets import modules as gn_modules
from graph_nets import utils_tf

import sonnet as snt
import tensorflow.compat.v1 as tf
from typing import Any, Dict, Text, Tuple, Optional


def make_graph_from_static_structure(
    positions,
    types,
    box,
    edge_threshold):
  """Returns graph representing the static structure of the glass.

  Each particle is represented by a node in the graph. The particle type is
  stored as a node feature.
  Two particles at a distance less than the threshold are connected by an edge.
  The relative distance vector is stored as an edge feature.

  Args:
    positions: particle positions with shape [n_particles, 3].
    types: particle types with shape [n_particles].
    box: dimensions of the cubic box that contains the particles with shape
      [3].
    edge_threshold: particles at distance less than threshold are connected by
      an edge.
  """
  # Calculate pairwise relative distances between particles: shape [n, n, 3].
  cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]
  # Enforces periodic boundary conditions.
  box_ = box[tf.newaxis, tf.newaxis, :]
  cross_positions += tf.cast(cross_positions < -box_ / 2., tf.float32) * box_
  cross_positions -= tf.cast(cross_positions > box_ / 2., tf.float32) * box_
  # Calculates adjacency matrix in a sparse format (indices), based on the
  # given distances and threshold.
  distances = tf.norm(cross_positions, axis=-1)
  indices = tf.where(distances < edge_threshold)

  # Defines graph.
  nodes = types[:, tf.newaxis]
  senders = indices[:, 0]
  receivers = indices[:, 1]
  edges = tf.gather_nd(cross_positions, indices)

  return graphs.GraphsTuple(
      nodes=tf.cast(nodes, tf.float32),
      n_node=tf.reshape(tf.shape(nodes)[0], [1]),
      edges=tf.cast(edges, tf.float32),
      n_edge=tf.reshape(tf.shape(edges)[0], [1]),
      globals=tf.zeros((1, 1), dtype=tf.float32),
      receivers=tf.cast(receivers, tf.int32),
      senders=tf.cast(senders, tf.int32))


def apply_random_rotation(graph):
  """Returns randomly rotated graph representation.

  The rotation is an element of O(3) with rotation angles multiple of pi/2.
  This function assumes that the relative particle distances are stored in
  the edge features.

  Args:
    graph: The graphs tuple as defined in `graph_nets.graphs`.
  """
  # Transposes edge features, so that the axes are in the first dimension.
  # Outputs a tensor of shape [3, n_particles].
  xyz = tf.transpose(graph.edges)
  # Random pi/2 rotation(s).
  permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
  xyz = tf.gather(xyz, permutation)
  # Random reflections.
  symmetry = tf.random_uniform([3], minval=0, maxval=2, dtype=tf.int32)
  symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), tf.float32)
  xyz = xyz * symmetry
  edges = tf.transpose(xyz)
  return graph.replace(edges=edges)


class GraphBasedModel(snt.AbstractModule):
  """Graph based model which predicts particle mobilities from their positions.

  This network encodes the nodes and edges of the input graph independently,
  and then performs message-passing on this graph, updating its edges based on
  their associated nodes, then updating the nodes based on the input nodes'
  features and their associated updated edge features.
  This update is repeated several times.
  Afterwards the resulting node embeddings are decoded to predict the particle
  mobility.
  """

  def __init__(self,
               n_recurrences,
               mlp_sizes,
               mlp_kwargs=None,
               name='Graph'):
    """Creates a new GraphBasedModel object.

    Args:
      n_recurrences: the number of message passing steps in the graph network.
      mlp_sizes: the number of neurons in each layer of the MLP.
      mlp_kwargs: additional keyword arguments passed to the MLP.
      name: the name of the Sonnet module.
    """
    super(GraphBasedModel, self).__init__(name=name)
    self._n_recurrences = n_recurrences

    if mlp_kwargs is None:
      mlp_kwargs = {}

    model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes,
        activate_final=True,
        **mlp_kwargs)

    final_model_fn = functools.partial(
        snt.nets.MLP,
        output_sizes=mlp_sizes + (1,),
        activate_final=False,
        **mlp_kwargs)

    with self._enter_variable_scope():
      self._encoder = gn_modules.GraphIndependent(
          node_model_fn=model_fn,
          edge_model_fn=model_fn)

      if self._n_recurrences > 0:
        self._propagation_network = gn_modules.GraphNetwork(
            node_model_fn=model_fn,
            edge_model_fn=model_fn,
            # We do not use globals, hence we just pass the identity function.
            global_model_fn=lambda: lambda x: x,
            reducer=tf.unsorted_segment_sum,
            edge_block_opt=dict(use_globals=False),
            node_block_opt=dict(use_globals=False),
            global_block_opt=dict(use_globals=False))

      self._decoder = gn_modules.GraphIndependent(
          node_model_fn=final_model_fn,
          edge_model_fn=model_fn)

  def _build(self, graphs_tuple):
    """Connects the model into the tensorflow graph.

    Args:
      graphs_tuple: input graph tensor as defined in `graph_nets.graphs`.

    Returns:
      tensor with shape [n_particles] containing the predicted particle
      mobilities.
    """
    encoded = self._encoder(graphs_tuple)
    outputs = encoded

    for _ in range(self._n_recurrences):
      # Adds skip connections.
      inputs = utils_tf.concat([outputs, encoded], axis=-1)
      outputs = self._propagation_network(inputs)

    decoded = self._decoder(outputs)
    return tf.squeeze(decoded.nodes, axis=-1)
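For orientation, a minimal sketch of how these pieces fit together on a toy
system; all values are arbitrary and not taken from the dataset:

```python
import numpy as np
import tensorflow.compat.v1 as tf

from glassy_dynamics import graph_model

# Eight particles of type A in a 4 x 10 x 10 periodic box (arbitrary values).
positions = tf.constant(
    np.random.uniform(0.0, 4.0, size=(8, 3)), dtype=tf.float32)
types = tf.constant(np.zeros(8, dtype=np.int32))
box = tf.constant([4.0, 10.0, 10.0], dtype=tf.float32)

# Builds the static-structure graph and applies the model to it.
graph = graph_model.make_graph_from_static_structure(
    positions, types, box, edge_threshold=2.0)
model = graph_model.GraphBasedModel(n_recurrences=7, mlp_sizes=(64, 64))
mobilities = model(graph)  # Tensor of shape [n_particles].

with tf.Session() as session:
  session.run(tf.global_variables_initializer())
  print(session.run(mobilities))
```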
147
glassy_dynamics/graph_model_test.py
Normal file
@@ -0,0 +1,147 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for graph_model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
from absl.testing import parameterized

from graph_nets import graphs
import numpy as np
import tensorflow.compat.v1 as tf

from glassy_dynamics import graph_model


class GraphModelTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    """Initializes a small tractable test (particle) system."""
    super(GraphModelTest, self).setUp()
    # Fixes random seed to ensure deterministic outputs.
    tf.random.set_random_seed(1234)

    # In this test we use a small tractable set of particles covering all
    # corner cases:
    # a) eight particles with different types,
    # b) periodic box is not cubic,
    # c) three disjoint clusters of particles separated by a threshold > 2,
    # d) first two clusters overlap with the periodic boundary,
    # e) first cluster is not fully connected,
    # f) second cluster is fully connected,
    # g) and third cluster is a single isolated particle.
    #
    # The formatting of the code below separates the three clusters by
    # adding linebreaks after each cluster.
    self._positions = np.array(
        [[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, 9.0],
         [0.0, 5.0, 0.0], [0.0, 5.0, 1.0], [3.0, 5.0, 0.0],
         [2.0, 3.0, 3.0]])
    self._types = np.array([0.0, 0.0, 1.0, 0.0,
                            0.0, 1.0, 0.0,
                            0.0])
    self._box = np.array([4.0, 10.0, 10.0])

    # Creates the corresponding graph elements, assuming a threshold of 2 and
    # the conventions described in `graph_nets.graphs`.
    self._edge_threshold = 2
    self._nodes = np.array(
        [[0.0], [0.0], [1.0], [0.0],
         [0.0], [1.0], [0.0],
         [0.0]])
    self._edges = np.array(
        [[0.0, 0.0, 0.0], [-1.5, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, 0.0, -1.0],
         [1.5, 0.0, 0.0], [0.0, 0.0, 0.0], [1.5, 0.0, -1.0],
         [0.0, -1.5, 0.0], [0.0, 0.0, 0.0], [0.0, -1.5, -1.0],
         [0.0, 0.0, 1.0], [-1.5, 0.0, 1.0], [0.0, 1.5, 1.0], [0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [-1.0, 0.0, 0.0],
         [0.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, -1.0],
         [1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0]])

    self._receivers = np.array(
        [0, 1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2, 3,
         4, 5, 6, 4, 5, 6, 4, 5, 6,
         7])
    self._senders = np.array(
        [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3,
         4, 4, 4, 5, 5, 5, 6, 6, 6,
         7])

  def _get_graphs_tuple(self):
    """Returns a GraphsTuple containing a graph based on the test system."""
    return graphs.GraphsTuple(
        nodes=tf.constant(self._nodes, dtype=tf.float32),
        edges=tf.constant(self._edges, dtype=tf.float32),
        globals=tf.constant(np.array([[0.0]]), dtype=tf.float32),
        receivers=tf.constant(self._receivers, dtype=tf.int32),
        senders=tf.constant(self._senders, dtype=tf.int32),
        n_node=tf.constant([len(self._nodes)], dtype=tf.int32),
        n_edge=tf.constant([len(self._edges)], dtype=tf.int32))

  def test_make_graph_from_static_structure(self):
    graphs_tuple_op = graph_model.make_graph_from_static_structure(
        tf.constant(self._positions, dtype=tf.float32),
        tf.constant(self._types, dtype=tf.int32),
        tf.constant(self._box, dtype=tf.float32),
        self._edge_threshold)
    graphs_tuple = self.evaluate(graphs_tuple_op)
    self.assertLen(self._nodes, graphs_tuple.n_node)
    self.assertLen(self._edges, graphs_tuple.n_edge)
    np.testing.assert_almost_equal(graphs_tuple.nodes, self._nodes)
    np.testing.assert_equal(graphs_tuple.senders, self._senders)
    np.testing.assert_equal(graphs_tuple.receivers, self._receivers)
    np.testing.assert_almost_equal(graphs_tuple.globals, np.array([[0.0]]))
    np.testing.assert_almost_equal(graphs_tuple.edges, self._edges)

  def _is_equal_up_to_rotation(self, x, y):
    for axes in itertools.permutations([0, 1, 2]):
      for mirrors in itertools.product([1, -1], repeat=3):
        if np.allclose(x, y[:, axes] * mirrors):
          return True
    return False

  def test_apply_random_rotation(self):
    graphs_tuple = self._get_graphs_tuple()
    rotated_graphs_tuple_op = graph_model.apply_random_rotation(graphs_tuple)
    rotated_graphs_tuple = self.evaluate(rotated_graphs_tuple_op)
    np.testing.assert_almost_equal(rotated_graphs_tuple.nodes, self._nodes)
    np.testing.assert_almost_equal(rotated_graphs_tuple.senders, self._senders)
    np.testing.assert_almost_equal(
        rotated_graphs_tuple.receivers, self._receivers)
    np.testing.assert_almost_equal(
        rotated_graphs_tuple.globals, np.array([[0.0]]))
    self.assertTrue(self._is_equal_up_to_rotation(rotated_graphs_tuple.edges,
                                                  self._edges))

  @parameterized.named_parameters(('no_propagation', 0, (30,)),
                                  ('multi_propagation', 5, (15,)),
                                  ('multi_layer', 1, (20, 30)))
  def test_GraphModel(self, n_recurrences, mlp_sizes):
    graphs_tuple = self._get_graphs_tuple()
    output_op = graph_model.GraphBasedModel(n_recurrences=n_recurrences,
                                            mlp_sizes=mlp_sizes)(graphs_tuple)
    self.assertListEqual(output_op.shape.as_list(), [len(self._types)])
    # Tests if the model runs without crashing.
    with self.session():
      tf.global_variables_initializer().run()
      output_op.eval()


if __name__ == '__main__':
  tf.test.main()
30
glassy_dynamics/requirements.txt
Normal file
@@ -0,0 +1,30 @@
# GlassyDynamics
absl-py==0.8.1
astor==0.8.0
cloudpickle==1.1.1
contextlib2==0.6.0.post1
decorator==4.4.0
dm-sonnet==1.35
future==0.18.0
gast==0.3.2
google-pasta==0.1.7
graph-nets==1.0.4
grpcio==1.24.1
h5py==2.10.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
Markdown==3.1.1
mock==3.0.5
networkx==2.3
numpy==1.17.2
pkg-resources==0.0.0
protobuf==3.10.0
semantic-version==2.8.2
six==1.12.0
tensorboard==1.13.1
tensorflow==1.13.1
tensorflow-estimator==1.13.0
tensorflow-probability==0.6.0
termcolor==1.1.0
Werkzeug==0.16.0
wrapt==1.11.2
BIN
glassy_dynamics/testdata/test_large.pickle
vendored
Normal file
Binary file not shown.
BIN
glassy_dynamics/testdata/test_small.pickle
vendored
Normal file
Binary file not shown.
381
glassy_dynamics/train.py
Normal file
@@ -0,0 +1,381 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training pipeline for the prediction of particle mobilities in glasses."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import pickle

from absl import logging
import enum
import numpy as np

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

from typing import Any, Dict, List, Optional, Text, Tuple, Sequence

from glassy_dynamics import graph_model

tf.enable_resource_variables()

LossCollection = collections.namedtuple('LossCollection',
                                        'l1_loss, l2_loss, correlation')
GlassSimulationData = collections.namedtuple('GlassSimulationData',
                                             'positions, targets, types, box')


class ParticleType(enum.IntEnum):
  """The simulation contains two particle types, identified as type A and B.

  The dataset encodes the particle type in an integer.
    - 0 corresponds to particle type A.
    - 1 corresponds to particle type B.
  """
  A = 0
  B = 1


def get_targets(
    initial_positions,
    trajectory_target_positions):
  """Returns the averaged particle mobilities from the sampled trajectories.

  Args:
    initial_positions: the initial positions of the particles with shape
      [n_particles, 3].
    trajectory_target_positions: the absolute positions of the particles at
      the target time for all sampled trajectories, each with shape
      [n_particles, 3].
  """
  targets = np.mean([np.linalg.norm(t - initial_positions, axis=-1)
                     for t in trajectory_target_positions], axis=0)
  return targets.astype(np.float32)


def load_data(
    file_pattern,
    time_index,
    max_files_to_load=None):
  """Returns a list of `GlassSimulationData` elements, one per matched file.

  Each element contains:
    `positions`: `np.ndarray` containing the particle positions with shape
      [n_particles, 3].
    `targets`: `np.ndarray` containing particle mobilities with shape
      [n_particles].
    `types`: `np.ndarray` containing the particle types with shape
      [n_particles].
    `box`: `np.ndarray` containing the dimensions of the periodic box with
      shape [3].

  Args:
    file_pattern: pattern matching the files with the simulation data.
    time_index: the time index of the targets.
    max_files_to_load: the maximum number of files to load.
  """
  filenames = tf.io.gfile.glob(file_pattern)
  if max_files_to_load:
    filenames = filenames[:max_files_to_load]

  static_structures = []
  for filename in filenames:
    with tf.io.gfile.GFile(filename, 'rb') as f:
      data = pickle.load(f)
    static_structures.append(GlassSimulationData(
        positions=data['positions'].astype(np.float32),
        targets=get_targets(
            data['positions'],
            data['trajectory_target_positions'][time_index]),
        types=data['types'].astype(np.int32),
        box=data['box'].astype(np.float32)))
  return static_structures


def get_loss_ops(
    prediction,
    target,
    types):
  """Returns L1/L2 loss and correlation for type A particles.

  Args:
    prediction: tensor with shape [n_particles] containing the predicted
      particle mobilities.
    target: tensor with shape [n_particles] containing the true particle
      mobilities.
    types: tensor with shape [n_particles] containing the particle types.
  """
  # Considers only type A particles.
  mask = tf.equal(types, ParticleType.A)
  prediction = tf.boolean_mask(prediction, mask)
  target = tf.boolean_mask(target, mask)
  return LossCollection(
      l1_loss=tf.reduce_mean(tf.abs(prediction - target)),
      l2_loss=tf.reduce_mean((prediction - target)**2),
      correlation=tf.squeeze(tfp.stats.correlation(
          prediction[:, tf.newaxis], target[:, tf.newaxis])))


def get_minimize_op(
    loss,
    learning_rate,
    grad_clip=None):
  """Returns minimization operation.

  Args:
    loss: the loss tensor which is minimized.
    learning_rate: the learning rate used by the optimizer.
    grad_clip: all gradients are clipped to the given value if not None or 0.
  """
  optimizer = tf.train.AdamOptimizer(learning_rate)
  grads_and_vars = optimizer.compute_gradients(loss)
  if grad_clip:
    grads, _ = tf.clip_by_global_norm([g for g, _ in grads_and_vars],
                                      grad_clip)
    grads_and_vars = [(g, pair[1]) for g, pair in zip(grads, grads_and_vars)]
  minimize = optimizer.apply_gradients(grads_and_vars)
  return minimize


def _log_stats_and_return_mean_correlation(
    label,
    stats):
  """Logs performance statistics and returns mean correlation.

  Args:
    label: label printed before the combined statistics, e.g. train or test.
    stats: statistics calculated for each batch in a dataset.

  Returns:
    mean correlation
  """
  for key in LossCollection._fields:
    values = [getattr(s, key) for s in stats]
    mean = np.mean(values)
    std = np.std(values)
    logging.info('%s: %s: %.4f +/- %.4f', label, key, mean, std)
  return np.mean([s.correlation for s in stats])


def train_model(train_file_pattern,
                test_file_pattern,
                max_files_to_load=None,
                n_epochs=1000,
                time_index=9,
                augment_data_using_rotations=True,
                learning_rate=1e-4,
                grad_clip=1.0,
                n_recurrences=7,
                mlp_sizes=(64, 64),
                mlp_kwargs=None,
                edge_threshold=2.0,
                measurement_store_interval=1000,
                checkpoint_path=None):
  """Trains GraphModel using tensorflow.

  Args:
    train_file_pattern: pattern matching the files with the training data.
    test_file_pattern: pattern matching the files with the test data.
    max_files_to_load: the maximum number of train and test files to load.
      If None, all files will be loaded.
    n_epochs: the number of passes through the training dataset (epochs).
    time_index: the time index (0-9) of the target mobilities.
    augment_data_using_rotations: data is augmented by using random rotations.
    learning_rate: the learning rate used by the optimizer.
    grad_clip: all gradients are clipped to the given value.
    n_recurrences: the number of message passing steps in the graphnet.
    mlp_sizes: the number of neurons in each layer of the MLP.
    mlp_kwargs: additional keyword arguments passed to the MLP.
    edge_threshold: particles at distance less than threshold are connected by
      an edge.
    measurement_store_interval: number of steps between storing objective
      values (loss and correlation).
    checkpoint_path: path used to store the checkpoint with the highest
      correlation on the test set.

  Returns:
    Correlation on the test dataset of best model encountered during training.
  """
  if mlp_kwargs is None:
    mlp_kwargs = dict(
        initializers=dict(w=tf.variance_scaling_initializer(1.0),
                          b=tf.variance_scaling_initializer(0.1)))
  # Loads train and test dataset.
  dataset_kwargs = dict(
      time_index=time_index,
      max_files_to_load=max_files_to_load)
  training_data = load_data(train_file_pattern, **dataset_kwargs)
  test_data = load_data(test_file_pattern, **dataset_kwargs)

  # Defines wrapper functions, which can directly be passed to the
  # tf.data.Dataset.map function.
  def _make_graph_from_static_structure(static_structure):
    """Converts static structure to graph, targets and types."""
    return (graph_model.make_graph_from_static_structure(
        static_structure.positions,
        static_structure.types,
        static_structure.box,
        edge_threshold),
            static_structure.targets,
            static_structure.types)

  def _apply_random_rotation(graph, targets, types):
    """Applies random rotations to the graph and forwards targets and types."""
    return graph_model.apply_random_rotation(graph), targets, types

  # Defines data-pipeline based on tf.data.Dataset following the official
  # guideline: https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays.
  # We use initializable iterators to avoid embedding the training and test
  # data directly into the graph.
  # Instead we feed the data to the iterators during the initialization of the
  # iterators before the main training loop.
  placeholders = GlassSimulationData._make(
      tf.placeholder(s.dtype, (None,) + s.shape) for s in training_data[0])
  dataset = tf.data.Dataset.from_tensor_slices(placeholders)
  dataset = dataset.map(_make_graph_from_static_structure)
  dataset = dataset.cache()
  dataset = dataset.shuffle(400)
  # Augments data. This has to be done after calling dataset.cache!
  if augment_data_using_rotations:
    dataset = dataset.map(_apply_random_rotation)
  dataset = dataset.repeat()
  train_iterator = dataset.make_initializable_iterator()

  dataset = tf.data.Dataset.from_tensor_slices(placeholders)
  dataset = dataset.map(_make_graph_from_static_structure)
  dataset = dataset.cache()
  dataset = dataset.repeat()
  test_iterator = dataset.make_initializable_iterator()

  # Creates tensorflow graph.
  # Note: We decouple the training and test datasets from the input pipeline
  # by creating a new iterator from a string-handle placeholder with the same
  # output types and shapes as the training dataset.
  dataset_handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      dataset_handle, train_iterator.output_types,
      train_iterator.output_shapes)
  graph, targets, types = iterator.get_next()

  model = graph_model.GraphBasedModel(
      n_recurrences, mlp_sizes, mlp_kwargs)
  prediction = model(graph)

  # Defines loss and minimization operations.
  loss_ops = get_loss_ops(prediction, targets, types)
  minimize_op = get_minimize_op(loss_ops.l2_loss, learning_rate, grad_clip)

  best_so_far = -1
  train_stats = []
  test_stats = []

  saver = tf.train.Saver()

  with tf.train.SingularMonitoredSession() as session:
    # Initializes train and test iterators with the training and test
    # datasets. The obtained training and test string-handles can be passed
    # to the dataset_handle placeholder to select the dataset.
    train_handle = session.run(train_iterator.string_handle())
    test_handle = session.run(test_iterator.string_handle())
    feed_dict = {p: [x[i] for x in training_data]
                 for i, p in enumerate(placeholders)}
    session.run(train_iterator.initializer, feed_dict=feed_dict)
    feed_dict = {p: [x[i] for x in test_data]
                 for i, p in enumerate(placeholders)}
    session.run(test_iterator.initializer, feed_dict=feed_dict)

    # Trains model using stochastic gradient descent on the training dataset.
    n_training_steps = len(training_data) * n_epochs
    for i in range(n_training_steps):
      feed_dict = {dataset_handle: train_handle}
      train_loss, _ = session.run((loss_ops, minimize_op),
                                  feed_dict=feed_dict)
      train_stats.append(train_loss)

      if (i+1) % measurement_store_interval == 0:
        # Evaluates model on test dataset.
        for _ in range(len(test_data)):
          feed_dict = {dataset_handle: test_handle}
          test_stats.append(session.run(loss_ops, feed_dict=feed_dict))

        # Outputs performance statistics on training and test dataset.
        _log_stats_and_return_mean_correlation('Train', train_stats)
        correlation = _log_stats_and_return_mean_correlation(
            'Test', test_stats)
        train_stats = []
        test_stats = []

        # Updates best model based on the observed correlation on the test
        # dataset.
        if correlation > best_so_far:
          best_so_far = correlation
          if checkpoint_path:
            saver.save(session.raw_session(), checkpoint_path)

  return best_so_far


def apply_model(checkpoint_path,
                file_pattern,
                max_files_to_load=None,
                time_index=9):
  """Applies trained GraphModel using tensorflow.

  Args:
    checkpoint_path: path from which the model is loaded.
    file_pattern: pattern matching the files with the data.
    max_files_to_load: the maximum number of files to load.
      If None, all files will be loaded.
    time_index: the time index (0-9) of the target mobilities.

  Returns:
    Predictions of the model for all files.
  """
  dataset_kwargs = dict(
      time_index=time_index,
      max_files_to_load=max_files_to_load)
  data = load_data(file_pattern, **dataset_kwargs)

  tf.reset_default_graph()
  saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
  graph = tf.get_default_graph()

  placeholders = GlassSimulationData(
      positions=graph.get_tensor_by_name('Placeholder:0'),
      targets=graph.get_tensor_by_name('Placeholder_1:0'),
      types=graph.get_tensor_by_name('Placeholder_2:0'),
      box=graph.get_tensor_by_name('Placeholder_3:0'))
  prediction_tensor = graph.get_tensor_by_name('Graph_1/Squeeze:0')
  correlation_tensor = graph.get_tensor_by_name('Squeeze:0')

  dataset_handle = graph.get_tensor_by_name('Placeholder_4:0')
  test_initializer = graph.get_operation_by_name('MakeIterator_1')
  test_string_handle = graph.get_tensor_by_name('IteratorToStringHandle_1:0')

  with tf.Session() as session:
    saver.restore(session, checkpoint_path)
    handle = session.run(test_string_handle)
    feed_dict = {p: [x[i] for x in data] for i, p in enumerate(placeholders)}
    session.run(test_initializer, feed_dict=feed_dict)
    predictions = []
    correlations = []
    for _ in range(len(data)):
      p, c = session.run((prediction_tensor, correlation_tensor),
                         feed_dict={dataset_handle: handle})
      predictions.append(p)
      correlations.append(c)

  logging.info('Correlation: %.4f +/- %.4f',
               np.mean(correlations),
               np.std(correlations))
  return predictions
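A minimal sketch of driving this pipeline from Python; the file patterns and
checkpoint path are illustrative, and `train_binary.py` below wires the same
call to command-line flags:

```python
from glassy_dynamics import train

best_test_correlation = train.train_model(
    train_file_pattern='/path/to/dataset/train/aggregated*',
    test_file_pattern='/path/to/dataset/test/aggregated*',
    time_index=9,
    n_epochs=1000,
    checkpoint_path='/tmp/best_model.ckpt')
print('Best correlation on the test set:', best_test_correlation)
```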
64
glassy_dynamics/train_binary.py
Normal file
@@ -0,0 +1,64 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trains a graph-based network to predict particle mobilities in glasses."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags

from glassy_dynamics import train

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'data_directory',
    '',
    'Directory which contains the train and test datasets.')
flags.DEFINE_integer(
    'time_index',
    9,
    'The time index of the target mobilities.')
flags.DEFINE_integer(
    'max_files_to_load',
    None,
    'The maximum number of files to load from the train and test datasets.')
flags.DEFINE_string(
    'checkpoint_path',
    None,
    'Path used to store a checkpoint of the best model.')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  train_file_pattern = os.path.join(FLAGS.data_directory, 'train/aggregated*')
  test_file_pattern = os.path.join(FLAGS.data_directory, 'test/aggregated*')
  train.train_model(
      train_file_pattern=train_file_pattern,
      test_file_pattern=test_file_pattern,
      max_files_to_load=FLAGS.max_files_to_load,
      time_index=FLAGS.time_index,
      checkpoint_path=FLAGS.checkpoint_path)


if __name__ == '__main__':
  app.run(main)
126
glassy_dynamics/train_test.py
Normal file
@@ -0,0 +1,126 @@
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for train."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow.compat.v1 as tf

from glassy_dynamics import train


class TrainTest(tf.test.TestCase):

  def test_get_targets(self):
    initial_positions = np.array([[0, 0, 0], [1, 2, 3]])
    trajectory_target_positions = [
        np.array([[1, 0, 0], [1, 2, 4]]),
        np.array([[0, 1, 0], [1, 0, 3]]),
        np.array([[0, 0, 5], [1, 2, 3]]),
    ]
    expected_targets = np.array([7.0 / 3.0, 1.0])
    targets = train.get_targets(initial_positions,
                                trajectory_target_positions)
    np.testing.assert_almost_equal(expected_targets, targets)

  def test_load_data(self):
    file_pattern = os.path.join(os.path.dirname(__file__), 'testdata',
                                'test_small.pickle')

    with self.subTest('ContentAndShapesAreAsExpected'):
      data = train.load_data(file_pattern, 0)
      self.assertEqual(len(data), 1)
      element = data[0]
      self.assertTupleEqual(element.positions.shape, (20, 3))
      self.assertTupleEqual(element.box.shape, (3,))
      self.assertTupleEqual(element.targets.shape, (20,))
      self.assertTupleEqual(element.types.shape, (20,))

    with self.subTest('TargetsGrowAsAFunctionOfTime'):
      previous_mean_target = 0.0
      # Time index 9 corresponds to 1/e ~= 0.37 in the intermediate scattering
      # function, and therefore lies between time index 5 (0.4) and time
      # index 6 (0.3).
      for time_index in [0, 1, 2, 3, 4, 5, 9, 6, 7, 8]:
        data = train.load_data(file_pattern, time_index)[0]
        current_mean_target = data.targets.mean()
        self.assertGreater(current_mean_target, previous_mean_target)
        previous_mean_target = current_mean_target


class TensorflowTrainTest(tf.test.TestCase):

  def test_get_loss_op(self):
    """Tests the correct calculation of the loss operations."""
    prediction = tf.constant([0.0, 1.0, 2.0, 1.0, 2.0], dtype=tf.float32)
    target = tf.constant([1.0, 25.0, 0.0, 4.0, 2.0], dtype=tf.float32)
    types = tf.constant([0, 1, 0, 0, 0], dtype=tf.int32)
    loss_ops = train.get_loss_ops(prediction, target, types)
    loss = self.evaluate(loss_ops)
    self.assertAlmostEqual(loss.l1_loss, 1.5)
    self.assertAlmostEqual(loss.l2_loss, 14.0 / 4.0)
    self.assertAlmostEqual(loss.correlation, -0.15289416)

  def test_get_minimize_op(self):
    """Tests the minimize operation by minimizing a single variable."""
    var = tf.Variable([1.0], name='test')
    loss = var**2
    minimize = train.get_minimize_op(loss, 1e-1)
    with self.session():
      tf.global_variables_initializer().run()
      for _ in range(100):
        minimize.run()
      value = var.eval()
    self.assertLess(abs(value[0]), 0.01)

  def test_train_model(self):
    """Tests if we can overfit to a small test dataset."""
    file_pattern = os.path.join(os.path.dirname(__file__), 'testdata',
                                'test_small.pickle')
    best_correlation_value = train.train_model(
        train_file_pattern=file_pattern,
        test_file_pattern=file_pattern,
        n_epochs=1000,
        augment_data_using_rotations=False,
        learning_rate=1e-4,
        n_recurrences=2,
        edge_threshold=5,
        mlp_sizes=(32, 32),
        measurement_store_interval=1000)
    # The test dataset contains only a single sample with 20 particles.
    # Therefore we expect the model to be able to memorize the targets
    # perfectly if the model works correctly.
    self.assertGreater(best_correlation_value, 0.99)

  def test_apply_model(self):
    """Tests if we can apply a model to a small test dataset."""
    checkpoint_path = os.path.join(os.path.dirname(__file__), 'checkpoints',
                                   't044_s09.ckpt')
    file_pattern = os.path.join(os.path.dirname(__file__), 'testdata',
                                'test_large.pickle')
    predictions = train.apply_model(checkpoint_path=checkpoint_path,
                                    file_pattern=file_pattern,
                                    time_index=0)
    data = train.load_data(file_pattern, 0)
    targets = data[0].targets
    correlation_value = np.corrcoef(predictions[0], targets)[0, 1]
    self.assertGreater(correlation_value, 0.5)


if __name__ == '__main__':
  tf.test.main()