torchrl.registry package

Submodules

torchrl.registry.problems module

class torchrl.registry.problems.HParams(kwargs=None)[source]

Bases: object

This class is a friendly wrapper over a Python dictionary to represent named hyperparameters.

Example

One can manually set arbitrary hyperparameters as attributes:

import torchrl.registry as registry
hparams = registry.HParams()
hparams.paramA = 'myparam'
hparams.paramB = 10

or just send in a dictionary object containing all the relevant key/value pairs.

import torchrl.registry as registry
hparams = registry.HParams({'paramA': 'myparam', 'paramB': 10})
assert hparams.paramA == 'myparam'
assert hparams.paramB == 10

Both forms produce equivalent hyperparameter objects.

To update/override the hyperparameters, use the update() method.

hparams.update({'paramA': 20, 'paramB': 'otherparam', 'paramC': 5.0})
assert hparams.paramA == 20
assert hparams.paramB == 'otherparam'
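
Internally, HParams simply exposes dictionary entries as attributes. A minimal sketch of the idea (illustrative only, not the actual torchrl source):

class AttrDict:
  """Illustrative stand-in for registry.HParams."""

  def __init__(self, kwargs=None):
    # Store the initial key/value pairs as attributes.
    self.__dict__.update(kwargs or {})

  def update(self, items: dict):
    # Override repeated keys with values from `items`.
    self.__dict__.update(items)
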
Parameters:
  • kwargs (dict) – Python dictionary representing named hyperparameters and values.
update(items: dict)[source]

Merge new values into this hyperparameter object, overriding any repeated keys with those from the items parameter.

Parameters: items (dict) – Python dictionary containing updated values.
class torchrl.registry.problems.Problem(hparams: torchrl.registry.problems.HParams, problem_args: argparse.Namespace, log_dir: str, device: str = 'cuda', show_progress: bool = True, checkpoint_prefix='checkpoint')[source]

Bases: object

This abstract class defines a Reinforcement Learning problem.

Parameters:
  • hparams (HParams) – Object containing all named-hyperparameters.
  • problem_args (argparse.Namespace) – Argparse namespace object containing Problem arguments like seed, log_interval, eval_interval.
  • log_dir (str) – Path to log directory.
  • device (str) – String passed to torch.device().
  • show_progress (bool) – If true, an animated progress bar is shown based on tqdm.
  • checkpoint_prefix (str) – Prefix for the saved checkpoint files.

Todo

  • Remove usage of argparse.Namespace for problem_args and use HParams instead. As a temporary usage fix, convert any dictionary into argparse.Namespace using argparse.Namespace(**mydict). Tracked by #61.
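
For instance, the temporary fix mentioned above looks like:

import argparse

mydict = {'seed': 1, 'log_interval': 100, 'eval_interval': 500}
problem_args = argparse.Namespace(**mydict)
assert problem_args.seed == 1
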
eval(epoch)[source]

This method must be overridden by the derived Problem class and should contain the core idea behind the evaluation of the trained model. This is also responsible for any metric logging using the self.logger object.

self.args.num_eval should be a helpful variable.

Note

It is a good idea to always use train() to set training mode to False here.

Parameters: epoch (int) – Epoch number in question.
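
A hedged sketch of an override, assuming the constructed agent is available as self.agent, a hypothetical run_one_episode helper, and a TensorBoard-style self.logger.add_scalar() (the actual logger interface may differ):

def eval(self, epoch):
  self.agent.train(False)  # unset training mode, per the note above
  rewards = []
  for _ in range(self.args.num_eval):
    # roll out one evaluation episode and record its return
    rewards.append(run_one_episode(self.runner, self.agent))  # hypothetical helper
  self.logger.add_scalar('eval/avg_reward', sum(rewards) / len(rewards), epoch)
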
init_agent() → torchrl.agents.base_agent.BaseAgent[source]

This method is called by the constructor and must be overridden by any derived class. Using the hyperparameters and problem arguments, one should construct an agent here.

Returns: Any derived agent class.
Return type: BaseAgent
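
A sketch of an override; MyAgent is a hypothetical BaseAgent subclass, and the hyperparameter and attribute names are assumptions:

def init_agent(self):
  # MyAgent is hypothetical; construct any BaseAgent subclass from
  # the stored hyperparameters and problem arguments.
  return MyAgent(lr=self.hparams.lr, gamma=self.hparams.gamma, device=self.device)
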
load_checkpoint(load_dir, epoch=None)[source]

This method loads the latest checkpoint from a directory. It also updates the self.start_epoch attribute so that any further calls to save_checkpoint don’t overwrite the previously saved checkpoints. The file name format is <CHECKPOINT_PREFIX>-<EPOCH>.ckpt.

Parameters:
  • load_dir (str) – Path to directory containing checkpoint files.
  • epoch (int) – Epoch number to load. If None, then the file with the latest timestamp is loaded from the given directory.
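
For example (the directory path and epoch are illustrative):

problem.load_checkpoint('./log_dir')            # latest checkpoint by timestamp
problem.load_checkpoint('./log_dir', epoch=10)  # loads checkpoint-10.ckpt
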
make_runner(n_envs=1, seed=None) → torchrl.runners.base_runner.BaseRunner[source]

This method is called by the constructor and must be overridden by any derived class. Using the hyperparameters and problem arguments, one should construct an environment runner here.

Returns: Any derived runner class.
Return type: BaseRunner
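
A sketch of an override; GymRunner stands in for any hypothetical BaseRunner subclass:

def make_runner(self, n_envs=1, seed=None):
  # GymRunner is hypothetical; return any BaseRunner subclass here.
  return GymRunner('CartPole-v1', n_envs=n_envs, seed=seed)
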
run()[source]

This is the entrypoint to a problem class and can be overridden if desired. However, a common rollout, train and eval loop has already been provided here. All variables for logging are prefixed with “log_”.

self.args.log_interval and self.args.eval_interval should be helpful variables.

Note

This precoded routine implements the following general steps:

  • Set agent to train mode using train().
  • Rollout trajectories using runner’s rollout().
  • Unset agent’s train mode.
  • Run the training routine using train(), which may in turn use the agent’s learn().
  • Evaluate the learned agent using eval().
  • Periodically log and save checkpoints using save_checkpoint().

Since this routine handles multiple parallel trajectories, care must be taken to reset the environment instances (this should be handled by the appropriate runner, or as desired).
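
In outline, the precoded routine resembles the following simplified sketch; num_epochs, self.agent, self.runner, and the rollout() signature are assumptions here:

def run(self):
  for epoch in range(self.start_epoch, self.args.num_epochs + 1):
    self.agent.train(True)                          # set train mode
    history_list = self.runner.rollout(self.agent)  # collect trajectories
    self.agent.train(False)                         # unset train mode
    log_losses = self.train(history_list)           # training step
    if epoch % self.args.log_interval == 0:
      pass  # log `log_losses` via self.logger
    if epoch % self.args.eval_interval == 0:
      self.eval(epoch)
      self.save_checkpoint(epoch)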

save_checkpoint(epoch)[source]

Save a checkpoint at the given epoch. The file name format is <CHECKPOINT_PREFIX>-<EPOCH>.ckpt.

Parameters: epoch (int) – Value of the epoch number.
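
For example, with the default checkpoint_prefix:

problem.save_checkpoint(10)  # writes checkpoint-10.ckpt
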
train(history_list: list) → dict[source]

This method must be overridden by the derived Problem class and should contain the core idea behind the training step.

There are no restrictions on what comes in via this argument as long as the derived class takes care of processing it. Typically, it is a list of rollouts (possibly one per parallel trajectory) containing all relevant values for each rollout: observation, action, reward, next observation, termination flag, and potentially other information. This raw data must be processed as desired. See hist_to_tensor() for a sample routine.

Note

It is a good idea to always use train() appropriately here.

Parameters: history_list (list) – A list of histories. This will typically be returned by the rollout() method of the runner.
Returns: A Python dictionary containing labeled losses.
Return type: dict
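
A hedged sketch of an override; the hist_to_tensor() call and the agent's learn() signature are assumptions standing in for the actual processing:

def train(self, history_list):
  # Convert raw rollout data (observations, actions, rewards, ...)
  # into tensors, then run one learning step and label the losses.
  batch = hist_to_tensor(history_list)                # sample routine, see above
  actor_loss, critic_loss = self.agent.learn(*batch)  # assumed signature
  return {'actor_loss': actor_loss, 'critic_loss': critic_loss}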

torchrl.registry.registry module

torchrl.registry.registry.register_hparam(name: Union[Callable, str])[source]

A decorator to register a hyperparameter function.

Example

import torchrl.registry as registry

@registry.register_hparam
def my_new_hparams():
  hparams = registry.HParams()
  hparams.x = 1
  return hparams

This will be registered under the name my_new_hparams. Optionally, we can also provide a name as an argument to the decorator.

@registry.register_hparam('my_renamed_hparams')
def my_new_hparams(): ...
Parameters: name (str, Callable) – Optionally, a string name for the registration; when the decorator is used without arguments, this is the decorated callable itself.
Returns: A decorated function.
Return type: Callable
torchrl.registry.registry.list_hparams() → list[source]

List all registered hyperparameters.

Returns: List of hyperparameter name strings.
Return type: list
torchrl.registry.registry.get_hparam(hparam_set_id: str) → Callable[source]

Get a registered hyperparameter function by name.

Parameters: hparam_set_id (str) – A string representing the name of the hyperparameter set.
Returns: A function that returns HParams.
Return type: Callable
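
A typical lookup, continuing the my_new_hparams example above:

import torchrl.registry as registry

assert 'my_new_hparams' in registry.list_hparams()
hparams = registry.get_hparam('my_new_hparams')()  # call the registered function
assert hparams.x == 1
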
torchrl.registry.registry.remove_hparam(hparam_set_id: str)[source]

De-register a hyperparameter set.

Parameters: hparam_set_id (str) – Name of registered hyperparameter.
torchrl.registry.registry.register_problem(name: Union[Callable, str])[source]

A decorator to register problems.

Example

import torchrl.registry as registry

@registry.register_problem
class MyProblem(registry.Problem):
  ...

This will be registered under the name my_problem. Optionally, we can also provide a name as an argument to the decorator.

@registry.register_problem('my_renamed_problem')
class MyProblem(registry.Problem): ...
Parameters: name (str, Callable) – Optionally, a string name for the registration; when the decorator is used without arguments, this is the decorated callable itself.
Returns: The decorated callable.
Return type: Callable
torchrl.registry.registry.list_problems()[source]

List all registered Problems.

Returns: List of strings containing all problem names.
Return type: list
torchrl.registry.registry.get_problem(problem_id: str)[source]

Get an uninstantiated problem class.

Parameters: problem_id (str) – Name of registered problem.
Returns: Any derived problem class.
Return type: torchrl.registry.problems.Problem
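
Continuing the MyProblem example above (constructor arguments are illustrative):

import argparse
import torchrl.registry as registry

problem_cls = registry.get_problem('my_problem')
problem = problem_cls(
  registry.get_hparam('my_new_hparams')(),
  argparse.Namespace(seed=1, log_interval=100, eval_interval=500),
  log_dir='./log_dir',
)
problem.run()
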
torchrl.registry.registry.remove_problem(problem_id: str)[source]

De-register a problem.

Parameters: problem_id (str) – Name of registered problem.
torchrl.registry.registry.list_problem_hparams()[source]

List all registered hyperparameters associated with a problem. Any static method of a Problem class whose name is prefixed with hparams_ is associated with that problem. This routine returns all such available associations.

Example

The format of the returned value is:

{
  "problem_name": [
    "hparam_set1", "hparam_set2"
  ],
  "other_problem": [
    "other_problem_hparam1"
  ]
}
Returns: Problem-to-hyperparameter associations in the format shown above.
Return type: dict
torchrl.registry.registry.get_problem_hparam(problem_id: str)[source]

Get the hyperparameter sets associated with a problem.

Parameters: problem_id (str) – Name of registered problem.
Returns: List of hyperparameter sets.
Return type: list
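
For instance, continuing the MyProblem example above:

import torchrl.registry as registry

hparam_sets = registry.get_problem_hparam('my_problem')
# a list naming the hyperparameter sets defined by MyProblem's
# hparams_-prefixed static methods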

Module contents