trainer
- class rbgame.trainer.DecentralizedTrainer(env_args, num_train_envs=16, num_test_envs=16, batch_size=64, update_freq=100, test_freq=100, episodes_per_train=5000, episodes_per_test=50, train_fn=None, test_fn=None, save_best_fn=None, save_last_fn=None, stop_fn=None, reward_metric=None, shared_memory=True)
Bases: object
A decentralized trainer.
- Parameters:
  - num_train_envs (int) – Number of environments in the DummyVectorEnv used during training.
  - num_test_envs (int) – Number of environments in the DummyVectorEnv used during testing.
  - batch_size (int) – Batch size.
  - update_freq (int) – Number of environment steps between policy updates. Currently, updates are scheduled by steps only.
  - test_freq (int) – Number of training episodes between tests.
  - episodes_per_train (int) – Total number of training episodes.
  - episodes_per_test (int) – Number of episodes run in each test.
  - train_fn (Optional[Callable[[int, int], None]]) – A hook called after every num_train_envs episodes during training. It can be used to perform additional custom operations, with the signature f(num_collected_episodes: int, num_collected_steps: int) -> None.
  - test_fn (Optional[Callable[[int, int], None]]) – A hook called after every num_test_envs episodes during testing. It can be used to perform additional custom operations, with the signature f(num_collected_episodes: int, num_collected_steps: int) -> None.
  - save_best_fn (Optional[Callable[[int], None]]) – A hook called when the reward metric improves during training, with the signature f(episode_to_call: int) -> None.
  - save_last_fn (Optional[Callable[[], None]]) – A hook called when training has finished.
  - stop_fn (Optional[Callable[[float, int], bool]]) – A hook called after every num_train_envs episodes during training, with the signature f(reward_to_stop: float, episode_to_stop: int) -> bool.
  - reward_metric (Optional[Callable[[ndarray], float]]) – A function with the signature f(rewards: np.ndarray of shape (num_episode, agent_num)) -> scalar. Training is monitored through a single scalar, and this function defines the desired metric, e.g., the reward of agent 1 or the average reward over all agents.
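The hook signatures above can be illustrated with a minimal sketch. The names mean_reward_metric, make_agent_reward_metric, make_stop_fn and HookLog are hypothetical helpers, not part of rbgame. The sketch also assumes that stop_fn returning True ends training and that save_best_fn is where a checkpoint would be written.

```python
import numpy as np


def mean_reward_metric(rewards):
    """Mean reward over all episodes and all agents.

    `rewards` has shape (num_episode, agent_num), as documented above.
    """
    rewards = np.asarray(rewards, dtype=float)
    return float(rewards.mean())


def make_agent_reward_metric(agent_index):
    """Build a metric that tracks one agent's mean episode reward."""
    def metric(rewards):
        rewards = np.asarray(rewards, dtype=float)
        # Column `agent_index` holds that agent's reward in every episode.
        return float(rewards[:, agent_index].mean())
    return metric


def make_stop_fn(target_reward):
    """Assumption: returning True asks the trainer to stop early."""
    def stop_fn(reward, episode):
        return reward >= target_reward
    return stop_fn


class HookLog:
    """Records hook calls; a stand-in for real logging or checkpointing."""

    def __init__(self):
        self.progress = []       # (episodes, steps) pairs seen by train_fn
        self.best_episodes = []  # episodes at which save_best_fn fired
        self.finished = False    # set by save_last_fn

    def train_fn(self, num_collected_episodes, num_collected_steps):
        self.progress.append((num_collected_episodes, num_collected_steps))

    def save_best_fn(self, episode):
        # A real hook would checkpoint the learning agents here (assumption).
        self.best_episodes.append(episode)

    def save_last_fn(self):
        self.finished = True
```

With log = HookLog(), these could then be passed by keyword, e.g. DecentralizedTrainer(env_args, train_fn=log.train_fn, save_best_fn=log.save_best_fn, save_last_fn=log.save_last_fn, stop_fn=make_stop_fn(0.9), reward_metric=mean_reward_metric), where env_args is whatever environment configuration the trainer expects and 0.9 is an arbitrary target.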
- train(agents, learning_mask, exploration_mask=None, plot=True)
Agents play together to learn.
- Parameters:
  - agents (list[RLAgent]) – List of agents participating in the game.
  - learning_mask (list | ndarray) – A binary vector specifying which agents learn.
  - exploration_mask (Union[list, ndarray, None]) – A binary vector specifying how each agent behaves during training: for off-policy agents, whether they explore; for on-policy agents, whether actions are sampled from the policy or taken as its mode. Defaults to None, which means all agents explore during training.
  - plot (bool) – Whether to plot the evolution of the metric and save the figure.
- Return type:
- Returns:
Training statistics.
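A minimal sketch of building the two masks, assuming that position i in each mask refers to agents[i]. The helper make_masks is hypothetical and not part of rbgame. In the example, agent 0 learns and explores while the remaining agents stay frozen and act without exploring.

```python
def make_masks(num_agents, learners, explorers=None):
    """Build binary learning/exploration masks for `num_agents` agents.

    learners: indices of agents whose policies are updated.
    explorers: indices of agents that explore (off-policy) or sample
        actions (on-policy); None keeps the documented default in which
        every agent explores.
    Assumes mask position i corresponds to agents[i].
    """
    learners = set(learners)
    chosen = learners if explorers is None else learners | set(explorers)
    if any(i < 0 or i >= num_agents for i in chosen):
        raise ValueError("agent index out of range")

    learning_mask = [1 if i in learners else 0 for i in range(num_agents)]
    if explorers is None:
        # Returning None lets train() fall back to "all agents explore".
        return learning_mask, None
    explorers = set(explorers)
    exploration_mask = [1 if i in explorers else 0 for i in range(num_agents)]
    return learning_mask, exploration_mask


# Agent 0 learns and explores; agents 1-3 are frozen and act without exploring.
learning_mask, exploration_mask = make_masks(4, learners=[0], explorers=[0])
```

The result would then be passed as trainer.train(agents, learning_mask, exploration_mask), where trainer is a DecentralizedTrainer instance and agents is a list of four agents.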