Skip to content

Pgx API

This is the list of all public APIs of Pgx. Two important components in Pgx are State and Env.

pgx.State

Bases: ABC

Base state class of all Pgx game environments. It is essentially an immutable (frozen) dataclass. A state is typically created via Env.init:

state = env.init(jax.random.PRNGKey(0))

and Env.step receives and returns this state class:

state = env.step(state, action)

Serialization via flax.struct.serialization is supported. There are 6 attributes common to all games:

Attributes:

Name Type Description
current_player Array

id of agent to play. Note that this does NOT represent the turn (e.g., black/white in Go). This ID is consistent over the parallel vmapped states.

observation Array

observation for the current state. Env.observe is called to compute.

rewards Array

the i-th element indicates the intermediate reward for the agent with player-id i. If Env.step is called for a terminal state, the returned state.rewards are zero for all players.

terminated Array

denotes that the state is a terminal state. Note that some environments (e.g., Go) have a max_termination_steps parameter inside and will terminate within a limited number of steps (following AlphaGo).

truncated Array

indicates that the episode ended for a reason other than termination. Note that current Pgx environments do not invoke truncation, but users can use the TimeLimit wrapper to truncate the environment. In Pgx, some MinAtar games may not terminate within a finite number of timesteps; all other environments are expected to terminate within a finite number of timesteps with probability one.

legal_action_mask Array

Boolean array of legal actions. If an illegal action is taken, the game terminates immediately with a penalty to the player who took it.

Source code in pgx/core.py
@dataclass
class State(abc.ABC):
    """Base state class of all Pgx game environments. Essentially an immutable (frozen) dataclass.
    A state is typically created via `Env.init`:

        state = env.init(jax.random.PRNGKey(0))

    and `Env.step` receives and returns an instance of this class:

        state = env.step(state, action)

    Serialization via `flax.struct.serialization` is supported.
    There are 6 attributes common to all games:

    Attributes:
        current_player (Array): id of agent to play.
            Note that this does NOT represent the turn (e.g., black/white in Go).
            This ID is consistent over the parallel vmapped states.
        observation (Array): observation for the current state.
            `Env.observe` is called to compute.
        rewards (Array): the `i`-th element indicates the intermediate reward for
            the agent with player-id `i`. If `Env.step` is called for a terminal state,
            the returned `state.rewards` is zero for all players.
        terminated (Array): denotes that the state is a terminal state. Note that
            some environments (e.g., Go) have a `max_termination_steps` parameter inside
            and will terminate within a limited number of steps (following AlphaGo).
        truncated (Array): indicates that the episode ended for a reason other than termination.
            Note that current Pgx environments do not invoke truncation but users can use `TimeLimit` wrapper
            to truncate the environment. In Pgx environments, some MinAtar games may not terminate within a finite timestep.
            However, the other environments are supposed to terminate within a finite timestep with probability one.
        legal_action_mask (Array): Boolean array of legal actions. If an illegal action is taken,
            the game will terminate immediately with a penalty to the player who took it.
    """

    current_player: Array
    observation: Array
    rewards: Array
    terminated: Array
    truncated: Array
    legal_action_mask: Array
    # Internal step counter: `Env.step` increments it each time it advances a state
    # that is neither terminated nor truncated.
    _step_count: Array

    @property
    @abc.abstractmethod
    def env_id(self) -> EnvId:
        """Environment id (e.g. "go_19x19")"""
        ...

    def _repr_html_(self) -> str:
        """Rich HTML display hook (IPython `_repr_html_` protocol): renders the state as SVG."""
        return self.to_svg()

    def to_svg(
        self,
        *,
        color_theme: Optional[Literal["light", "dark"]] = None,
        scale: Optional[float] = None,
    ) -> str:
        """Return SVG string. Useful for visualization in notebook.

        Args:
            color_theme (Optional[Literal["light", "dark"]]): color theme of the image.
                If None, defers to the global visualization config
                (see `pgx.set_visualization_config`).
            scale (Optional[float]): change image size. If None, defers to the global
                visualization config (1.0 by default).

        Returns:
            str: SVG string
        """
        from pgx._src.visualizer import Visualizer

        v = Visualizer(color_theme=color_theme, scale=scale)
        return v.get_dwg(states=self).tostring()

    def save_svg(
        self,
        filename,
        *,
        color_theme: Optional[Literal["light", "dark"]] = None,
        scale: Optional[float] = None,
    ) -> None:
        """Save the entire state (not observation) to a file.
        The filename must end with `.svg`

        Args:
            filename (Union[str, Path]): output path; must end with `.svg`.
            color_theme (Optional[Literal["light", "dark"]]): color theme of the image.
                If None, defers to the global visualization config
                (see `pgx.set_visualization_config`).
            scale (Optional[float]): change image size. If None, defers to the global
                visualization config (1.0 by default).

        Returns:
            None
        """
        # NOTE(review): for MinAtar states, `pgx._src.visualizer.save_svg` calls back into
        # `state.save_svg(filename=...)`; MinAtar State classes presumably override this
        # method, otherwise the two would recurse — confirm.
        from pgx._src.visualizer import save_svg

        save_svg(self, filename, color_theme=color_theme, scale=scale)

