Add GPE/GPI experiments.

PiperOrigin-RevId: 323750949
This commit is contained in:
Shaobo Hou
2020-07-29 10:49:45 +01:00
committed by Diego de Las Casas
parent 59c0cf5044
commit a24bda5ed0
20 changed files with 1732 additions and 62 deletions

View File

@@ -18,37 +18,44 @@
from absl import logging
import numpy as np
def _ema(base, val, decay=0.995):
return base * decay + (1 - decay) * val
def run(env, agent, num_episodes, report_every=200, num_eval_reps=1):
    """Runs an agent on an environment.

    Each iteration runs one training episode followed by `num_eval_reps`
    evaluation episodes. Episode returns are folded into exponential moving
    averages with `_ema`, and those averages are logged every `report_every`
    training episodes.

    Args:
        env: The environment; passed through to `run_episode`.
        agent: The agent; passed through to `run_episode`. If it has a
            `get_logs` attribute, `agent.get_logs()` is logged with each
            progress report.
        num_episodes: Number of episodes to train for.
        report_every: Frequency at which training progress is reported
            (episodes).
        num_eval_reps: Number of eval episodes to run per training episode.
    """
    # Both averages start at zero, so reports early in training are pulled
    # towards zero until enough episodes have been folded in.
    train_return_ema = 0.
    eval_return_ema = 0.

    for episode_id in range(num_episodes):
        # Run a training episode.
        train_episode_return = run_episode(env, agent, is_training=True)
        train_return_ema = _ema(train_return_ema, train_episode_return)

        # Run the evaluation episode(s). The average is updated inside the
        # loop so every eval episode contributes and `num_eval_reps=0` simply
        # leaves the eval average unchanged.
        for _ in range(num_eval_reps):
            eval_episode_return = run_episode(env, agent, is_training=False)
            eval_return_ema = _ema(eval_return_ema, eval_episode_return)

        if ((episode_id + 1) % report_every) == 0:
            logging.info(
                "Episode %s, avg train return %.3f, avg eval return %.3f",
                episode_id + 1, train_return_ema, eval_return_ema)

            # Agent-specific diagnostics are optional; only agents exposing
            # `get_logs` report them.
            if hasattr(agent, "get_logs"):
                logging.info("Episode %s, agent logs: %s", episode_id + 1,
                             agent.get_logs())
def run_episode(environment, agent, is_training=False):