from abc import abstractmethod
from typing import Sequence
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from prl.function_approximators.function_approximators import FunctionApproximator
from prl.typing import PytorchNetABC
from prl.utils import timeit, nn_logger
class PytorchNet(PytorchNetABC):
    """
    Base class for neural networks used by PytorchFA. It has a separate
    predict() method, used by Agent.act(), which can behave differently from
    the forward() method.

    Note:
        This class has two abstract methods that need to be implemented (listed below).
    """

    @abstractmethod
    def forward(self, x: torch.Tensor):
        """Defines the computation performed at every training step.

        Args:
            x: input data

        Returns:
            network output
        """
        pass

    @abstractmethod
    def predict(self, x: torch.Tensor):
        """Makes a prediction based on input data.

        Args:
            x: input data

        Returns:
            prediction for the agent.act(x) method
        """
        pass
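
# A minimal concrete PytorchNet is sketched below for illustration only; the
# class name, layer sizes, and the softmax in predict() are assumptions, not
# part of the library.
class _ExampleNet(PytorchNet):
    """Toy network: forward() returns raw scores for training, while predict()
    applies a softmax so Agent.act() can sample from a distribution."""

    def __init__(self, x_size: int = 4, y_size: int = 2):
        super().__init__()
        self.fc1 = nn.Linear(x_size, 16)
        self.fc2 = nn.Linear(16, y_size)

    def forward(self, x: torch.Tensor):
        return self.fc2(F.relu(self.fc1(x)))

    def predict(self, x: torch.Tensor):
        return F.softmax(self.forward(x), dim=1)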

class PytorchFA(FunctionApproximator):
    """Class for PyTorch-based neural network function approximators.

    Args:
        net: neural network (a PytorchNet instance)
        loss: loss function
        optimizer: optimizer
        device: device for computation: "cpu" or "cuda"
        batch_size: size of a training batch
        last_batch: whether the final batch (usually shorter than batch_size) is fed to the network
        network_id: name of the network for debugging and logging purposes
    """

    def __init__(
        self,
        net: PytorchNet,
        loss: _Loss,
        optimizer: Optimizer,
        device: str = "cpu",
        batch_size: int = 64,
        last_batch: bool = True,
        network_id: str = "pytorch_nn",
    ):
        self._id = network_id
        self._device = device
        self._net = net.to(self._device)
        self._optimizer = optimizer
        self._loss = loss
        self._batch_size = batch_size
        self._last_batch = bool(last_batch)

    def convert_to_pytorch(self, y: np.ndarray):
        """Converts a numpy array to a tensor on the target device, preserving
        integer vs. floating dtype (e.g. action indices become LongTensors)."""
        if np.issubdtype(y.dtype, np.integer):
            y = torch.LongTensor(y).to(self._device)
        elif np.issubdtype(y.dtype, np.floating):
            y = torch.FloatTensor(y).to(self._device)
        return y

    @property
    def id(self):
        return self._id

    @timeit
    def train(self, x: np.ndarray, *loss_args):
        """Trains the network on a dataset.

        Args:
            x: input array for the network
            *loss_args: arguments passed directly to the loss function
        """
        x = torch.from_numpy(x).to(self._device)
        indices = torch.randperm(x.shape[0])
        loss_args = tuple(self.convert_to_pytorch(y) for y in loss_args)
        n_batches = x.shape[0] // self._batch_size
        if self._last_batch and x.shape[0] % self._batch_size:
            n_batches += 1  # include the final, shorter batch
        for i in range(n_batches):
            start = i * self._batch_size
            end = (i + 1) * self._batch_size
            self._optimizer.zero_grad()
            batch = x[indices[start:end]]
            y_pred = self._net(batch)
            loss_args_batch = (y[indices[start:end]] for y in loss_args)
            loss = self._loss(y_pred, *loss_args_batch)
            loss.backward()
            nn_logger.add(self.id + str(id(self)), loss.item())
            self._optimizer.step()

    @timeit
    def predict(self, x: np.ndarray):
        """Makes a prediction for the given input batch."""
        x = timeit(torch.from_numpy, "from_numpy")(x).float().to(self._device)
        return self._net.predict(x).cpu().data.numpy()
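
# A minimal usage sketch for PytorchFA; the network, loss, optimizer, and data
# below are illustrative assumptions, not values prescribed by the library.
def _example_fa_usage():
    net = _ExampleNet(x_size=4, y_size=2)
    fa = PytorchFA(
        net=net,
        loss=nn.MSELoss(),
        optimizer=torch.optim.Adam(net.parameters(), lr=1e-3),
        batch_size=32,
    )
    x = np.random.rand(100, 4).astype(np.float32)  # 100 fake states
    y = np.random.rand(100, 2).astype(np.float32)  # 100 fake regression targets
    fa.train(x, y)            # y is forwarded straight to the loss function
    return fa.predict(x[:5])  # numpy array of shape (5, 2)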

class PytorchMLP(PytorchNet):
    """Multilayer perceptron with ReLU activations on the hidden layers."""

    def __init__(self, x_shape, y_size, output_activation, hidden_sizes: Sequence[int]):
        super().__init__()
        assert len(x_shape) == 1, "Input must be flat for MLP."
        assert len(hidden_sizes) > 0
        self.y_size = y_size
        self.output_activation = output_activation
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(x_shape[0], hidden_sizes[0]))
        for i in range(1, len(hidden_sizes)):
            self.layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
        self.layers.append(nn.Linear(hidden_sizes[-1], y_size))

    def forward(self, x):
        # ReLU on every layer except the output, which stays linear for training.
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        return self.layers[-1](x)

    def predict(self, x):
        return self.output_activation(self.forward(x))
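
# Usage sketch for PytorchMLP; the shapes, sizes, and activation below are
# illustrative assumptions.
def _example_mlp():
    mlp = PytorchMLP(
        x_shape=(4,),
        y_size=2,
        output_activation=nn.Softmax(dim=1),
        hidden_sizes=[32, 32],
    )
    return mlp.predict(torch.rand(8, 4))  # action probabilities, shape (8, 2)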

class PytorchConv(PytorchNet):
    """Stack of stride-2 convolutions for channels-last (HWC) image inputs."""

    def __init__(self, x_shape, hidden_sizes: Sequence[int], y_size):
        super().__init__()
        assert len(x_shape) == 3, "Input must be an image for a conv network."
        assert len(hidden_sizes) > 0
        self.softmax = nn.Softmax(dim=1)
        dims = [x_shape[-1]] + list(hidden_sizes)
        (height, width, _) = x_shape
        self.conv_layers = nn.ModuleList(
            [
                nn.Conv2d(*in_out_dim, kernel_size=3, stride=2, padding=1)
                for in_out_dim in zip(dims[:-1], dims[1:])
            ]
        )
        # Each stride-2 convolution halves both spatial dimensions, so the
        # flattened size is height * width / 4**n_layers * channels (exact only
        # when height and width are divisible by 2**len(hidden_sizes)).
        self.out_layer = nn.Linear(
            height * width // 4 ** len(hidden_sizes) * dims[-1], y_size
        )

    def forward(self, x):
        # Incoming images are NHWC; PyTorch convolutions expect NCHW.
        x = x.permute(0, 3, 1, 2)
        for layer in self.conv_layers:
            x = F.relu(layer(x))
        return self.out_layer(x.view(x.size(0), -1))

    def predict(self, x):
        return self.softmax(self.forward(x))
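
# Usage sketch for PytorchConv; the image size and channel counts are
# illustrative assumptions (32 is divisible by 2**2, so the linear layer's
# size formula is exact here).
def _example_conv():
    conv = PytorchConv(x_shape=(32, 32, 3), hidden_sizes=[8, 16], y_size=4)
    frames = torch.rand(2, 32, 32, 3)  # batch of 2 channels-last frames
    return conv.predict(frames)        # action probabilities, shape (2, 4)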

class PolicyGradientLoss(_Loss):
    """REINFORCE-style loss: the negative mean of the return-weighted
    log-probabilities of the actions that were taken."""

    def __init__(self, size_average=None, reduce=None, reduction="mean"):
        super().__init__(size_average, reduce, reduction)

    def forward(self, nn_outputs, actions, returns):
        output_log_probs = F.log_softmax(nn_outputs, dim=1)
        # Pick each taken action's log-probability and weight it by its return.
        log_prob_actions_v = returns * output_log_probs[range(len(actions)), actions]
        return -log_prob_actions_v.mean()
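
# Usage sketch for PolicyGradientLoss; the logits, actions, and returns below
# are illustrative assumptions.
def _example_pg_loss():
    logits = torch.randn(3, 4, requires_grad=True)  # 3 states, 4 actions
    actions = torch.tensor([0, 2, 1])               # actions that were taken
    returns = torch.tensor([1.0, -0.5, 2.0])        # discounted returns
    return PolicyGradientLoss()(logits, actions, returns)  # scalar loss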

class DQNLoss(_Loss):
    """DQN loss: regresses the Q-value of each taken action towards its target
    while leaving the remaining action values untouched."""

    def __init__(self, mode="huber", size_average=None, reduce=None, reduction="mean"):
        super().__init__(size_average, reduce, reduction)
        self.mode = mode
        self.loss = {"huber": F.smooth_l1_loss, "mse": F.mse_loss}[mode]

    def forward(self, nn_outputs, actions, target_outputs):
        # Copy the predictions and overwrite only the taken actions' entries,
        # so the loss is non-zero (and gradients flow) only for those Q-values.
        target = nn_outputs.clone().detach()
        target[torch.arange(target.shape[0]), actions] = target_outputs
        return self.loss(nn_outputs, target, reduction=self.reduction)
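
# Usage sketch for DQNLoss; the Q-values, actions, and targets below are
# illustrative assumptions.
def _example_dqn_loss():
    q_values = torch.randn(3, 4, requires_grad=True)  # Q(s, a) over 4 actions
    actions = torch.tensor([1, 0, 3])                 # actions that were taken
    targets = torch.tensor([0.5, 1.2, -0.3])          # r + gamma * max_a' Q(s', a')
    return DQNLoss(mode="huber")(q_values, actions, targets)  # scalar loss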