env_id: EnvId abstractmethod property

Environment id (e.g. "go_19x19")

save_svg(filename, *, color_theme=None, scale=None)

Save the entire state (not observation) to a file. The filename must end with .svg

Parameters:

Name Type Description Default
color_theme Optional[Literal['light', 'dark']]

Color theme of the image. If None, defers to the global visualization config (see pgx.set_visualization_config).

None
scale Optional[float]

change image size. Default(None) is 1.0

None

Returns:

Type Description
None

None

Source code in pgx/core.py
def save_svg(
    self,
    filename,
    *,
    color_theme: Optional[Literal["light", "dark"]] = None,
    scale: Optional[float] = None,
) -> None:
    """Save the entire state (not observation) to a file.
    The filename must end with `.svg`

    Args:
        filename (Union[str, Path]): output path; must end with `.svg`.
        color_theme (Optional[Literal["light", "dark"]]): color theme of the image.
            If None, defers to the global visualization config
            (see `pgx.set_visualization_config`).
        scale (Optional[float]): change image size. If None, defers to the global
            visualization config (1.0 by default).

    Returns:
        None
    """
    # Delegates to the module-level visualizer helper, which does the actual rendering/writing.
    from pgx._src.visualizer import save_svg

    save_svg(self, filename, color_theme=color_theme, scale=scale)

to_svg(*, color_theme=None, scale=None)

Return SVG string. Useful for visualization in notebook.

Parameters:

Name Type Description Default
color_theme Optional[Literal['light', 'dark']]

Color theme of the image. If None, defers to the global visualization config (see pgx.set_visualization_config).

None
scale Optional[float]

change image size. Default(None) is 1.0

None

Returns:

Name Type Description
str str

SVG string

Source code in pgx/core.py
def to_svg(
    self,
    *,
    color_theme: Optional[Literal["light", "dark"]] = None,
    scale: Optional[float] = None,
) -> str:
    """Return SVG string. Useful for visualization in notebook.

    Args:
        color_theme (Optional[Literal["light", "dark"]]): color theme of the image.
            If None, defers to the global visualization config
            (see `pgx.set_visualization_config`).
        scale (Optional[float]): change image size. If None, defers to the global
            visualization config (1.0 by default).

    Returns:
        str: SVG string
    """
    from pgx._src.visualizer import Visualizer

    v = Visualizer(color_theme=color_theme, scale=scale)
    return v.get_dwg(states=self).tostring()

pgx.Env

Bases: ABC

Environment class API.

Example usage

