prl.callbacks package

Submodules

prl.callbacks.callbacks module

class AgentCallback[source]

Bases: prl.typing.AgentCallbackABC

Interface for callbacks defining actions that are executed automatically during different phases of agent training.

on_iteration_end(agent)[source]

Method called at the end of every iteration in prl.base.Agent.train method.

Parameters:agent (AgentABC) – Agent in which this callback is called.
Return type:bool
Returns:True if training should be interrupted, False otherwise
on_training_begin(agent)[source]

Method called after prl.base.Agent.pre_train_setup.

Parameters:agent (AgentABC) – Agent in which this callback is called
on_training_end(agent)[source]

Method called after prl.base.Agent.post_train_cleanup.

Parameters:agent (AgentABC) – Agent in which this callback is called.
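The contract above can be sketched without the library itself. Below is a minimal, hypothetical illustration of the hook order and the meaning of on_iteration_end's return value; StopAfterN and DummyAgent are made-up names, and DummyAgent only mimics how prl.base.Agent.train drives the hooks:

```python
class StopAfterN:
    """Minimal callback following the documented AgentCallback contract."""

    def __init__(self, max_iterations):
        self.max_iterations = max_iterations
        self.count = 0

    def on_training_begin(self, agent):
        # Called once, after the agent's pre-train setup.
        self.count = 0

    def on_iteration_end(self, agent):
        # Returning True asks the training loop to interrupt training.
        self.count += 1
        return self.count >= self.max_iterations

    def on_training_end(self, agent):
        # Called once, after the agent's post-train cleanup.
        pass


class DummyAgent:
    """Hypothetical stand-in for prl.base.Agent that drives the hooks."""

    def train(self, callback):
        callback.on_training_begin(self)
        iterations = 0
        while True:
            iterations += 1
            if callback.on_iteration_end(self):
                break
        callback.on_training_end(self)
        return iterations
```

With this sketch, DummyAgent().train(StopAfterN(3)) runs exactly three iterations before the callback interrupts training.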
class BaseAgentCheckpoint(target_path, save_best_only=True, iteration_interval=1, number_of_test_runs=1)[source]

Bases: prl.callbacks.callbacks.AgentCallback

Saves agents during training. This is a base class that implements only the checkpointing logic; use a subclass whose saving method matches your networks’ framework. For more info on methods see the base class.

Parameters:
  • target_path (str) – Directory in which agents will be saved. Must exist before creating this callback.
  • save_best_only (bool) – Whether to save all models, or only the one with highest reward.
  • iteration_interval (int) – Interval between calculating test reward. Using low values may make the training process slower.
  • number_of_test_runs (int) – Number of test runs when calculating reward. A higher value averages out variance, but makes training longer.
on_iteration_end(agent)[source]

Method called at the end of every iteration in prl.base.Agent.train method.

Parameters:agent (AgentABC) – Agent in which this callback is called.
Returns:True if training should be interrupted, False otherwise
on_training_end(agent)[source]

Method called after prl.base.Agent.post_train_cleanup.

Parameters:agent (AgentABC) – Agent in which this callback is called.
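The save_best_only behaviour described above can be sketched in isolation. This is a hypothetical stand-alone version, not the prl class; save_fn stands in for a framework-specific saver such as torch.save:

```python
import os


class BestOnlyCheckpoint:
    """Sketch of save_best_only: checkpoint only on improved test reward."""

    def __init__(self, target_path, save_best_only=True):
        self.target_path = target_path  # must already exist
        self.save_best_only = save_best_only
        self.best_reward = float("-inf")

    def maybe_save(self, iteration, reward, save_fn):
        # save_fn(path) is a framework-specific saver, e.g. torch.save.
        if self.save_best_only and reward <= self.best_reward:
            return False  # no improvement: skip this checkpoint
        self.best_reward = max(self.best_reward, reward)
        save_fn(os.path.join(self.target_path, f"agent_{iteration}.ckpt"))
        return True
```

Under this sketch, rewards of 10.0, 5.0, 12.0 on successive checks would produce two saved checkpoints (iterations 1 and 3), since the 5.0 run does not beat the best reward seen so far.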
class CallbackHandler(callback_list, env)[source]

Bases: object

Callback that dispatches to all given callbacks. Calls the appropriate methods on each callback and aggregates their break codes. For more info on methods see the base class.

static check_run_condition(current_count, interval)[source]
on_iteration_end(agent)[source]
on_training_begin(agent)[source]
on_training_end(agent)[source]
run_tests(agent)[source]
Return type:HistoryABC
setup_callbacks()[source]

Sets up callbacks. This calculates optimal intervals for calling callbacks, and for calling testing procedure.

class EarlyStopping(target_reward, iteration_interval=1, number_of_test_runs=1, verbose=1)[source]

Bases: prl.callbacks.callbacks.AgentCallback

Implements early stopping for RL agents. Training is stopped after reaching the given target reward.

Parameters:
  • target_reward (float) – Target reward.
  • iteration_interval (int) – Interval between calculating test reward. Using low values may make the training process slower.
  • number_of_test_runs (int) – Number of test runs when calculating reward. A higher value averages out variance, but makes training longer.
  • verbose (int) – Whether to print a message after stopping training (1) or not (0).

Note

By reward we mean here the untransformed reward returned by the Agent.test method. For more info on methods see the base class.

on_iteration_end(agent)[source]

Method called at the end of every iteration in prl.base.Agent.train method.

Parameters:agent (AgentABC) – Agent in which this callback is called.
Returns:True if training should be interrupted, False otherwise
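The stopping behaviour documented above can be sketched as a stand-alone re-implementation. This is hypothetical (the real class lives in prl.callbacks.callbacks); FixedRewardAgent is a toy stand-in whose test() always returns the same value:

```python
class FixedRewardAgent:
    """Toy agent whose test() always returns the same reward."""

    def __init__(self, reward):
        self.reward = reward

    def test(self):
        return self.reward


class EarlyStoppingSketch:
    """Stand-alone sketch of the documented EarlyStopping behaviour."""

    def __init__(self, target_reward, iteration_interval=1,
                 number_of_test_runs=1, verbose=1):
        self.target_reward = target_reward
        self.iteration_interval = iteration_interval
        self.number_of_test_runs = number_of_test_runs
        self.verbose = verbose
        self._iteration = 0

    def on_iteration_end(self, agent):
        self._iteration += 1
        if self._iteration % self.iteration_interval != 0:
            return False
        # Average several test runs to reduce reward variance.
        rewards = [agent.test() for _ in range(self.number_of_test_runs)]
        if sum(rewards) / len(rewards) >= self.target_reward:
            if self.verbose:
                print("Target reward reached; interrupting training.")
            return True  # interrupt training
        return False
```

An agent averaging 200.0 reward against a 195.0 target would stop training at the first checked iteration; one averaging 10.0 would continue.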
class PyTorchAgentCheckpoint(target_path, save_best_only=True, iteration_interval=1, number_of_test_runs=1)[source]

Bases: prl.callbacks.callbacks.BaseAgentCheckpoint

Class for saving PyTorch-based agents. For more details, see the parent class.

class TensorboardLogger(file_path='logs_1581541984', iteration_interval=1, number_of_test_runs=1, show_time_logs=False)[source]

Bases: prl.callbacks.callbacks.AgentCallback

Writes various information to TensorBoard during training. For more info on methods see the base class.

Parameters:
  • file_path (str) – Path to file with output.
  • iteration_interval (int) – Interval between calculating test reward. Using low values may make the training process slower.
  • number_of_test_runs (int) – Number of test runs when calculating reward. A higher value averages out variance, but makes training longer.
  • show_time_logs (bool) – Whether to show logs from time_logger.
on_iteration_end(agent)[source]

Method called at the end of every iteration in prl.base.Agent.train method.

Parameters:agent (AgentABC) – Agent in which this callback is called.
Returns:True if training should be interrupted, False otherwise
on_training_end(agent)[source]

Method called after prl.base.Agent.post_train_cleanup.

Parameters:agent (AgentABC) – Agent in which this callback is called.
class TrainingLogger(on_screen=True, to_file=False, file_path=None, iteration_interval=1)[source]

Bases: prl.callbacks.callbacks.AgentCallback

Logs training information every given number of iterations. Data may appear in the output or be written to a file. For more info on methods see the base class.

Parameters:
  • on_screen (bool) – Whether to show info in output.
  • to_file (bool) – Whether to save info into a file.
  • file_path (Optional[str]) – Path to file with output.
  • iteration_interval (int) – How often info should be logged on screen. File output remains logged every iteration.
on_iteration_end(agent)[source]

Method called at the end of every iteration in prl.base.Agent.train method.

Parameters:agent (AgentABC) – Agent in which this callback is called.
Returns:True if training should be interrupted, False otherwise
class ValidationLogger(on_screen=True, to_file=False, file_path=None, iteration_interval=1, number_of_test_runs=3)[source]

Bases: prl.callbacks.callbacks.AgentCallback

Logs validation information every given number of iterations. Data may appear in the output or be written to a file. For more info on methods see the base class.

Parameters:
  • on_screen (bool) – Whether to show info in output.
  • to_file (bool) – Whether to save info into a file.
  • file_path (Optional[str]) – Path to file with output.
  • iteration_interval (int) – How often info should be logged on screen. File output remains logged every iteration.
  • number_of_test_runs (int) – Number of played episodes in history’s summary logs.
on_iteration_end(agent)[source]

Method called at the end of every iteration in prl.base.Agent.train method.

Parameters:agent (AgentABC) – Agent in which this callback is called.
Returns:True if training should be interrupted, False otherwise

Module contents