Mirror of https://github.com/google-deepmind/deepmind-research.git, synced 2025-12-16 05:34:12 +08:00
Add a colab for generating figures.
Export training curves to file and fix some inconsistencies. PiperOrigin-RevId: 324825810
Committed by Diego de Las Casas
parent 99aaa6930a
commit 60550a5bc6
@@ -135,7 +135,7 @@ class EnvironmentWithKeyboard(dm_env.Environment):
        break

      # Terminate option.
-      if self._compute_reward(option, action_step.observation) > 0:
+      if self._should_terminate(option, action_step.observation):
        break

    if not self._call_and_return:
@@ -143,6 +143,16 @@ class EnvironmentWithKeyboard(dm_env.Environment):

    return option_step

+  def _should_terminate(self, option, obs):
+    if self._compute_reward(option, obs) > 0:
+      return True
+    elif np.all(self._options_np[option] <= 0):
+      # TODO(shaobohou) A hack ensure option with non-positive weights
+      # terminates after one step
+      return True
+    else:
+      return False
+
  def action_spec(self):
    return dm_env.specs.DiscreteArray(
        num_values=self._options_np.shape[0], name="action")
@@ -228,7 +238,7 @@ class EnvironmentWithKeyboardDirect(dm_env.Environment):
        break

      # Terminate option.
-      if self._compute_reward(option, action_step.observation) > 0:
+      if self._should_terminate(option, action_step.observation):
        break

    if not self._call_and_return:
@@ -236,6 +246,16 @@ class EnvironmentWithKeyboardDirect(dm_env.Environment):

    return option_step

+  def _should_terminate(self, option, obs):
+    if self._compute_reward(option, obs) > 0:
+      return True
+    elif np.all(option <= 0):
+      # TODO(shaobohou) A hack ensure option with non-positive weights
+      # terminates after one step
+      return True
+    else:
+      return False
+
  def action_spec(self):
    return dm_env.specs.BoundedArray(shape=(self._keyboard.num_cumulants,),
                                     dtype=np.float32,
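Both wrapper classes gain the same termination rule; they differ only in where the option's cumulant weights come from (a row of self._options_np for the discrete keyboard, the raw option vector for the direct variant). A minimal standalone sketch of that shared rule follows; the function and argument names are hypothetical, not code from this commit.

import numpy as np

def should_terminate(reward, option_weights):
  """Terminate an option once it pays off, or if it never can."""
  if reward > 0:
    return True
  # The hack noted in the TODO above: weights that are all non-positive can
  # never yield a positive cumulant reward, so such an option stops after
  # one step.
  return bool(np.all(np.asarray(option_weights) <= 0))

assert should_terminate(reward=0.5, option_weights=[1.0, 0.0])       # paid off
assert should_terminate(reward=0.0, option_weights=[-1.0, -0.5])     # can never pay off
assert not should_terminate(reward=0.0, option_weights=[1.0, -1.0])  # keep running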
@@ -280,10 +300,7 @@ def _discretize_actions(num_actions_per_dim,

  # Remove options with all zeros.
  non_zero_entries = np.sum(np.square(discretized_actions), axis=-1) != 0.0
-  # Remove options with no positive elements.
-  non_negative_entries = np.any(discretized_actions > 0, axis=-1)
-  discretized_actions = discretized_actions[np.logical_and(
-      non_zero_entries, non_negative_entries)]
+  discretized_actions = discretized_actions[non_zero_entries]
  logging.info("Total number of discretized actions: %s",
               len(discretized_actions))
  logging.info("Discretized actions: %s", discretized_actions)
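The filter on discretized actions is relaxed here: previously an action had to be non-zero and contain at least one positive component, now only the all-zero action is dropped. A small numpy check of the two rules on made-up data, not taken from the commit:

import numpy as np

actions = np.array([[0.0, 0.0],     # all zeros: dropped by both rules
                    [-1.0, -1.0],   # no positive component: dropped only by the old rule
                    [1.0, -1.0]])   # kept by both rules

non_zero_entries = np.sum(np.square(actions), axis=-1) != 0.0
non_negative_entries = np.any(actions > 0, axis=-1)

old_rule = actions[np.logical_and(non_zero_entries, non_negative_entries)]
new_rule = actions[non_zero_entries]

print(old_rule)  # [[ 1. -1.]]
print(new_rule)  # [[-1. -1.] [ 1. -1.]]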
@@ -16,7 +16,10 @@
# ============================================================================
"""A simple training loop."""

+import csv
+
from absl import logging
+from tensorflow.compat.v1.io import gfile


def _ema(base, val, decay=0.995):
@@ -32,31 +35,42 @@ def run(env, agent, num_episodes, report_every=200, num_eval_reps=1):
    num_episodes: Number of episodes to train for.
    report_every: Frequency at which training progress are reported (episodes).
    num_eval_reps: Number of eval episodes to run per training episode.

  Returns:
    A list of dicts containing training and evaluation returns, and a list of
    reported returns smoothed by EMA.
  """

  train_returns = []
+  returns = []
+  logged_returns = []
  train_return_ema = 0.
  eval_returns = []
  eval_return_ema = 0.
-  for episode_id in range(num_episodes):
+  for episode in range(num_episodes):
+    returns.append(dict(episode=episode))

    # Run a training episode.
    train_episode_return = run_episode(env, agent, is_training=True)
    train_returns.append(train_episode_return)
    train_return_ema = _ema(train_return_ema, train_episode_return)
+    returns[-1]["train"] = train_episode_return

    # Run an evaluation episode.
+    returns[-1]["eval"] = []
    for _ in range(num_eval_reps):
      eval_episode_return = run_episode(env, agent, is_training=False)
      eval_returns.append(eval_episode_return)
      eval_return_ema = _ema(eval_return_ema, eval_episode_return)
+      returns[-1]["eval"].append(eval_episode_return)

-    if ((episode_id + 1) % report_every) == 0:
+    if ((episode + 1) % report_every) == 0 or episode == 0:
+      logged_returns.append(
+          dict(episode=episode, train=train_return_ema, eval=[eval_return_ema]))
      logging.info("Episode %s, avg train return %.3f, avg eval return %.3f",
-                   episode_id + 1, train_return_ema, eval_return_ema)
+                   episode + 1, train_return_ema, eval_return_ema)
      if hasattr(agent, "get_logs"):
-        logging.info("Episode %s, agent logs: %s", episode_id + 1,
+        logging.info("Episode %s, agent logs: %s", episode + 1,
                     agent.get_logs())

+  return returns, logged_returns


def run_episode(environment, agent, is_training=False):
  """Run a single episode."""
@@ -75,3 +89,14 @@ def run_episode(environment, agent, is_training=False):
  episode_return = environment.episode_return

  return episode_return


+def write_returns_to_file(path, returns):
+  """Write returns to file."""
+
+  with gfile.GFile(path, "w") as file:
+    writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
+    writer.writerow(["episode", "train"] +
+                    [f"eval_{idx}" for idx in range(len(returns[0]["eval"]))])
+    for row in returns:
+      writer.writerow([row["episode"], row["train"]] + row["eval"])
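The exported file is a space-delimited table with a header row of episode, train, and one eval_<i> column per evaluation repetition, which is presumably what the new figure-generating colab consumes. A minimal sketch of reading such a file back for plotting; the path is hypothetical and this is not code from the commit.

import csv

def read_returns(path):
  """Parse a file written by experiment.write_returns_to_file."""
  with open(path) as f:
    reader = csv.reader(f, delimiter=" ")
    header = next(reader)                          # ["episode", "train", "eval_0", ...]
    rows = [[float(v) for v in r] for r in reader]
  episodes = [int(r[0]) for r in rows]
  train = [r[1] for r in rows]
  evals = [r[2:] for r in rows]                    # one value per eval repetition
  return header, episodes, train, evals

# e.g. header, episodes, train, evals = read_returns("/tmp/option_keyboard/fig4b_returns.csv")
# plt.plot(episodes, train)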
@@ -30,7 +30,7 @@ python3 train_keyboard.py --logtostderr --policy_weights_name=5

Then generate the polar plot data as follows:

-python3 eval_keyboard_fig5a.py --logtostderr \
+python3 eval_keyboard_fig5.py --logtostderr \
  --keyboard_paths=/tmp/option_keyboard/keyboard_12/tfhub,/tmp/option_keyboard/keyboard_34/tfhub,/tmp/option_keyboard/keyboard_5/tfhub \
  --num_episodes=1000
@@ -57,11 +57,14 @@ Example outout:
 [ 0.099 0.349 0.055 ]]
"""

+import csv
+
from absl import app
from absl import flags

import numpy as np
import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1.io import gfile
import tensorflow_hub as hub

from option_keyboard import configs
@@ -75,20 +78,12 @@ from option_keyboard.gpe_gpi_experiments import regressed_agent
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
flags.DEFINE_list("keyboard_paths", [], "Path to keyboard model.")
+flags.DEFINE_string("output_path", None, "Path to write out returns.")


-def evaluate_keyboard(keyboard_path):
+def evaluate_keyboard(keyboard_path, weights_to_sweep):
  """Evaluate a keyboard."""

-  angles_to_sweep = np.deg2rad(np.linspace(-90, 180, num=19, endpoint=True))
-  weights_to_sweep = np.stack(
-      [np.cos(angles_to_sweep),
-       np.sin(angles_to_sweep)], axis=-1)
-  weights_to_sweep /= np.sum(
-      np.maximum(weights_to_sweep, 0.0), axis=-1, keepdims=True)
-  weights_to_sweep = np.clip(weights_to_sweep, -1000, 1000)
-  tf.logging.info(weights_to_sweep)
-
  # Load the keyboard.
  keyboard = smart_module.SmartModuleImport(hub.Module(keyboard_path))
@@ -124,20 +119,41 @@ def evaluate_keyboard(keyboard_path):
        f"{FLAGS.num_episodes} episodes is {np.mean(returns)}")
    all_returns.append(returns)

-  return all_returns, weights_to_sweep
+  return all_returns


def main(argv):
  del argv

+  angles_to_sweep = np.deg2rad(np.linspace(-90, 180, num=19, endpoint=True))
+  weights_to_sweep = np.stack(
+      [np.sin(angles_to_sweep),
+       np.cos(angles_to_sweep)], axis=-1)
+  weights_to_sweep /= np.sum(
+      np.maximum(weights_to_sweep, 0.0), axis=-1, keepdims=True)
+  weights_to_sweep = np.clip(weights_to_sweep, -1000, 1000)
+  tf.logging.info(weights_to_sweep)
+
  all_returns = []
  for keyboard_path in FLAGS.keyboard_paths:
-    returns, _ = evaluate_keyboard(keyboard_path)
+    returns = evaluate_keyboard(keyboard_path, weights_to_sweep)
    all_returns.append(returns)

  print("Results:")
  print(np.mean(all_returns, axis=-1).T)

+  if FLAGS.output_path:
+    with gfile.GFile(FLAGS.output_path, "w") as file:
+      writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
+      writer.writerow(["angle", "return", "idx"])
+      for idx, returns in enumerate(all_returns):
+        for row in np.array(returns).T.tolist():
+          assert len(angles_to_sweep) == len(row)
+          for ang, val in zip(angles_to_sweep, row):
+            ang = "{:.4g}".format(ang)
+            val = "{:.4g}".format(val)
+            writer.writerow([ang, val, idx])


if __name__ == "__main__":
  tf.disable_v2_behavior()
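The polar sweep is now built once in main() and passed into evaluate_keyboard; note that the component order also flips from [cos, sin] to [sin, cos], one of the inconsistencies the commit message mentions. A quick standalone check of what the sweep formula produces for two angles, computed from the formula above rather than taken from the commit:

import numpy as np

angles = np.deg2rad([45.0, 135.0])
weights = np.stack([np.sin(angles), np.cos(angles)], axis=-1)
# Scale each weight vector so that its positive components sum to one.
weights /= np.sum(np.maximum(weights, 0.0), axis=-1, keepdims=True)
weights = np.clip(weights, -1000, 1000)

print(np.round(weights, 3))
# 45 degrees  -> [0.5  0.5]    both cumulants weighted equally
# 135 degrees -> [1.  -1. ]    first cumulant rewarded, second penalised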
option_keyboard/gpe_gpi_experiments/generate_figures.ipynb (normal file, 839 lines added)
File diff suppressed because one or more lines are too long
@@ -32,6 +32,9 @@ from option_keyboard import scavenger

FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
+flags.DEFINE_integer("report_every", 5,
+                     "Frequency at which metrics are reported.")
+flags.DEFINE_string("output_path", None, "Path to write out training curves.")


def main(argv):
@@ -56,7 +59,13 @@ def main(argv):
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

-  experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
+  _, ema_returns = experiment.run(
+      env,
+      agent,
+      num_episodes=FLAGS.num_episodes,
+      report_every=FLAGS.report_every)
+  if FLAGS.output_path:
+    experiment.write_returns_to_file(FLAGS.output_path, ema_returns)


if __name__ == "__main__":
@@ -32,7 +32,10 @@ from option_keyboard import scavenger

FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
-flags.DEFINE_list("test_w", [], "The w to test.")
+flags.DEFINE_list("test_w", None, "The w to test.")
+flags.DEFINE_integer("report_every", 200,
+                     "Frequency at which metrics are reported.")
+flags.DEFINE_string("output_path", None, "Path to write out training curves.")


def main(argv):
@@ -58,7 +61,13 @@ def main(argv):
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

-  experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
+  _, ema_returns = experiment.run(
+      env,
+      agent,
+      num_episodes=FLAGS.num_episodes,
+      report_every=FLAGS.report_every)
+  if FLAGS.output_path:
+    experiment.write_returns_to_file(FLAGS.output_path, ema_returns)


if __name__ == "__main__":
option_keyboard/gpe_gpi_experiments/run_regressed_w_fig4b.py (normal file, 97 lines added)
@@ -0,0 +1,97 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
r"""Run an experiment.

Run GPE/GPI on task (1, -1) with w obtained by regression.


For example, first train a keyboard:

python3 train_keyboard.py -- --logtostderr --policy_weights_name=12 \
  --export_path=/tmp/option_keyboard/keyboard


Then, evaluate the keyboard with w by regression.

python3 run_regressed_w_fig4b.py -- --logtostderr \
  --keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
"""

from absl import app
from absl import flags

import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

from option_keyboard import configs
from option_keyboard import environment_wrappers
from option_keyboard import experiment
from option_keyboard import scavenger
from option_keyboard import smart_module

from option_keyboard.gpe_gpi_experiments import regressed_agent

FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 4000, "Number of training episodes.")
flags.DEFINE_integer("report_every", 5,
                     "Frequency at which metrics are reported.")
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
flags.DEFINE_string("output_path", None, "Path to write out training curves.")


def main(argv):
  del argv

  # Load the keyboard.
  keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))

  # Create the task environment.
  base_env_config = configs.get_fig4_task_config()
  base_env = scavenger.Scavenger(**base_env_config)
  base_env = environment_wrappers.EnvironmentWithLogging(base_env)

  # Wrap the task environment with the keyboard.
  additional_discount = 0.9
  env = environment_wrappers.EnvironmentWithKeyboardDirect(
      env=base_env,
      keyboard=keyboard,
      keyboard_ckpt_path=None,
      additional_discount=additional_discount,
      call_and_return=False)

  # Create the player agent.
  agent = regressed_agent.Agent(
      batch_size=10,
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-2,),
      init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
  )

  _, ema_returns = experiment.run(
      env,
      agent,
      num_episodes=FLAGS.num_episodes,
      report_every=FLAGS.report_every,
      num_eval_reps=20)
  if FLAGS.output_path:
    experiment.write_returns_to_file(FLAGS.output_path, ema_returns)


if __name__ == "__main__":
  tf.disable_v2_behavior()
  app.run(main)
@@ -27,7 +27,7 @@ python3 train_keyboard.py -- --logtostderr --policy_weights_name=12 \

Then, evaluate the keyboard with w by regression.

-python3 run_regressed_w_fig4.py -- --logtostderr \
+python3 run_regressed_w_fig4c.py -- --logtostderr \
  --keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
"""
@@ -47,8 +47,11 @@ from option_keyboard import smart_module
from option_keyboard.gpe_gpi_experiments import regressed_agent

FLAGS = flags.FLAGS
-flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
+flags.DEFINE_integer("num_episodes", 100, "Number of training episodes.")
+flags.DEFINE_integer("report_every", 1,
+                     "Frequency at which metrics are reported.")
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
+flags.DEFINE_string("output_path", None, "Path to write out training curves.")


def main(argv):
@@ -75,16 +78,18 @@ def main(argv):
  agent = regressed_agent.Agent(
      batch_size=10,
      optimizer_name="AdamOptimizer",
-      optimizer_kwargs=dict(learning_rate=1e-1,),
+      optimizer_kwargs=dict(learning_rate=3e-2,),
      init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
  )

-  experiment.run(
+  _, ema_returns = experiment.run(
      env,
      agent,
      num_episodes=FLAGS.num_episodes,
-      report_every=2,
+      report_every=FLAGS.report_every,
      num_eval_reps=100)
+  if FLAGS.output_path:
+    experiment.write_returns_to_file(FLAGS.output_path, ema_returns)


if __name__ == "__main__":
@@ -35,7 +35,7 @@ python3 train_keyboard_with_phi.py -- --logtostderr \

Finally, evaluate the keyboard with w by regression.

-python3 run_regressed_w_with_phi_fig4b.py -- --logtostderr \
+python3 run_regressed_w_with_phi_fig4c.py -- --logtostderr \
  --phi_model_path=/tmp/option_keyboard/phi_model_3d \
  --keyboard_path=/tmp/option_keyboard/keyboard_3d/tfhub
"""
@@ -56,9 +56,12 @@ from option_keyboard import smart_module
from option_keyboard.gpe_gpi_experiments import regressed_agent

FLAGS = flags.FLAGS
-flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
+flags.DEFINE_integer("num_episodes", 100, "Number of training episodes.")
+flags.DEFINE_integer("report_every", 1,
+                     "Frequency at which metrics are reported.")
flags.DEFINE_string("phi_model_path", None, "Path to phi model.")
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
+flags.DEFINE_string("output_path", None, "Path to write out training curves.")


def main(argv):
@@ -88,16 +91,18 @@ def main(argv):
  agent = regressed_agent.Agent(
      batch_size=10,
      optimizer_name="AdamOptimizer",
-      optimizer_kwargs=dict(learning_rate=1e-1,),
+      optimizer_kwargs=dict(learning_rate=3e-2,),
      init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
  )

-  experiment.run(
+  _, ema_returns = experiment.run(
      env,
      agent,
      num_episodes=FLAGS.num_episodes,
-      report_every=2,
+      report_every=FLAGS.report_every,
      num_eval_reps=100)
+  if FLAGS.output_path:
+    experiment.write_returns_to_file(FLAGS.output_path, ema_returns)


if __name__ == "__main__":
@@ -30,11 +30,14 @@ python3 run_true_w_fig4.py -- --logtostderr \
  --keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
"""

+import csv
+
from absl import app
from absl import flags

import numpy as np
import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1.io import gfile
import tensorflow_hub as hub

from option_keyboard import configs
@@ -48,6 +51,7 @@ from option_keyboard.gpe_gpi_experiments import regressed_agent
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
+flags.DEFINE_string("output_path", None, "Path to write out returns.")


def main(argv):
@@ -86,6 +90,13 @@ def main(argv):
      f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
  tf.logging.info("#" * 80)

+  if FLAGS.output_path:
+    with gfile.GFile(FLAGS.output_path, "w") as file:
+      writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
+      writer.writerow(["return"])
+      for val in returns:
+        writer.writerow([val])


if __name__ == "__main__":
  tf.disable_v2_behavior()
@@ -26,15 +26,18 @@ python3 train_keyboard.py -- --logtostderr --policy_weights_name=12

Then, evaluate the keyboard with a fixed w.

-python3 run_true_w_fig4.py -- --logtostderr \
+python3 run_true_w_fig6.py -- --logtostderr \
  --keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
"""

+import csv
+
from absl import app
from absl import flags

import numpy as np
import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1.io import gfile
import tensorflow_hub as hub

from option_keyboard import configs
@@ -48,7 +51,8 @@ from option_keyboard.gpe_gpi_experiments import regressed_agent
FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
-flags.DEFINE_list("test_w", [], "The w to test.")
+flags.DEFINE_list("test_w", None, "The w to test.")
+flags.DEFINE_string("output_path", None, "Path to write out returns.")


def main(argv):
@@ -87,6 +91,13 @@ def main(argv):
      f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
  tf.logging.info("#" * 80)

+  if FLAGS.output_path:
+    with gfile.GFile(FLAGS.output_path, "w") as file:
+      writer = csv.writer(file, delimiter=" ", quoting=csv.QUOTE_MINIMAL)
+      writer.writerow(["return"])
+      for val in returns:
+        writer.writerow([val])


if __name__ == "__main__":
  tf.disable_v2_behavior()
@@ -139,10 +139,8 @@ def main(argv):
  tasks = [
      [1.0, 0.0],
      [0.0, 1.0],
      [-1.0, 0.0],
      [0.0, -1.0],
      [0.7, 0.3],
      [-0.3, -0.7],
      [1.0, 1.0],
      [-1.0, 1.0],
  ]

  if FLAGS.normalisation == "L1":
@@ -29,6 +29,9 @@ from option_keyboard import scavenger

FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
+flags.DEFINE_integer("report_every", 200,
+                     "Frequency at which metrics are reported.")
+flags.DEFINE_string("output_path", None, "Path to write out training curves.")


def main(argv):
@@ -53,7 +56,13 @@ def main(argv):
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

-  experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
+  _, ema_returns = experiment.run(
+      env,
+      agent,
+      num_episodes=FLAGS.num_episodes,
+      report_every=FLAGS.report_every)
+  if FLAGS.output_path:
+    experiment.write_returns_to_file(FLAGS.output_path, ema_returns)


if __name__ == "__main__":
@@ -36,7 +36,10 @@ FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
flags.DEFINE_integer("num_pretrain_episodes", 20000,
                     "Number of pretraining episodes.")
+flags.DEFINE_integer("report_every", 200,
+                     "Frequency at which metrics are reported.")
flags.DEFINE_string("keyboard_path", None, "Path to pretrained keyboard model.")
+flags.DEFINE_string("output_path", None, "Path to write out training curves.")


def main(argv):
@@ -84,7 +87,13 @@ def main(argv):
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

-  experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
+  _, ema_returns = experiment.run(
+      env,
+      agent,
+      num_episodes=FLAGS.num_episodes,
+      report_every=FLAGS.report_every)
+  if FLAGS.output_path:
+    experiment.write_returns_to_file(FLAGS.output_path, ema_returns)


if __name__ == "__main__":
@@ -17,11 +17,10 @@
"""Smart module export/import utilities."""

import inspect
-import os
import pickle
-import shutil

import tensorflow.compat.v1 as tf
+from tensorflow.compat.v1.io import gfile
import tensorflow_hub as hub
import tree as nest
import wrapt
@@ -164,9 +163,9 @@ class SmartModuleExport(object):
      module_session.run(
          assign_ops, feed_dict=dict(zip(assign_phs, module_weights)))

-    if overwrite and os.path.exists(path):
-      shutil.rmtree(path)
-    os.makedirs(path)
+    if overwrite and gfile.exists(path):
+      gfile.rmtree(path)
+    gfile.makedirs(path)
    hub_module.export(path, module_session)