mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2025-12-13 12:00:08 +08:00
Add GPE/GPI experiments.
PiperOrigin-RevId: 323750949
This commit is contained in:
committed by
Diego de Las Casas
parent
59c0cf5044
commit
a24bda5ed0
@@ -21,7 +21,7 @@ def get_task_config():
|
||||
return dict(
|
||||
arena_size=11,
|
||||
num_channels=2,
|
||||
max_num_steps=50, # 5o for the actual task.
|
||||
max_num_steps=50, # 50 for the actual task.
|
||||
num_init_objects=10,
|
||||
object_priors=[0.5, 0.5],
|
||||
egocentric=True,
|
||||
@@ -39,3 +39,27 @@ def get_pretrain_config():
|
||||
egocentric=True,
|
||||
default_w=(1, 1),
|
||||
)
|
||||
|
||||
|
||||
def get_fig4_task_config():
|
||||
return dict(
|
||||
arena_size=11,
|
||||
num_channels=2,
|
||||
max_num_steps=50, # 50 for the actual task.
|
||||
num_init_objects=10,
|
||||
object_priors=[0.5, 0.5],
|
||||
egocentric=True,
|
||||
default_w=(1, -1),
|
||||
)
|
||||
|
||||
|
||||
def get_fig5_task_config(default_w):
|
||||
return dict(
|
||||
arena_size=11,
|
||||
num_channels=2,
|
||||
max_num_steps=50, # 50 for the actual task.
|
||||
num_init_objects=10,
|
||||
object_priors=[0.5, 0.5],
|
||||
egocentric=True,
|
||||
default_w=default_w,
|
||||
)
|
||||
|
||||
@@ -24,6 +24,10 @@ import dm_env
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_hub as hub
|
||||
import tree
|
||||
|
||||
from option_keyboard import smart_module
|
||||
|
||||
|
||||
class EnvironmentWithLogging(dm_env.Environment):
|
||||
@@ -93,8 +97,9 @@ class EnvironmentWithKeyboard(dm_env.Environment):
|
||||
self._keyboard(tf.expand_dims(obs_ph, axis=0))[0], [obs_ph])
|
||||
session.run(tf.global_variables_initializer())
|
||||
|
||||
saver = tf.train.Saver(var_list=keyboard.variables)
|
||||
saver.restore(session, keyboard_ckpt_path)
|
||||
if keyboard_ckpt_path:
|
||||
saver = tf.train.Saver(var_list=keyboard.variables)
|
||||
saver.restore(session, keyboard_ckpt_path)
|
||||
|
||||
def _compute_reward(self, option, obs):
|
||||
return np.sum(self._options_np[option] * obs["cumulants"])
|
||||
@@ -152,6 +157,102 @@ class EnvironmentWithKeyboard(dm_env.Environment):
|
||||
return getattr(self._env, name)
|
||||
|
||||
|
||||
class EnvironmentWithKeyboardDirect(dm_env.Environment):
|
||||
"""Wraps an environment with a keyboard.
|
||||
|
||||
This is different from EnvironmentWithKeyboard as the actions space is not
|
||||
discretized.
|
||||
|
||||
TODO(shaobohou) Merge the two implementations.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
env,
|
||||
keyboard,
|
||||
keyboard_ckpt_path,
|
||||
additional_discount,
|
||||
call_and_return=False):
|
||||
self._env = env
|
||||
self._keyboard = keyboard
|
||||
self._discount = additional_discount
|
||||
self._call_and_return = call_and_return
|
||||
|
||||
obs_spec = self._extract_observation(env.observation_spec())
|
||||
obs_ph = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
|
||||
option_ph = tf.placeholder(
|
||||
shape=(keyboard.num_cumulants,), dtype=tf.float32)
|
||||
gpi_action = self._keyboard.gpi(obs_ph, option_ph)
|
||||
|
||||
session = tf.Session()
|
||||
self._gpi_action = session.make_callable(gpi_action, [obs_ph, option_ph])
|
||||
self._keyboard_action = session.make_callable(
|
||||
self._keyboard(tf.expand_dims(obs_ph, axis=0))[0], [obs_ph])
|
||||
session.run(tf.global_variables_initializer())
|
||||
|
||||
if keyboard_ckpt_path:
|
||||
saver = tf.train.Saver(var_list=keyboard.variables)
|
||||
saver.restore(session, keyboard_ckpt_path)
|
||||
|
||||
def _compute_reward(self, option, obs):
|
||||
assert option.shape == obs["cumulants"].shape
|
||||
return np.sum(option * obs["cumulants"])
|
||||
|
||||
def reset(self):
|
||||
return self._env.reset()
|
||||
|
||||
def step(self, option):
|
||||
"""Take a step in the keyboard, then the environment."""
|
||||
|
||||
step_count = 0
|
||||
option_step = None
|
||||
while True:
|
||||
obs = self._extract_observation(self._env.observation())
|
||||
action = self._gpi_action(obs, option)
|
||||
action_step = self._env.step(action)
|
||||
step_count += 1
|
||||
|
||||
if option_step is None:
|
||||
option_step = action_step
|
||||
else:
|
||||
new_discount = (
|
||||
option_step.discount * self._discount * action_step.discount)
|
||||
new_reward = (
|
||||
option_step.reward + new_discount * action_step.reward)
|
||||
option_step = option_step._replace(
|
||||
observation=action_step.observation,
|
||||
reward=new_reward,
|
||||
discount=new_discount,
|
||||
step_type=action_step.step_type)
|
||||
|
||||
if action_step.last():
|
||||
break
|
||||
|
||||
# Terminate option.
|
||||
if self._compute_reward(option, action_step.observation) > 0:
|
||||
break
|
||||
|
||||
if not self._call_and_return:
|
||||
break
|
||||
|
||||
return option_step
|
||||
|
||||
def action_spec(self):
|
||||
return dm_env.specs.BoundedArray(shape=(self._keyboard.num_cumulants,),
|
||||
dtype=np.float32,
|
||||
minimum=-1.0,
|
||||
maximum=1.0,
|
||||
name="action")
|
||||
|
||||
def _extract_observation(self, obs):
|
||||
return obs["arena"]
|
||||
|
||||
def observation_spec(self):
|
||||
return self._env.observation_spec()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._env, name)
|
||||
|
||||
|
||||
def _discretize_actions(num_actions_per_dim,
|
||||
action_space_dim,
|
||||
min_val=-1.0,
|
||||
@@ -188,3 +289,71 @@ def _discretize_actions(num_actions_per_dim,
|
||||
logging.info("Discretized actions: %s", discretized_actions)
|
||||
|
||||
return discretized_actions
|
||||
|
||||
|
||||
class EnvironmentWithLearnedPhi(dm_env.Environment):
|
||||
"""Wraps an environment with learned phi model."""
|
||||
|
||||
def __init__(self, env, model_path):
|
||||
self._env = env
|
||||
|
||||
create_ph = lambda x: tf.placeholder(shape=x.shape, dtype=x.dtype)
|
||||
add_batch = lambda x: tf.expand_dims(x, axis=0)
|
||||
|
||||
# Make session and callables.
|
||||
with tf.Graph().as_default():
|
||||
model = smart_module.SmartModuleImport(hub.Module(model_path))
|
||||
|
||||
obs_spec = env.observation_spec()
|
||||
obs_ph = tree.map_structure(create_ph, obs_spec)
|
||||
action_ph = tf.placeholder(shape=(), dtype=tf.int32)
|
||||
phis = model(tree.map_structure(add_batch, obs_ph), add_batch(action_ph))
|
||||
|
||||
self.num_phis = phis.shape.as_list()[-1]
|
||||
self._last_phis = np.zeros((self.num_phis,), dtype=np.float32)
|
||||
|
||||
session = tf.Session()
|
||||
self._session = session
|
||||
self._phis_fn = session.make_callable(
|
||||
phis[0], tree.flatten([obs_ph, action_ph]))
|
||||
self._session.run(tf.global_variables_initializer())
|
||||
|
||||
def reset(self):
|
||||
self._last_phis = np.zeros((self.num_phis,), dtype=np.float32)
|
||||
return self._env.reset()
|
||||
|
||||
def step(self, action):
|
||||
"""Take action in the environment and do some logging."""
|
||||
|
||||
phis = self._phis_fn(*tree.flatten([self._env.observation(), action]))
|
||||
step = self._env.step(action)
|
||||
|
||||
if step.first():
|
||||
phis = self._phis_fn(*tree.flatten([self._env.observation(), action]))
|
||||
step = self._env.step(action)
|
||||
|
||||
step.observation["cumulants"] = phis
|
||||
self._last_phis = phis
|
||||
|
||||
return step
|
||||
|
||||
def action_spec(self):
|
||||
return self._env.action_spec()
|
||||
|
||||
def observation(self):
|
||||
obs = self._env.observation()
|
||||
obs["cumulants"] = self._last_phis
|
||||
return obs
|
||||
|
||||
def observation_spec(self):
|
||||
obs_spec = self._env.observation_spec()
|
||||
obs_spec["cumulants"] = dm_env.specs.BoundedArray(
|
||||
shape=(self.num_phis,),
|
||||
dtype=np.float32,
|
||||
minimum=-1e9,
|
||||
maximum=1e9,
|
||||
name="collected_resources")
|
||||
return obs_spec
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._env, name)
|
||||
|
||||
@@ -18,37 +18,44 @@
|
||||
|
||||
from absl import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
def _ema(base, val, decay=0.995):
|
||||
return base * decay + (1 - decay) * val
|
||||
|
||||
|
||||
def run(environment, agent, num_episodes, report_every=200):
|
||||
def run(env, agent, num_episodes, report_every=200, num_eval_reps=1):
|
||||
"""Runs an agent on an environment.
|
||||
|
||||
Args:
|
||||
environment: The environment.
|
||||
env: The environment.
|
||||
agent: The agent.
|
||||
num_episodes: Number of episodes to train for.
|
||||
report_every: Frequency at which training progress are reported (episodes).
|
||||
num_eval_reps: Number of eval episodes to run per training episode.
|
||||
"""
|
||||
|
||||
train_returns = []
|
||||
train_return_ema = 0.
|
||||
eval_returns = []
|
||||
eval_return_ema = 0.
|
||||
for episode_id in range(num_episodes):
|
||||
# Run a training episode.
|
||||
train_episode_return = run_episode(environment, agent, is_training=True)
|
||||
train_episode_return = run_episode(env, agent, is_training=True)
|
||||
train_returns.append(train_episode_return)
|
||||
train_return_ema = _ema(train_return_ema, train_episode_return)
|
||||
|
||||
# Run an evaluation episode.
|
||||
eval_episode_return = run_episode(environment, agent, is_training=False)
|
||||
eval_returns.append(eval_episode_return)
|
||||
for _ in range(num_eval_reps):
|
||||
eval_episode_return = run_episode(env, agent, is_training=False)
|
||||
eval_returns.append(eval_episode_return)
|
||||
eval_return_ema = _ema(eval_return_ema, eval_episode_return)
|
||||
|
||||
if ((episode_id + 1) % report_every) == 0:
|
||||
logging.info(
|
||||
"Episode %s, avg train return %.3f, avg eval return %.3f",
|
||||
episode_id + 1,
|
||||
np.mean(train_returns[-report_every:]),
|
||||
np.mean(eval_returns[-report_every:]),
|
||||
)
|
||||
logging.info("Episode %s, avg train return %.3f, avg eval return %.3f",
|
||||
episode_id + 1, train_return_ema, eval_return_ema)
|
||||
if hasattr(agent, "get_logs"):
|
||||
logging.info("Episode %s, agent logs: %s", episode_id + 1,
|
||||
agent.get_logs())
|
||||
|
||||
|
||||
def run_episode(environment, agent, is_training=False):
|
||||
|
||||
144
option_keyboard/gpe_gpi_experiments/eval_keyboard_fig5.py
Normal file
144
option_keyboard/gpe_gpi_experiments/eval_keyboard_fig5.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# pylint: disable=line-too-long
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
r"""Run an experiment.
|
||||
|
||||
This script generates the raw data for the polar plots used to visualise how
|
||||
well a trained keyboard covers the space of w.
|
||||
|
||||
|
||||
For example, train 3 separate keyboards with different base policies:
|
||||
|
||||
python3 train_keyboard.py --logtostderr --policy_weights_name=12
|
||||
python3 train_keyboard.py --logtostderr --policy_weights_name=34
|
||||
python3 train_keyboard.py --logtostderr --policy_weights_name=5
|
||||
|
||||
|
||||
Then generate the polar plot data as follows:
|
||||
|
||||
python3 eval_keyboard_fig5a.py --logtostderr \
|
||||
--keyboard_paths=/tmp/option_keyboard/keyboard_12/tfhub,/tmp/option_keyboard/keyboard_34/tfhub,/tmp/option_keyboard/keyboard_5/tfhub \
|
||||
--num_episodes=1000
|
||||
|
||||
|
||||
Example outout:
|
||||
[[ 0.11 0.261 -0.933 ]
|
||||
[ 1.302 3.955 0.54 ]
|
||||
[ 2.398 4.434 1.2105359 ]
|
||||
[ 3.459 4.606 2.087 ]
|
||||
[ 4.09026795 4.60911325 3.06106882]
|
||||
[ 4.55499485 4.71947818 3.8123229 ]
|
||||
[ 4.715 4.835 4.395 ]
|
||||
[ 4.75743564 4.64095528 4.46330207]
|
||||
[ 4.82518207 4.71232378 4.56190708]
|
||||
[ 4.831 4.7155 4.5735 ]
|
||||
[ 4.78074425 4.6754641 4.58312762]
|
||||
[ 4.70154374 4.5416429 4.47850417]
|
||||
[ 4.694 4.631 4.427 ]
|
||||
[ 4.25085125 4.56606664 3.68157677]
|
||||
[ 3.61726795 4.4838453 2.68154403]
|
||||
[ 2.714 4.43 1.554 ]
|
||||
[ 1.69 4.505 0.9635359 ]
|
||||
[ 0.894 4.043 0.424 ]
|
||||
[ 0.099 0.349 0.055 ]]
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from option_keyboard import configs
|
||||
from option_keyboard import environment_wrappers
|
||||
from option_keyboard import experiment
|
||||
from option_keyboard import scavenger
|
||||
from option_keyboard import smart_module
|
||||
|
||||
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||
flags.DEFINE_list("keyboard_paths", [], "Path to keyboard model.")
|
||||
|
||||
|
||||
def evaluate_keyboard(keyboard_path):
|
||||
"""Evaluate a keyboard."""
|
||||
|
||||
angles_to_sweep = np.deg2rad(np.linspace(-90, 180, num=19, endpoint=True))
|
||||
weights_to_sweep = np.stack(
|
||||
[np.cos(angles_to_sweep),
|
||||
np.sin(angles_to_sweep)], axis=-1)
|
||||
weights_to_sweep /= np.sum(
|
||||
np.maximum(weights_to_sweep, 0.0), axis=-1, keepdims=True)
|
||||
weights_to_sweep = np.clip(weights_to_sweep, -1000, 1000)
|
||||
tf.logging.info(weights_to_sweep)
|
||||
|
||||
# Load the keyboard.
|
||||
keyboard = smart_module.SmartModuleImport(hub.Module(keyboard_path))
|
||||
|
||||
# Create the task environment.
|
||||
all_returns = []
|
||||
for w_to_sweep in weights_to_sweep.tolist():
|
||||
base_env_config = configs.get_fig5_task_config(w_to_sweep)
|
||||
base_env = scavenger.Scavenger(**base_env_config)
|
||||
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||
|
||||
# Wrap the task environment with the keyboard.
|
||||
with tf.variable_scope(None, default_name="inner_loop"):
|
||||
additional_discount = 0.9
|
||||
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||
env=base_env,
|
||||
keyboard=keyboard,
|
||||
keyboard_ckpt_path=None,
|
||||
additional_discount=additional_discount,
|
||||
call_and_return=False)
|
||||
|
||||
# Create the player agent.
|
||||
agent = regressed_agent.Agent(
|
||||
batch_size=10,
|
||||
optimizer_name="AdamOptimizer",
|
||||
# Disable training.
|
||||
optimizer_kwargs=dict(learning_rate=0.0,),
|
||||
init_w=w_to_sweep)
|
||||
|
||||
returns = []
|
||||
for _ in range(FLAGS.num_episodes):
|
||||
returns.append(experiment.run_episode(env, agent))
|
||||
tf.logging.info(f"Task: {w_to_sweep}, mean returns over "
|
||||
f"{FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
||||
all_returns.append(returns)
|
||||
|
||||
return all_returns, weights_to_sweep
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
all_returns = []
|
||||
for keyboard_path in FLAGS.keyboard_paths:
|
||||
returns, _ = evaluate_keyboard(keyboard_path)
|
||||
all_returns.append(returns)
|
||||
|
||||
print("Results:")
|
||||
print(np.mean(all_returns, axis=-1).T)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
95
option_keyboard/gpe_gpi_experiments/regressed_agent.py
Normal file
95
option_keyboard/gpe_gpi_experiments/regressed_agent.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Regressed agent."""
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class Agent():
|
||||
"""A DQN Agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size,
|
||||
optimizer_name,
|
||||
optimizer_kwargs,
|
||||
init_w,
|
||||
):
|
||||
"""A simple DQN agent.
|
||||
|
||||
Args:
|
||||
batch_size: Size of update batch.
|
||||
optimizer_name: Name of an optimizer from tf.train
|
||||
optimizer_kwargs: Keyword arguments for the optimizer.
|
||||
init_w: The initial cumulant weight.
|
||||
"""
|
||||
self._batch_size = batch_size
|
||||
self._init_w = np.array(init_w)
|
||||
self._replay = []
|
||||
|
||||
# Regress w by gradient descent, could also use closed-form solution.
|
||||
self._n_cumulants = len(init_w)
|
||||
self._regressed_w = tf.get_variable(
|
||||
"regressed_w",
|
||||
dtype=tf.float32,
|
||||
initializer=lambda: tf.to_float(init_w))
|
||||
cumulants_ph = tf.placeholder(
|
||||
shape=(None, self._n_cumulants), dtype=tf.float32)
|
||||
rewards_ph = tf.placeholder(shape=(None,), dtype=tf.float32)
|
||||
predicted_rewards = tf.reduce_sum(
|
||||
tf.multiply(self._regressed_w, cumulants_ph), axis=-1)
|
||||
loss = tf.reduce_sum(tf.square(predicted_rewards - rewards_ph))
|
||||
|
||||
with tf.variable_scope("optimizer"):
|
||||
self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs)
|
||||
train_op = self._optimizer.minimize(loss)
|
||||
|
||||
# Make session and callables.
|
||||
session = tf.Session()
|
||||
self._update_fn = session.make_callable(train_op,
|
||||
[cumulants_ph, rewards_ph])
|
||||
self._action = session.make_callable(self._regressed_w.read_value(), [])
|
||||
session.run(tf.global_variables_initializer())
|
||||
|
||||
def step(self, timestep, is_training=False):
|
||||
"""Select actions according to epsilon-greedy policy."""
|
||||
del timestep
|
||||
|
||||
if is_training:
|
||||
# Can also just use random actions at environment level.
|
||||
return np.random.uniform(low=-1.0, high=1.0, size=(self._n_cumulants,))
|
||||
|
||||
return self._action()
|
||||
|
||||
def update(self, step_tm1, action, step_t):
|
||||
"""Takes in a transition from the environment."""
|
||||
del step_tm1, action
|
||||
|
||||
transition = [
|
||||
step_t.observation["cumulants"],
|
||||
step_t.reward,
|
||||
]
|
||||
self._replay.append(transition)
|
||||
|
||||
if len(self._replay) == self._batch_size:
|
||||
batch = list(zip(*self._replay))
|
||||
self._update_fn(*batch)
|
||||
self._replay = [] # Just a queue.
|
||||
|
||||
def get_logs(self):
|
||||
return dict(regressed=self._action())
|
||||
64
option_keyboard/gpe_gpi_experiments/run_dqn_fig4b.py
Normal file
64
option_keyboard/gpe_gpi_experiments/run_dqn_fig4b.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Run an experiment.
|
||||
|
||||
Run a q-learning agent on task (1, -1).
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from option_keyboard import configs
|
||||
from option_keyboard import dqn_agent
|
||||
from option_keyboard import environment_wrappers
|
||||
from option_keyboard import experiment
|
||||
from option_keyboard import scavenger
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
# Create the task environment.
|
||||
env_config = configs.get_fig4_task_config()
|
||||
env = scavenger.Scavenger(**env_config)
|
||||
env = environment_wrappers.EnvironmentWithLogging(env)
|
||||
|
||||
# Create the flat agent.
|
||||
agent = dqn_agent.Agent(
|
||||
obs_spec=env.observation_spec(),
|
||||
action_spec=env.action_spec(),
|
||||
network_kwargs=dict(
|
||||
output_sizes=(64, 128),
|
||||
activate_final=True,
|
||||
),
|
||||
epsilon=0.1,
|
||||
additional_discount=0.9,
|
||||
batch_size=10,
|
||||
optimizer_name="AdamOptimizer",
|
||||
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||
|
||||
experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
66
option_keyboard/gpe_gpi_experiments/run_dqn_fig5.py
Normal file
66
option_keyboard/gpe_gpi_experiments/run_dqn_fig5.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Run an experiment.
|
||||
|
||||
Run a q-learning agent on a task.
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from option_keyboard import configs
|
||||
from option_keyboard import dqn_agent
|
||||
from option_keyboard import environment_wrappers
|
||||
from option_keyboard import experiment
|
||||
from option_keyboard import scavenger
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
|
||||
flags.DEFINE_list("test_w", [], "The w to test.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
# Create the task environment.
|
||||
test_w = [float(x) for x in FLAGS.test_w]
|
||||
env_config = configs.get_fig5_task_config(test_w)
|
||||
env = scavenger.Scavenger(**env_config)
|
||||
env = environment_wrappers.EnvironmentWithLogging(env)
|
||||
|
||||
# Create the flat agent.
|
||||
agent = dqn_agent.Agent(
|
||||
obs_spec=env.observation_spec(),
|
||||
action_spec=env.action_spec(),
|
||||
network_kwargs=dict(
|
||||
output_sizes=(64, 128),
|
||||
activate_final=True,
|
||||
),
|
||||
epsilon=0.1,
|
||||
additional_discount=0.9,
|
||||
batch_size=10,
|
||||
optimizer_name="AdamOptimizer",
|
||||
optimizer_kwargs=dict(learning_rate=3e-4,))
|
||||
|
||||
experiment.run(env, agent, num_episodes=FLAGS.num_episodes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
92
option_keyboard/gpe_gpi_experiments/run_regressed_w_fig4.py
Normal file
92
option_keyboard/gpe_gpi_experiments/run_regressed_w_fig4.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
r"""Run an experiment.
|
||||
|
||||
Run GPE/GPI on task (1, -1) with w obtained by regression.
|
||||
|
||||
|
||||
For example, first train a keyboard:
|
||||
|
||||
python3 train_keyboard.py -- --logtostderr --policy_weights_name=12 \
|
||||
--export_path=/tmp/option_keyboard/keyboard
|
||||
|
||||
|
||||
Then, evaluate the keyboard with w by regression.
|
||||
|
||||
python3 run_regressed_w_fig4.py -- --logtostderr \
|
||||
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from option_keyboard import configs
|
||||
from option_keyboard import environment_wrappers
|
||||
from option_keyboard import experiment
|
||||
from option_keyboard import scavenger
|
||||
from option_keyboard import smart_module
|
||||
|
||||
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
# Load the keyboard.
|
||||
keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))
|
||||
|
||||
# Create the task environment.
|
||||
base_env_config = configs.get_fig4_task_config()
|
||||
base_env = scavenger.Scavenger(**base_env_config)
|
||||
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||
|
||||
# Wrap the task environment with the keyboard.
|
||||
additional_discount = 0.9
|
||||
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||
env=base_env,
|
||||
keyboard=keyboard,
|
||||
keyboard_ckpt_path=None,
|
||||
additional_discount=additional_discount,
|
||||
call_and_return=False)
|
||||
|
||||
# Create the player agent.
|
||||
agent = regressed_agent.Agent(
|
||||
batch_size=10,
|
||||
optimizer_name="AdamOptimizer",
|
||||
optimizer_kwargs=dict(learning_rate=1e-1,),
|
||||
init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
|
||||
)
|
||||
|
||||
experiment.run(
|
||||
env,
|
||||
agent,
|
||||
num_episodes=FLAGS.num_episodes,
|
||||
report_every=2,
|
||||
num_eval_reps=100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
@@ -0,0 +1,105 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
r"""Run an experiment.
|
||||
|
||||
Run GPE/GPI on task (1, -1) with a learned phi model and w by regression.
|
||||
|
||||
|
||||
For example, first train a phi model with 3 dimenional phi:
|
||||
|
||||
python3 train_phi_model.py -- --logtostderr --use_random_tasks \
|
||||
--export_path=/tmp/option_keyboard/phi_model_3d --num_phis=3
|
||||
|
||||
|
||||
Then train a keyboard:
|
||||
|
||||
python3 train_keyboard_with_phi.py -- --logtostderr \
|
||||
--export_path=/tmp/option_keyboard/keyboard_3d \
|
||||
--phi_model_path=/tmp/option_keyboard/phi_model_3d \
|
||||
--num_phis=2
|
||||
|
||||
|
||||
Finally, evaluate the keyboard with w by regression.
|
||||
|
||||
python3 run_regressed_w_with_phi_fig4b.py -- --logtostderr \
|
||||
--phi_model_path=/tmp/option_keyboard/phi_model_3d \
|
||||
--keyboard_path=/tmp/option_keyboard/keyboard_3d/tfhub
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from option_keyboard import configs
|
||||
from option_keyboard import environment_wrappers
|
||||
from option_keyboard import experiment
|
||||
from option_keyboard import scavenger
|
||||
from option_keyboard import smart_module
|
||||
|
||||
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||
flags.DEFINE_string("phi_model_path", None, "Path to phi model.")
|
||||
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
# Load the keyboard.
|
||||
keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))
|
||||
|
||||
# Create the task environment.
|
||||
base_env_config = configs.get_fig4_task_config()
|
||||
base_env = scavenger.Scavenger(**base_env_config)
|
||||
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||
|
||||
base_env = environment_wrappers.EnvironmentWithLearnedPhi(
|
||||
base_env, FLAGS.phi_model_path)
|
||||
|
||||
# Wrap the task environment with the keyboard.
|
||||
additional_discount = 0.9
|
||||
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||
env=base_env,
|
||||
keyboard=keyboard,
|
||||
keyboard_ckpt_path=None,
|
||||
additional_discount=additional_discount,
|
||||
call_and_return=False)
|
||||
|
||||
# Create the player agent.
|
||||
agent = regressed_agent.Agent(
|
||||
batch_size=10,
|
||||
optimizer_name="AdamOptimizer",
|
||||
optimizer_kwargs=dict(learning_rate=1e-1,),
|
||||
init_w=np.random.normal(size=keyboard.num_cumulants) * 0.1,
|
||||
)
|
||||
|
||||
experiment.run(
|
||||
env,
|
||||
agent,
|
||||
num_episodes=FLAGS.num_episodes,
|
||||
report_every=2,
|
||||
num_eval_reps=100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
92
option_keyboard/gpe_gpi_experiments/run_true_w_fig4.py
Normal file
92
option_keyboard/gpe_gpi_experiments/run_true_w_fig4.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
r"""Run an experiment.
|
||||
|
||||
Run GPE/GPI on task (1, -1) with the groundtruth w.
|
||||
|
||||
|
||||
For example, first train a keyboard:
|
||||
|
||||
python3 train_keyboard.py -- --logtostderr --policy_weights_name=12
|
||||
|
||||
|
||||
Then, evaluate the keyboard with groundtruth w.
|
||||
|
||||
python3 run_true_w_fig4.py -- --logtostderr \
|
||||
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from option_keyboard import configs
|
||||
from option_keyboard import environment_wrappers
|
||||
from option_keyboard import experiment
|
||||
from option_keyboard import scavenger
|
||||
from option_keyboard import smart_module
|
||||
|
||||
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
# Load the keyboard.
|
||||
keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))
|
||||
|
||||
# Create the task environment.
|
||||
base_env_config = configs.get_fig4_task_config()
|
||||
base_env = scavenger.Scavenger(**base_env_config)
|
||||
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||
|
||||
# Wrap the task environment with the keyboard.
|
||||
additional_discount = 0.9
|
||||
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||
env=base_env,
|
||||
keyboard=keyboard,
|
||||
keyboard_ckpt_path=None,
|
||||
additional_discount=additional_discount,
|
||||
call_and_return=False)
|
||||
|
||||
# Create the player agent.
|
||||
agent = regressed_agent.Agent(
|
||||
batch_size=10,
|
||||
optimizer_name="AdamOptimizer",
|
||||
# Disable training.
|
||||
optimizer_kwargs=dict(learning_rate=0.0,),
|
||||
init_w=[1., -1.])
|
||||
|
||||
returns = []
|
||||
for _ in range(FLAGS.num_episodes):
|
||||
returns.append(experiment.run_episode(env, agent))
|
||||
tf.logging.info("#" * 80)
|
||||
tf.logging.info(
|
||||
f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
||||
tf.logging.info("#" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
93
option_keyboard/gpe_gpi_experiments/run_true_w_fig6.py
Normal file
93
option_keyboard/gpe_gpi_experiments/run_true_w_fig6.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
r"""Run an experiment.
|
||||
|
||||
Run GPE/GPI on the "balancing" task with a fixed w
|
||||
|
||||
|
||||
For example, first train a keyboard:
|
||||
|
||||
python3 train_keyboard.py -- --logtostderr --policy_weights_name=12
|
||||
|
||||
|
||||
Then, evaluate the keyboard with a fixed w.
|
||||
|
||||
python3 run_true_w_fig4.py -- --logtostderr \
|
||||
--keyboard_path=/tmp/option_keyboard/keyboard_12/tfhub
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from option_keyboard import configs
|
||||
from option_keyboard import environment_wrappers
|
||||
from option_keyboard import experiment
|
||||
from option_keyboard import scavenger
|
||||
from option_keyboard import smart_module
|
||||
|
||||
from option_keyboard.gpe_gpi_experiments import regressed_agent
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_episodes", 1000, "Number of training episodes.")
|
||||
flags.DEFINE_string("keyboard_path", None, "Path to keyboard model.")
|
||||
flags.DEFINE_list("test_w", [], "The w to test.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
# Load the keyboard.
|
||||
keyboard = smart_module.SmartModuleImport(hub.Module(FLAGS.keyboard_path))
|
||||
|
||||
# Create the task environment.
|
||||
base_env_config = configs.get_task_config()
|
||||
base_env = scavenger.Scavenger(**base_env_config)
|
||||
base_env = environment_wrappers.EnvironmentWithLogging(base_env)
|
||||
|
||||
# Wrap the task environment with the keyboard.
|
||||
additional_discount = 0.9
|
||||
env = environment_wrappers.EnvironmentWithKeyboardDirect(
|
||||
env=base_env,
|
||||
keyboard=keyboard,
|
||||
keyboard_ckpt_path=None,
|
||||
additional_discount=additional_discount,
|
||||
call_and_return=False)
|
||||
|
||||
# Create the player agent.
|
||||
agent = regressed_agent.Agent(
|
||||
batch_size=10,
|
||||
optimizer_name="AdamOptimizer",
|
||||
# Disable training.
|
||||
optimizer_kwargs=dict(learning_rate=0.0,),
|
||||
init_w=[float(x) for x in FLAGS.test_w])
|
||||
|
||||
returns = []
|
||||
for _ in range(FLAGS.num_episodes):
|
||||
returns.append(experiment.run_episode(env, agent))
|
||||
tf.logging.info("#" * 80)
|
||||
tf.logging.info(
|
||||
f"Avg. return over {FLAGS.num_episodes} episodes is {np.mean(returns)}")
|
||||
tf.logging.info("#" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
65
option_keyboard/gpe_gpi_experiments/train_keyboard.py
Normal file
65
option_keyboard/gpe_gpi_experiments/train_keyboard.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train a keyboard."""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from option_keyboard import keyboard_utils
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_pretrain_episodes", 20000,
|
||||
"Number of pretraining episodes.")
|
||||
flags.DEFINE_string("export_path", None,
|
||||
"Where to save the keyboard checkpoints.")
|
||||
flags.DEFINE_string("policy_weights_name", None,
|
||||
"A string repsenting the policy weights.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
all_policy_weights = {
|
||||
"1": [1., 0.],
|
||||
"2": [0., 1.],
|
||||
"3": [1., -1.],
|
||||
"4": [-1., 1.],
|
||||
"5": [1., 1.],
|
||||
}
|
||||
if FLAGS.policy_weights_name:
|
||||
policy_weights = np.array(
|
||||
[all_policy_weights[v] for v in FLAGS.policy_weights_name])
|
||||
num_episodes = ((FLAGS.num_pretrain_episodes // 2) *
|
||||
max(2, len(policy_weights)))
|
||||
export_path = FLAGS.export_path + "_" + FLAGS.policy_weights_name
|
||||
else:
|
||||
policy_weights = None
|
||||
num_episodes = FLAGS.num_pretrain_episodes
|
||||
export_path = FLAGS.export_path
|
||||
|
||||
keyboard_utils.create_and_train_keyboard(
|
||||
num_episodes=num_episodes,
|
||||
policy_weights=policy_weights,
|
||||
export_path=export_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
@@ -0,0 +1,49 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train a keyboard."""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from option_keyboard import keyboard_utils
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_pretrain_episodes", 20000,
|
||||
"Number of pretraining episodes.")
|
||||
flags.DEFINE_integer("num_phis", None, "Size of phi")
|
||||
flags.DEFINE_string("phi_model_path", None,
|
||||
"Where to load the phi model checkpoints.")
|
||||
flags.DEFINE_string("export_path", None,
|
||||
"Where to save the keyboard checkpoints.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
keyboard_utils.create_and_train_keyboard_with_phi(
|
||||
num_episodes=FLAGS.num_pretrain_episodes,
|
||||
phi_model_path=FLAGS.phi_model_path,
|
||||
policy_weights=np.eye(FLAGS.num_phis, dtype=np.float32),
|
||||
export_path=FLAGS.export_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
270
option_keyboard/gpe_gpi_experiments/train_phi_model.py
Normal file
270
option_keyboard/gpe_gpi_experiments/train_phi_model.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# Lint as: python3
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train simple phi model."""
|
||||
|
||||
import collections
|
||||
import random
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
|
||||
import numpy as np
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tree
|
||||
|
||||
from option_keyboard import scavenger
|
||||
from option_keyboard import smart_module
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("num_phis", 2, "Dimensionality of phis.")
|
||||
flags.DEFINE_integer("num_train_steps", 2000, "Number of training steps.")
|
||||
flags.DEFINE_integer("num_replay_steps", 500, "Number of replay steps.")
|
||||
flags.DEFINE_integer("min_replay_size", 1000,
|
||||
"Minimum replay size before starting training.")
|
||||
flags.DEFINE_integer("num_train_repeats", 10, "Number of training repeats.")
|
||||
flags.DEFINE_float("learning_rate", 3e-3, "Learning rate.")
|
||||
flags.DEFINE_bool("use_random_tasks", False, "Use random tasks.")
|
||||
flags.DEFINE_string("normalisation", "L2",
|
||||
"Normalisation method for cumulant weights.")
|
||||
flags.DEFINE_string("export_path", None, "Export path.")
|
||||
|
||||
|
||||
StepOutput = collections.namedtuple("StepOutput",
|
||||
["obs", "actions", "rewards", "next_obs"])
|
||||
|
||||
|
||||
def collect_experience(env, num_episodes, verbose=False):
|
||||
"""Collect experience."""
|
||||
|
||||
num_actions = env.action_spec().maximum + 1
|
||||
|
||||
observations = []
|
||||
actions = []
|
||||
rewards = []
|
||||
next_observations = []
|
||||
|
||||
for _ in range(num_episodes):
|
||||
timestep = env.reset()
|
||||
episode_return = 0
|
||||
while not timestep.last():
|
||||
action = np.random.randint(num_actions)
|
||||
observations.append(timestep.observation)
|
||||
actions.append(action)
|
||||
|
||||
timestep = env.step(action)
|
||||
rewards.append(timestep.observation["aux_tasks_reward"])
|
||||
episode_return += timestep.reward
|
||||
|
||||
next_observations.append(timestep.observation)
|
||||
|
||||
if verbose:
|
||||
logging.info("Total return for episode: %f", episode_return)
|
||||
|
||||
observation_spec = tree.map_structure(lambda _: None, observations[0])
|
||||
|
||||
def stack_observations(obs_list):
|
||||
obs_list = [
|
||||
np.stack(obs) for obs in zip(*[tree.flatten(obs) for obs in obs_list])
|
||||
]
|
||||
obs_dict = tree.unflatten_as(observation_spec, obs_list)
|
||||
obs_dict.pop("aux_tasks_reward")
|
||||
return obs_dict
|
||||
|
||||
observations = stack_observations(observations)
|
||||
actions = np.array(actions, dtype=np.int32)
|
||||
rewards = np.stack(rewards)
|
||||
next_observations = stack_observations(next_observations)
|
||||
|
||||
return StepOutput(observations, actions, rewards, next_observations)
|
||||
|
||||
|
||||
class PhiModel(snt.AbstractModule):
|
||||
"""A model for learning phi."""
|
||||
|
||||
def __init__(self,
|
||||
n_actions,
|
||||
n_phis,
|
||||
network_kwargs,
|
||||
final_activation="sigmoid",
|
||||
name="PhiModel"):
|
||||
super(PhiModel, self).__init__(name=name)
|
||||
self._n_actions = n_actions
|
||||
self._n_phis = n_phis
|
||||
self._network_kwargs = network_kwargs
|
||||
self._final_activation = final_activation
|
||||
|
||||
def _build(self, observation, actions):
|
||||
obs = observation["arena"]
|
||||
|
||||
n_outputs = self._n_actions * self._n_phis
|
||||
flat_obs = snt.BatchFlatten()(obs)
|
||||
net = snt.nets.MLP(**self._network_kwargs)(flat_obs)
|
||||
net = snt.Linear(output_size=n_outputs)(net)
|
||||
net = snt.BatchReshape((self._n_actions, self._n_phis))(net)
|
||||
|
||||
indices = tf.stack([tf.range(tf.shape(actions)[0]), actions], axis=1)
|
||||
values = tf.gather_nd(net, indices)
|
||||
if self._final_activation:
|
||||
values = getattr(tf.nn, self._final_activation)(values)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
def create_ph(tensor):
|
||||
return tf.placeholder(shape=(None,) + tensor.shape[1:], dtype=tensor.dtype)
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
if FLAGS.use_random_tasks:
|
||||
tasks = np.random.normal(size=(8, 2))
|
||||
else:
|
||||
tasks = [
|
||||
[1.0, 0.0],
|
||||
[0.0, 1.0],
|
||||
[-1.0, 0.0],
|
||||
[0.0, -1.0],
|
||||
[0.7, 0.3],
|
||||
[-0.3, -0.7],
|
||||
]
|
||||
|
||||
if FLAGS.normalisation == "L1":
|
||||
tasks /= np.sum(np.abs(tasks), axis=-1, keepdims=True)
|
||||
elif FLAGS.normalisation == "L2":
|
||||
tasks /= np.linalg.norm(tasks, axis=-1, keepdims=True)
|
||||
else:
|
||||
raise ValueError("Unknown normlisation_method {}".format(
|
||||
FLAGS.normalisation))
|
||||
|
||||
logging.info("Tasks: %s", tasks)
|
||||
|
||||
env_config = dict(
|
||||
arena_size=11,
|
||||
num_channels=2,
|
||||
max_num_steps=100,
|
||||
num_init_objects=10,
|
||||
object_priors=[1.0, 1.0],
|
||||
egocentric=True,
|
||||
default_w=None,
|
||||
aux_tasks_w=tasks)
|
||||
env = scavenger.Scavenger(**env_config)
|
||||
num_actions = env.action_spec().maximum + 1
|
||||
|
||||
model_config = dict(
|
||||
n_actions=num_actions,
|
||||
n_phis=FLAGS.num_phis,
|
||||
network_kwargs=dict(
|
||||
output_sizes=(64, 128),
|
||||
activate_final=True,
|
||||
),
|
||||
)
|
||||
model = smart_module.SmartModuleExport(lambda: PhiModel(**model_config))
|
||||
|
||||
dummy_steps = collect_experience(env, num_episodes=10, verbose=True)
|
||||
num_rewards = dummy_steps.rewards.shape[-1]
|
||||
|
||||
# Placeholders
|
||||
steps_ph = tree.map_structure(create_ph, dummy_steps)
|
||||
|
||||
phis = model(steps_ph.obs, steps_ph.actions)
|
||||
phis_to_rewards = snt.Linear(
|
||||
num_rewards, initializers=dict(w=tf.zeros), use_bias=False)
|
||||
preds = phis_to_rewards(phis)
|
||||
loss_per_batch = tf.square(preds - steps_ph.rewards)
|
||||
loss_op = tf.reduce_mean(loss_per_batch)
|
||||
|
||||
replay = []
|
||||
|
||||
# Optimizer and train op.
|
||||
with tf.variable_scope("optimizer"):
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
|
||||
train_op = optimizer.minimize(loss_op)
|
||||
# Add normalisation of weights in phis_to_rewards
|
||||
if FLAGS.normalisation == "L1":
|
||||
w_norm = tf.reduce_sum(tf.abs(phis_to_rewards.w), axis=0, keepdims=True)
|
||||
elif FLAGS.normalisation == "L2":
|
||||
w_norm = tf.norm(phis_to_rewards.w, axis=0, keepdims=True)
|
||||
else:
|
||||
raise ValueError("Unknown normlisation_method {}".format(
|
||||
FLAGS.normalisation))
|
||||
|
||||
normalise_w = tf.assign(phis_to_rewards.w,
|
||||
phis_to_rewards.w / tf.maximum(w_norm, 1e-6))
|
||||
|
||||
def filter_steps(steps):
|
||||
mask = np.sum(np.abs(steps.rewards), axis=-1) > 0.1
|
||||
nonzero_inds = np.where(mask)[0]
|
||||
zero_inds = np.where(np.logical_not(mask))[0]
|
||||
zero_inds = np.random.choice(
|
||||
zero_inds, size=len(nonzero_inds), replace=False)
|
||||
selected_inds = np.concatenate([nonzero_inds, zero_inds])
|
||||
selected_steps = tree.map_structure(lambda x: x[selected_inds], steps)
|
||||
return selected_steps, selected_inds
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
step = 0
|
||||
while step < FLAGS.num_train_steps:
|
||||
step += 1
|
||||
steps_output = collect_experience(env, num_episodes=10)
|
||||
selected_step_outputs, selected_inds = filter_steps(steps_output)
|
||||
|
||||
if len(replay) > FLAGS.min_replay_size:
|
||||
# Do training.
|
||||
for _ in range(FLAGS.num_train_repeats):
|
||||
train_samples = random.choices(replay, k=128)
|
||||
train_samples = tree.map_structure(
|
||||
lambda *x: np.stack(x, axis=0), *train_samples)
|
||||
train_samples = tree.unflatten_as(steps_ph, train_samples)
|
||||
feed_dict = dict(
|
||||
zip(tree.flatten(steps_ph), tree.flatten(train_samples)))
|
||||
_, train_loss = sess.run([train_op, loss_op], feed_dict=feed_dict)
|
||||
sess.run(normalise_w)
|
||||
|
||||
# Do evaluation.
|
||||
if step % 50 == 0:
|
||||
feed_dict = dict(
|
||||
zip(tree.flatten(steps_ph), tree.flatten(selected_step_outputs)))
|
||||
eval_loss = sess.run(loss_op, feed_dict=feed_dict)
|
||||
logging.info("Step %d, train loss %f, eval loss %f, replay %s",
|
||||
step, train_loss, eval_loss, len(replay))
|
||||
print(sess.run(phis_to_rewards.get_variables())[0].T)
|
||||
|
||||
values = dict(step=step, train_loss=train_loss, eval_loss=eval_loss)
|
||||
logging.info(values)
|
||||
|
||||
# Add to replay.
|
||||
if step <= FLAGS.num_replay_steps:
|
||||
def select_fn(ind):
|
||||
return lambda x: x[ind]
|
||||
for idx in range(len(selected_inds)):
|
||||
replay.append(
|
||||
tree.flatten(
|
||||
tree.map_structure(select_fn(idx), selected_step_outputs)))
|
||||
|
||||
# Export trained model.
|
||||
if FLAGS.export_path:
|
||||
model.export(FLAGS.export_path, sess, overwrite=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.disable_v2_behavior()
|
||||
app.run(main)
|
||||
@@ -16,10 +16,14 @@
|
||||
# ============================================================================
|
||||
"""Keyboard agent."""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from option_keyboard import smart_module
|
||||
|
||||
|
||||
class Agent():
|
||||
"""An Option Keyboard Agent."""
|
||||
@@ -51,6 +55,7 @@ class Agent():
|
||||
optimizer_kwargs: Keyword arguments for the optimizer.
|
||||
"""
|
||||
|
||||
tf.logging.info(policy_weights)
|
||||
self._policy_weights = tf.convert_to_tensor(
|
||||
policy_weights, dtype=tf.float32)
|
||||
self._current_policy = None
|
||||
@@ -61,13 +66,16 @@ class Agent():
|
||||
|
||||
self._n_actions = action_spec.num_values
|
||||
self._n_policies, self._n_cumulants = policy_weights.shape
|
||||
self._network = OptionValueNet(
|
||||
self._n_policies,
|
||||
self._n_cumulants,
|
||||
self._n_actions,
|
||||
network_kwargs=network_kwargs,
|
||||
)
|
||||
|
||||
def create_network():
|
||||
return OptionValueNet(
|
||||
self._n_policies,
|
||||
self._n_cumulants,
|
||||
self._n_actions,
|
||||
network_kwargs=network_kwargs,
|
||||
)
|
||||
|
||||
self._network = smart_module.SmartModuleExport(create_network)
|
||||
self._replay = []
|
||||
|
||||
obs_spec = self._extract_observation(obs_spec)
|
||||
@@ -103,6 +111,12 @@ class Agent():
|
||||
td_error = tf.stop_gradient(c_t + g * qa_t) - qa_tm1
|
||||
loss = tf.reduce_sum(tf.square(td_error) / 2)
|
||||
|
||||
# Dummy calls to keyboard for SmartModule
|
||||
_ = self._network.gpi(o_tm1[0], c_t[0])
|
||||
_ = self._network.num_cumulants
|
||||
_ = self._network.num_policies
|
||||
_ = self._network.num_actions
|
||||
|
||||
with tf.variable_scope("optimizer"):
|
||||
self._optimizer = getattr(tf.train, optimizer_name)(**optimizer_kwargs)
|
||||
train_op = self._optimizer.minimize(loss)
|
||||
@@ -155,7 +169,9 @@ class Agent():
|
||||
|
||||
def export(self, path):
|
||||
tf.logging.info("Exporting keyboard to %s", path)
|
||||
self._saver.save(self._session, path)
|
||||
self._network.export(
|
||||
os.path.join(path, "tfhub"), self._session, overwrite=True)
|
||||
self._saver.save(self._session, os.path.join(path, "checkpoints"))
|
||||
|
||||
|
||||
class OptionValueNet(snt.AbstractModule):
|
||||
|
||||
88
option_keyboard/keyboard_utils.py
Normal file
88
option_keyboard/keyboard_utils.py
Normal file
@@ -0,0 +1,88 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Keyboard utils."""

import numpy as np

from option_keyboard import configs
from option_keyboard import environment_wrappers
from option_keyboard import experiment
from option_keyboard import keyboard_agent
from option_keyboard import scavenger


def create_and_train_keyboard(num_episodes,
                              policy_weights=None,
                              export_path=None):
  """Train an option keyboard."""
  if policy_weights is None:
    policy_weights = np.eye(2, dtype=np.float32)

  env_config = configs.get_pretrain_config()
  env = scavenger.Scavenger(**env_config)
  env = environment_wrappers.EnvironmentWithLogging(env)

  agent = keyboard_agent.Agent(
      obs_spec=env.observation_spec(),
      action_spec=env.action_spec(),
      policy_weights=policy_weights,
      network_kwargs=dict(
          output_sizes=(64, 128),
          activate_final=True,
      ),
      epsilon=0.1,
      additional_discount=0.9,
      batch_size=10,
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

  if num_episodes:
    experiment.run(env, agent, num_episodes=num_episodes)
    agent.export(export_path)

  return agent


def create_and_train_keyboard_with_phi(num_episodes,
                                       phi_model_path,
                                       policy_weights,
                                       export_path=None):
  """Train an option keyboard."""
  env_config = configs.get_pretrain_config()
  env = scavenger.Scavenger(**env_config)
  env = environment_wrappers.EnvironmentWithLogging(env)
  env = environment_wrappers.EnvironmentWithLearnedPhi(env, phi_model_path)

  agent = keyboard_agent.Agent(
      obs_spec=env.observation_spec(),
      action_spec=env.action_spec(),
      policy_weights=policy_weights,
      network_kwargs=dict(
          output_sizes=(64, 128),
          activate_final=True,
      ),
      epsilon=0.1,
      additional_discount=0.9,
      batch_size=10,
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

  if num_episodes:
    experiment.run(env, agent, num_episodes=num_episodes)
    agent.export(export_path)

  return agent
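For reference, the pretraining entry point in run.py below uses these helpers roughly as follows (a sketch; the episode count and export path are illustrative):

import os

from option_keyboard import keyboard_utils

export_path = "/tmp/option_keyboard/keyboard"  # illustrative path
keyboard_utils.create_and_train_keyboard(
    num_episodes=20000, export_path=export_path)
keyboard_path = os.path.join(export_path, "tfhub")  # TF-Hub module location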
@@ -1,6 +1,9 @@
absl-py
dm-env==1.2
dm-sonnet==1.34
dm-tree
numpy==1.16.4
tensorflow==1.13.2
tensorflow_hub==0.7.0
tensorflow_probability==0.6.0
wrapt

@@ -16,60 +16,44 @@
# ============================================================================
"""Run an experiment."""

import os

from absl import app
from absl import flags

import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

from option_keyboard import configs
from option_keyboard import dqn_agent
from option_keyboard import environment_wrappers
from option_keyboard import experiment
from option_keyboard import keyboard_agent
from option_keyboard import keyboard_utils
from option_keyboard import scavenger
from option_keyboard import smart_module

FLAGS = flags.FLAGS
flags.DEFINE_integer("num_episodes", 10000, "Number of training episodes.")
flags.DEFINE_integer("num_pretrain_episodes", 20000,
                     "Number of pretraining episodes.")


def _train_keyboard(num_episodes):
  """Train an option keyboard."""
  env_config = configs.get_pretrain_config()
  env = scavenger.Scavenger(**env_config)
  env = environment_wrappers.EnvironmentWithLogging(env)

  agent = keyboard_agent.Agent(
      obs_spec=env.observation_spec(),
      action_spec=env.action_spec(),
      policy_weights=np.array([
          [1.0, 0.0],
          [0.0, 1.0],
      ]),
      network_kwargs=dict(
          output_sizes=(64, 128),
          activate_final=True,
      ),
      epsilon=0.1,
      additional_discount=0.9,
      batch_size=10,
      optimizer_name="AdamOptimizer",
      optimizer_kwargs=dict(learning_rate=3e-4,))

  experiment.run(env, agent, num_episodes=num_episodes)

  return agent
flags.DEFINE_string("keyboard_path", None, "Path to pretrained keyboard model.")


def main(argv):
  del argv

  # Pretrain the keyboard and save a checkpoint.
  pretrain_agent = _train_keyboard(num_episodes=FLAGS.num_pretrain_episodes)
  keyboard_ckpt_path = "/tmp/option_keyboard/keyboard.ckpt"
  pretrain_agent.export(keyboard_ckpt_path)
  if FLAGS.keyboard_path:
    keyboard_path = FLAGS.keyboard_path
  else:
    with tf.Graph().as_default():
      export_path = "/tmp/option_keyboard/keyboard"
      _ = keyboard_utils.create_and_train_keyboard(
          num_episodes=FLAGS.num_pretrain_episodes, export_path=export_path)
    keyboard_path = os.path.join(export_path, "tfhub")

  # Load the keyboard.
  keyboard = smart_module.SmartModuleImport(hub.Module(keyboard_path))

  # Create the task environment.
  base_env_config = configs.get_task_config()
@@ -80,11 +64,11 @@ def main(argv):
  additional_discount = 0.9
  env = environment_wrappers.EnvironmentWithKeyboard(
      env=base_env,
      keyboard=pretrain_agent.keyboard,
      keyboard_ckpt_path=keyboard_ckpt_path,
      keyboard=keyboard,
      keyboard_ckpt_path=None,
      n_actions_per_dim=3,
      additional_discount=additional_discount,
      call_and_return=True)
      call_and_return=False)

  # Create the player agent.
  agent = dqn_agent.Agent(
@@ -56,7 +56,8 @@ class Scavenger(auto_reset_environment.Base):
               num_init_objects=15,
               object_priors=None,
               egocentric=True,
               rewarder=None):
               rewarder=None,
               aux_tasks_w=None):
    self._arena_size = arena_size
    self._num_channels = num_channels
    self._max_num_steps = max_num_steps
@@ -64,12 +65,13 @@ class Scavenger(auto_reset_environment.Base):
    self._egocentric = egocentric
    self._rewarder = (
        getattr(this_module, rewarder)() if rewarder is not None else None)
    self._aux_tasks_w = aux_tasks_w

    if object_priors is None:
      self._object_priors = np.ones(num_channels) / num_channels
    else:
      assert len(object_priors) == num_channels
      self._object_priors = np.array(object_priors)
      self._object_priors = np.array(object_priors) / np.sum(object_priors)

    if default_w is None:
      self._default_w = np.ones(shape=(num_channels,))
@@ -203,10 +205,15 @@ class Scavenger(auto_reset_environment.Base):

    collected_resources = np.copy(self._prev_collected).astype(np.float32)

    return dict(
    obs = dict(
        arena=arena,
        cumulants=collected_resources,
    )
    if self._aux_tasks_w is not None:
      obs["aux_tasks_reward"] = np.dot(
          np.array(self._aux_tasks_w), self._prev_collected).astype(np.float32)

    return obs

  def observation_spec(self):
    arena = dm_env.specs.BoundedArray(
@@ -222,10 +229,19 @@ class Scavenger(auto_reset_environment.Base):
        maximum=1e9,
        name="collected_resources")

    return dict(
    obs_spec = dict(
        arena=arena,
        cumulants=collected_resources,
    )
    if self._aux_tasks_w is not None:
      obs_spec["aux_tasks_reward"] = dm_env.specs.BoundedArray(
          shape=(len(self._aux_tasks_w),),
          dtype=np.float32,
          minimum=-1e9,
          maximum=1e9,
          name="aux_tasks_reward")

    return obs_spec

  def action_spec(self):
    return dm_env.specs.DiscreteArray(num_values=len(Action), name="action")
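The new aux_tasks_w hook can be exercised directly; a small sketch (the weight vectors here are illustrative, not from the commit) of reading the extra observation:

from option_keyboard import scavenger

# Each row of aux_tasks_w weights the per-channel cumulants of one auxiliary
# task; the environment then reports all of their rewards at every step.
env = scavenger.Scavenger(
    num_channels=2,
    aux_tasks_w=[[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]])  # illustrative weights

timestep = env.reset()
aux_rewards = timestep.observation["aux_tasks_reward"]  # shape (3,), float32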
228
option_keyboard/smart_module.py
Normal file
@@ -0,0 +1,228 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Smart module export/import utilities."""

import inspect
import os
import pickle
import shutil

import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
import tree as nest
import wrapt


_ALLOWED_TYPES = (bool, float, int, str)


def _getcallargs(signature, *args, **kwargs):
  bound_args = signature.bind(*args, **kwargs)
  bound_args.apply_defaults()
  inputs = bound_args.arguments
  inputs.pop("self", None)
  return inputs


def _to_placeholder(arg):
  if arg is None or isinstance(arg, bool):
    return arg

  arg = tf.convert_to_tensor(arg)
  return tf.placeholder(dtype=arg.dtype, shape=arg.shape)


class SmartModuleExport(object):
  """Helper class for exporting TF-Hub modules."""

  def __init__(self, object_factory):
    self._object_factory = object_factory
    self._wrapped_object = self._object_factory()
    self._variable_scope = tf.get_variable_scope()
    self._captured_calls = {}
    self._captured_attrs = {}

  def _create_captured_method(self, method_name):
    """Creates a wrapped method that captures its inputs."""
    with tf.variable_scope(self._variable_scope):
      method_ = getattr(self._wrapped_object, method_name)

    @wrapt.decorator
    def wrapper(method, instance, args, kwargs):
      """Wrapped method to capture inputs."""
      del instance

      specs = inspect.signature(method)
      inputs = _getcallargs(specs, *args, **kwargs)

      with tf.variable_scope(self._variable_scope):
        output = method(*args, **kwargs)

      self._captured_calls[method_name] = [inputs, specs]

      return output

    return wrapper(method_)  # pylint: disable=no-value-for-parameter
  def __getattr__(self, name):
    """Helper method for accessing attributes of the wrapped object."""
    # if "_wrapped_object" not in self.__dict__:
    #   return super(ExportableModule, self).__getattr__(name)

    with tf.variable_scope(self._variable_scope):
      attr = getattr(self._wrapped_object, name)

    if inspect.ismethod(attr) or inspect.isfunction(attr):
      return self._create_captured_method(name)
    else:
      if all([isinstance(v, _ALLOWED_TYPES) for v in nest.flatten(attr)]):
        self._captured_attrs[name] = attr
      return attr

  def __call__(self, *args, **kwargs):
    return self._create_captured_method("__call__")(*args, **kwargs)

  def export(self, path, session, overwrite=False):
    """Build the TF-Hub spec, module and sync ops."""

    method_specs = {}

    def module_fn():
      """A module_fn for use with hub.create_module_spec()."""
      # We will use a copy of the original object to build the graph.
      wrapped_object = self._object_factory()

      for method_name, method_info in self._captured_calls.items():
        captured_inputs, captured_specs = method_info
        tensor_inputs = nest.map_structure(_to_placeholder, captured_inputs)
        method_to_call = getattr(wrapped_object, method_name)
        tensor_outputs = method_to_call(**tensor_inputs)

        flat_tensor_inputs = nest.flatten(tensor_inputs)
        flat_tensor_inputs = {
            str(k): v for k, v in zip(
                range(len(flat_tensor_inputs)), flat_tensor_inputs)
        }
        flat_tensor_outputs = nest.flatten(tensor_outputs)
        flat_tensor_outputs = {
            str(k): v for k, v in zip(
                range(len(flat_tensor_outputs)), flat_tensor_outputs)
        }

        method_specs[method_name] = dict(
            specs=captured_specs,
            inputs=nest.map_structure(lambda _: None, tensor_inputs),
            outputs=nest.map_structure(lambda _: None, tensor_outputs))

        signature_name = ("default"
                          if method_name == "__call__" else method_name)
        hub.add_signature(signature_name, flat_tensor_inputs,
                          flat_tensor_outputs)

      hub.attach_message(
          "methods", tf.train.BytesList(value=[pickle.dumps(method_specs)]))
      hub.attach_message(
          "properties",
          tf.train.BytesList(value=[pickle.dumps(self._captured_attrs)]))

    # Create the spec that will be later used in export.
    hub_spec = hub.create_module_spec(module_fn, drop_collections=["sonnet"])

    # Get the variable values.
    module_weights = [
        session.run(v) for v in self._wrapped_object.get_all_variables()
    ]

    # Create the sync ops.
    with tf.Graph().as_default():
      hub_module = hub.Module(hub_spec, trainable=True, name="hub")

      assign_ops = []
      assign_phs = []
      for _, v in sorted(hub_module.variable_map.items()):
        ph = tf.placeholder(shape=v.shape, dtype=v.dtype)
        assign_phs.append(ph)
        assign_ops.append(tf.assign(v, ph))

      with tf.Session() as module_session:
        module_session.run(tf.local_variables_initializer())
        module_session.run(tf.global_variables_initializer())
        module_session.run(
            assign_ops, feed_dict=dict(zip(assign_phs, module_weights)))

        if overwrite and os.path.exists(path):
          shutil.rmtree(path)
        os.makedirs(path)
        hub_module.export(path, module_session)


class SmartModuleImport(object):
  """A class for importing graph building objects from TF-Hub modules."""

  def __init__(self, module):
    self._module = module
    self._method_specs = pickle.loads(
        self._module.get_attached_message("methods",
                                          tf.train.BytesList).value[0])
    self._properties = pickle.loads(
        self._module.get_attached_message("properties",
                                          tf.train.BytesList).value[0])

  def _create_wrapped_method(self, method):
    """Creates a wrapped method that converts nested inputs and outputs."""

    def wrapped_method(*args, **kwargs):
      """A wrapped method around a TF-Hub module signature."""

      inputs = _getcallargs(self._method_specs[method]["specs"], *args,
                            **kwargs)
      nest.assert_same_structure(self._method_specs[method]["inputs"], inputs)
      flat_inputs = nest.flatten(inputs)
      flat_inputs = {
          str(k): v for k, v in zip(range(len(flat_inputs)), flat_inputs)
      }

      signature = "default" if method == "__call__" else method
      flat_outputs = self._module(
          flat_inputs, signature=signature, as_dict=True)
      flat_outputs = [v for _, v in sorted(flat_outputs.items())]

      output_spec = self._method_specs[method]["outputs"]
      if output_spec is None:
        if len(flat_outputs) != 1:
          raise ValueError(
              "Expected output containing a single tensor, found {}".format(
                  flat_outputs))
        outputs = flat_outputs[0]
      else:
        outputs = nest.unflatten_as(output_spec, flat_outputs)

      return outputs

    return wrapped_method

  def __getattr__(self, name):
    if name in self._method_specs:
      return self._create_wrapped_method(name)

    if name in self._properties:
      return self._properties[name]

    return getattr(self._module, name)

  def __call__(self, *args, **kwargs):
    return self._create_wrapped_method("__call__")(*args, **kwargs)
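As a rough illustration of the intended round trip (not part of the commit; the tiny Sonnet module and paths are assumptions), a network factory is wrapped, called once so its signature is captured, exported as a TF-Hub module, and then re-imported in a fresh graph:

import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
import sonnet as snt

from option_keyboard import smart_module

# Wrap a factory; calling the wrapper records the inputs for the hub signature.
module = smart_module.SmartModuleExport(lambda: snt.Linear(output_size=4))
x = tf.placeholder(shape=(1, 8), dtype=tf.float32)
y = module(x)

with tf.Session() as session:
  session.run(tf.global_variables_initializer())
  module.export("/tmp/option_keyboard/linear_demo", session, overwrite=True)

# Re-import later, possibly in a different graph or process.
with tf.Graph().as_default():
  imported = smart_module.SmartModuleImport(
      hub.Module("/tmp/option_keyboard/linear_demo"))
  y_again = imported(tf.zeros((1, 8), dtype=tf.float32))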