rl_agent

class rbgame.agent.rl_agent.RLAgent(policy, memory=None, update_per_step=1.0, repeat_per_collect=1000)

Bases: BaseAgent

Base Reinforcement Learning agent.

Parameters:
  • policy (BasePolicy) – Policy.

  • memory (Optional[VectorReplayBuffer]) – Replay buffer.

  • update_per_step (float) – How many times the agent samples from memory and learns per collected step; used only in off-policy algorithms.

  • repeat_per_collect (float) – How many times the agent learns on the sampled data; used only in on-policy algorithms.
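To make the difference between the two knobs concrete, here is a minimal sketch of how many gradient steps each might produce after one collection phase. It assumes tianshou-style conventions (off-policy: round(update_per_step * num_collected_steps) updates; on-policy: repeat_per_collect passes over the collected data, split into mini-batches). Both helper names are hypothetical and not part of rbgame.

```python
import math


def offpolicy_gradient_steps(update_per_step: float, num_collected_steps: int) -> int:
    # Hypothetical helper: one sample-and-learn round per collected step,
    # scaled by update_per_step (assumed tianshou-style rounding).
    return round(update_per_step * num_collected_steps)


def onpolicy_gradient_steps(repeat_per_collect: int, num_collected_steps: int, batch_size: int) -> int:
    # Hypothetical helper: the freshly collected data is split into
    # mini-batches, and the whole set is revisited repeat_per_collect times.
    minibatches = math.ceil(num_collected_steps / batch_size)
    return repeat_per_collect * minibatches


print(offpolicy_gradient_steps(0.1, 200))   # 20
print(onpolicy_gradient_steps(4, 200, 64))  # 16
```

Under these assumptions, update_per_step scales with how much new experience arrived, while repeat_per_collect multiplies passes over a fixed batch of fresh data.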

abstract get_action(obs)

Compute an action from an observation.

Parameters:

obs (dict[str, ndarray]) – Observation and action mask from the game.

Return type:

int

Returns:

Action.
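Because the observation carries an action mask, a concrete get_action has to avoid illegal moves. A minimal sketch, assuming the dictionary uses the hypothetical keys "observation" and "action_mask" and that the policy has already produced one score per action (plain lists stand in for ndarray here; masked_argmax is a hypothetical helper):

```python
from typing import Sequence


def masked_argmax(scores: Sequence[float], action_mask: Sequence[bool]) -> int:
    # Return the index of the highest-scoring action among those the mask allows.
    legal = [i for i, ok in enumerate(action_mask) if ok]
    if not legal:
        raise ValueError("action mask allows no actions")
    return max(legal, key=lambda i: scores[i])


# Hypothetical observation layout; the real keys come from the game.
obs = {"observation": [0.0, 1.0, 0.0], "action_mask": [True, False, True]}
scores = [0.2, 0.9, 0.5]  # e.g. Q-values from the policy network
print(masked_argmax(scores, obs["action_mask"]))  # 2
```

Action 1 has the highest score but is masked out, so the sketch returns action 2.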

infer_act(obs_batch, exploration_noise)

Forward a batch of observations through the network.

Parameters:
  • obs_batch (Batch) – Batch of observations.

  • exploration_noise (bool) – Whether to apply exploration noise to the actions.

Return type:

ndarray

Returns:

Batch of actions.
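The exploration_noise flag switches between greedy and exploratory behaviour. As a sketch, assume an epsilon-greedy scheme over masked scores. This is a common choice for value-based off-policy methods, but it is not necessarily what rbgame's policies do. The function name infer_actions and the eps and rng parameters are hypothetical.

```python
import random
from typing import List, Optional, Sequence


def infer_actions(
    batch_scores: Sequence[Sequence[float]],
    batch_masks: Sequence[Sequence[bool]],
    exploration_noise: bool,
    eps: float = 0.1,
    rng: Optional[random.Random] = None,
) -> List[int]:
    # Hypothetical batch inference: one action per observation in the batch.
    if rng is None:
        rng = random.Random()
    actions = []
    for scores, mask in zip(batch_scores, batch_masks):
        legal = [i for i, ok in enumerate(mask) if ok]
        if exploration_noise and rng.random() < eps:
            # Explore: pick a uniformly random legal action.
            actions.append(rng.choice(legal))
        else:
            # Exploit: pick the highest-scoring legal action.
            actions.append(max(legal, key=lambda i: scores[i]))
    return actions


scores = [[0.2, 0.9, 0.5], [0.7, 0.1, 0.3]]
masks = [[True, False, True], [True, True, True]]
print(infer_actions(scores, masks, exploration_noise=False))  # [2, 0]
```

With exploration_noise=False the output is deterministic. With it enabled, some actions are replaced by random legal ones.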

abstract policy_update_fn(batch_size, num_collected_steps)

Update the policy.

Parameters:
  • batch_size (int) – Batch size.

  • num_collected_steps – Number of collected steps.

Return type:

int

Returns:

Number of gradient steps.
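Since get_action and policy_update_fn are abstract, a concrete agent must override both. The skeleton below mirrors that contract with Python's abc module. ToyAgentBase and CountingAgent are hypothetical stand-ins, not rbgame classes, and the method bodies are placeholders.

```python
from abc import ABC, abstractmethod
from typing import Dict


class ToyAgentBase(ABC):
    # Hypothetical stand-in for the abstract part of RLAgent's interface.

    @abstractmethod
    def get_action(self, obs: Dict[str, list]) -> int:
        """Compute an action from one observation."""

    @abstractmethod
    def policy_update_fn(self, batch_size: int, num_collected_steps: int) -> int:
        """Update the policy and return the number of gradient steps."""


class CountingAgent(ToyAgentBase):
    # Placeholder subclass: always plays the first legal action and
    # pretends to take one gradient step per collected step.
    def __init__(self) -> None:
        self.total_gradient_steps = 0

    def get_action(self, obs: Dict[str, list]) -> int:
        return next(i for i, ok in enumerate(obs["action_mask"]) if ok)

    def policy_update_fn(self, batch_size: int, num_collected_steps: int) -> int:
        steps = num_collected_steps
        self.total_gradient_steps += steps
        return steps


agent = CountingAgent()
print(agent.get_action({"action_mask": [0, 1, 1]}))                    # 1
print(agent.policy_update_fn(batch_size=64, num_collected_steps=10))   # 10
```

Instantiating ToyAgentBase directly raises TypeError, which is how abc enforces that both methods are implemented.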

class rbgame.agent.rl_agent.OffPolicyAgent(policy, memory=None, update_per_step=1.0, repeat_per_collect=1000)

Bases: RLAgent

get_action(obs)

Compute an action from an observation.

Parameters:

obs (dict[str, ndarray]) – Observation and action mask from the game.

Return type:

int

Returns:

Action.
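One plausible way for a concrete get_action to reuse batch inference is to wrap the single observation in a batch of size one and unwrap the first result. The sketch below assumes that design and uses a hypothetical infer_act stand-in over plain lists rather than a tianshou Batch. Whether exploration noise is enabled at this point is also an assumption.

```python
from typing import Dict, List, Sequence


def infer_act(obs_batch: Sequence[Dict[str, list]], exploration_noise: bool) -> List[int]:
    # Hypothetical batch inference stand-in: greedy over a fixed score
    # vector, restricted to each observation's legal actions.
    # exploration_noise is accepted but ignored by this stand-in.
    scores = [0.2, 0.9, 0.5]
    return [
        max((i for i, ok in enumerate(obs["action_mask"]) if ok), key=lambda i: scores[i])
        for obs in obs_batch
    ]


def get_action(obs: Dict[str, list], exploration_noise: bool = True) -> int:
    # Wrap the single observation in a batch of one, run batch inference,
    # and unwrap the only result as a plain int.
    return int(infer_act([obs], exploration_noise)[0])


print(get_action({"action_mask": [True, False, True]}))  # 2
```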

infer_act(obs_batch, exploration_noise)

Forward a batch of observations through the network.

Parameters:
  • obs_batch (Batch) – Batch of observations.

  • exploration_noise (bool) – Whether to apply exploration noise to the actions.

Return type:

ndarray

Returns:

Batch of actions.

policy_update_fn(batch_size, num_collected_steps)

Update the policy. For off-policy algorithms, the agent samples batch_size transitions from the replay buffer, learns from them, and repeats this several times (see update_per_step).

Parameters:
  • batch_size (int) – Batch size.

  • num_collected_steps – Number of collected steps.

Return type:

int

Returns:

Number of gradient steps.
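The off-policy update described above can be sketched as a plain loop. The sketch makes several assumptions. The number of rounds is round(update_per_step * num_collected_steps), following tianshou's off-policy convention. A Python list stands in for VectorReplayBuffer. offpolicy_update and its learn callback are hypothetical names, not rbgame's API.

```python
import random
from typing import Callable, List, Optional, Sequence


def offpolicy_update(
    memory: Sequence[tuple],
    learn: Callable[[List[tuple]], None],
    batch_size: int,
    num_collected_steps: int,
    update_per_step: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    if rng is None:
        rng = random.Random()
    # Assumed tianshou-style count of sample-and-learn rounds.
    num_updates = round(update_per_step * num_collected_steps)
    for _ in range(num_updates):
        # Sample transitions uniformly with replacement, then learn from them.
        batch = [rng.choice(memory) for _ in range(batch_size)]
        learn(batch)
    # Each sample-and-learn round counts as one gradient step.
    return num_updates


buffer = [(s, 0, 0.0, s + 1) for s in range(100)]  # toy (obs, act, rew, obs_next)
seen = []
steps = offpolicy_update(buffer, lambda b: seen.append(len(b)), batch_size=32,
                         num_collected_steps=10, update_per_step=0.5)
print(steps, seen)  # 5 [32, 32, 32, 32, 32]
```

Returning the number of rounds matches the documented return value (number of gradient steps) under the one-step-per-round assumption.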