env: Env = pgx.make("tic_tac_toe")
state = env.init(jax.random.PRNGKey(0))
action = jnp.int32(4)
state = env.step(state, action)
Source code in pgx/core.py
class Env(abc.ABC):
    """Environment class API.

    !!! example "Example usage"

        ```py
        env: Env = pgx.make("tic_tac_toe")
        state = env.init(jax.random.PRNGKey(0))
        action = jnp.int32(4)
        state = env.step(state, action)
        ```

    """

    def __init__(self): ...

    def init(self, key: PRNGKey) -> State:
        """Return the initial state. Note that no internal state of
        environment changes.

        Args:
            key: pseudo-random generator key in JAX. Consumed in this function.

        Returns:
            State: initial state of environment, with `observation` already computed.

        """
        state = self._init(key)
        observation = self.observe(state)
        return state.replace(observation=observation)  # type: ignore

    def step(
        self,
        state: State,
        action: Array,
        key: Optional[Array] = None,
    ) -> State:
        """Advance the environment by one step.

        The game-specific transition is delegated to `_step`. On top of it, this method:

        * returns the same state with all-zero rewards (and no `_step_count` increment)
          if `state` is already terminated or truncated;
        * terminates the game via `_step_with_illegal_action` if `action` is not legal
          according to `state.legal_action_mask`;
        * sets every `legal_action_mask` element to True when the resulting state is terminal;
        * recomputes `observation` for the returned state.

        Args:
            state: current state.
            action: action taken by `state.current_player`.
            key: optional pseudo-random generator key in JAX, passed through to `_step`.

        Returns:
            State: next state.
        """
        # Both values are read from the incoming state, before any transition is applied.
        is_illegal = ~state.legal_action_mask[action]
        current_player = state.current_player

        # If the state is already terminated or truncated, environment does not take usual step,
        # but return the same state with zero-rewards for all players
        state = jax.lax.cond(
            (state.terminated | state.truncated),
            lambda: state.replace(rewards=jnp.zeros_like(state.rewards)),  # type: ignore
            lambda: self._step(state.replace(_step_count=state._step_count + 1), action, key),  # type: ignore
        )

        # Taking illegal action leads to immediate game terminal with negative reward
        state = jax.lax.cond(
            is_illegal,
            lambda: self._step_with_illegal_action(state, current_player),
            lambda: state,
        )

        # All legal_action_mask elements are **TRUE** at terminal state
        # This is to avoid zero-division error when normalizing action probability
        # Taking any action at terminal state does not give any effect to the state
        state = jax.lax.cond(
            state.terminated,
            lambda: state.replace(legal_action_mask=jnp.ones_like(state.legal_action_mask)),  # type: ignore
            lambda: state,
        )

        observation = self.observe(state)
        state = state.replace(observation=observation)  # type: ignore

        return state

    def observe(self, state: State, player_id: Optional[Array] = None) -> Array:
        """Return the observation of `state`.

        Args:
            state: state to observe.
            player_id: deprecated. If given, a DeprecationWarning is emitted and the
                observation is computed for this player; otherwise `state.current_player`
                is used.

        Returns:
            Array: result of `_observe`, wrapped in `jax.lax.stop_gradient`.
        """
        if player_id is None:
            player_id = state.current_player
        else:
            # stacklevel=2 attributes the warning to the caller of `observe`, not to Pgx itself.
            warnings.warn(
                "[Pgx] `player_id` in `observe` is deprecated. This argument will be removed in the future.",
                DeprecationWarning,
                stacklevel=2,
            )
        obs = self._observe(state, player_id)
        return jax.lax.stop_gradient(obs)

    @abc.abstractmethod
    def _init(self, key: PRNGKey) -> State:
        """Implement game-specific init function here."""
        ...

    @abc.abstractmethod
    def _step(self, state, action, key) -> State:
        """Implement game-specific step function here.

        `key` is forwarded from `step` and may be None.
        """
        ...

    @abc.abstractmethod
    def _observe(self, state: State, player_id: Array) -> Array:
        """Implement game-specific observe function here."""
        ...

    @property
    @abc.abstractmethod
    def id(self) -> EnvId:
        """Environment id."""
        ...

    @property
    @abc.abstractmethod
    def version(self) -> str:
        """Environment version. Updated when behavior, parameter, or API is changed.
        Refactoring or speeding up without any expected behavior changes will NOT update the version number.
        """
        ...

    @property
    @abc.abstractmethod
    def num_players(self) -> int:
        """Number of players (e.g., 2 in Tic-tac-toe)"""
        ...

    @property
    def num_actions(self) -> int:
        """Return the size of action space (e.g., 9 in Tic-tac-toe).

        Computed by running `init` with a fixed key and reading the length of
        `legal_action_mask`; every access re-runs `init`.
        """
        state = self.init(jax.random.PRNGKey(0))
        return int(state.legal_action_mask.shape[0])

    @property
    def observation_shape(self) -> Tuple[int, ...]:
        """Return the matrix shape of observation.

        Computed by running `init` with a fixed key and `_observe` for the current
        player; every access re-runs both.
        """
        state = self.init(jax.random.PRNGKey(0))
        obs = self._observe(state, state.current_player)
        return obs.shape

    @property
    def _illegal_action_penalty(self) -> float:
        """Negative reward given when illegal action is selected."""
        return -1.0

    def _step_with_illegal_action(self, state: State, loser: Array) -> State:
        """Terminate `state` because player `loser` took an illegal action.

        `loser` receives `_illegal_action_penalty`; every other player receives
        `-_illegal_action_penalty * (num_players - 1)`.
        """
        penalty = self._illegal_action_penalty
        # NOTE(review): for num_players > 2 the other players' rewards scale with
        # (num_players - 1) while the loser's does not, so the rewards do not sum to
        # zero; confirm this is intended.
        reward = jnp.ones_like(state.rewards) * (-1 * penalty) * (self.num_players - 1)
        reward = reward.at[loser].set(penalty)
        return state.replace(rewards=reward, terminated=TRUE)  # type: ignore

