# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Control RL Unplugged datasets.
Examples in the dataset represent sequences stored when running a partially
trained agent (trained in online way) as described in
https://arxiv.org/abs/2006.13888.
Every dataset has a SARSA version, and datasets for environments for solving
which we believe one may need a recurrent agent also include a version of the
dataset with overlapping sequences of length 40.
Datasets for the dm_control_suite environments only include proprio
observations, while datasets for dm_locomotion include both pixel and proprio
observations.
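
Example (a minimal sketch; `gs://rl_unplugged` is the public bucket from the
RL Unplugged release, assumed here):

  control_task = ControlSuite(task_name='cheetah_run')
  ds = dataset(
      root_path='gs://rl_unplugged',
      data_path=control_task.data_path,
      shapes=control_task.shapes,
      uint8_features=control_task.uint8_features,
      num_threads=1,
      batch_size=2,
      sarsa=True)
  environment = control_task.environment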
"""
import collections
import functools
import os
from typing import Dict, Optional, Tuple, Set
from acme import wrappers
from acme.adders import reverb as adders
from dm_control import composer
from dm_control import suite
from dm_control.composer.variation import colors
from dm_control.composer.variation import distributions
from dm_control.locomotion import arenas
from dm_control.locomotion import props
from dm_control.locomotion import tasks
from dm_control.locomotion import walkers
from dm_env import specs
import numpy as np
import reverb
import tensorflow as tf
import tree
def _build_rodent_escape_env():
"""Build environment where a rodent escapes from a bowl."""
walker = walkers.Rat(
observable_options={'egocentric_camera': dict(enabled=True)},
)
arena = arenas.bowl.Bowl(
size=(20., 20.),
aesthetic='outdoor_natural')
locomotion_task = tasks.escape.Escape(
walker=walker,
arena=arena,
physics_timestep=0.001,
control_timestep=.02)
raw_env = composer.Environment(
time_limit=20,
task=locomotion_task,
strip_singleton_obs_buffer_dim=True)
return raw_env
def _build_rodent_maze_env():
"""Build environment where a rodent runs to targets."""
walker = walkers.Rat(
observable_options={'egocentric_camera': dict(enabled=True)},
)
wall_textures = arenas.labmaze_textures.WallTextures(
style='style_01')
arena = arenas.mazes.RandomMazeWithTargets(
x_cells=11,
y_cells=11,
xy_scale=.5,
z_height=.3,
max_rooms=4,
room_min_size=4,
room_max_size=5,
spawns_per_room=1,
targets_per_room=3,
wall_textures=wall_textures,
aesthetic='outdoor_natural')
rodent_task = tasks.random_goal_maze.ManyGoalsMaze(
walker=walker,
maze_arena=arena,
target_builder=functools.partial(
props.target_sphere.TargetSphere,
radius=0.05,
height_above_ground=.125,
rgb1=(0, 0, 0.4),
rgb2=(0, 0, 0.7)),
target_reward_scale=50.,
contact_termination=False,
control_timestep=.02,
physics_timestep=0.001)
raw_env = composer.Environment(
time_limit=30,
task=rodent_task,
strip_singleton_obs_buffer_dim=True)
return raw_env
def _build_rodent_corridor_gaps():
"""Build environment where a rodent runs over gaps."""
walker = walkers.Rat(
observable_options={'egocentric_camera': dict(enabled=True)},
)
platform_length = distributions.Uniform(low=0.4, high=0.8)
gap_length = distributions.Uniform(low=0.05, high=0.2)
arena = arenas.corridors.GapsCorridor(
corridor_width=2,
platform_length=platform_length,
gap_length=gap_length,
corridor_length=40,
aesthetic='outdoor_natural')
rodent_task = tasks.corridors.RunThroughCorridor(
walker=walker,
arena=arena,
walker_spawn_position=(5, 0, 0),
walker_spawn_rotation=0,
target_velocity=1.0,
contact_termination=False,
terminate_at_height=-0.3,
physics_timestep=0.001,
control_timestep=.02)
raw_env = composer.Environment(
time_limit=30,
task=rodent_task,
strip_singleton_obs_buffer_dim=True)
return raw_env
def _build_rodent_two_touch_env():
"""Build environment where a rodent touches targets."""
walker = walkers.Rat(
observable_options={'egocentric_camera': dict(enabled=True)},
)
arena_floor = arenas.floors.Floor(
size=(10., 10.), aesthetic='outdoor_natural')
task_reach = tasks.reach.TwoTouch(
walker=walker,
arena=arena_floor,
target_builders=[
functools.partial(
props.target_sphere.TargetSphereTwoTouch,
radius=0.025),
],
randomize_spawn_rotation=True,
target_type_rewards=[25.],
shuffle_target_builders=False,
target_area=(1.5, 1.5),
physics_timestep=0.001,
control_timestep=.02)
raw_env = composer.Environment(
time_limit=30,
task=task_reach,
strip_singleton_obs_buffer_dim=True)
return raw_env
def _build_humanoid_walls_env():
"""Build humanoid walker walls environment."""
walker = walkers.CMUHumanoidPositionControlled(
name='walker',
observable_options={'egocentric_camera': dict(enabled=True)},
)
wall_width = distributions.Uniform(low=1, high=7)
wall_height = distributions.Uniform(low=2.5, high=4.0)
swap_wall_side = distributions.Bernoulli(prob=0.5)
wall_r = distributions.Uniform(low=0.5, high=0.6)
wall_g = distributions.Uniform(low=0.21, high=0.41)
wall_rgba = colors.RgbVariation(r=wall_r, g=wall_g, b=0, alpha=1)
arena = arenas.WallsCorridor(
wall_gap=5.0,
wall_width=wall_width,
wall_height=wall_height,
swap_wall_side=swap_wall_side,
wall_rgba=wall_rgba,
corridor_width=10,
corridor_length=100)
humanoid_task = tasks.RunThroughCorridor(
walker=walker,
arena=arena,
walker_spawn_rotation=1.57, # pi / 2
physics_timestep=0.005,
control_timestep=0.03)
raw_env = composer.Environment(
time_limit=30,
task=humanoid_task,
strip_singleton_obs_buffer_dim=True)
return raw_env
def _build_humanoid_corridor_env():
"""Build humanoid walker walls environment."""
walker = walkers.CMUHumanoidPositionControlled(
name='walker',
observable_options={'egocentric_camera': dict(enabled=True)},
)
arena = arenas.EmptyCorridor(
corridor_width=10,
corridor_length=100)
humanoid_task = tasks.RunThroughCorridor(
walker=walker,
arena=arena,
walker_spawn_rotation=1.57, # pi / 2
physics_timestep=0.005,
control_timestep=0.03)
raw_env = composer.Environment(
time_limit=30,
task=humanoid_task,
strip_singleton_obs_buffer_dim=True)
return raw_env
def _build_humanoid_corridor_gaps():
"""Build humanoid walker walls environment."""
walker = walkers.CMUHumanoidPositionControlled(
name='walker',
observable_options={'egocentric_camera': dict(enabled=True)},
)
platform_length = distributions.Uniform(low=0.3, high=2.5)
gap_length = distributions.Uniform(low=0.75, high=1.25)
arena = arenas.GapsCorridor(
corridor_width=10,
platform_length=platform_length,
gap_length=gap_length,
corridor_length=100)
humanoid_task = tasks.RunThroughCorridor(
walker=walker,
arena=arena,
walker_spawn_position=(2, 0, 0),
walker_spawn_rotation=1.57, # pi / 2
physics_timestep=0.005,
control_timestep=0.03)
raw_env = composer.Environment(
time_limit=30,
task=humanoid_task,
strip_singleton_obs_buffer_dim=True)
return raw_env
class MujocoActionNormalizer(wrappers.EnvironmentWrapper):
"""Rescale actions to [-1, 1] range for mujoco physics engine.
For control environments whose actions have bounded range in [-1, 1], this
adaptor rescale actions to the desired range. This allows actor network to
output unscaled actions for better gradient dynamics.
"""
def __init__(self, environment, rescale='clip'):
super().__init__(environment)
self._rescale = rescale
def step(self, action):
"""Rescale actions to [-1, 1] range before stepping wrapped environment."""
if self._rescale == 'tanh':
scaled_actions = tree.map_structure(np.tanh, action)
elif self._rescale == 'clip':
scaled_actions = tree.map_structure(lambda a: np.clip(a, -1., 1.), action)
else:
raise ValueError('Unrecognized scaling option: %s' % self._rescale)
return self._environment.step(scaled_actions)
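

# For example, with rescale='clip' an out-of-range action of 1.7 is mapped to
# 1.0, whereas rescale='tanh' squashes it smoothly to np.tanh(1.7) ~= 0.935.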
class NormalizeActionSpecWrapper(wrappers.EnvironmentWrapper):
  """Rescale each action dimension to the range [-1, 1]."""
def __init__(self, environment):
super().__init__(environment)
action_spec = environment.action_spec()
# pytype: disable=attribute-error # always-use-return-annotations
self._scale = action_spec.maximum - action_spec.minimum
self._offset = action_spec.minimum
minimum = action_spec.minimum * 0 - 1.
maximum = action_spec.minimum * 0 + 1.
self._action_spec = specs.BoundedArray(
action_spec.shape,
action_spec.dtype,
minimum,
maximum,
name=action_spec.name)
# pytype: enable=attribute-error # always-use-return-annotations
def _from_normal_actions(self, actions):
actions = 0.5 * (actions + 1.0) # a_t is now in the range [0, 1]
# scale range to [minimum, maximum]
return actions * self._scale + self._offset
def step(self, action):
action = self._from_normal_actions(action)
return self._environment.step(action)
def action_spec(self):
return self._action_spec
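

# Worked example for NormalizeActionSpecWrapper: for an original spec with
# minimum=0 and maximum=10, scale=10 and offset=0, so the mapping
# 0.5 * (a + 1) * scale + offset sends a=-1 to 0, a=0 to 5, and a=+1 to 10.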
class FilterObservationsWrapper(wrappers.EnvironmentWrapper):
"""Filter out all the observations not specified to this wrapper."""
def __init__(self, environment, observations_to_keep):
super().__init__(environment)
self._observations_to_keep = observations_to_keep
spec = self._environment.observation_spec()
filtered = [(k, spec[k]) for k in observations_to_keep]
self._observation_spec = collections.OrderedDict(filtered)
def _filter_observation(self, timestep):
observation = timestep.observation
filtered = [(k, observation[k]) for k in self._observations_to_keep]
return timestep._replace(observation=collections.OrderedDict(filtered))
def step(self, action):
return self._filter_observation(self._environment.step(action))
def reset(self):
return self._filter_observation(self._environment.reset())
def observation_spec(self):
return self._observation_spec
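

# For instance, FilterObservationsWrapper(env, ['walker/joints_pos']) keeps
# only that key in both observation_spec() and the timesteps returned by
# reset() and step().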
class ControlSuite:
"""Create bits needed to run agents on an Control Suite dataset."""
def __init__(self, task_name='humanoid_run'):
"""Initializes datasets/environments for the Deepmind Control suite.
Args:
task_name: take name. Must be one of,
finger_turn_hard, manipulator_insert_peg, humanoid_run,
cartpole_swingup, cheetah_run, fish_swim, manipulator_insert_ball,
walker_stand, walker_walk
"""
self.task_name = task_name
    self._uint8_features = set()
    self._environment = None
    if task_name == 'fish_swim':
self._domain_name = 'fish'
self._task_name = 'swim'
self._shapes = {
'observation/target': (3,),
'observation/velocity': (13,),
'observation/upright': (1,),
'observation/joint_angles': (7,),
'action': (5,),
'discount': (),
'reward': (),
'episodic_reward': (),
'step_type': ()
}
elif task_name == 'humanoid_run':
self._domain_name = 'humanoid'
self._task_name = 'run'
self._shapes = {
'observation/velocity': (27,),
'observation/com_velocity': (3,),
'observation/torso_vertical': (3,),
'observation/extremities': (12,),
'observation/head_height': (1,),
'observation/joint_angles': (21,),
'action': (21,),
'discount': (),
'reward': (),
'episodic_reward': (),
'step_type': ()
}
elif task_name == 'manipulator_insert_ball':
self._domain_name = 'manipulator'
self._task_name = 'insert_ball'
self._shapes = {
'observation/arm_pos': (16,),
'observation/arm_vel': (8,),
'observation/touch': (5,),
'observation/hand_pos': (4,),
'observation/object_pos': (4,),
'observation/object_vel': (3,),
'observation/target_pos': (4,),
'action': (5,),
'discount': (),
'reward': (),
'episodic_reward': (),
'step_type': ()}
elif task_name == 'manipulator_insert_peg':
self._domain_name = 'manipulator'
self._task_name = 'insert_peg'
      self._shapes = {
          'observation/arm_pos': (16,),
          'observation/arm_vel': (8,),
          'observation/touch': (5,),
          'observation/hand_pos': (4,),
          'observation/object_pos': (4,),
          'observation/object_vel': (3,),
          'observation/target_pos': (4,),
          'action': (5,),
          'discount': (),
          'reward': (),
          'episodic_reward': (),
          'step_type': ()}
elif task_name == 'cartpole_swingup':
self._domain_name = 'cartpole'
self._task_name = 'swingup'
self._shapes = {
'observation/position': (3,),
'observation/velocity': (2,),
'action': (1,),
'discount': (),
'reward': (),
'episodic_reward': (),
'step_type': ()}
elif task_name == 'walker_walk':
self._domain_name = 'walker'
self._task_name = 'walk'
self._shapes = {
'observation/orientations': (14,),
'observation/velocity': (9,),
'observation/height': (1,),
'action': (6,),
'discount': (),
'reward': (),
'episodic_reward': (),
'step_type': ()}
elif task_name == 'walker_stand':
self._domain_name = 'walker'
self._task_name = 'stand'
self._shapes = {
'observation/orientations': (14,),
'observation/velocity': (9,),
'observation/height': (1,),
'action': (6,),
'discount': (),
'reward': (),
'episodic_reward': (),
'step_type': ()}
elif task_name == 'cheetah_run':
self._domain_name = 'cheetah'
self._task_name = 'run'
self._shapes = {
'observation/position': (8,),
'observation/velocity': (9,),
'action': (6,),
'discount': (),
'reward': (),
'episodic_reward': (),
'step_type': ()}
elif task_name == 'finger_turn_hard':
self._domain_name = 'finger'
self._task_name = 'turn_hard'
self._shapes = {
'observation/position': (4,),
'observation/velocity': (3,),
'observation/touch': (2,),
'observation/target_position': (2,),
'observation/dist_to_target': (1,),
'action': (2,),
'discount': (),
'reward': (),
'episodic_reward': (),
'step_type': ()}
else:
raise ValueError('Task \'{}\' not found.'.format(task_name))
self._data_path = 'dm_control_suite/{}/train'.format(task_name)
@property
def shapes(self):
return self._shapes
@property
def data_path(self):
return self._data_path
@property
def uint8_features(self):
return self._uint8_features
@property
def environment(self):
"""Build and return the environment."""
if self._environment is not None:
return self._environment
self._environment = suite.load(
domain_name=self._domain_name,
task_name=self._task_name)
self._environment = wrappers.SinglePrecisionWrapper(self._environment)
    self._environment = NormalizeActionSpecWrapper(self._environment)
return self._environment
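

# A minimal usage sketch (names as defined in this module):
#
#   control_task = ControlSuite('cartpole_swingup')
#   env = control_task.environment
#   timestep = env.reset()
#   action = np.zeros(control_task.shapes['action'], dtype=np.float32)
#   timestep = env.step(action)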
class CmuThirdParty:
"""Create bits needed to run agents on an locomotion humanoid dataset."""
def __init__(self, task_name='humanoid_walls'):
# 'humanoid_corridor|humanoid_gaps|humanoid_walls'
self._task_name = task_name
self._pixel_keys = self.get_pixel_keys()
self._uint8_features = set(['observation/walker/egocentric_camera'])
self.additional_paths = {}
self._proprio_keys = [
'walker/joints_vel',
'walker/sensors_velocimeter',
'walker/sensors_gyro',
'walker/joints_pos',
'walker/world_zaxis',
'walker/body_height',
'walker/sensors_accelerometer',
'walker/end_effectors_pos'
]
self._shapes = {
'observation/walker/joints_vel': (56,),
'observation/walker/sensors_velocimeter': (3,),
'observation/walker/sensors_gyro': (3,),
'observation/walker/joints_pos': (56,),
'observation/walker/world_zaxis': (3,),
'observation/walker/body_height': (1,),
'observation/walker/sensors_accelerometer': (3,),
'observation/walker/end_effectors_pos': (12,),
'observation/walker/egocentric_camera': (
64,
64,
3,
),
'action': (56,),
'discount': (),
'reward': (),
'episodic_reward': (),
'step_type': ()
}
if task_name == 'humanoid_corridor':
self._data_path = 'dm_locomotion/humanoid_corridor/seq2/train'
elif task_name == 'humanoid_gaps':
self._data_path = 'dm_locomotion/humanoid_gaps/seq2/train'
elif task_name == 'humanoid_walls':
self._data_path = 'dm_locomotion/humanoid_walls/seq40/train'
else:
raise ValueError('Task \'{}\' not found.'.format(task_name))
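
  # Note: the 'seq2' datasets store two-step sequences (sufficient to form
  # SARSA transitions), while 'seq40' datasets store overlapping length-40
  # sequences for recurrent agents (see the module docstring).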
@staticmethod
def get_pixel_keys():
return ('walker/egocentric_camera',)
@property
def uint8_features(self):
return self._uint8_features
@property
def shapes(self):
return self._shapes
@property
def data_path(self):
return self._data_path
@property
def environment(self):
"""Build and return the environment."""
if self._task_name == 'humanoid_corridor':
self._environment = _build_humanoid_corridor_env()
elif self._task_name == 'humanoid_gaps':
self._environment = _build_humanoid_corridor_gaps()
elif self._task_name == 'humanoid_walls':
self._environment = _build_humanoid_walls_env()
    self._environment = NormalizeActionSpecWrapper(self._environment)
self._environment = MujocoActionNormalizer(
environment=self._environment, rescale='clip')
self._environment = wrappers.SinglePrecisionWrapper(self._environment)
all_observations = list(self._proprio_keys) + list(self._pixel_keys)
self._environment = FilterObservationsWrapper(self._environment,
all_observations)
return self._environment
class Rodent:
"""Create bits needed to run agents on an Rodent dataset."""
def __init__(self, task_name='rodent_gaps'):
# 'rodent_escape|rodent_two_touch|rodent_gaps|rodent_mazes'
self._task_name = task_name
self._pixel_keys = self.get_pixel_keys()
self._uint8_features = set(['observation/walker/egocentric_camera'])
self._proprio_keys = [
'walker/joints_pos', 'walker/joints_vel', 'walker/tendons_pos',
'walker/tendons_vel', 'walker/appendages_pos', 'walker/world_zaxis',
'walker/sensors_accelerometer', 'walker/sensors_velocimeter',
'walker/sensors_gyro', 'walker/sensors_touch',
]
self._shapes = {
'observation/walker/joints_pos': (30,),
'observation/walker/joints_vel': (30,),
'observation/walker/tendons_pos': (8,),
'observation/walker/tendons_vel': (8,),
'observation/walker/appendages_pos': (15,),
'observation/walker/world_zaxis': (3,),
'observation/walker/sensors_accelerometer': (3,),
'observation/walker/sensors_velocimeter': (3,),
'observation/walker/sensors_gyro': (3,),
'observation/walker/sensors_touch': (4,),
'observation/walker/egocentric_camera': (64, 64, 3),
'action': (38,),
'discount': (),
'reward': (),
'step_type': ()
}
if task_name == 'rodent_gaps':
self._data_path = 'dm_locomotion/rodent_gaps/seq2/train'
elif task_name == 'rodent_escape':
self._data_path = 'dm_locomotion/rodent_bowl_escape/seq2/train'
elif task_name == 'rodent_two_touch':
self._data_path = 'dm_locomotion/rodent_two_touch/seq40/train'
elif task_name == 'rodent_mazes':
self._data_path = 'dm_locomotion/rodent_mazes/seq40/train'
else:
raise ValueError('Task \'{}\' not found.'.format(task_name))
@staticmethod
def get_pixel_keys():
return ('walker/egocentric_camera',)
@property
def shapes(self):
return self._shapes
@property
def uint8_features(self):
return self._uint8_features
@property
def data_path(self):
return self._data_path
@property
def environment(self):
"""Return environment."""
if self._task_name == 'rodent_escape':
self._environment = _build_rodent_escape_env()
elif self._task_name == 'rodent_gaps':
self._environment = _build_rodent_corridor_gaps()
elif self._task_name == 'rodent_two_touch':
self._environment = _build_rodent_two_touch_env()
elif self._task_name == 'rodent_mazes':
self._environment = _build_rodent_maze_env()
    self._environment = NormalizeActionSpecWrapper(self._environment)
self._environment = MujocoActionNormalizer(
environment=self._environment, rescale='clip')
self._environment = wrappers.SinglePrecisionWrapper(self._environment)
all_observations = list(self._proprio_keys) + list(self._pixel_keys)
self._environment = FilterObservationsWrapper(self._environment,
all_observations)
return self._environment
def _parse_seq_tf_example(example, uint8_features, shapes):
"""Parse tf.Example containing one or two episode steps."""
def to_feature(key, shape):
if key in uint8_features:
return tf.io.FixedLenSequenceFeature(
shape=[], dtype=tf.string, allow_missing=True)
else:
return tf.io.FixedLenSequenceFeature(
shape=shape, dtype=tf.float32, allow_missing=True)
feature_map = {}
for k, v in shapes.items():
feature_map[k] = to_feature(k, v)
parsed = tf.io.parse_single_example(example, features=feature_map)
observation = {}
restructured = {}
for k in parsed.keys():
if 'observation' not in k:
restructured[k] = parsed[k]
continue
if k in uint8_features:
observation[k.replace('observation/', '')] = tf.reshape(
tf.io.decode_raw(parsed[k], out_type=tf.uint8), (-1,) + shapes[k])
else:
observation[k.replace('observation/', '')] = parsed[k]
restructured['observation'] = observation
restructured['length'] = tf.shape(restructured['action'])[0]
return restructured
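

# For 'cheetah_run', for example, the parsed result has the structure
#   {'observation': {'position': float32[T, 8], 'velocity': float32[T, 9]},
#    'action': float32[T, 6], 'reward': float32[T], 'discount': float32[T],
#    'episodic_reward': float32[T], 'step_type': float32[T], 'length': T}
# where T is the number of steps stored in the tf.Example.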
def _build_sequence_example(sequences):
"""Convert raw sequences into a Reverb sequence sample."""
data = adders.Step(
observation=sequences['observation'],
action=sequences['action'],
reward=sequences['reward'],
discount=sequences['discount'],
start_of_episode=(),
extras=())
info = reverb.SampleInfo(
key=tf.constant(0, tf.uint64),
probability=tf.constant(1.0, tf.float64),
table_size=tf.constant(0, tf.int64),
priority=tf.constant(1.0, tf.float64),
      times_sampled=tf.constant(1, tf.int32))
return reverb.ReplaySample(info=info, data=data)
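

# The SampleInfo fields above are placeholders: the data is read from files
# rather than sampled from a live Reverb table, so constant key, probability,
# and priority values are used purely to match Reverb's sample structure.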
def _build_sarsa_example(sequences):
"""Convert raw sequences into a Reverb n-step SARSA sample."""
o_tm1 = tree.map_structure(lambda t: t[0], sequences['observation'])
o_t = tree.map_structure(lambda t: t[1], sequences['observation'])
a_tm1 = tree.map_structure(lambda t: t[0], sequences['action'])
a_t = tree.map_structure(lambda t: t[1], sequences['action'])
r_t = tree.map_structure(lambda t: t[0], sequences['reward'])
p_t = tree.map_structure(lambda t: t[0], sequences['discount'])
info = reverb.SampleInfo(
key=tf.constant(0, tf.uint64),
probability=tf.constant(1.0, tf.float64),
table_size=tf.constant(0, tf.int64),
priority=tf.constant(1.0, tf.float64),
      times_sampled=tf.constant(1, tf.int32))
return reverb.ReplaySample(info=info, data=(o_tm1, a_tm1, r_t, p_t, o_t, a_t))
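

# The data tuple above follows the (o_tm1, a_tm1, r_t, discount_t, o_t, a_t)
# layout, so each stored two-step sequence yields a single SARSA transition.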
def _padded_batch(example_ds, batch_size, shapes, drop_remainder=False):
"""Batch data while handling unequal lengths."""
padded_shapes = {}
padded_shapes['observation'] = {}
for k, v in shapes.items():
if 'observation' in k:
padded_shapes['observation'][
k.replace('observation/', '')] = (-1,) + v
else:
padded_shapes[k] = (-1,) + v
padded_shapes['length'] = ()
return example_ds.padded_batch(batch_size,
padded_shapes=padded_shapes,
drop_remainder=drop_remainder)
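

# The leading -1 in each padded shape lets padded_batch pad the time axis to
# the longest sequence in the batch; the 'length' entry keeps each example's
# true, unpadded number of steps.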
def dataset(root_path: str,
data_path: str,
            shapes: Dict[str, Tuple[int, ...]],
num_threads: int,
batch_size: int,
uint8_features: Optional[Set[str]] = None,
num_shards: int = 100,
shuffle_buffer_size: int = 100000,
sarsa: bool = True) -> tf.data.Dataset:
"""Create tf dataset for training."""
uint8_features = uint8_features if uint8_features else {}
path = os.path.join(root_path, data_path)
filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)]
file_ds = tf.data.Dataset.from_tensor_slices(filenames)
file_ds = file_ds.repeat().shuffle(num_shards)
example_ds = file_ds.interleave(
functools.partial(tf.data.TFRecordDataset, compression_type='GZIP'),
cycle_length=tf.data.experimental.AUTOTUNE,
block_length=5)
example_ds = example_ds.shuffle(shuffle_buffer_size)
def map_func(example):
example = _parse_seq_tf_example(example, uint8_features, shapes)
return example
example_ds = example_ds.map(map_func, num_parallel_calls=num_threads)
example_ds = example_ds.repeat().shuffle(batch_size * 10)
if sarsa:
example_ds = example_ds.map(
_build_sarsa_example,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
    example_ds = example_ds.batch(batch_size)
else:
example_ds = _padded_batch(
example_ds, batch_size, shapes, drop_remainder=True)
example_ds = example_ds.map(
_build_sequence_example,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
example_ds = example_ds.prefetch(tf.data.experimental.AUTOTUNE)
return example_ds
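

if __name__ == '__main__':
  # Minimal smoke test: a sketch only. It assumes the RL Unplugged data has
  # been mirrored locally under /tmp/rl_unplugged (e.g. copied from the public
  # gs://rl_unplugged bucket); adjust root_path to your copy.
  control_task = ControlSuite(task_name='cartpole_swingup')
  ds = dataset(
      root_path='/tmp/rl_unplugged',
      data_path=control_task.data_path,
      shapes=control_task.shapes,
      uint8_features=control_task.uint8_features,
      num_threads=1,
      batch_size=2,
      shuffle_buffer_size=10,
      sarsa=True)
  sample = next(iter(ds))
  print(tree.map_structure(lambda t: t.shape, sample))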