import functools

import torch
from kondo import Experiment
from tqdm.auto import tqdm

from torchrl.controllers import Controller, RandomController
from torchrl.envs import ParallelEnvs, make_gym_env
from torchrl.utils.storage import Transition

def log_dict(logger, info: dict, tag: str, step=None):
    """Log every scalar entry of ``info`` as ``tag/<key>`` at ``step``."""
    for k, v in info.items():
        if v is not None:
            try:
                logger.add_scalar(f'{tag}/{k}', v, global_step=step)
            except AssertionError:
                # NOTE(sanyam): some info values are not scalars; skip them
                # instead of aborting the logging step.
                pass
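
# A minimal sketch of calling ``log_dict``, assuming the logger is a
# ``torch.utils.tensorboard.SummaryWriter`` (any object exposing
# ``add_scalar`` works; the values below are illustrative):
#
#     from torch.utils.tensorboard import SummaryWriter
#     writer = SummaryWriter()
#     log_dict(writer, {'return': 21.0, 'length': 200}, tag='episode',
#              step=10)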


class BaseExperiment(Experiment):
    """Runs a controller in ``n_envs`` parallel copies of ``env_id`` for
    ``n_frames`` total frames, taking random actions for the first
    ``n_rand_frames`` and training every ``n_train_interval`` frames.
    """
    def __init__(self, env_id: str = None, n_envs: int = 1,
                 n_frames: int = int(1e3), n_rand_frames: int = 0,
                 n_train_interval: int = 100, **kwargs):
        assert env_id is not None, '"env_id" cannot be None'
        super().__init__(**kwargs)

        self.device = torch.device('cuda' if self.cuda else 'cpu')

        self.n_frames = n_frames
        self.n_rand_frames = n_rand_frames
        self.n_train_interval = n_train_interval

        self.envs = ParallelEnvs(functools.partial(make_gym_env, env_id),
                                 n_envs=n_envs, base_seed=self.seed)
        self.controller = self.build_controller()

        self._cur_frames = 0

    def build_controller(self) -> Controller:
        '''Returns the controller for this experiment; subclasses override
        this. Defaults to a uniformly random controller.
        '''
        return RandomController(self.envs.action_space)
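
    # Sketch of a typical override, assuming ``ParallelEnvs`` also exposes
    # ``observation_space`` and with a hypothetical ``MyController``:
    #
    #     def build_controller(self) -> Controller:
    #         return MyController(self.envs.observation_space,
    #                             self.envs.action_space,
    #                             device=self.device)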

    def act(self, obs_list: list) -> list:
        # Take uniform random actions for the first ``n_rand_frames``
        # frames, then delegate to the learned controller.
        if self._cur_frames < self.n_rand_frames:
            return RandomController(self.envs.action_space).act(obs_list)
        return self.controller.act(obs_list)

    def store(self, transition_list):
        '''Placeholder method for storage-related usage, e.g. appending
        transitions to a replay buffer. No-op by default.
        '''

    def train(self) -> dict:
        '''Placeholder method for training-related usage. Returns a dict
        of scalars that is logged under the ``train`` tag.
        '''
        return {}
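
    # A minimal sketch of a concrete subclass, assuming a plain Python list
    # as the replay buffer and a hypothetical ``self.controller.update()``
    # that returns a loss:
    #
    #     class MyExperiment(BaseExperiment):
    #         def __init__(self, **kwargs):
    #             super().__init__(**kwargs)
    #             self.buffer = []
    #
    #         def store(self, transition_list):
    #             self.buffer.extend(transition_list)
    #
    #         def train(self) -> dict:
    #             loss = self.controller.update(self.buffer)  # hypothetical
    #             return {'loss': loss}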

    def run(self):
        with tqdm(initial=self._cur_frames,
                  total=self.n_frames, unit='steps') as steps_bar:
            obs_list = self.envs.reset(range(self.envs.n_procs))
            while self._cur_frames < self.n_frames:
                action_list = self.act(obs_list)
                step_list = self.envs.step(list(range(self.envs.n_procs)),
                                           action_list)
                steps_bar.update(self.envs.n_procs)

                transition_list = []
                for i, (obs, action, (next_obs, rew, done, _)) in \
                        enumerate(zip(obs_list, action_list, step_list)):
                    obs_list[i] = next_obs
                    self._cur_frames += 1

                    transition_list.append(
                        Transition(obs=obs, action=action,
                                   reward=rew, next_obs=next_obs,
                                   done=done))

                    if done:
                        # Log per-episode info and reset the finished env.
                        log_dict(self.logger,
                                 self.envs.exec_remote('info',
                                                       proc_list=[i])[0],
                                 tag='episode', step=self._cur_frames)
                        obs_list[i] = self.envs.reset([i])[0]

                self.store(transition_list)

                # NOTE: the counter advances by ``n_envs`` each step, so
                # ``n_train_interval`` should be a multiple of ``n_envs``
                # for this check to fire.
                if self._cur_frames >= self.n_rand_frames \
                        and self._cur_frames % self.n_train_interval == 0:
                    train_info = self.train()
                    log_dict(self.logger, train_info, tag='train',
                             step=self._cur_frames)

        self.logger.close()
        self.envs.close()
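
# A minimal usage sketch, assuming the ``kondo.Experiment`` base accepts
# the ``seed`` and ``cuda`` attributes used above as keyword arguments
# (env id and hyperparameters here are illustrative):
#
#     if __name__ == '__main__':
#         exp = BaseExperiment(env_id='CartPole-v1', n_envs=2,
#                              n_frames=int(1e4), n_rand_frames=500,
#                              seed=0, cuda=False)
#         exp.run()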