mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2025-12-12 19:43:15 +08:00
Add GPE/GPI experiments.
PiperOrigin-RevId: 323750949
This commit is contained in:
committed by
Diego de Las Casas
parent
59c0cf5044
commit
a24bda5ed0
@@ -18,37 +18,44 @@
|
||||
|
||||
from absl import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
def _ema(base, val, decay=0.995):
|
||||
return base * decay + (1 - decay) * val
|
||||
|
||||
|
||||
def run(env, agent, num_episodes, report_every=200, num_eval_reps=1):
  """Runs an agent on an environment, alternating training and evaluation.

  For every training episode, `num_eval_reps` evaluation episodes are run.
  Exponential moving averages of the training and evaluation returns are
  logged every `report_every` episodes, together with `agent.get_logs()` when
  the agent provides that method. Nothing is returned.

  Args:
    env: The environment.
    agent: The agent.
    num_episodes: Number of episodes to train for.
    report_every: Frequency at which training progress is reported (episodes).
    num_eval_reps: Number of eval episodes to run per training episode.
  """
  # Both averages start at zero, so with the default decay of `_ema` the early
  # reports are biased towards zero until enough episodes have been seen.
  # Only the running averages are kept; per-episode returns are not stored
  # because nothing reads them, and storing them would grow without bound.
  train_return_ema = 0.
  eval_return_ema = 0.

  for episode_id in range(num_episodes):
    # Run a training episode.
    train_episode_return = run_episode(env, agent, is_training=True)
    train_return_ema = _ema(train_return_ema, train_episode_return)

    # Run the evaluation episodes for this training episode.
    for _ in range(num_eval_reps):
      eval_episode_return = run_episode(env, agent, is_training=False)
      eval_return_ema = _ema(eval_return_ema, eval_episode_return)

    if ((episode_id + 1) % report_every) == 0:
      logging.info("Episode %s, avg train return %.3f, avg eval return %.3f",
                   episode_id + 1, train_return_ema, eval_return_ema)
      # Agent-specific diagnostics are optional; only log them when offered.
      if hasattr(agent, "get_logs"):
        logging.info("Episode %s, agent logs: %s", episode_id + 1,
                     agent.get_logs())
|
||||
|
||||
|
||||
def run_episode(environment, agent, is_training=False):
|
||||
|
||||
Reference in New Issue
Block a user