Source code for torchrl.policies.ou_noise

import numpy as np


[docs]class OUNoise: def __init__(self, action_space, mu=0.0, theta=0.15, max_sigma=0.3, min_sigma=0.3, decay_period=100000): self.mu = mu self.theta = theta self.sigma = max_sigma self.max_sigma = max_sigma self.min_sigma = min_sigma self.decay_period = decay_period self.action_dim = action_space.shape[0] self.low = action_space.low self.high = action_space.high self.state = None self.reset()
[docs] def reset(self): self.state = np.ones(self.action_dim) * self.mu
[docs] def evolve_state(self): x = self.state dx = self.theta * (self.mu - x) + \ self.sigma * np.random.randn(self.action_dim) self.state = x + dx return self.state
[docs] def get_action(self, action, t=0): ou_state = self.evolve_state() self.sigma = self.max_sigma - \ (self.max_sigma - self.min_sigma) * \ min(1.0, t / self.decay_period) return np.clip(action + ou_state, self.low, self.high)