id: EnvId abstractmethod property

Environment id.

num_actions: int property

Return the size of action space (e.g., 9 in Tic-tac-toe)

num_players: int abstractmethod property

Number of players (e.g., 2 in Tic-tac-toe)

observation_shape: Tuple[int, ...] property

Return the matrix shape of observation

version: str abstractmethod property

Environment version. Updated when behavior, parameter, or API is changed. Refactoring or speeding up without any expected behavior changes will NOT update the version number.

init(key)

Return the initial state. Note that no internal state of environment changes.

Parameters:

Name Type Description Default
key PRNGKey

pseudo-random generator key in JAX. Consumed in this function.

required

Returns:

Name Type Description
State State

initial state of environment

Source code in pgx/core.py
def init(self, key: PRNGKey) -> State:
    """Build and return the initial state; the environment object itself is not mutated.

    Args:
        key: pseudo-random generator key in JAX. Consumed in this function.

    Returns:
        State: initial state of environment, with `observation` filled in.
    """
    initial_state = self._init(key)
    return initial_state.replace(observation=self.observe(initial_state))  # type: ignore

observe(state, player_id=None)

Observation function.

Source code in pgx/core.py
def observe(self, state: State, player_id: Optional[Array] = None) -> Array:
    """Return the observation of `state`.

    Args:
        state: state to observe.
        player_id: deprecated. If given, a DeprecationWarning is emitted and the
            observation is computed for this player; otherwise `state.current_player`
            is used.

    Returns:
        Array: result of `_observe`, wrapped in `jax.lax.stop_gradient`.
    """
    if player_id is None:
        player_id = state.current_player
    else:
        # stacklevel=2 attributes the warning to the caller of `observe`, not to Pgx itself.
        warnings.warn(
            "[Pgx] `player_id` in `observe` is deprecated. This argument will be removed in the future.",
            DeprecationWarning,
            stacklevel=2,
        )
    obs = self._observe(state, player_id)
    return jax.lax.stop_gradient(obs)

