from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import register
from scml.oneshot.agent import OneShotAgent
from scml.oneshot.agents import Placeholder
from scml.oneshot.context import BaseContext, FixedPartnerNumbersOneShotContext
from scml.oneshot.rl.action import ActionManager
from scml.oneshot.rl.observation import ObservationManager
from scml.oneshot.rl.reward import DefaultRewardFunction, RewardFunction
from scml.oneshot.world import SCMLBaseWorld
# Public API of this module: only the environment class is exported.
__all__ = ["OneShotEnv"]
class OneShotEnv(gym.Env):
    """Gymnasium environment that controls a single agent inside a generated SCML world.

    On every :meth:`reset` a new world is generated from ``context`` containing
    exactly one agent of type ``agent_type``. :meth:`step` decodes the RL action
    with the action manager, hands it to that agent for a world step, and returns
    the observation encoded by the observation manager together with the reward
    computed by the reward function.
    """

    def __init__(
        self,
        action_manager: ActionManager,
        observation_manager: ObservationManager,
        reward_function: RewardFunction | None = None,
        context: BaseContext | None = None,
        agent_type: type[OneShotAgent] = Placeholder,
        agent_params: dict[str, Any] | None = None,
        extra_checks: bool = True,
        skip_after_negotiations: bool = True,
        render_mode: str | None = None,
        debug: bool = False,
    ):
        """Create the environment.

        Args:
            action_manager: Provides the action space (``make_space()``) and
                decodes RL actions into agent actions (``decode(awi, action)``).
                Its ``context`` must be contained in ``context``.
            observation_manager: Provides the observation space (``make_space()``)
                and encodes observations (``encode(awi)``). Its ``context`` must
                be contained in ``context``.
            reward_function: Called with ``before_action(awi)`` before each world
                step and then as ``reward_function(awi, action, info)`` after it.
                Defaults to a new ``DefaultRewardFunction`` per environment.
            context: Generates the world on every ``reset``. Defaults to a new
                ``FixedPartnerNumbersOneShotContext`` per environment.
            agent_type: Class of the single agent controlled through this env.
            agent_params: Parameters passed to ``context.generate`` for that agent.
            extra_checks: If true, ``reset`` asserts the generated world is in
                ``context``.
            skip_after_negotiations: If true, ``step`` keeps advancing the world
                with empty actions while the agent's ``awi.current_states`` is
                empty, accumulating the reward of those extra steps.
            render_mode: If ``"human"``, ``_render_frame`` is called after every
                ``reset`` and ``step``.
            debug: Forwarded as ``raise_on_failure`` to ``context.contains_context``.

        Raises:
            AssertionError: If either manager's context is not contained in
                ``context`` (unless ``contains_context`` raises first in debug mode).
        """
        # Build defaults per instance; instances in the signature would be created
        # once at import time and shared by every environment.
        if reward_function is None:
            reward_function = DefaultRewardFunction()
        if context is None:
            context = FixedPartnerNumbersOneShotContext()
        # NOTE: these compatibility checks are asserts and are skipped under `python -O`.
        assert context.contains_context(
            action_manager.context, raise_on_failure=debug
        ), (
            "Action Manager is not compatible with the given environment.\n"
            "Some worlds that can be generated by this environment are not handled"
            " correctly by this action manager"
        )
        assert context.contains_context(
            observation_manager.context, raise_on_failure=debug
        ), (
            "observation Manager is not compatible with the given environment.\n"
            "Some worlds that can be generated by this environment are not handled"
            " correctly by this observation manager"
        )
        self._skip_after_negotiations = skip_after_negotiations
        self._extra_checks = extra_checks
        self._reward_function = reward_function
        # World and agent are only created by `reset()`.
        self._world: SCMLBaseWorld = None  # type: ignore
        self._agent_type = agent_type
        self._agent_params = agent_params if agent_params is not None else dict()
        self._agent_id: str = ""
        self._agent: OneShotAgent = None  # type: ignore
        self._obs_manager = observation_manager
        self._action_manager = action_manager
        self._context = context
        self.action_space = action_manager.make_space()
        self.observation_space = observation_manager.make_space()
        self.render_mode = render_mode

    def _get_obs(self):
        """Encode the controlled agent's current state with the observation manager."""
        return self._obs_manager.encode(self._agent.awi)

    def calc_info(self):
        """Calculate the info dict returned by `reset()` and `step()`.

        The base implementation returns an empty dict; override to add diagnostics.
        """
        return dict()

    def _render_frame(self):
        """Used for rendering. Override with your rendering code"""
        pass

    def render(self):
        """Render the environment by delegating to `_render_frame()`."""
        return self._render_frame()

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        """Generate a fresh world and return the agent's first observation.

        Args:
            seed: If given, seeds the environment's ``np_random`` (via
                ``gym.Env.reset``) and Python's global ``random`` module before
                the world is generated. If ``None``, the global ``random`` state
                is left untouched.
            options: Ignored.

        Returns:
            ``(observation, info)`` for the newly generated world.
        """
        import random

        _ = options
        super().reset(seed=seed)
        # Only reseed when asked to: `random.seed(None)` would reseed the global
        # generator from system entropy and discard any seed the caller set.
        if seed is not None:
            random.seed(seed)
        self._world, agents = self._context.generate(
            types=(self._agent_type,),
            params=(self._agent_params,),
        )
        assert len(agents) == 1
        self._agent = agents[0]
        if self._extra_checks:
            assert self._world in self._context
        self._agent_id = self._agent.id
        self._world.step_with(dict(), init=True)
        observation = self._get_obs()
        info = self.calc_info()
        if self.render_mode == "human":
            self._render_frame()
        return observation, info

    def _advance_world(self, agent_action, reward_info) -> tuple[bool, Any]:
        """Step the world once with ``agent_action`` and return ``(terminated, reward)``.

        ``reward_info`` is the value obtained from ``before_action`` right before
        this step.
        """
        terminated = not self._world.step_with(
            {self._agent_id: agent_action}  # type: ignore
        )
        reward = self._reward_function(self._agent.awi, agent_action, reward_info)
        # Reaching the world's last step ends the episode even if `step_with`
        # did not report it.
        if self._world.current_step >= self._world.n_steps - 1:
            terminated = True
        return terminated, reward

    def step(self, action):  # type: ignore
        """Apply ``action`` for one world step and return the gymnasium 5-tuple.

        The action is decoded by the action manager and given to the controlled
        agent. When ``skip_after_negotiations`` is set, the world keeps advancing
        with empty actions while ``awi.current_states`` is empty. The rewards of
        those extra steps are added to the returned reward.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``terminated``
            is a ``bool`` and ``truncated`` is always ``False``.
        """
        reward_info = self._reward_function.before_action(self._agent.awi)
        decoded_action = self._action_manager.decode(self._agent.awi, action)
        terminated, reward = self._advance_world(decoded_action, reward_info)
        if self._skip_after_negotiations:
            while not terminated and len(self._agent.awi.current_states) < 1:
                reward_info = self._reward_function.before_action(self._agent.awi)
                terminated, extra_reward = self._advance_world(dict(), reward_info)
                reward += extra_reward
        obs = self._get_obs()
        info = self.calc_info()
        if self.render_mode == "human":
            self._render_frame()
        return obs, reward, terminated, False, info
# Make the environment available to ``gymnasium.make`` under the id
# "scml/OneShot-v0"; no episode step limit is set at registration time.
register(
    "scml/OneShot-v0",
    entry_point="scml.oneshot.rl.env:OneShotEnv",
    max_episode_steps=None,
)