{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KDiJzbb8KFvP"
      },
      "source": [
        "Copyright 2020 DeepMind Technologies Limited.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
        "this file except in compliance with the License. You may obtain a copy of the\n",
        "License at\n",
        "\n",
        "[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed\n",
        "under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
        "CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
        "specific language governing permissions and limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zzJlIvx4tnrM"
      },
      "source": [
        "# RL Unplugged: Offline D4PG - RWRL\n",
        "\n",
        "## Guide to training an Acme D4PG agent on RWRL data.\n",
        "# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/rwrl_d4pg.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "L6EObDZat-6b"
      },
      "source": [
        "## Installation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "GxZogj3guCzN"
      },
      "outputs": [],
      "source": [
        "!pip install dm-acme\n",
        "!pip install dm-acme[reverb]\n",
        "!pip install dm-acme[tf]\n",
        "!pip install dm-sonnet"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pDil_q8cDjov"
      },
      "source": [
        "### MuJoCo\n",
        "\n",
        "More detailed instructions in [this tutorial](https://colab.research.google.com/github/deepmind/dm_control/blob/master/tutorial.ipynb#scrollTo=YvyGCsgSCxHQ)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VEEj3Qw60y73"
      },
      "source": [
        "#### Institutional MuJoCo license."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "IbZxYDxzoz5R"
      },
      "outputs": [],
      "source": [
        "#@title Edit and run\n",
        "mjkey = \"\"\"\n",
        "\n",
        "REPLACE THIS LINE WITH YOUR MUJOCO LICENSE KEY\n",
        "\n",
        "\"\"\".strip()\n",
        "\n",
        "mujoco_dir = \"$HOME/.mujoco\"\n",
        "\n",
        "# Install OpenGL deps\n",
        "!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
        "  libgl1-mesa-glx libosmesa6 libglew2.0\n",
        "\n",
        "# Fetch MuJoCo binaries from Roboti\n",
        "!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
        "!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
        "\n",
        "# Copy over MuJoCo license\n",
        "!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
        "\n",
        "# Configure dm_control to use the OSMesa rendering backend\n",
        "%env MUJOCO_GL=osmesa\n",
        "\n",
        "# Install dm_control\n",
        "!pip install dm_control"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-_7tVg-zzjzW"
      },
      "source": [
        "#### Machine-locked MuJoCo license."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "ZvCGB5LSr-30"
      },
      "outputs": [],
      "source": [
        "#@title Add your MuJoCo License and run\n",
        "mjkey = \"\"\"\n",
        "\"\"\".strip()\n",
        "\n",
        "mujoco_dir = \"$HOME/.mujoco\"\n",
        "\n",
        "# Install OpenGL dependencies\n",
        "!apt-get update \u0026\u0026 apt-get install -y --no-install-recommends \\\n",
        "  libgl1-mesa-glx libosmesa6 libglew2.0\n",
        "\n",
        "# Get MuJoCo binaries\n",
        "!wget -q https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip\n",
        "!unzip -o -q mujoco.zip -d \"$mujoco_dir\"\n",
        "\n",
        "# Copy over MuJoCo license\n",
        "!echo \"$mjkey\" \u003e \"$mujoco_dir/mjkey.txt\"\n",
        "\n",
        "# Configure dm_control to use the OSMesa rendering backend\n",
        "%env MUJOCO_GL=osmesa\n",
        "\n",
        "# Install dm_control, including extra dependencies needed for the locomotion\n",
        "# mazes.\n",
        "!pip install dm_control[locomotion_mazes]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZrszOaSSMGiF"
      },
      "source": [
        "### RWRL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Oe5D7pEQAq82"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/google-research/realworldrl_suite.git\n",
        "!pip install realworldrl_suite/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "jSr-8p2BA1wl"
      },
      "source": [
        "### RL Unplugged"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DGCc6SBaAnyb"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/deepmind/deepmind-research.git\n",
        "%cd deepmind-research"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "IE2nV9Hivnv5"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RI7NgnJIvs4s"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "import collections.abc\n",
        "import copy\n",
        "from typing import Mapping, Sequence\n",
        "\n",
        "import acme\n",
        "from acme import specs\n",
        "from acme.agents.tf import actors\n",
        "from acme.agents.tf import d4pg\n",
        "from acme.tf import networks\n",
        "from acme.tf import utils as tf2_utils\n",
        "from acme.utils import loggers\n",
        "from acme.wrappers import single_precision\n",
        "import numpy as np\n",
        "import realworldrl_suite.environments as rwrl_envs\n",
        "from reverb import replay_sample\n",
        "import six\n",
        "from rl_unplugged import rwrl\n",
        "import sonnet as snt\n",
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "a2PCwF3bwBII"
      },
      "source": [
        "## Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "VaEJbXjampPy"
      },
      "outputs": [],
      "source": [
        "domain_name = 'cartpole' #@param\n",
        "task_name = 'swingup' #@param\n",
        "difficulty = 'easy' #@param\n",
        "combined_challenge = 'easy' #@param\n",
        "combined_challenge_str = str(combined_challenge).lower()\n",
        "\n",
        "tmp_path = '/tmp/rwrl'\n",
        "gs_path = 'gs://rl_unplugged/rwrl'\n",
        "data_path = (f'combined_challenge_{combined_challenge_str}/{domain_name}/'\n",
        "             f'{task_name}/offline_rl_challenge_{difficulty}')\n",
        "\n",
        "!mkdir -p {tmp_path}/{data_path}\n",
        "!gsutil cp -r {gs_path}/{data_path}/* {tmp_path}/{data_path}\n",
        "\n",
        "# Count the downloaded shard files.\n",
        "num_shards_str, = !ls {tmp_path}/{data_path}/* | wc -l\n",
        "num_shards = int(num_shards_str)"
      ]
    },
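    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A quick, optional check: list the downloaded shards to confirm the copy succeeded. It relies only on the `tmp_path`, `data_path`, and `num_shards` variables defined above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Optional sanity check: confirm the shards landed where we expect.\n",
        "!ls -lh {tmp_path}/{data_path}\n",
        "print(f'Found {num_shards} shard(s) under {tmp_path}/{data_path}')"
      ]
    },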
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mQ1as51Mww7X"
      },
      "source": [
        "## Dataset and environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "K8HnIIi3ywYC"
      },
      "outputs": [],
      "source": [
        "#@title Auxiliary functions\n",
        "\n",
        "def flatten_observation(observation):\n",
        "  \"\"\"Flattens multiple observation arrays into a single tensor.\n",
        "\n",
        "  Args:\n",
        "    observation: A mutable mapping from observation names to tensors.\n",
        "\n",
        "  Returns:\n",
        "    A flattened and concatenated observation array.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If `observation` is not a `collections.abc.MutableMapping`.\n",
        "  \"\"\"\n",
        "  # `collections.MutableMapping` was removed in Python 3.10; use the\n",
        "  # `collections.abc` alias instead.\n",
        "  if not isinstance(observation, collections.abc.MutableMapping):\n",
        "    raise ValueError('Can only flatten dict-like observations.')\n",
        "\n",
        "  if isinstance(observation, collections.OrderedDict):\n",
        "    keys = six.iterkeys(observation)\n",
        "  else:\n",
        "    # Keep a consistent ordering for other mappings.\n",
        "    keys = sorted(six.iterkeys(observation))\n",
        "\n",
        "  observation_arrays = [tf.reshape(observation[key], [-1]) for key in keys]\n",
        "  return tf.concat(observation_arrays, 0)\n",
        "\n",
        "\n",
        "def preprocess_fn(sample):\n",
        "  \"\"\"Flattens the observations of a sampled (o_tm1, a_tm1, r_t, d_t, o_t) transition.\"\"\"\n",
        "  o_tm1, a_tm1, r_t, d_t, o_t = sample.data[:5]\n",
        "  o_tm1 = flatten_observation(o_tm1)\n",
        "  o_t = flatten_observation(o_t)\n",
        "  return replay_sample.ReplaySample(\n",
        "      info=sample.info, data=(o_tm1, a_tm1, r_t, d_t, o_t))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "5kHzJpfcw306"
      },
      "outputs": [],
      "source": [
        "batch_size = 10 #@param\n",
        "\n",
        "environment = rwrl_envs.load(\n",
        "    domain_name=domain_name,\n",
        "    task_name=f'realworld_{task_name}',\n",
        "    environment_kwargs=dict(log_safety_vars=False, flat_observation=True),\n",
        "    combined_challenge=combined_challenge)\n",
        "environment = single_precision.SinglePrecisionWrapper(environment)\n",
        "environment_spec = specs.make_environment_spec(environment)\n",
        "act_spec = environment_spec.actions\n",
        "obs_spec = environment_spec.observations\n",
        "\n",
        "dataset = rwrl.dataset(\n",
        "    tmp_path,\n",
        "    combined_challenge=combined_challenge_str,\n",
        "    domain=domain_name,\n",
        "    task=task_name,\n",
        "    difficulty=difficulty,\n",
        "    num_shards=num_shards,\n",
        "    shuffle_buffer_size=10)\n",
        "dataset = dataset.map(preprocess_fn).batch(batch_size)"
      ]
    },
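    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a quick sanity check, the next cell pulls a single batch from `dataset` and prints the tensor shapes, assuming the `(o_tm1, a_tm1, r_t, d_t, o_t)` layout produced by `preprocess_fn` above. Every shape should have `batch_size` as its leading dimension."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Optional sanity check: inspect the shapes of one batch of transitions.\n",
        "sample = next(iter(dataset))\n",
        "o_tm1, a_tm1, r_t, d_t, o_t = sample.data\n",
        "for name, tensor in [('o_tm1', o_tm1), ('a_tm1', a_tm1), ('r_t', r_t),\n",
        "                     ('d_t', d_t), ('o_t', o_t)]:\n",
        "  print(f'{name}: shape={tensor.shape}, dtype={tensor.dtype.name}')"
      ]
    },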
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "adb0cyE5qu9G"
      },
      "source": [
        "## D4PG learner"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "CriIaelxqwQg"
      },
      "outputs": [],
      "source": [
        "#@title Auxiliary functions\n",
        "\n",
        "def make_networks(\n",
        "    action_spec: specs.BoundedArray,\n",
        "    hidden_size: int = 1024,\n",
        "    num_blocks: int = 4,\n",
        "    num_mixtures: int = 5,\n",
        "    vmin: float = -150.,\n",
        "    vmax: float = 150.,\n",
        "    num_atoms: int = 51,\n",
        "):\n",
        "  \"\"\"Creates networks used by the agent.\"\"\"\n",
        "  num_dimensions = np.prod(action_spec.shape, dtype=int)\n",
        "\n",
        "  policy_network = snt.Sequential([\n",
        "      networks.LayerNormAndResidualMLP(\n",
        "          hidden_size=hidden_size, num_blocks=num_blocks),\n",
        "      # Converts the policy output into the same shape as the action spec.\n",
        "      snt.Linear(num_dimensions),\n",
        "      # Note that TanhToSpec applies tanh to the input.\n",
        "      networks.TanhToSpec(action_spec)\n",
        "  ])\n",
        "  # The multiplexer concatenates the (maybe transformed) observations/actions.\n",
        "  critic_network = snt.Sequential([\n",
        "      networks.CriticMultiplexer(\n",
        "          critic_network=networks.LayerNormAndResidualMLP(\n",
        "              hidden_size=hidden_size, num_blocks=num_blocks),\n",
        "          observation_network=tf2_utils.batch_concat),\n",
        "      networks.DiscreteValuedHead(vmin, vmax, num_atoms)\n",
        "  ])\n",
        "\n",
        "  return {\n",
        "      'policy': policy_network,\n",
        "      'critic': critic_network,\n",
        "  }"
      ]
    },
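    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "It can help to glance at the specs these networks must match before building them; the cell below simply prints the `act_spec` and `obs_spec` computed earlier."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Optional: print the specs the networks are built against.\n",
        "print('Observation spec:', obs_spec)\n",
        "print('Action spec:     ', act_spec)"
      ]
    },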
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "83naOY7a_A4I"
      },
      "outputs": [],
      "source": [
        "# Create the networks to optimize.\n",
        "online_networks = make_networks(act_spec)\n",
        "target_networks = copy.deepcopy(online_networks)\n",
        "\n",
        "# Create variables.\n",
        "tf2_utils.create_variables(online_networks['policy'], [obs_spec])\n",
        "tf2_utils.create_variables(online_networks['critic'], [obs_spec, act_spec])\n",
        "tf2_utils.create_variables(target_networks['policy'], [obs_spec])\n",
        "tf2_utils.create_variables(target_networks['critic'], [obs_spec, act_spec])\n",
        "\n",
        "# The learner updates the parameters (and initializes them).\n",
        "learner = d4pg.D4PGLearner(\n",
        "    policy_network=online_networks['policy'],\n",
        "    critic_network=online_networks['critic'],\n",
        "    target_policy_network=target_networks['policy'],\n",
        "    target_critic_network=target_networks['critic'],\n",
        "    dataset=dataset,\n",
        "    discount=0.99,\n",
        "    target_update_period=100)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PYkjKaduy_xj"
      },
      "source": [
        "## Training loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 136
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 15293,
          "status": "ok",
          "timestamp": 1593617156917,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "HbQOyCG4zCwa",
        "outputId": "7a1055c4-4256-41b9-92a6-6fb05af89725"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Learner] Critic Loss = 4.016 | Policy Loss = 0.500 | Steps = 1 | Walltime = 0\n",
            "[Learner] Critic Loss = 3.851 | Policy Loss = 0.279 | Steps = 16 | Walltime = 2.165\n",
            "[Learner] Critic Loss = 3.832 | Policy Loss = 0.190 | Steps = 32 | Walltime = 3.216\n",
            "[Learner] Critic Loss = 3.744 | Policy Loss = 0.223 | Steps = 48 | Walltime = 4.262\n",
            "[Learner] Critic Loss = 3.782 | Policy Loss = 0.287 | Steps = 64 | Walltime = 5.305\n",
            "[Learner] Critic Loss = 3.799 | Policy Loss = 0.315 | Steps = 80 | Walltime = 6.353\n",
            "[Learner] Critic Loss = 3.796 | Policy Loss = 0.230 | Steps = 96 | Walltime = 7.397\n"
          ]
        }
      ],
      "source": [
        "for _ in range(100):\n",
        "  learner.step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LJ_XsuQSzFSV"
      },
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "height": 102
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 21830,
          "status": "ok",
          "timestamp": 1593617178762,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "blvNCANKb22J",
        "outputId": "b05b7272-78cc-4bef-fe64-b95cc6fe5688"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Evaluation] Episode Length = 1000 | Episode Return = 135.244 | Episodes = 1 | Steps = 1000 | Steps Per Second = 222.692\n",
            "[Evaluation] Episode Length = 1000 | Episode Return = 136.972 | Episodes = 2 | Steps = 2000 | Steps Per Second = 233.214\n",
            "[Evaluation] Episode Length = 1000 | Episode Return = 136.173 | Episodes = 3 | Steps = 3000 | Steps Per Second = 229.496\n",
            "[Evaluation] Episode Length = 1000 | Episode Return = 146.131 | Episodes = 4 | Steps = 4000 | Steps Per Second = 230.199\n",
            "[Evaluation] Episode Length = 1000 | Episode Return = 137.818 | Episodes = 5 | Steps = 5000 | Steps Per Second = 232.834\n"
          ]
        }
      ],
      "source": [
        "# Create a logger.\n",
        "logger = loggers.TerminalLogger(label='evaluation', time_delta=1.)\n",
        "\n",
        "# Create an environment loop.\n",
        "loop = acme.EnvironmentLoop(\n",
        "    environment=environment,\n",
        "    actor=actors.DeprecatedFeedForwardActor(online_networks['policy']),\n",
        "    logger=logger)\n",
        "\n",
        "loop.run(5)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "VEEj3Qw60y73",
        "-_7tVg-zzjzW"
      ],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "RL Unplugged: Offline D4PG - RWRL",
      "provenance": [
        {
          "file_id": "1OerSIsTjv4d3rQCjAsi0ljPaLan87juJ",
          "timestamp": 1593080049369
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}