step(state, action, key=None)

Step function.

Source code in pgx/core.py
def step(
    self,
    state: State,
    action: Array,
    key: Optional[Array] = None,
) -> State:
    """Advance the environment by one step.

    Finished episodes are frozen (same state, zero rewards), illegal actions end the
    game with a penalty for the acting player, terminal states get an all-True
    `legal_action_mask`, and the observation is recomputed for the returned state.

    Args:
        state: current state.
        action: action taken by `state.current_player`.
        key: optional pseudo-random generator key in JAX, passed through to `_step`.

    Returns:
        State: next state.
    """
    # Everything below is read from the *incoming* state, before any transition.
    action_is_illegal = ~state.legal_action_mask[action]
    acting_player = state.current_player
    episode_over = state.terminated | state.truncated

    # A finished episode is not advanced: same state, zero rewards, no step-count increment.
    stepped = jax.lax.cond(
        episode_over,
        lambda: state.replace(rewards=jnp.zeros_like(state.rewards)),  # type: ignore
        lambda: self._step(state.replace(_step_count=state._step_count + 1), action, key),  # type: ignore
    )

    # An illegal action ends the game immediately and penalizes the acting player.
    penalized = jax.lax.cond(
        action_is_illegal,
        lambda: self._step_with_illegal_action(stepped, acting_player),
        lambda: stepped,
    )

    # At a terminal state every action is marked legal, so normalizing action
    # probabilities never divides by zero; such actions have no effect on the state.
    finalized = jax.lax.cond(
        penalized.terminated,
        lambda: penalized.replace(legal_action_mask=jnp.ones_like(penalized.legal_action_mask)),  # type: ignore
        lambda: penalized,
    )

    return finalized.replace(observation=self.observe(finalized))  # type: ignore

pgx.EnvId = Literal['2048', 'animal_shogi', 'backgammon', 'bridge_bidding', 'chess', 'connect_four', 'gardner_chess', 'go_9x9', 'go_19x19', 'hex', 'kuhn_poker', 'leduc_holdem', 'minatar-asterix', 'minatar-breakout', 'minatar-freeway', 'minatar-seaquest', 'minatar-space_invaders', 'othello', 'shogi', 'sparrow_mahjong', 'tic_tac_toe'] module-attribute

Naming convention of EnvId

A hyphen (-) indicates that the game comes from a different original source (e.g., MinAtar); an underscore (_) is used in all other cases.

pgx.make(env_id)

Load the specified environment.

Example usage

env = pgx.make("tic_tac_toe")

BridgeBidding environment

The BridgeBidding environment requires domain knowledge of the game of bridge, so it cannot be loaded via make("bridge_bidding"). Instead, use the BridgeBidding class directly: from pgx.bridge_bidding import BridgeBidding.

