"""Gymnasium environment for SCML OneShot worlds (module ``scml.oneshot.rl.env``)."""

from typing import Any

import gymnasium as gym
from gymnasium.envs.registration import register

from scml.oneshot.agent import OneShotAgent
from scml.oneshot.agents import Placeholder
from scml.oneshot.context import BaseContext, FixedPartnerNumbersOneShotContext
from scml.oneshot.rl.action import ActionManager
from scml.oneshot.rl.observation import ObservationManager
from scml.oneshot.rl.reward import DefaultRewardFunction, RewardFunction
from scml.oneshot.world import SCMLBaseWorld

__all__ = ["OneShotEnv"]


class OneShotEnv(gym.Env):
    """A Gymnasium environment that controls one agent in a generated world.

    Every call to `reset()` asks the context to generate a fresh world
    containing a single agent of type `agent_type`. Every call to `step()`
    decodes the given action with the action manager, advances the world by
    one step using it, and returns the observation encoded by the observation
    manager together with the reward computed by the reward function.

    Args:
        action_manager: Creates the action space and decodes actions.
        observation_manager: Creates the observation space and encodes
            observations from the agent's AWI.
        reward_function: Called before (`before_action`) and after
            (`__call__`) each world step to compute the reward.
        context: Generates the worlds used by this environment. It must
            contain the contexts of both managers.
        agent_type: Type of the agent created in every generated world.
        agent_params: Parameters passed to the context for that agent
            (an empty dict is used when `None`).
        extra_checks: If true, `reset()` asserts that the generated world
            belongs to `context`.
        skip_after_negotiations: If true, `step()` keeps stepping the world
            with an empty action while the agent has no current negotiation
            states and the episode has not terminated.
        render_mode: When equal to ``"human"``, `_render_frame()` is called
            at the end of `reset()` and `step()`.
        debug: Forwarded as `raise_on_failure` to the context compatibility
            checks.

    Raises:
        AssertionError: If the context does not contain the context of the
            action manager or of the observation manager.
    """

    # NOTE: the `reward_function` and `context` defaults are evaluated once,
    # when the class is defined, so every environment created without
    # explicit values shares the same two objects.
    def __init__(
        self,
        action_manager: ActionManager,
        observation_manager: ObservationManager,
        reward_function: RewardFunction = DefaultRewardFunction(),
        context: BaseContext = FixedPartnerNumbersOneShotContext(),
        agent_type: type[OneShotAgent] = Placeholder,
        agent_params: dict[str, Any] | None = None,
        extra_checks: bool = True,
        skip_after_negotiations: bool = True,
        render_mode=None,
        debug=False,
    ):
        assert context.contains_context(
            action_manager.context, raise_on_failure=debug
        ), (
            "Action Manager is not compatible with the given environment.\n"
            "Some worlds that can be generated by this environment are not handled"
            " correctly by this action manager"
        )
        assert context.contains_context(
            observation_manager.context, raise_on_failure=debug
        ), (
            "observation Manager is not compatible with the given environment.\n"
            "Some worlds that can be generated by this environment are not handled"
            " correctly by this observation manager"
        )
        self._skip_after_negotiations = skip_after_negotiations
        self._extra_checks = extra_checks
        self._reward_function = reward_function
        # The world and the agent are only created by `reset()`.
        self._world: SCMLBaseWorld = None  # type: ignore
        self._agent_type = agent_type
        self._agent_params = agent_params if agent_params is not None else {}
        self._agent_id: str = ""
        self._agent: OneShotAgent = None  # type: ignore
        self._obs_manager = observation_manager
        self._action_manager = action_manager
        self._context = context
        self.action_space = action_manager.make_space()
        self.observation_space = observation_manager.make_space()
        self.render_mode = render_mode

    def _get_obs(self):
        """Encodes the current state of the agent's AWI as an observation."""
        return self._obs_manager.encode(self._agent.awi)

    def calc_info(self):
        """Calculates info to be returned from `step()`."""
        return {}

    def _render_frame(self):
        """Used for rendering. Override with your rendering code"""
        pass

    def close(self):
        """Does nothing; there are no resources to release here."""
        pass

    def render(self):
        """Renders the environment by delegating to `_render_frame()`."""
        return self._render_frame()

    def _episode_over(self, still_running) -> bool:
        """Tells whether the episode ended after the latest world step.

        Args:
            still_running: The value returned by the latest `step_with()` call.

        Returns:
            True if that value is falsy or the world reached its last step.
        """
        if not still_running:
            return True
        return self._world.current_step >= self._world.n_steps - 1

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        """Generates a new world and returns the first observation.

        Args:
            seed: If given, seeds both the environment's `np_random` (through
                `gym.Env.reset`) and Python's global `random` module. When
                `None`, existing random state is left untouched.
            options: Ignored.

        Returns:
            A tuple of the initial observation and the info dict.
        """
        _ = options
        import random

        # Seeds `self.np_random` as the Gymnasium API expects.
        super().reset(seed=seed)
        # Only re-seed on request: `random.seed(None)` would re-seed from
        # system entropy and throw away any seeding done by the caller.
        if seed is not None:
            random.seed(seed)

        self._world, agents = self._context.generate(
            types=(self._agent_type,),
            params=(self._agent_params,),
        )
        assert len(agents) == 1
        self._agent = agents[0]
        if self._extra_checks:
            assert self._world in self._context
        self._agent_id = self._agent.id
        self._world.step_with({}, init=True)

        observation = self._get_obs()
        info = self.calc_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def step(self, action):  # type: ignore
        """Applies `action` and advances the world.

        The action is decoded by the action manager and used for one world
        step. If `skip_after_negotiations` is set, the world is then stepped
        with empty actions for as long as the agent has no current
        negotiation states and the episode is not over; the rewards of these
        extra steps are added to the returned reward.

        Args:
            action: An action from `action_space`.

        Returns:
            A tuple `(observation, reward, terminated, truncated, info)` in
            which `terminated` is a `bool` and `truncated` is always False.
        """
        reward_info = self._reward_function.before_action(self._agent.awi)
        decoded_action = self._action_manager.decode(self._agent.awi, action)
        still_running = self._world.step_with(
            {self._agent_id: decoded_action}  # type: ignore
        )
        reward = self._reward_function(self._agent.awi, decoded_action, reward_info)
        terminated = self._episode_over(still_running)

        if self._skip_after_negotiations:
            while not terminated and len(self._agent.awi.current_states) < 1:
                reward_info = self._reward_function.before_action(self._agent.awi)
                still_running = self._world.step_with(
                    {self._agent_id: {}}  # type: ignore
                )
                reward += self._reward_function(self._agent.awi, {}, reward_info)
                terminated = self._episode_over(still_running)

        obs = self._get_obs()
        info = self.calc_info()

        if self.render_mode == "human":
            self._render_frame()

        return obs, reward, terminated, False, info
# Register the environment with Gymnasium under the id "scml/OneShot-v0".
register(
    id="scml/OneShot-v0",
    entry_point="scml.oneshot.rl.env:OneShotEnv",
    max_episode_steps=None,
)