import glob
import os
import time
from copy import deepcopy
from functools import reduce
from math import gcd
from operator import iadd
import numpy as np
import torch
from tensorboardX import SummaryWriter
from prl.typing import AgentABC, HistoryABC, EnvironmentABC, AgentCallbackABC
from prl.utils import (
timeit,
time_logger,
agent_logger,
memory_logger,
nn_logger,
misc_logger,
)
class AgentCallback(AgentCallbackABC):
    """
    Interface for Callbacks defining actions that are executed automatically during
    different phases of agent training.

    On construction every callback calls `register()` on each logger imported from
    `prl.utils` and stores the returned values as `*_logger_cid` attributes.
    Subclasses in this module pass these ids to the matching logger's `flush`.

    Attributes:
        iteration_interval: Every how many iterations `CallbackHandler` should call
            `on_iteration_end` on this callback.
        number_of_test_runs: How many test runs this callback asks for. Only read
            by `CallbackHandler` when `needs_tests` is True.
        needs_tests: Whether the callback consumes test results.
    """

    # Class-level scheduling defaults. `CallbackHandler.setup_callbacks` reads these
    # attributes from every callback, so a subclass that does not set them in its
    # own `__init__` would otherwise fail with AttributeError. Subclasses override
    # them by assigning instance attributes.
    iteration_interval = 1
    number_of_test_runs = 1
    needs_tests = False

    def __init__(self):
        self.time_logger_cid = time_logger.register()
        self.agent_logger_cid = agent_logger.register()
        self.memory_logger_cid = memory_logger.register()
        self.nn_logger_cid = nn_logger.register()
        self.misc_logger_cid = misc_logger.register()

    def on_iteration_end(self, agent: AgentABC) -> bool:
        """
        Method called at the end of every iteration in `prl.base.Agent.train` method.

        Args:
            agent: Agent in which this callback is called.

        Returns:
            True if training should be interrupted, False otherwise. The base
            implementation never interrupts training.
        """
        # Explicit False (instead of an implicit None) to honour the documented
        # return type; both are falsy for callers that aggregate with `any`.
        return False

    def on_training_end(self, agent: AgentABC):
        """
        Method called after `prl.base.Agent.post_train_cleanup`.

        Args:
            agent: Agent in which this callback is called.
        """

    def on_training_begin(self, agent: AgentABC):
        """
        Method called after `prl.base.Agent.pre_train_setup`.

        Args:
            agent: Agent in which this callback is called.
        """
class EarlyStopping(AgentCallback):
    """
    Early stopping for RL agents: requests that training stop once the mean test
    reward reaches a given target.

    Args:
        target_reward: Reward threshold; training is stopped when the mean test
            reward is greater than or equal to it.
        iteration_interval: Interval between calculating test reward.
            Using low values may make training process slower.
        number_of_test_runs: Number of test runs when calculating reward.
            Higher value averages variance out, but makes training longer.
        verbose: Whether to print message after stopping training (1), or not (0).

    Note:
        By reward, we mean here untransformed reward given by `Agent.test` method.
        For more info on methods see base class.
    """

    def __init__(
        self,
        target_reward: float,
        iteration_interval: int = 1,
        number_of_test_runs: int = 1,
        verbose: int = 1,
    ):
        super().__init__()
        self.target_reward = target_reward
        self.iteration_interval = iteration_interval
        self.number_of_test_runs = number_of_test_runs
        self.verbose = verbose
        # This callback reads test rewards, so it asks for the test procedure.
        self.needs_tests = True

    def on_iteration_end(self, agent: AgentABC):
        """
        Compare the mean of the flushed test rewards against the target.

        Args:
            agent: Agent in which this callback is called.

        Returns:
            Truthy value when the mean test reward reached `target_reward`,
            i.e. when training should be interrupted.
        """
        flushed = agent_logger.flush(self.agent_logger_cid)
        test_rewards = flushed[0]["test_episode_total_reward"]
        mean_test_reward = np.mean(test_rewards)
        target_reached = mean_test_reward >= self.target_reward

        # Nothing to announce unless we are stopping and verbosity is enabled.
        if not (target_reached and self.verbose):
            return target_reached

        message_template = (
            "Early stopping in iteration_number %s. "
            "Achieved mean raw reward of %.4f (target was %.4f)"
        )
        report_values = (agent.iteration_count, mean_test_reward, self.target_reward)
        print(message_template % report_values)
        return target_reached
