# Pgx API Usage

## Example.1: Random play
```py
import jax
import jax.numpy as jnp
import pgx

seed = 42
batch_size = 10
key = jax.random.PRNGKey(seed)


def act_randomly(rng_key, obs, mask):
    """Ignore observation and choose randomly from legal actions"""
    del obs
    probs = mask / mask.sum()
    logits = jnp.maximum(jnp.log(probs), jnp.finfo(probs.dtype).min)
    return jax.random.categorical(rng_key, logits=logits, axis=-1)


# Load the environment
env = pgx.make("go_9x9")
init_fn = jax.jit(jax.vmap(env.init))
step_fn = jax.jit(jax.vmap(env.step))

# Initialize the states
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, batch_size)
state = init_fn(keys)

# Run random simulation
while not (state.terminated | state.truncated).all():
    key, subkey = jax.random.split(key)
    action = act_randomly(subkey, state.observation, state.legal_action_mask)
    state = step_fn(state, action)  # state.rewards has shape (10, 2)
```
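The masking in `act_randomly` pushes the log-probability of every illegal action down to the smallest finite float, so `jax.random.categorical` effectively never samples it. Here is a minimal standalone check; the toy mask below is only illustrative and is not part of the Pgx example:

```py
# Toy mask: only actions 1 and 3 are legal.
toy_mask = jnp.array([0.0, 1.0, 0.0, 1.0])
demo_keys = jax.random.split(jax.random.PRNGKey(0), 100)
samples = jax.vmap(lambda k: act_randomly(k, None, toy_mask))(demo_keys)
assert bool(jnp.isin(samples, jnp.array([1, 3])).all())  # every sample is a legal action
```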
## Example.2: Random agent vs Baseline model
This illustrative example helps you understand:

- How `state.current_player` is defined
- How to access the reward of each player
- How `Env.step` behaves against already terminated states (see the short sketch after this list)
- How to use baseline models provided by Pgx
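On the third point, stepping a state that has already terminated is expected to leave it terminated and return zero rewards for that game, which is why the accumulation `R += state.rewards` in the loop below does not double-count finished games. A minimal single-game sketch (it uses `tic_tac_toe` only for brevity; it is not part of the original example):

```py
import jax
import jax.numpy as jnp
import pgx

env = pgx.make("tic_tac_toe")
init = jax.jit(env.init)
step = jax.jit(env.step)

state = init(jax.random.PRNGKey(0))
while not state.terminated:
    # Deterministically play the first legal action until the game ends.
    state = step(state, jnp.argmax(state.legal_action_mask))

# Step the already terminated state once more.
after = step(state, jnp.argmax(state.legal_action_mask))
print(after.terminated, after.rewards)  # expected: still terminated, zero rewards
```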
```py
import jax
import jax.numpy as jnp
import pgx
from pgx.experimental.utils import act_randomly

seed = 42
batch_size = 10
key = jax.random.PRNGKey(seed)

# Prepare agent A and B
#   Agent A: random player
#   Agent B: baseline player provided by Pgx
A = 0
B = 1

# Load the environment
env = pgx.make("go_9x9")
init_fn = jax.jit(jax.vmap(env.init))
step_fn = jax.jit(jax.vmap(env.step))

# Prepare the baseline model
# Note that it additionally requires the Haiku library ($ pip install dm-haiku)
model_id = "go_9x9_v0"
model = pgx.make_baseline_model(model_id)

# Initialize the states
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, batch_size)
state = init_fn(keys)
print(f"Game index: {jnp.arange(batch_size)}")  # [0 1 2 3 4 5 6 7 8 9]
print(f"Black player: {state.current_player}")  # [1 1 0 1 0 0 1 1 1 1]
# In other words
print(f"A is black: {state.current_player == A}")  # [False False True False True True False False False False]
print(f"B is black: {state.current_player == B}")  # [ True True False True False False True True True True]

# Run simulation
R = state.rewards
while not (state.terminated | state.truncated).all():
    # Action of random player A
    key, subkey = jax.random.split(key)
    action_A = jax.jit(act_randomly)(subkey, state)
    # Greedy action of baseline model B
    logits, value = model(state.observation)
    action_B = logits.argmax(axis=-1)
    action = jnp.where(state.current_player == A, action_A, action_B)
    state = step_fn(state, action)
    R += state.rewards

print(f"Return of agent A = {R[:, A]}")  # [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1.]
print(f"Return of agent B = {R[:, B]}")  # [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
```
Note that we can avoid explicitly dealing with the first batch dimension (e.g., `[:, A]`) by using `vmap`, as shown later.
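As a rough idea of what that looks like, a per-game function can index the rewards of a single game with a scalar player id and then be lifted over the batch with `jax.vmap`. The helper below is hypothetical and not a Pgx API:

```py
def single_game_return(rewards, player_id):
    # rewards: (2,) rewards of one game; player_id: scalar player index
    return rewards[player_id]

# Lift over the batch dimension; equivalent to R[:, A] in the example above.
returns_A = jax.vmap(single_game_return, in_axes=(0, None))(R, A)
```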