Source code in pgx/core.py
def make(env_id: EnvId):  # noqa: C901
    """Load the specified environment.

    !!! example "Example usage"

        ```py
        env = pgx.make("tic_tac_toe")
        ```

    !!! note "`BridgeBidding` environment"

        The `BridgeBidding` environment requires domain knowledge of the game of bridge,
        so it cannot be loaded via `make("bridge_bidding")`.
        Use the `BridgeBidding` class directly: `from pgx.bridge_bidding import BridgeBidding`.

    Args:
        env_id: environment id. See `available_envs()` for the accepted values.

    Returns:
        A new instance of the environment class corresponding to `env_id`.

    Raises:
        ValueError: if `env_id` is "bridge_bidding" or is not a known environment id.
    """
    # NOTE: BridgeBidding environment requires the domain knowledge of bridge,
    # so we forbid users to load it by `make("bridge_bidding")`. It IS a valid EnvId,
    # so reject it with a dedicated message rather than the generic "Wrong env_id" one.
    if env_id == "bridge_bidding":
        raise ValueError(
            "'bridge_bidding' cannot be loaded via `pgx.make`. "
            "Use the BridgeBidding class directly: `from pgx.bridge_bidding import BridgeBidding`."
        )
    if env_id == "2048":
        from pgx.play2048 import Play2048

        return Play2048()
    elif env_id == "animal_shogi":
        from pgx.animal_shogi import AnimalShogi

        return AnimalShogi()
    elif env_id == "backgammon":
        from pgx.backgammon import Backgammon

        return Backgammon()
    elif env_id == "chess":
        from pgx.chess import Chess

        return Chess()
    elif env_id == "connect_four":
        from pgx.connect_four import ConnectFour

        return ConnectFour()
    elif env_id == "gardner_chess":
        from pgx.gardner_chess import GardnerChess

        return GardnerChess()
    elif env_id == "go_9x9":
        from pgx.go import Go

        return Go(size=9, komi=7.5)
    elif env_id == "go_19x19":
        from pgx.go import Go

        return Go(size=19, komi=7.5)
    elif env_id == "hex":
        from pgx.hex import Hex

        return Hex()
    elif env_id == "kuhn_poker":
        from pgx.kuhn_poker import KuhnPoker

        return KuhnPoker()
    elif env_id == "leduc_holdem":
        from pgx.leduc_holdem import LeducHoldem

        return LeducHoldem()
    # elif env_id == "mahjong":
    #     from pgx.mahjong import Mahjong

    #     return Mahjong()
    elif env_id == "minatar-asterix":
        from pgx.minatar.asterix import MinAtarAsterix  # type: ignore

        return MinAtarAsterix()
    elif env_id == "minatar-breakout":
        from pgx.minatar.breakout import MinAtarBreakout  # type: ignore

        return MinAtarBreakout()
    elif env_id == "minatar-freeway":
        from pgx.minatar.freeway import MinAtarFreeway  # type: ignore

        return MinAtarFreeway()
    elif env_id == "minatar-seaquest":
        from pgx.minatar.seaquest import MinAtarSeaquest  # type: ignore

        return MinAtarSeaquest()
    elif env_id == "minatar-space_invaders":
        from pgx.minatar.space_invaders import MinAtarSpaceInvaders  # type: ignore

        return MinAtarSpaceInvaders()
    elif env_id == "othello":
        from pgx.othello import Othello

        return Othello()
    elif env_id == "shogi":
        from pgx.shogi import Shogi

        return Shogi()
    elif env_id == "sparrow_mahjong":
        from pgx.sparrow_mahjong import SparrowMahjong

        return SparrowMahjong()
    elif env_id == "tic_tac_toe":
        from pgx.tic_tac_toe import TicTacToe

        return TicTacToe()
    else:
        envs = "\n".join(available_envs())
        raise ValueError(f"Wrong env_id '{env_id}' is passed. Available ids are: \n{envs}")

pgx.available_envs()

List all environment IDs available in the pgx.make function.

Example usage

pgx.available_envs()
('2048', 'animal_shogi', 'backgammon', 'chess', 'connect_four', 'gardner_chess', 'go_9x9', 'go_19x19', 'hex', 'kuhn_poker', 'leduc_holdem', 'minatar-asterix', 'minatar-breakout', 'minatar-freeway', 'minatar-seaquest', 'minatar-space_invaders', 'othello', 'shogi', 'sparrow_mahjong', 'tic_tac_toe')

BridgeBidding environment

The BridgeBidding environment requires domain knowledge of the game of bridge, so it cannot be loaded via make("bridge_bidding"). Instead, use the BridgeBidding class directly: from pgx.bridge_bidding import BridgeBidding.