class BaseAgentCheckpoint(AgentCallback):
    """
    Framework-independent checkpointing logic for agents during training.

    This base class only decides *when* an agent should be saved; the actual
    serialization is delegated to `_save_agent`, which subclasses matching the
    networks' framework must implement.
    For more info on methods see base class.

    Args:
        target_path: Directory in which agents will be saved. Must exist before
            creating this callback.
        save_best_only: Whether to save all models, or only the one with highest
            reward.
        iteration_interval: Interval between calculating test reward. Using low
            values may make training process slower.
        number_of_test_runs: Number of test runs when calculating reward. Higher
            value averages variance out, but makes training longer.
    """

    def __init__(
        self,
        target_path: str,
        save_best_only: bool = True,
        iteration_interval: int = 1,
        number_of_test_runs: int = 1,
    ):
        super().__init__()
        path_exists = os.path.exists(target_path)
        assert path_exists, "Provided path (%s) does not exist!" % target_path

        self.target_path = target_path
        self.save_best_only = save_best_only
        self.iteration_interval = iteration_interval
        self.number_of_test_runs = number_of_test_runs
        # Any finite reward beats the initial best score.
        self.best_score = -np.inf
        self.needs_tests = True

    def _try_save(self, agent: AgentABC):
        """
        Flush the agent logs and save the agent if the saving policy allows it.

        The agent is saved when its mean test reward beats the best one seen so
        far, or unconditionally when `save_best_only` is False.
        """
        flushed = agent_logger.flush(self.agent_logger_cid)
        mean_test_reward = np.mean(flushed[0]["test_episode_total_reward"])

        is_new_best = mean_test_reward > self.best_score
        if is_new_best:
            # Record the new best before saving.
            self.best_score = mean_test_reward
        if is_new_best or not self.save_best_only:
            self._save_agent(agent, mean_test_reward)

    def on_iteration_end(self, agent: AgentABC):
        """Attempt a checkpoint at the end of an iteration."""
        self._try_save(agent)

    def on_training_end(self, agent: AgentABC):
        """Attempt a final checkpoint once training has finished."""
        self._try_save(agent)

    def _save_agent(self, agent: AgentABC, reward: float):
        """
        Persist the agent. Must be overridden by backend-specific subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = (
            "This is a base class for agent checkpoints. Please use one of "
            "the subclasses corresponding to your backend."
        )
        raise NotImplementedError(message)
class PyTorchAgentCheckpoint(BaseAgentCheckpoint):
    """Class for saving PyTorch-based agents. For more details, see parent class."""

    def _save_agent(self, agent: AgentABC, reward: float):
        """
        Serialize the agent with `torch.save` into `self.target_path`.

        The file is named `<agent.id>_<agent.iteration_count>_<reward>` with the
        reward formatted to four decimal places. When `save_best_only` is set,
        checkpoints of this agent that existed before the save are removed
        afterwards, so only the newest file remains.

        Args:
            agent: Agent to save.
            reward: Mean test reward achieved by the agent; used in the file name.
        """
        file_name = "%s_%s_%.4f" % (agent.id, agent.iteration_count, reward)
        full_path = os.path.join(self.target_path, file_name)

        stale_paths = []
        if self.save_best_only:
            # Escape glob metacharacters so paths/ids containing e.g. "[" still
            # match literally. The "_" separator keeps the pattern from matching
            # files of another agent whose id merely starts with this id
            # (ids that themselves extend this one with "_" are still matched).
            pattern = os.path.join(
                glob.escape(self.target_path), glob.escape(agent.id) + "_*"
            )
            new_checkpoint = os.path.normpath(full_path)
            # Never schedule the file we are about to write for deletion: if the
            # name is unchanged (same iteration, reward rounding to the same
            # value), the save overwrites it and removing it would lose the
            # checkpoint entirely.
            stale_paths = [
                path
                for path in glob.glob(pattern)
                if os.path.normpath(path) != new_checkpoint
            ]

        # Save first, delete afterwards, so a failed save keeps the old checkpoint.
        torch.save(agent, full_path)

        for stale_path in stale_paths:
            os.remove(stale_path)
class TrainingLogger(AgentCallback):
    """
    Logs training information after certain amount of iterations.
    Data may appear in output, or be written into a file.
    For more info on methods see base class.

    Args:
        on_screen: Whether to show info in output.
        to_file: Whether to save info into a file.
        file_path: Path to file with output. Required when `to_file` is True.
        iteration_interval: How often (in iterations) info should be logged.
            `CallbackHandler` only invokes this callback on matching iterations,
            so both screen and file output follow this interval.
    """

    def __init__(
        self,
        on_screen: bool = True,
        to_file: bool = False,
        file_path: str = None,
        iteration_interval: int = 1,
    ):
        super().__init__()
        assert (
            not to_file or file_path is not None
        ), "file_path must be provided when to_file is True"
        self.on_screen = on_screen
        self.to_file = to_file
        self.file_path = file_path
        self.iteration_interval = iteration_interval
        # Training metrics do not require the test procedure.
        self.needs_tests = False
        # The output file is truncated and given a header on the first write only.
        self._file_created = False
        self._last_agent_step_time = time.time()

    def _log_on_screen(
        self,
        iteration_number: int,
        mean_steps_per_second: float,
        loss: float,
        mean_reward: float,
        mean_episode_length: float,
    ):
        """Print one line with the training metrics of the current iteration."""
        print(
            "Iteration: %d. "
            "Training metrics: loss=%.4f, total_reward_mean=%.2f, "
            "mean_episode_length=%.2f, mean_steps_per_second=%.4f, "
            % (
                iteration_number,
                loss,
                mean_reward,
                mean_episode_length,
                mean_steps_per_second,
            )
        )

    def _log_to_file(
        self,
        iteration_number: int,
        mean_steps_per_second: float,
        loss: float,
        mean_reward: float,
        mean_episode_length: float,
    ):
        """
        Append one CSV row with the training metrics to `self.file_path`.

        The first call made by this instance truncates the file and writes the
        header row; subsequent calls append.
        """
        first_write = not self._file_created
        with open(self.file_path, "w" if first_write else "a") as f:
            if first_write:
                f.write(
                    "iteration_number,loss,total_reward_mean,"
                    "mean_episode_length,mean_steps_per_second\n"
                )
                self._file_created = True
            f.write(
                "%s,%.4f,%.2f,%.2f,%.4f\n"
                % (
                    iteration_number,
                    loss,
                    mean_reward,
                    mean_episode_length,
                    mean_steps_per_second,
                )
            )

    def _calculate_summaries(self, agent_log, nn_log):
        """
        Reduce flushed logger output to the four metrics that get reported.

        Args:
            agent_log: Value returned by `agent_logger.flush`. Index 0 is read as
                a mapping of logged values and index 2 as a mapping of timestamps.
            nn_log: Value returned by `nn_logger.flush`. Index 0 is read as a
                mapping of logged values.

        Returns:
            Tuple `(mean_steps_per_second, mean_loss, mean_reward,
            mean_episode_length)`. A metric for which nothing was logged since
            the previous flush is reported as 0.
        """
        step_timestamps = agent_log[2]
        # Sometimes agent_step is missing (rarely) - why?
        if "agent_step" in step_timestamps:
            # Include last step time from previous iteration to make this work
            # with just one step per iteration.
            step_times = [self._last_agent_step_time] + list(
                step_timestamps["agent_step"]
            )
            mean_time_per_step = np.mean(np.diff(step_times))
            mean_steps_per_second = 1 / mean_time_per_step
            self._last_agent_step_time = step_times[-1]
        else:
            mean_steps_per_second = 0

        # The loss is taken from the first series present in the nn log. When
        # nothing was logged there since the last flush, fall back to 0 like the
        # other metrics instead of failing with IndexError.
        loss_series = list(nn_log[0].values())
        mean_loss_since_last_log = np.mean(loss_series[0]) if loss_series else 0

        # Episode might not have ended in this iteration.
        episode_data = agent_log[0]
        if "episode_total_reward" in episode_data:
            mean_reward_since_last_log = np.mean(episode_data["episode_total_reward"])
            mean_episode_length_since_last_log = np.mean(
                episode_data["episode_length"]
            )
        else:
            mean_reward_since_last_log = 0
            mean_episode_length_since_last_log = 0

        return (
            mean_steps_per_second,
            mean_loss_since_last_log,
            mean_reward_since_last_log,
            mean_episode_length_since_last_log,
        )

    def on_iteration_end(self, agent: AgentABC):
        """
        Flush agent and nn logs, then report their summaries.

        Args:
            agent: Agent in which this callback is called.
        """
        recent_agent_logs = agent_logger.flush(self.agent_logger_cid)
        recent_nn_logs = nn_logger.flush(self.nn_logger_cid)
        summaries = self._calculate_summaries(recent_agent_logs, recent_nn_logs)
        if self.on_screen:
            self._log_on_screen(agent.iteration_count, *summaries)
        if self.to_file:
            self._log_to_file(agent.iteration_count, *summaries)
class ValidationLogger(AgentCallback):
    """
    Logs validation information after certain amount of iterations.
    Data may appear in output, or be written into a file.
    For more info on methods see base class.

    Args:
        on_screen: Whether to show info in output.
        to_file: Whether to save info into a file.
        file_path: Path to file with output. Required when `to_file` is True.
        iteration_interval: How often (in iterations) info should be logged.
            `CallbackHandler` only invokes this callback on matching iterations,
            so both screen and file output follow this interval.
        number_of_test_runs: Number of played episodes in history's summary logs.
    """

    def __init__(
        self,
        on_screen: bool = True,
        to_file: bool = False,
        file_path: str = None,
        iteration_interval: int = 1,
        number_of_test_runs: int = 3,
    ):
        super().__init__()
        assert (
            not to_file or file_path is not None
        ), "file_path must be provided when to_file is True"
        self.on_screen = on_screen
        self.to_file = to_file
        self.file_path = file_path
        self.iteration_interval = iteration_interval
        self.number_of_test_runs = number_of_test_runs
        self.needs_tests = True
        # The output file is truncated and given a header on the first write only.
        self._file_created = False

    def _log_on_screen(self, iteration_number: int, history_summary: tuple):
        """Print one line with the validation metrics of the current iteration."""
        print(
            "Iteration: %d. Validation metrics: total_reward_mean=%.1f, "
            "mean_length=%.1f" % ((iteration_number,) + history_summary)
        )

    def _log_to_file(self, iteration_number: int, history_summary: tuple):
        """
        Append one CSV row with the validation metrics to `self.file_path`.

        The first call made by this instance truncates the file and writes the
        header row; subsequent calls append. Tracking this with a flag (rather
        than testing for iteration 0) guarantees a header even when the first
        logged iteration is not iteration 0.
        """
        first_write = not self._file_created
        with open(self.file_path, "w" if first_write else "a") as f:
            if first_write:
                f.write("iteration_number,total_reward_mean,mean_length\n")
                self._file_created = True
            f.write("%s,%.1f,%.1f\n" % ((iteration_number,) + history_summary))

    def on_iteration_end(self, agent: AgentABC):
        """
        Flush agent logs and report mean test reward and mean test episode length.

        Args:
            agent: Agent in which this callback is called.
        """
        agent_logs = agent_logger.flush(self.agent_logger_cid)
        mean_test_reward = np.mean(agent_logs[0]["test_episode_total_reward"])
        mean_test_episode_length = np.mean(agent_logs[0]["test_episode_length"])
        history_summary = (mean_test_reward, mean_test_episode_length)
        if self.on_screen:
            self._log_on_screen(agent.iteration_count, history_summary)
        if self.to_file:
            self._log_to_file(agent.iteration_count, history_summary)
class TensorboardLogger(AgentCallback):
    """
    Writes various information to tensorboard during training.
    For more info on methods see base class.

    Args:
        file_path: Path passed to `SummaryWriter`. When omitted, defaults to
            `"logs_<unix timestamp>"` computed when the callback is created.
        iteration_interval: Interval between calculating test reward. Using low
            values may make training process slower.
        number_of_test_runs: Number of test runs when calculating reward. Higher
            value averages variance out, but makes training longer.
        show_time_logs: If shows logs from time_logger.
    """

    def __init__(
        self,
        file_path: str = None,
        iteration_interval: int = 1,
        number_of_test_runs: int = 1,
        show_time_logs: bool = False,
    ):
        super().__init__()
        if file_path is None:
            # Build the default at call time. A default expression in the
            # signature is evaluated once at import, which would give every
            # instance the same, stale timestamp.
            file_path = "logs_" + str(int(time.time()))
        self.writer = SummaryWriter(file_path)
        self.iteration_interval = iteration_interval
        self.number_of_test_runs = number_of_test_runs
        self.show_time_logs = show_time_logs
        self.needs_tests = True

    def on_iteration_end(self, agent: AgentABC):
        """
        Flush the loggers and write every flushed value as a tensorboard scalar.

        Scalars are tagged `<logger name>/<key>`; the index and timestamp that
        `flush` returns for each value are passed as the scalar's step and
        walltime. `time_logger` is included only when `show_time_logs` is set.

        Args:
            agent: Agent in which this callback is called.
        """
        sources = [
            ("agent_logger", agent_logger, self.agent_logger_cid),
            ("nn_logger", nn_logger, self.nn_logger_cid),
            ("memory_logger", memory_logger, self.memory_logger_cid),
            ("misc_logger", misc_logger, self.misc_logger_cid),
        ]
        if self.show_time_logs:
            sources.append(("time_logger", time_logger, self.time_logger_cid))

        for name, logger, cid in sources:
            data, indices, timestamps = logger.flush(cid)
            for key in data:
                tag = "%s/%s" % (name, key)
                for value, index, timestamp in zip(
                    data[key], indices[key], timestamps[key]
                ):
                    self.writer.add_scalar(tag, value, index, timestamp)

    def on_training_end(self, agent: AgentABC):
        """Close the tensorboard writer once training has finished."""
        self.writer.close()
class CallbackHandler:
    """
    Dispatches training events to all given callbacks. Calls appropriate methods
    on each callback, runs the test procedure when callbacks need it, and
    aggregates break codes.

    Args:
        callback_list: Callbacks to handle. `None` or an empty list means no
            callbacks.
        env: Environment used for test runs. The handler keeps its own deep copy.
    """

    def __init__(self, callback_list: list, env: EnvironmentABC):
        self.callback_list = callback_list or []
        self.common_iteration_interval = None
        self.common_test_procedure_interval = None
        self.number_of_test_runs = None
        self.env = deepcopy(env)
        self.setup_callbacks()

    def setup_callbacks(self):
        """
        Sets up callbacks. This calculates the intervals for calling callbacks,
        and for calling the testing procedure.

        Both intervals are the greatest common divisor of the relevant callbacks'
        `iteration_interval`; with no relevant callbacks the interval is 0, which
        `check_run_condition` treats as "never".
        """
        test_callbacks = [
            callback for callback in self.callback_list if callback.needs_tests
        ]
        # We have to run enough test iterations to satisfy the callback with the
        # highest amount of runs; with no test callbacks there are no test runs.
        self.number_of_test_runs = max(
            (callback.number_of_test_runs for callback in test_callbacks), default=0
        )
        self.common_iteration_interval = reduce(
            gcd, (callback.iteration_interval for callback in self.callback_list), 0
        )
        # NOTE(review): with test intervals such as 4 and 6 the common interval
        # is 2, so tests also run on iterations where no test callback is due.
        self.common_test_procedure_interval = reduce(
            gcd, (callback.iteration_interval for callback in test_callbacks), 0
        )

    def run_tests(self, agent: AgentABC) -> HistoryABC:
        """
        Run `agent.test` on the handler's environment `number_of_test_runs` times.

        Args:
            agent: Agent to test.

        Returns:
            The histories of all runs combined with in-place addition, or None
            when `number_of_test_runs` is 0 and nothing was run.
        """
        history_list = [agent.test(self.env) for _ in range(self.number_of_test_runs)]
        if not history_list:
            # reduce() without an initial value raises TypeError on an empty
            # sequence; there is simply no history to return in that case.
            return None
        return reduce(iadd, history_list)

    @staticmethod
    def check_run_condition(current_count, interval):
        """
        Tell whether something scheduled every `interval` iterations is due.

        Returns:
            False when `interval` is 0, otherwise whether `current_count` is a
            multiple of `interval`.
        """
        if interval == 0:
            return False
        return current_count % interval == 0

    @timeit
    def on_iteration_end(self, agent: AgentABC):
        """
        Run tests if they are due and call `on_iteration_end` on due callbacks.

        Args:
            agent: Agent in which the callbacks are called.

        Returns:
            True if any called callback asked to interrupt training.
        """
        iteration = agent.iteration_count
        if not self.check_run_condition(iteration, self.common_iteration_interval):
            return False
        if self.check_run_condition(iteration, self.common_test_procedure_interval):
            self.run_tests(agent)
        # Every due callback is called, even after one of them signals a break.
        break_signals = [
            callback.on_iteration_end(agent)
            for callback in self.callback_list
            if self.check_run_condition(iteration, callback.iteration_interval)
        ]
        return any(break_signals)

    @timeit
    def on_training_end(self, agent: AgentABC):
        """
        Run a final round of tests (if any are configured) and notify callbacks.

        Args:
            agent: Agent in which the callbacks are called.
        """
        if self.number_of_test_runs > 0:
            self.run_tests(agent)
        for callback in self.callback_list:
            callback.on_training_end(agent)

    @timeit
    def on_training_begin(self, agent: AgentABC):
        """
        Notify every callback that training is about to begin.

        Args:
            agent: Agent in which the callbacks are called.
        """
        for callback in self.callback_list:
            callback.on_training_begin(agent)