mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2025-12-06 17:18:46 +08:00
673 lines
23 KiB
Plaintext
673 lines
23 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "pdgOfM42e7in"
|
|
},
|
|
"source": [
|
|
"Copyright 2021 DeepMind Technologies Limited.\n",
|
|
"\n",
|
|
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use\n",
|
|
"this file except in compliance with the License. You may obtain a copy of the\n",
|
|
"License at\n",
|
|
"\n",
|
|
"[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n",
|
|
"\n",
|
|
"Unless required by applicable law or agreed to in writing, software distributed\n",
|
|
"under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n",
|
|
"CONDITIONS OF ANY KIND, either express or implied. See the License for the\n",
|
|
"specific language governing permissions and limitations under the License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "WOzmAie8e-NK"
|
|
},
|
|
"source": [
|
|
"# RL Unplugged: Offline R2D2 - DeepMind Lab\n",
|
|
"\n",
|
|
"## A Colab example of an Acme R2D2 agent on DeepMind Lab data.\n",
|
|
"# \u003ca href=\"https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/rl_unplugged/dmlab_r2d2.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "tr2MoADAQepq"
|
|
},
|
|
"source": [
|
|
"## Installation\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "SvWeEWGd5Nx_"
|
|
},
|
|
"source": [
|
|
"### External dependencies"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "fTuqZxDv4v0y"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!apt-get install libsdl2-dev\n",
|
|
"!apt-get install libosmesa6-dev\n",
|
|
"!apt-get install libffi-dev\n",
|
|
"!apt-get install gettext\n",
|
|
"!apt-get install python3-numpy-dev python3-dev"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ewPoBUDd04xh"
|
|
},
|
|
"source": [
|
|
"### Bazel"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "ewVV3-Oh0sBm"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"BAZEL_VERSION = '3.6.0'\n",
|
|
"!wget https://github.com/bazelbuild/bazel/releases/download/{BAZEL_VERSION}/bazel-{BAZEL_VERSION}-installer-linux-x86_64.sh\n",
|
|
"!chmod +x bazel-{BAZEL_VERSION}-installer-linux-x86_64.sh\n",
|
|
"!./bazel-{BAZEL_VERSION}-installer-linux-x86_64.sh\n",
|
|
"!bazel --version"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "fwdmJWW3KB7g"
|
|
},
|
|
"source": [
|
|
"### DeepMind Lab"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "Ng9xopirzVYA"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!git clone https://github.com/deepmind/lab.git"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "NeO57QYqDG-L"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%%writefile lab/bazel/python.BUILD\n",
|
|
"\n",
|
|
"# Description:\n",
|
|
"# Build rule for Python and Numpy.\n",
|
|
"# This rule works for Debian and Ubuntu. Other platforms might keep the\n",
|
|
"# headers in different places, cf. 'How to build DeepMind Lab' in build.md.\n",
|
|
"\n",
|
|
"cc_library(\n",
|
|
" name = \"python\",\n",
|
|
" hdrs = select(\n",
|
|
" {\n",
|
|
" \"@bazel_tools//tools/python:PY3\": glob([\n",
|
|
" \"usr/include/python3.6m/*.h\",\n",
|
|
" \"usr/local/lib/python3.6/dist-packages/numpy/core/include/numpy/*.h\",\n",
|
|
" ]),\n",
|
|
" },\n",
|
|
" no_match_error = \"Internal error, Python version should be one of PY2 or PY3\",\n",
|
|
" ),\n",
|
|
" includes = select(\n",
|
|
" {\n",
|
|
" \"@bazel_tools//tools/python:PY3\": [\n",
|
|
" \"usr/include/python3.6m\",\n",
|
|
" \"usr/local/lib/python3.6/dist-packages/numpy/core/include\",\n",
|
|
" ],\n",
|
|
" },\n",
|
|
" no_match_error = \"Internal error, Python version should be one of PY2 or PY3\",\n",
|
|
" ),\n",
|
|
" visibility = [\"//visibility:public\"],\n",
|
|
")\n",
|
|
"\n",
|
|
"alias(\n",
|
|
" name = \"python_headers\",\n",
|
|
" actual = \":python\",\n",
|
|
" visibility = [\"//visibility:public\"],\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "pRuLCRzpzX8E"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!cd lab \u0026\u0026 bazel build -c opt --python_version=PY3 //python/pip_package:build_pip_package"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "Oen2E99T0E58"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!cd lab \u0026\u0026 ./bazel-bin/python/pip_package/build_pip_package /tmp/dmlab_pkg"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "okrzzmrC0H_O"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install /tmp/dmlab_pkg/deepmind_lab-1.0-py3-none-any.whl --force-reinstall"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "hq-VGgvbRKSI"
|
|
},
|
|
"source": [
|
|
"### Python dependencies"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "8Fme7zxOKejg"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install dm_env\n",
|
|
"!pip install dm-acme[reverb]\n",
|
|
"!pip install dm-acme[tf]\n",
|
|
"!pip install dm-sonnet"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "Rfd4jQGFt-HB"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Upgrade to recent commit for latest R2D2 learner.\n",
|
|
"!pip install --upgrade git+https://github.com/deepmind/acme.git@3dfda9d392312d948906e6c567c7f56d8c911de5"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "DvicrJPBqemz"
|
|
},
|
|
"source": [
|
|
"## Imports and Utils"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "_8qxA0KLU468"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# @title Imports\n",
|
|
"import copy\n",
|
|
"import functools\n",
|
|
"\n",
|
|
"from acme import environment_loop\n",
|
|
"from acme import specs\n",
|
|
"from acme.adders import reverb as acme_reverb\n",
|
|
"from acme.agents.tf import actors\n",
|
|
"from acme.agents.tf.r2d2 import learning as r2d2\n",
|
|
"from acme.tf import utils as tf_utils\n",
|
|
"from acme.tf import networks\n",
|
|
"from acme.utils import loggers\n",
|
|
"from acme.wrappers import observation_action_reward\n",
|
|
"import tree\n",
|
|
"\n",
|
|
"import deepmind_lab\n",
|
|
"import dm_env\n",
|
|
"import numpy as np\n",
|
|
"import reverb\n",
|
|
"import sonnet as snt\n",
|
|
"import tensorflow as tf\n",
|
|
"import trfl"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "becmQVMMuCRU"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# @title Environment\n",
|
|
"\n",
|
|
"_ACTION_MAP = {\n",
|
|
" 0: (0, 0, 0, 1, 0, 0, 0),\n",
|
|
" 1: (0, 0, 0, -1, 0, 0, 0),\n",
|
|
" 2: (0, 0, -1, 0, 0, 0, 0),\n",
|
|
" 3: (0, 0, 1, 0, 0, 0, 0),\n",
|
|
" 4: (-10, 0, 0, 0, 0, 0, 0),\n",
|
|
" 5: (10, 0, 0, 0, 0, 0, 0),\n",
|
|
" 6: (-60, 0, 0, 0, 0, 0, 0),\n",
|
|
" 7: (60, 0, 0, 0, 0, 0, 0),\n",
|
|
" 8: (0, 10, 0, 0, 0, 0, 0),\n",
|
|
" 9: (0, -10, 0, 0, 0, 0, 0),\n",
|
|
" 10: (-10, 0, 0, 1, 0, 0, 0),\n",
|
|
" 11: (10, 0, 0, 1, 0, 0, 0),\n",
|
|
" 12: (-60, 0, 0, 1, 0, 0, 0),\n",
|
|
" 13: (60, 0, 0, 1, 0, 0, 0),\n",
|
|
" 14: (0, 0, 0, 0, 1, 0, 0),\n",
|
|
"}\n",
|
|
"\n",
|
|
"class DeepMindLabEnvironment(dm_env.Environment):\n",
|
|
"  \"\"\"DeepMind Lab environment emitting observation/action/reward tuples.\n",
|
|
"\n",
|
|
"  Every observation is an `OAR` tuple holding the current RGB frame together\n",
|
|
"  with the most recent discrete action and reward. Discrete actions are\n",
|
|
"  translated into DeepMind Lab action vectors through `_ACTION_MAP`.\n",
|
|
"  \"\"\"\n",
|
|
"\n",
|
|
"  def __init__(self, level_name: str, action_repeats: int = 4):\n",
|
|
"    \"\"\"Construct environment.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"      level_name: DeepMind lab level name (e.g. 'rooms_watermaze').\n",
|
|
"      action_repeats: Number of times the same action is repeated on every\n",
|
|
"        step().\n",
|
|
"    \"\"\"\n",
|
|
"    config = dict(fps='30',\n",
|
|
"                  height='72',\n",
|
|
"                  width='96',\n",
|
|
"                  maxAltCameraHeight='1',\n",
|
|
"                  maxAltCameraWidth='1',\n",
|
|
"                  hasAltCameras='false')\n",
|
|
"\n",
|
|
"    # seekavoid_arena_01 is not part of dmlab30.\n",
|
|
"    if level_name != 'seekavoid_arena_01':\n",
|
|
"      level_name = 'contributed/dmlab30/{}'.format(level_name)\n",
|
|
"\n",
|
|
"    self._lab = deepmind_lab.Lab(level_name, ['RGB_INTERLEAVED'], config)\n",
|
|
"    self._action_repeats = action_repeats\n",
|
|
"    # Most recent action and reward, reported inside every OAR observation.\n",
|
|
"    self._action = 0\n",
|
|
"    self._reward = 0\n",
|
|
"\n",
|
|
"  def _observation(self):\n",
|
|
"    \"\"\"Builds, caches and returns the current OAR observation.\"\"\"\n",
|
|
"    self._last_observation = observation_action_reward.OAR(\n",
|
|
"        observation=self._lab.observations()['RGB_INTERLEAVED'],\n",
|
|
"        action=np.array(self._action, dtype=np.int64),\n",
|
|
"        reward=np.array(self._reward, dtype=np.float32))\n",
|
|
"    return self._last_observation\n",
|
|
"\n",
|
|
"  def reset(self):\n",
|
|
"    \"\"\"Starts a new episode and returns its first timestep.\"\"\"\n",
|
|
"    self._lab.reset()\n",
|
|
"    # Clear the previous episode's action and reward so that every episode\n",
|
|
"    # starts from the same (0, 0) values as the very first one.\n",
|
|
"    self._action = 0\n",
|
|
"    self._reward = 0\n",
|
|
"    return dm_env.restart(self._observation())\n",
|
|
"\n",
|
|
"  def step(self, action):\n",
|
|
"    \"\"\"Applies a discrete action and returns the resulting timestep.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"      action: Key of `_ACTION_MAP`, given as a scalar array or an integer.\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"      A transition timestep while the episode keeps running, a termination\n",
|
|
"      timestep once it has ended, or the first timestep of a new episode when\n",
|
|
"      the previous episode was already over.\n",
|
|
"\n",
|
|
"    Raises:\n",
|
|
"      ValueError: If `action` is not a key of `_ACTION_MAP`.\n",
|
|
"    \"\"\"\n",
|
|
"    if not self._lab.is_running():\n",
|
|
"      # reset() already returns a restart timestep; wrapping it in another\n",
|
|
"      # dm_env.restart() would nest a TimeStep inside the observation field.\n",
|
|
"      return self.reset()\n",
|
|
"\n",
|
|
"    # Validate before storing so a rejected action never reaches observations.\n",
|
|
"    action = np.asarray(action).item()\n",
|
|
"    if action not in _ACTION_MAP:\n",
|
|
"      raise ValueError('Action not available')\n",
|
|
"    self._action = action\n",
|
|
"\n",
|
|
"    lab_action = np.array(_ACTION_MAP[self._action], dtype=np.intc)\n",
|
|
"    self._reward = self._lab.step(lab_action, num_steps=self._action_repeats)\n",
|
|
"\n",
|
|
"    if self._lab.is_running():\n",
|
|
"      return dm_env.transition(self._reward, self._observation())\n",
|
|
"    # The episode is over: reuse the cached observation instead of querying\n",
|
|
"    # the lab again (presumably no frame is available any more - TODO confirm).\n",
|
|
"    return dm_env.termination(self._reward, self._last_observation)\n",
|
|
"\n",
|
|
"  def observation_spec(self):\n",
|
|
"    \"\"\"Returns the spec of the OAR observation tuple.\"\"\"\n",
|
|
"    return observation_action_reward.OAR(\n",
|
|
"        observation=dm_env.specs.Array(shape=(72, 96, 3), dtype=np.uint8),\n",
|
|
"        action=dm_env.specs.Array(shape=(), dtype=np.int64),\n",
|
|
"        reward=dm_env.specs.Array(shape=(), dtype=np.float32))\n",
|
|
"\n",
|
|
"  def action_spec(self):\n",
|
|
"    \"\"\"Returns the discrete action spec, one value per `_ACTION_MAP` entry.\"\"\"\n",
|
|
"    return dm_env.specs.DiscreteArray(num_values=len(_ACTION_MAP),\n",
|
|
"                                      dtype=np.int64)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"cellView": "form",
|
|
"id": "4ms1TBjDSXr0"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# @title Dataset\n",
|
|
"\n",
|
|
"def _decode_images(pngs):\n",
|
|
" \"\"\"Decode tensor of PNGs.\"\"\"\n",
|
|
" decode_rgb_png = functools.partial(tf.io.decode_png, channels=3)\n",
|
|
" images = tf.map_fn(decode_rgb_png, pngs, dtype=tf.uint8,\n",
|
|
" parallel_iterations=10)\n",
|
|
" # [N, 72, 96, 3]\n",
|
|
" images.set_shape((pngs.shape[0], 72, 96, 3))\n",
|
|
" return images\n",
|
|
"\n",
|
|
"def _tf_example_to_step_ds(tf_example: tf.train.Example,\n",
|
|
"                           episode_length: int) -\u003e tf.data.Dataset:\n",
|
|
"  \"\"\"Converts one episode record into a dataset of Reverb replay samples.\n",
|
|
"\n",
|
|
"  Args:\n",
|
|
"    tf_example: Episode record, parsed with `tf.io.parse_single_example`.\n",
|
|
"    episode_length: Number of steps stored in the record.\n",
|
|
"\n",
|
|
"  Returns:\n",
|
|
"    A dataset yielding one `reverb.ReplaySample` per step of the episode.\n",
|
|
"  \"\"\"\n",
|
|
"\n",
|
|
"  def sequence_feature(shape, dtype=tf.float32):\n",
|
|
"    \"\"\"Fixed-length feature with a leading `episode_length` dimension.\"\"\"\n",
|
|
"    return tf.io.FixedLenFeature(shape=[episode_length] + shape, dtype=dtype)\n",
|
|
"\n",
|
|
"  feature_description = {\n",
|
|
"      'episode_id': tf.io.FixedLenFeature([], tf.int64),\n",
|
|
"      'start_idx': tf.io.FixedLenFeature([], tf.int64),\n",
|
|
"      'episode_return': tf.io.FixedLenFeature([], tf.float32),\n",
|
|
"      'observations_pixels': sequence_feature([], tf.string),\n",
|
|
"      'observations_reward': sequence_feature([]),\n",
|
|
"      # actions are one-hot arrays.\n",
|
|
"      'observations_action': sequence_feature([15]),\n",
|
|
"      'actions': sequence_feature([], tf.int64),\n",
|
|
"      'rewards': sequence_feature([]),\n",
|
|
"      'discounted_rewards': sequence_feature([]),\n",
|
|
"      'discounts': sequence_feature([]),\n",
|
|
"  }\n",
|
|
"\n",
|
|
"  parsed = tf.io.parse_single_example(tf_example, feature_description)\n",
|
|
"\n",
|
|
"  observation = observation_action_reward.OAR(\n",
|
|
"      observation=_decode_images(parsed['observations_pixels']),\n",
|
|
"      # Stored one-hot; the OAR tuple carries the action index instead.\n",
|
|
"      action=tf.argmax(parsed['observations_action'],\n",
|
|
"                       axis=1, output_type=tf.int64),\n",
|
|
"      reward=parsed['observations_reward'])\n",
|
|
"\n",
|
|
"  steps = acme_reverb.Step(\n",
|
|
"      observation=observation,\n",
|
|
"      action=parsed['actions'],\n",
|
|
"      reward=parsed['rewards'],\n",
|
|
"      discount=parsed['discounts'],\n",
|
|
"      start_of_episode=tf.zeros((episode_length,), tf.bool),\n",
|
|
"      extras={})\n",
|
|
"\n",
|
|
"  # Keys are all zero and probabilities are all one.\n",
|
|
"  info = reverb.SampleInfo(key=tf.zeros((episode_length,), tf.int64),\n",
|
|
"                           probability=tf.ones((episode_length,), tf.float32),\n",
|
|
"                           table_size=tf.zeros((episode_length,), tf.int64),\n",
|
|
"                           priority=tf.ones((episode_length,), tf.float32))\n",
|
|
"  sample = reverb.ReplaySample(info=info, data=steps)\n",
|
|
"  # Slice along the leading (step) dimension: one dataset element per step.\n",
|
|
"  return tf.data.Dataset.from_tensor_slices(sample)\n",
|
|
"\n",
|
|
"def subsequences(step_ds: tf.data.Dataset,\n",
|
|
" length: int, shift: int = 1\n",
|
|
" ) -\u003e tf.data.Dataset:\n",
|
|
" \"\"\"Dataset of subsequences from a dataset of episode steps.\"\"\"\n",
|
|
" window_ds = step_ds.window(length, shift=shift, stride=1)\n",
|
|
" return window_ds.interleave(_nest_ds).batch(length, drop_remainder=True)\n",
|
|
"\n",
|
|
"\n",
|
|
"def _nest_ds(nested_ds: tf.data.Dataset) -\u003e tf.data.Dataset:\n",
|
|
" \"\"\"Produces a dataset of nests from a nest of datasets of the same size.\"\"\"\n",
|
|
" flattened_ds = tuple(tree.flatten(nested_ds))\n",
|
|
" zipped_ds = tf.data.Dataset.zip(flattened_ds)\n",
|
|
" return zipped_ds.map(lambda *x: tree.unflatten_as(nested_ds, x))\n",
|
|
"\n",
|
|
"\n",
|
|
"def make_dataset(path: str,\n",
|
|
" episode_length: int,\n",
|
|
" sequence_length: int,\n",
|
|
" sequence_shift: int,\n",
|
|
" num_shards: int = 500) -\u003e tf.data.Dataset:\n",
|
|
" \"\"\"Create dataset of DeepMind Lab sequences.\"\"\"\n",
|
|
"\n",
|
|
" filenames = [f'{path}/tfrecord-{i:05d}-of-{num_shards:05d}'\n",
|
|
" for i in range(num_shards)]\n",
|
|
" file_ds = tf.data.Dataset.from_tensor_slices(filenames)\n",
|
|
" file_ds = file_ds.repeat().shuffle(num_shards)\n",
|
|
" tfrecord_dataset = functools.partial(tf.data.TFRecordDataset,\n",
|
|
" compression_type='GZIP')\n",
|
|
"\n",
|
|
" # Dataset of tf.Examples containing full episodes.\n",
|
|
" example_ds = file_ds.interleave(tfrecord_dataset)\n",
|
|
"\n",
|
|
" # Dataset of episodes, each represented as a dataset of steps.\n",
|
|
" _tf_example_to_step_ds_with_length = functools.partial(\n",
|
|
" _tf_example_to_step_ds, episode_length=episode_length)\n",
|
|
" episode_ds = example_ds.map(_tf_example_to_step_ds_with_length,\n",
|
|
" num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
|
|
"\n",
|
|
" # Dataset of sequences.\n",
|
|
" training_sequences = functools.partial(subsequences, length=sequence_length,\n",
|
|
" shift=sequence_shift)\n",
|
|
" return episode_ds.interleave(training_sequences)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "sV2vXWAsU5Zg"
|
|
},
|
|
"source": [
|
|
"## Experiment"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "0F-l-4LolX1c"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# task | episode length | run\n",
|
|
"# ----------------------------------------------------------------------------\n",
|
|
"# seekavoid_arena_01 | 301 | training_{0..2}\n",
|
|
"# seekavoid_arena_01 | 301 | snapshot_{0..1}_eps_0.0\n",
|
|
"# seekavoid_arena_01 | 301 | snapshot_{0..1}_eps_0.01\n",
|
|
"# seekavoid_arena_01 | 301 | snapshot_{0..1}_eps_0.1\n",
|
|
"# seekavoid_arena_01 | 301 | snapshot_{0..1}_eps_0.25\n",
|
|
"# explore_object_rewards_few | 1351 | training_{0..2}\n",
|
|
"# explore_object_rewards_many | 1801 | training_{0..2}\n",
|
|
"# rooms_select_nonmatching_object | 181 | training_{0..2}\n",
|
|
"# rooms_watermaze | 1801 | training_{0..2}\n",
|
|
"\n",
|
|
"TASK = 'seekavoid_arena_01'\n",
|
|
"RUN = 'training_0'\n",
|
|
"EPISODE_LENGTH = 301\n",
|
|
"BATCH_SIZE = 1\n",
|
|
"DATASET_PATH = f'gs://rl_unplugged/dmlab/{TASK}/{RUN}'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "H7YN_qwDVPqQ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"environment = DeepMindLabEnvironment(TASK, action_repeats=2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "-B25Lcgt8JD4"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"dataset = make_dataset(DATASET_PATH, num_shards=500,\n",
|
|
" episode_length=EPISODE_LENGTH,\n",
|
|
" sequence_length=120,\n",
|
|
" sequence_shift=40)\n",
|
|
"dataset = dataset.padded_batch(BATCH_SIZE, drop_remainder=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "gTO61WolqkzG"
|
|
},
|
|
"source": [
|
|
"### Learning"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "cBFmIYxTtBg4"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create network.\n",
|
|
"def process_observations(x):\n",
|
|
" return x._replace(observation=tf.image.convert_image_dtype(x.observation, tf.float32))\n",
|
|
"\n",
|
|
"environment_spec = specs.make_environment_spec(environment)\n",
|
|
"num_actions = environment_spec.actions.maximum + 1\n",
|
|
"network = snt.DeepRNN([\n",
|
|
" process_observations,\n",
|
|
" networks.R2D2AtariNetwork(num_actions=num_actions)\n",
|
|
"])\n",
|
|
"tf_utils.create_variables(network, [environment_spec.observations])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "QLoAU2zwwi3X"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create a logger.\n",
|
|
"logger = loggers.TerminalLogger(label='learner', time_delta=1.)\n",
|
|
"\n",
|
|
"# Create the R2D2 learner.\n",
|
|
"learner = r2d2.R2D2Learner(\n",
|
|
" environment_spec=environment_spec,\n",
|
|
" network=network,\n",
|
|
" target_network=copy.deepcopy(network),\n",
|
|
" discount=0.99,\n",
|
|
" learning_rate=1e-4,\n",
|
|
" importance_sampling_exponent=0.2,\n",
|
|
" target_update_period=100,\n",
|
|
" burn_in_length=0,\n",
|
|
" sequence_length=120,\n",
|
|
" store_lstm_state=False,\n",
|
|
" dataset=dataset,\n",
|
|
" logger=logger)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "lqMgZS9UWfWl"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"for _ in range(5):\n",
|
|
" learner.step()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "eMpO7eBeqmZn"
|
|
},
|
|
"source": [
|
|
"### Evaluation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "orUDJVmpA0lU"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create a logger.\n",
|
|
"logger = loggers.TerminalLogger(label='evaluator', time_delta=1.)\n",
|
|
"\n",
|
|
"# Create evaluation loop.\n",
|
|
"eval_network = snt.DeepRNN([\n",
|
|
" network,\n",
|
|
" lambda q: trfl.epsilon_greedy(q, epsilon=0.4**8).sample(),\n",
|
|
"])\n",
|
|
"eval_loop = environment_loop.EnvironmentLoop(\n",
|
|
" environment=environment,\n",
|
|
" actor=actors.DeprecatedRecurrentActor(policy_network=eval_network),\n",
|
|
" logger=logger)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "6FDsfWVXCcYZ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"eval_loop.run(2)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"collapsed_sections": [
|
|
"tr2MoADAQepq",
|
|
"SvWeEWGd5Nx_",
|
|
"ewPoBUDd04xh",
|
|
"fwdmJWW3KB7g"
|
|
],
|
|
"name": "RL Unplugged: Offline R2D2 - DeepMind Lab",
|
|
"provenance": [
|
|
{
|
|
"file_id": "1vgfEtkThYTNWHhi3pisuRFxgmoMniuQz",
|
|
"timestamp": 1605722818242
|
|
},
|
|
{
|
|
"file_id": "/v2/external/notebooks/intro.ipynb",
|
|
"timestamp": 1602763830869
|
|
}
|
|
]
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
}
|