Source code in pgx/core.py
def available_envs() -> Tuple[EnvId, ...]:
    """List all environment ids available in the `pgx.make` function.

    !!! example "Example usage"

        ```py
        pgx.available_envs()
        ('2048', 'animal_shogi', 'backgammon', 'chess', 'connect_four', 'gardner_chess', 'go_9x9', 'go_19x19', 'hex', 'kuhn_poker', 'leduc_holdem', 'minatar-asterix', 'minatar-breakout', 'minatar-freeway', 'minatar-seaquest', 'minatar-space_invaders', 'othello', 'shogi', 'sparrow_mahjong', 'tic_tac_toe')
        ```


    !!! note "`BridgeBidding` environment"

        The `BridgeBidding` environment requires domain knowledge of the game of bridge,
        so it cannot be loaded via `make("bridge_bidding")`.
        Use the `BridgeBidding` class directly: `from pgx.bridge_bidding import BridgeBidding`.

    Returns:
        Tuple[EnvId, ...]: every value of `EnvId` except "bridge_bidding", in declaration order.
    """
    # "bridge_bidding" is a valid EnvId but is deliberately not loadable via `make`.
    return tuple(env_id for env_id in get_args(EnvId) if env_id != "bridge_bidding")

pgx.set_visualization_config(*, color_theme='light', scale=1.0, frame_duration_seconds=0.2)

Source code in pgx/_src/visualizer.py
def set_visualization_config(
    *,
    color_theme: ColorTheme = "light",
    scale: float = 1.0,
    frame_duration_seconds: float = 0.2,
):
    """Set the global visualization defaults.

    The values are stored on the module-level ``global_config`` object;
    ``save_svg_animation``, for example, reads ``frame_duration_seconds`` from it
    when no explicit duration is given.

    Args:
        color_theme: color theme of rendered images (e.g. "light" or "dark").
        scale: image size multiplier.
        frame_duration_seconds: display time of each frame in SVG animations.
    """
    config = global_config
    config.color_theme = color_theme
    config.scale = scale
    config.frame_duration_seconds = frame_duration_seconds

pgx.save_svg(state, filename, *, color_theme=None, scale=None)

Source code in pgx/_src/visualizer.py
def save_svg(
    state: State,
    filename: Union[str, Path],
    *,
    color_theme: Optional[Literal["light", "dark"]] = None,
    scale: Optional[float] = None,
) -> None:
    """Render ``state`` and write it to ``filename`` as an SVG file.

    MinAtar states are delegated to their own ``save_svg`` method; in that case
    ``color_theme`` and ``scale`` are not forwarded.
    """
    if not state.env_id.startswith("minatar"):
        drawer = Visualizer(color_theme=color_theme, scale=scale)
        drawer.get_dwg(states=state).saveas(filename)
        return
    # NOTE(review): the base `State.save_svg` calls back into this function, so MinAtar
    # states presumably override `save_svg`; otherwise this would recurse — confirm.
    state.save_svg(filename=filename)

pgx.save_svg_animation(states, filename, *, color_theme=None, scale=None, frame_duration_seconds=None)

