import functools
from multiprocessing import Pipe, Process
from .env_utils import get_gym_spaces
[docs]def target_fn(conn, obj_fn):
obj = obj_fn()
while True:
fn_string, args, kwargs = conn.recv()
func = getattr(obj, fn_string)
if callable(func):
if args is not None:
func = functools.partial(func, *args)
if kwargs is not None:
func = functools.partial(func, **kwargs)
result = func()
else:
result = func
conn.send(result)
if fn_string == 'close':
break
[docs]class MultiProcWrapper:
"""
A generic wrapper which takes a list of functions to be run inside a process.
Each function must return an object, see `target_fn` for how it is used.
Communication between each new process and the parent process happens via
Pipes.
"""
def __init__(self, obj_fns, daemon=True, autostart=True):
self.n_procs = len(obj_fns)
self.daemon = daemon
self.p_conn, self.child_conn = zip(*[Pipe() for _ in range(self.n_procs)])
self.proc_list = [
Process(target=target_fn, args=(conn, obj_fn))
for conn, obj_fn in zip(self.child_conn, obj_fns)
]
if autostart:
self.start()
[docs] def start(self):
for proc in self.proc_list:
proc.daemon = self.daemon
proc.start()
[docs] def stop(self):
self.exec_remote('stop')
for proc in self.proc_list:
if proc.is_alive():
proc.join()
[docs] def exec_remote(self, fn_string, proc_list=None,
args_list=None, kwargs_list=None):
if proc_list is None:
proc_list = list(range(self.n_procs))
if args_list is None:
args_list = [None] * len(proc_list)
if kwargs_list is None:
kwargs_list = [None] * len(proc_list)
assert len(args_list) == len(proc_list) and \
len(kwargs_list) == len(proc_list), \
'Argument list mismatch!'
target_p_conn = []
for i, p_conn in enumerate(self.p_conn):
if i in proc_list:
target_p_conn.append(p_conn)
for conn, args, kwargs in zip(target_p_conn, args_list, kwargs_list):
conn.send((fn_string, args, kwargs))
return [conn.recv() for conn in target_p_conn]
[docs]class ParallelEnvs(MultiProcWrapper):
"""
A utility class which wraps around multiple environments
and runs them in subprocesses
"""
def __init__(self, make_env_fn, n_envs: int = 1, base_seed: int = 0,
daemon: bool = True, autostart: bool = True):
self.observation_space, self.action_space = get_gym_spaces(make_env_fn)
obj_fns = [
functools.partial(make_env_fn,
None if base_seed is None else base_seed + rank)
for rank in range(1, n_envs + 1)
]
super().__init__(obj_fns, daemon=daemon, autostart=autostart)
[docs] def reset(self, env_ids: list):
return self.exec_remote('reset', proc_list=env_ids)
[docs] def step(self, env_ids: list, actions: list):
action_args = [[a] for a in actions]
return self.exec_remote('step', args_list=action_args, proc_list=env_ids)
[docs] def close(self):
self.exec_remote('close')
for proc in self.proc_list:
if proc.is_alive():
proc.terminate()
[docs] def render(self, env_ids: list):
self.exec_remote('render', proc_list=env_ids)