deepmind-research/rl_unplugged/rwrl_d4pg.ipynb
Piotr Stanczyk f30fae8dbc Internal change
PiperOrigin-RevId: 431903167
2022-05-26 17:42:33 +01:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "KDiJzbb8KFvP"
},
"source": [
"Copyright 2020 DeepMind Technologies Limited.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
"this file except in compliance with the License. You may obtain a copy of the\n",
"License at\n",
"\n",
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
"\n",
"Unless required by applicable law or agreed to in writing, software distributed\n",
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
"specific language governing permissions and limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "zzJlIvx4tnrM"
},
"source": [
"# RL Unplugged: Offline D4PG - RWRL\n",
"\n",
"## Guide to training an Acme D4PG agent on RWRL data.\n",
"# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/rwrl_d4pg.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "L6EObDZat-6b"
},
"source": [
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "GxZogj3guCzN"
},
"outputs": [],
"source": [
"!pip install dm-acme\n",
"!pip install dm-acme[reverb]\n",
"!pip install dm-acme[tf]\n",
"!pip install dm-sonnet"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "pDil_q8cDjov"
},
"source": [
"### MuJoCo\n",
"\n",
"See [this tutorial](https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=YvyGCsgSCxHQ) for more detailed instructions."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "VEEj3Qw60y73"
},
"source": [
"#### Institutional MuJoCo license."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "IbZxYDxzoz5R"
},
"outputs": [],
"source": [
"#@title Edit and run\n",
"mjkey = \"\"\"\n",
"\n",
"REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY\n",
"\n",
"\"\"\".strip()\n",
"\n",
"mujoco_dir = \"$HOME/.mujoco\"\n",
"\n",
"# Install OpenGL deps\n",
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
"\n",
"# Fetch MuJoCo binaries from Roboti\n",
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
"\n",
"# Copy over MuJoCo license\n",
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
"\n",
"# Configure dm_control to use the OSMesa rendering backend\n",
"%env MUJOCO_GL=osmesa\n",
"\n",
"# Install dm_control\n",
"!pip install dm_control"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-_7tVg-zzjzW"
},
"source": [
"#### Machine-locked MuJoCo license."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "ZvCGB5LSr-30"
},
"outputs": [],
"source": [
"#@title Add your MuJoCo License and run\n",
"mjkey = \"\"\"\n",
"\"\"\".strip()\n",
"\n",
"mujoco_dir = \"$HOME/.mujoco\"\n",
"\n",
"# Install OpenGL dependencies\n",
"!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
" libgl1-mesa-glx libosmesa6 libglew2.0\n",
"\n",
"# Get MuJoCo binaries\n",
"!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
"!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
"\n",
"# Copy over MuJoCo license\n",
"!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
"\n",
"# Configure dm_control to use the OSMesa rendering backend\n",
"%env MUJOCO_GL=osmesa\n",
"\n",
"# Install dm_control, including extra dependencies needed for the locomotion\n",
"# mazes.\n",
"!pip install dm_control[locomotion_mazes]"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ZrszOaSSMGiF"
},
"source": [
"### RWRL"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Oe5D7pEQAq82"
},
"outputs": [],
"source": [
"!git clone https://github.com/google-research/realworldrl_suite.git\n",
"!pip install realworldrl_suite/"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "jSr-8p2BA1wl"
},
"source": [
"### RL Unplugged"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "DGCc6SBaAnyb"
},
"outputs": [],
"source": [
"!git clone https://github.com/deepmind/deepmind-research.git\n",
"%cd deepmind-research"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "IE2nV9Hivnv5"
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "RI7NgnJIvs4s"
},
"outputs": [],
"source": [
"import collections\n",
"import collections.abc\n",
"import copy\n",
"from typing import Mapping, Sequence\n",
"\n",
"import acme\n",
"from acme import specs\n",
"from acme.agents.tf import actors\n",
"from acme.agents.tf import d4pg\n",
"from acme.tf import networks\n",
"from acme.tf import utils as tf2_utils\n",
"from acme.utils import loggers\n",
"from acme.wrappers import single_precision\n",
"import numpy as np\n",
"import realworldrl_suite.environments as rwrl_envs\n",
"from reverb import replay_sample\n",
"import six\n",
"from rl_unplugged import rwrl\n",
"import sonnet as snt\n",
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "a2PCwF3bwBII"
},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "VaEJbXjampPy"
},
"outputs": [],
"source": [
"domain_name = 'cartpole' #@param\n",
"task_name = 'swingup' #@param\n",
"difficulty = 'easy' #@param\n",
"combined_challenge = 'easy' #@param\n",
"combined_challenge_str = str(combined_challenge).lower()\n",
"\n",
"tmp_path = '/tmp/rwrl'\n",
"gs_path = f'gs://rl_unplugged/rwrl'\n",
"data_path = (f'combined_challenge_{combined_challenge_str}/{domain_name}/'\n",
" f'{task_name}/offline_rl_challenge_{difficulty}')\n",
"\n",
"!mkdir -p {tmp_path}/{data_path}\n",
"!gsutil cp -r {gs_path}/{data_path}/* {tmp_path}/{data_path}\n",
"\n",
"num_shards_str, = !ls {tmp_path}/{data_path}/* | wc -l\n",
"num_shards = int(num_shards_str)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mQ1as51Mww7X"
},
"source": [
"## Dataset and environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "K8HnIIi3ywYC"
},
"outputs": [],
"source": [
"#@title Auxiliary functions\n",
"\n",
"def flatten_observation(observation):\n",
"  \"\"\"Flattens multiple observation arrays into a single tensor.\n",
"\n",
"  Args:\n",
"    observation: A mutable mapping from observation names to tensors.\n",
"\n",
"  Returns:\n",
"    A flattened and concatenated observation array.\n",
"\n",
"  Raises:\n",
"    ValueError: If `observation` is not a `collections.abc.MutableMapping`.\n",
"  \"\"\"\n",
"  # `collections.MutableMapping` was removed in Python 3.10; use the ABC.\n",
"  if not isinstance(observation, collections.abc.MutableMapping):\n",
"    raise ValueError('Can only flatten dict-like observations.')\n",
"\n",
"  if isinstance(observation, collections.OrderedDict):\n",
"    keys = six.iterkeys(observation)\n",
"  else:\n",
"    # Keep a consistent ordering for other mappings.\n",
"    keys = sorted(six.iterkeys(observation))\n",
"\n",
"  observation_arrays = [tf.reshape(observation[key], [-1]) for key in keys]\n",
"  return tf.concat(observation_arrays, 0)\n",
"\n",
"\n",
"def preprocess_fn(sample):\n",
"  \"\"\"Flattens the observations of a transition sample.\"\"\"\n",
"  o_tm1, a_tm1, r_t, d_t, o_t = sample.data[:5]\n",
"  o_tm1 = flatten_observation(o_tm1)\n",
"  o_t = flatten_observation(o_t)\n",
"  return replay_sample.ReplaySample(\n",
"      info=sample.info, data=(o_tm1, a_tm1, r_t, d_t, o_t))"
]
},
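{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "flattenObsSketch"
},
"source": [
"The cell below is a small sketch of what `flatten_observation` does. The observation keys (`position`, `velocity`) and their values are made up for illustration and are not taken from the RWRL data. For a plain `dict`, keys are visited in sorted order, each array is reshaped to rank 1, and the pieces are concatenated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "flattenObsSketchCode"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Hypothetical observation: the keys and values are illustrative only.\n",
"toy_observation = {\n",
"    'velocity': tf.constant([[0.1, 0.2]]),\n",
"    'position': tf.constant([1.0, 2.0, 3.0]),\n",
"}\n",
"\n",
"# Uses flatten_observation from the cell above. Sorted key order puts the\n",
"# three 'position' values first, followed by the two 'velocity' values.\n",
"flat = flatten_observation(toy_observation)\n",
"print(flat.shape, flat.numpy())"
]
},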
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5kHzJpfcw306"
},
"outputs": [],
"source": [
"batch_size = 10  #@param\n",
"\n",
"environment = rwrl_envs.load(\n",
"    domain_name=domain_name,\n",
"    task_name=f'realworld_{task_name}',\n",
"    environment_kwargs=dict(log_safety_vars=False, flat_observation=True),\n",
"    combined_challenge=combined_challenge)\n",
"environment = single_precision.SinglePrecisionWrapper(environment)\n",
"environment_spec = specs.make_environment_spec(environment)\n",
"act_spec = environment_spec.actions\n",
"obs_spec = environment_spec.observations\n",
"\n",
"dataset = rwrl.dataset(\n",
"    tmp_path,\n",
"    combined_challenge=combined_challenge_str,\n",
"    domain=domain_name,\n",
"    task=task_name,\n",
"    difficulty=difficulty,\n",
"    num_shards=num_shards,\n",
"    shuffle_buffer_size=10)\n",
"dataset = dataset.map(preprocess_fn).batch(batch_size)"
]
},
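{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "batchPeekSketch"
},
"source": [
"As a quick sanity check, one batch can be pulled from `dataset` to inspect what `preprocess_fn` produces. This is only a sketch: it assumes the data download above succeeded, and the name `sample_batch` is introduced here purely for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "batchPeekSketchCode"
},
"outputs": [],
"source": [
"# Pull a single batch; `sample_batch` is a name used only in this sketch.\n",
"sample_batch = next(iter(dataset))\n",
"o_tm1, a_tm1, r_t, d_t, o_t = sample_batch.data\n",
"\n",
"# Each tensor carries a leading batch dimension (up to `batch_size` items).\n",
"for name, tensor in [('o_tm1', o_tm1), ('a_tm1', a_tm1), ('r_t', r_t),\n",
"                     ('d_t', d_t), ('o_t', o_t)]:\n",
"  print(name, tensor.shape, tensor.dtype)"
]
},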
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "adb0cyE5qu9G"
},
"source": [
"## D4PG learner"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "CriIaelxqwQg"
},
"outputs": [],
"source": [
"#@title Auxiliary functions\n",
"\n",
"def make_networks(\n",
"    action_spec: specs.BoundedArray,\n",
"    hidden_size: int = 1024,\n",
"    num_blocks: int = 4,\n",
"    num_mixtures: int = 5,\n",
"    vmin: float = -150.,\n",
"    vmax: float = 150.,\n",
"    num_atoms: int = 51,\n",
"):\n",
"  \"\"\"Creates networks used by the agent.\"\"\"\n",
"  num_dimensions = np.prod(action_spec.shape, dtype=int)\n",
"\n",
"  policy_network = snt.Sequential([\n",
"      networks.LayerNormAndResidualMLP(\n",
"          hidden_size=hidden_size, num_blocks=num_blocks),\n",
"      # Converts the policy output into the same shape as the action spec.\n",
"      snt.Linear(num_dimensions),\n",
"      # Note that TanhToSpec applies tanh to the input.\n",
"      networks.TanhToSpec(action_spec)\n",
"  ])\n",
"  # The multiplexer concatenates the (maybe transformed) observations/actions.\n",
"  critic_network = snt.Sequential([\n",
"      networks.CriticMultiplexer(\n",
"          critic_network=networks.LayerNormAndResidualMLP(\n",
"              hidden_size=hidden_size, num_blocks=num_blocks),\n",
"          observation_network=tf2_utils.batch_concat),\n",
"      networks.DiscreteValuedHead(vmin, vmax, num_atoms)\n",
"  ])\n",
"\n",
"  return {\n",
"      'policy': policy_network,\n",
"      'critic': critic_network,\n",
"  }\n",
"\n"
]
},
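{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "criticSupportSketch"
},
"source": [
"The critic ends in `DiscreteValuedHead(vmin, vmax, num_atoms)`, which predicts a categorical distribution over returns. As a rough illustration, and assuming the head places its atoms on an evenly spaced grid between `vmin` and `vmax` (an assumption about the Acme implementation, not something stated in this notebook), the default arguments of `make_networks` give the support computed below. The names `support` and `uniform_probs` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "criticSupportSketchCode"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Defaults copied from make_networks above.\n",
"vmin, vmax, num_atoms = -150., 150., 51\n",
"\n",
"# Assumed evenly spaced atoms; with these defaults the spacing is 6.\n",
"support = np.linspace(vmin, vmax, num_atoms)\n",
"print(support[1] - support[0])\n",
"\n",
"# The expected value under a categorical distribution over the atoms is a\n",
"# probability-weighted sum; a uniform distribution is a stand-in here.\n",
"uniform_probs = np.full(num_atoms, 1. / num_atoms)\n",
"print(np.sum(uniform_probs * support))"
]
},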
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "83naOY7a_A4I"
},
"outputs": [],
"source": [
"# Create the networks to optimize.\n",
"online_networks = make_networks(act_spec)\n",
"target_networks = copy.deepcopy(online_networks)\n",
"\n",
"# Create variables.\n",
"tf2_utils.create_variables(online_networks['policy'], [obs_spec])\n",
"tf2_utils.create_variables(online_networks['critic'], [obs_spec, act_spec])\n",
"tf2_utils.create_variables(target_networks['policy'], [obs_spec])\n",
"tf2_utils.create_variables(target_networks['critic'], [obs_spec, act_spec])\n",
"\n",
"# The learner updates the parameters (and initializes them).\n",
"learner = d4pg.D4PGLearner(\n",
"    policy_network=online_networks['policy'],\n",
"    critic_network=online_networks['critic'],\n",
"    target_policy_network=target_networks['policy'],\n",
"    target_critic_network=target_networks['critic'],\n",
"    dataset=dataset,\n",
"    discount=0.99,\n",
"    target_update_period=100)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PYkjKaduy_xj"
},
"source": [
"## Training loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 136
},
"colab_type": "code",
"executionInfo": {
"elapsed": 15293,
"status": "ok",
"timestamp": 1593617156917,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": -60
},
"id": "HbQOyCG4zCwa",
"outputId": "7a1055c4-4256-41b9-92a6-6fb05af89725"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Learner] Critic Loss = 4.016 | Policy Loss = 0.500 | Steps = 1 | Walltime = 0\n",
"[Learner] Critic Loss = 3.851 | Policy Loss = 0.279 | Steps = 16 | Walltime = 2.165\n",
"[Learner] Critic Loss = 3.832 | Policy Loss = 0.190 | Steps = 32 | Walltime = 3.216\n",
"[Learner] Critic Loss = 3.744 | Policy Loss = 0.223 | Steps = 48 | Walltime = 4.262\n",
"[Learner] Critic Loss = 3.782 | Policy Loss = 0.287 | Steps = 64 | Walltime = 5.305\n",
"[Learner] Critic Loss = 3.799 | Policy Loss = 0.315 | Steps = 80 | Walltime = 6.353\n",
"[Learner] Critic Loss = 3.796 | Policy Loss = 0.230 | Steps = 96 | Walltime = 7.397\n"
]
}
],
"source": [
"for _ in range(100):\n",
"  learner.step()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "LJ_XsuQSzFSV"
},
"source": [
"## Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 102
},
"colab_type": "code",
"executionInfo": {
"elapsed": 21830,
"status": "ok",
"timestamp": 1593617178762,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": -60
},
"id": "blvNCANKb22J",
"outputId": "b05b7272-78cc-4bef-fe64-b95cc6fe5688"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Evaluation] Episode Length = 1000 | Episode Return = 135.244 | Episodes = 1 | Steps = 1000 | Steps Per Second = 222.692\n",
"[Evaluation] Episode Length = 1000 | Episode Return = 136.972 | Episodes = 2 | Steps = 2000 | Steps Per Second = 233.214\n",
"[Evaluation] Episode Length = 1000 | Episode Return = 136.173 | Episodes = 3 | Steps = 3000 | Steps Per Second = 229.496\n",
"[Evaluation] Episode Length = 1000 | Episode Return = 146.131 | Episodes = 4 | Steps = 4000 | Steps Per Second = 230.199\n",
"[Evaluation] Episode Length = 1000 | Episode Return = 137.818 | Episodes = 5 | Steps = 5000 | Steps Per Second = 232.834\n"
]
}
],
"source": [
"# Create a logger.\n",
"logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n",
"\n",
"# Create an environment loop.\n",
"loop = acme.EnvironmentLoop(\n",
"    environment=environment,\n",
"    actor=actors.DeprecatedFeedForwardActor(online_networks['policy']),\n",
"    logger=logger)\n",
"\n",
"loop.run(5)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"VEEj3Qw60y73",
"-_7tVg-zzjzW"
],
"last_runtime": {
"build_target": "",
"kind": "local"
},
"name": "RL Unplugged: Offline D4PG - RWRL",
"provenance": [
{
"file_id": "1OerSIsTjv4d3rQCjAsi0ljPaLan87juJ",
"timestamp": 1593080049369
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}