Source code in pgx/_src/visualizer.py
def save_svg_animation(
    states: Sequence[State],
    filename: Union[str, Path],
    *,
    color_theme: Optional[Literal["light", "dark"]] = None,
    scale: Optional[float] = None,
    frame_duration_seconds: Optional[float] = None,
) -> None:
    """Save a sequence of states as a single animated SVG file.

    Each state is rendered as one frame (an SVG group). The frames are shown one after
    another by a looping CSS keyframe animation.

    Args:
        states: states to render, one frame per state, in display order. Must be
            non-empty (``states[0]`` is read unconditionally) and must not be MinAtar states.
        filename: output path, passed to ``saveas`` of the svgwrite drawing.
        color_theme: passed to ``Visualizer``.
        scale: passed to ``Visualizer``.
        frame_duration_seconds: display time of each frame in seconds. If None,
            ``global_config.frame_duration_seconds`` is used.

    Raises:
        AssertionError: for MinAtar states, or if a rendered drawing does not contain
            exactly one group (these asserts are skipped under ``python -O``).
    """
    assert not states[0].env_id.startswith("minatar"), "MinAtar does not support svg animation."
    v = Visualizer(color_theme=color_theme, scale=scale)

    if frame_duration_seconds is None:
        frame_duration_seconds = global_config.frame_duration_seconds

    frame_groups = []
    dwg = None
    for i, state in enumerate(states):
        dwg = v.get_dwg(states=state)
        assert (
            len([e for e in dwg.elements if type(e) is svgwrite.container.Group]) == 1
        ), "Drawing must contain only one group"
        # Assumes the single group is the drawing's last element.
        group: svgwrite.container.Group = dwg.elements[-1]
        group["id"] = f"_fr{i:x}"  # hex frame number
        group["class"] = "frame"
        frame_groups.append(group)

    assert dwg is not None
    # Reuse the last drawing as the container: drop its own group here; all frame
    # groups (including that one) are re-added below.
    del dwg.elements[-1]
    total_seconds = frame_duration_seconds * len(frame_groups)

    # Every frame is hidden by default and runs the same infinite animation `_k`, which
    # makes it visible only during the first 1/len(frame_groups) of the cycle. The
    # per-frame animation-delay added below staggers the frames so they appear in turn.
    style = f".frame{{visibility:hidden; animation:{total_seconds}s linear _k infinite;}}"
    style += f"@keyframes _k{{0%,{100/len(frame_groups)}%{{visibility:visible}}{100/len(frame_groups) * 1.000001}%,100%{{visibility:hidden}}}}"

    for i, group in enumerate(frame_groups):
        dwg.add(group)
        style += f"#{group['id']}{{animation-delay:{i * frame_duration_seconds}s}}"
    dwg.defs.add(svgwrite.container.Style(content=style))
    dwg.saveas(filename)

pgx.BaselineModelId = Literal['animal_shogi_v0', 'gardner_chess_v0', 'go_9x9_v0', 'hex_v0', 'othello_v0', 'minatar-asterix_v0', 'minatar-breakout_v0', 'minatar-freeway_v0', 'minatar-seaquest_v0', 'minatar-space_invaders_v0'] module-attribute

pgx.make_baseline_model(model_id, download_dir='baselines')

Source code in pgx/_src/baseline.py
def make_baseline_model(model_id: BaselineModelId, download_dir: str = "baselines"):
    """Create the pretrained baseline model identified by ``model_id``.

    AlphaZero-style ids are handled by ``_make_az_baseline_model`` and MinAtar ids by
    ``_make_minatar_baseline_model``; the return value is whatever that helper returns.

    Args:
        model_id: one of the ``BaselineModelId`` values.
        download_dir: forwarded to the helper (presumably where model files are
            downloaded/cached — confirm in the helpers).

    Raises:
        ValueError: if ``model_id`` is not a known baseline model id.
    """
    if model_id in (
        "animal_shogi_v0",
        "gardner_chess_v0",
        "go_9x9_v0",
        "hex_v0",
        "othello_v0",
    ):
        return _make_az_baseline_model(model_id, download_dir)
    if model_id in (
        "minatar-asterix_v0",
        "minatar-breakout_v0",
        "minatar-freeway_v0",
        "minatar-seaquest_v0",
        "minatar-space_invaders_v0",
    ):
        return _make_minatar_baseline_model(model_id, download_dir)
    # Raise explicitly (not `assert`) so the check survives `python -O` instead of
    # silently returning None.
    raise ValueError(f"Unknown baseline model_id: {model_id!r}")

pgx.api_test(env, num=100, use_key=True)

Source code in pgx/_src/api_test.py
def api_test(env: Env, num: int = 100, use_key: bool = True):
    """Run the Pgx API checks on ``env``.

    Calls ``api_test_single`` and then ``api_test_batch``, forwarding ``env``,
    ``num`` and ``use_key`` positionally to both.

    Args:
        env: environment under test.
        num: forwarded to both checks.
        use_key: forwarded to both checks.
    """
    api_test_single(env, num, use_key)  # first check
    api_test_batch(env, num, use_key)  # second check, same arguments