import random
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import Any, Iterable, Union, Sequence
import numpy as np
from attr import define, field
from scml.oneshot.agents.greedy import GreedyOneShotAgent
from scml.oneshot.agents.rand import (
EqualDistOneShotAgent,
NiceAgent,
RandDistOneShotAgent,
)
from negmas.helpers.strings import unique_name
from scml.common import intin, isin, isinclass, isinfloat, isinobject, make_array
from scml.oneshot.agent import OneShotAgent
from scml.oneshot.agents import Placeholder
from scml.oneshot.awi import OneShotAWI
from scml.oneshot.common import is_system_agent
from scml.oneshot.world import (
SCML2021OneShotWorld,
SCML2022OneShotWorld,
SCML2023OneShotWorld,
SCML2024OneShotWorld,
SCMLBaseWorld,
get_n_agents_per_process,
get_n_lines,
)
# Public names exported by this module.
__all__ = [
    "Context",
    "GeneralContext",
    "ANACContext",
    "LimitedPartnerNumbersContext",
    "FixedPartnerNumbersContext",
    "ANACOneShotContext",
    "LimitedPartnerNumbersOneShotContext",
    "FixedPartnerNumbersOneShotContext",
    "SupplierContext",
    "ConsumerContext",
    "StrongSupplierContext",
    "StrongConsumerContext",
    "WeakSupplierContext",
    "WeakConsumerContext",
    "BalancedSupplierContext",
    "BalancedConsumerContext",
    "RepeatingContext",
    "ContextParams",
    "MonopolicContext",
    "SingleAgentPerLevelSupplierContext",
    "EutopiaContext",
    "EutopiaConsumerContext",
    "EutopiaSupplierContext",
]

# Default value of `BaseContext.non_competitors`: the agent types contexts may use
# for the agents that are not placeholders.
DefaultAgentsOneShot = (
    GreedyOneShotAgent,
    RandDistOneShotAgent,
    EqualDistOneShotAgent,
)
@define
class ContextParams:
    """Basic Parameters you can assume about a context. Returned by `extract_context_params`"""
class Strength(Enum):
    """Qualitative strength levels understood by `sample_with_strength`.

    Members are ordered from weakest to strongest and carry the integer values
    -1, 0 and +1 respectively.
    """

    # Selected by the fallback branch of `sample_with_strength`.
    Weak = -1
    Balanced = 0
    Strong = 1
def sample_with_strength(
c: int | tuple[int, int], n: int | tuple[int, int], s: Strength | None
):
if s is None:
c = intin(c)
n = intin(n)
return c, n
mnc, mxc = c if isinstance(c, Iterable) else (c, c)
if s == Strength.Balanced:
c = intin((mnc + 1, mxc - 1))
mn, mx = c, c + 2
elif s == Strength.Strong:
c = intin((mnc + 1, mxc))
mn = n[0] if isinstance(n, Iterable) else n
mx = max(mn, c - 1)
else:
c = intin((mnc, mxc - 1))
mx = n[1] if isinstance(n, Iterable) else n
mn = min(c + 2, mx)
return c, intin((mn, mx))
def safemax(x: Iterable | int):
if isinstance(x, Iterable):
return max(x)
return x
def safemin(x: Iterable | int):
if isinstance(x, Iterable):
return min(x)
return x
class Context(ABC):
    """A context used for generating worlds satisfying predefined conditions and testing for them"""

    def __call__(self, *args, **kwargs):
        """Alias for `generate`: forwards all arguments and returns its result."""
        return self.generate(*args, **kwargs)

    @abstractmethod
    def generate(
        self,
        types: tuple[type[OneShotAgent], ...] | None = None,
        params: tuple[dict[str, Any], ...] | None = None,
        name: str | None = None,
    ) -> tuple[SCMLBaseWorld, tuple[OneShotAgent, ...]]:
        """
        Generates a world with one or more agents to be controlled externally and returns both

        Args:
            types: The types of the agents guaranteed to exist in the world
            params: The parameters to pass to the constructors of these agents. None means no parameters for any agents
            name: The name of the world to generate. Uses a random name if not given

        Returns:
            The constructed world and a tuple of the agents created corresponding (in order) to the given agent types/params
        """
        ...

    @abstractmethod
    def is_valid_world(
        self,
        world: SCMLBaseWorld,
    ) -> bool:
        """Checks that the given world could have been generated from this context"""
        ...

    @abstractmethod
    def is_valid_awi(self, awi: OneShotAWI) -> bool:
        """Checks that the given AWI is connected to a world that could have been generated from this context"""
        ...

    @abstractmethod
    def contains_context(self, context: "Context") -> bool:
        """Checks that any world generated from the given `context` could have been generated from this context"""
        ...

    def __contains__(self, other: "Union[SCMLBaseWorld, OneShotAWI, Context]") -> bool:
        """Support ``x in context``.

        Contexts are checked with `contains_context`, AWIs with `is_valid_awi`, and
        anything else is treated as a world and checked with `is_valid_world`.
        """
        if isinstance(other, Context):
            return self.contains_context(other)
        if isinstance(other, OneShotAWI):
            return self.is_valid_awi(other)
        return self.is_valid_world(other)
# (low, high) pair bounding supplier counts.
N_SUPPLIERS: tuple[int, int] = (4, 8)
"""Numbers of suppliers supported"""
# (low, high) pair bounding consumer counts.
N_CONSUMERS: tuple[int, int] = (4, 8)
"""Numbers of consumers supported"""

# Default number of worlds sampled by `contains_context`.
NTESTS: int = 20

# Default `placeholder_types` of contexts: stand-ins later replaced by the
# agent types a caller asks for.
DEFAULT_PLACEHOLDER_AGENT_TYPES: tuple[type, ...] = (Placeholder,)

# Default reporting behaviour of the `_is` / `_not` failure helpers.
WARN_ON_FAILURE: bool = True
RAISE_ON_FAILURE: bool = False
def _is(
    condition: bool,
    raise_on_failure=RAISE_ON_FAILURE,
    warn_on_failure=WARN_ON_FAILURE,
    message: str = "",
) -> bool:
    """Report a failure when ``condition`` is truthy.

    Args:
        condition: The failure condition.
        raise_on_failure: Raise instead of returning when the condition holds.
        warn_on_failure: Emit ``message`` as a warning when the condition holds.
        message: Text used for the exception or warning.

    Returns:
        ``True`` if ``condition`` is truthy (a failure was found), else ``False``.

    Raises:
        AssertionError: with ``message`` when ``condition`` holds and
            ``raise_on_failure`` is set.
    """
    if condition:
        if raise_on_failure:
            raise AssertionError(message)
        if warn_on_failure:
            warnings.warn(message)
        return True
    return False
def _not(
    condition: bool,
    raise_on_failure=RAISE_ON_FAILURE,
    warn_on_failure=WARN_ON_FAILURE,
    message: str = "",
) -> bool:
    """Report a failure when ``condition`` is falsy; the mirror image of `_is`.

    Returns:
        ``True`` if ``condition`` is falsy (a failure was found), else ``False``.

    Raises:
        AssertionError: with ``message`` when ``condition`` is falsy and
            ``raise_on_failure`` is set.
    """
    return _is(not condition, raise_on_failure, warn_on_failure, message)
def _safeget(d: dict[str, dict[str, Any]], x: str, y: str):
return d.get(x, dict()).get(y, None)
def _world_matches_config(
    world: SCMLBaseWorld,
    config: dict[str, Any],
    expected_types: Iterable[type[OneShotAgent] | str] | None,
    expected_world_type: type[SCMLBaseWorld] | None = None,
    raise_on_failure: bool = False,
    warn_on_failure: bool = False,
):
    """Check that ``world`` is consistent with the settings recorded in ``config["info"]``.

    Checks run in a fixed order and stop at the first failure, which is reported
    through `_is` / `_not` (so it may raise or warn depending on the flags).

    Args:
        world: The world to check.
        config: A world configuration; its ``"info"`` entry holds the allowed value
            (or range) of each setting. Missing entries are read as ``None``.
        expected_types: When non-empty, the types of all non-system agents of
            ``world`` must pass `isinclass` against these types.
        expected_world_type: When given, ``world`` must be an instance of it.
        raise_on_failure: Raise ``AssertionError`` on the first failing check.
        warn_on_failure: Emit a warning on the first failing check.

    Returns:
        ``True`` if every check passes, ``False`` otherwise.
    """
    # Settings read as attributes of the world; all others come from ``world.info``.
    attribute_keys = ("perishable", "n_steps", "n_processes")
    # (kind, key in the world, key in config["info"]) in checking order. Kinds:
    #   equal         -> the two values must compare equal
    #   isin          -> isin(world value, config value) must hold
    #   each_isin     -> isin must hold for every element of the world value
    #   isinfloat     -> isinfloat(world value, config value) must hold
    #   optional_isin -> like isin, but skipped when the config value is None
    checks = (
        ("equal", "perishable", "perishable"),
        ("isin", "n_steps", "n_steps"),
        ("isin", "n_processes", "n_processes"),
        ("isin", "n_lines", "n_lines"),
        ("each_isin", "n_agents_per_process", "n_agents_per_process"),
        ("isin", "process_inputs_generator", "process_inputs"),
        ("isin", "process_outputs_generator", "process_outputs"),
        ("isin", "production_costs", "production_costs"),
        ("isinfloat", "profit_means", "profit_means"),
        ("isinfloat", "profit_stddevs", "profit_stddevs"),
        ("isinfloat", "max_productivity", "max_productivity"),
        ("optional_isin", "initial_balance", "initial_balance"),
        (
            "isinfloat",
            "exogenous_supply_predictability",
            "exogenous_supply_predictability",
        ),
        (
            "isinfloat",
            "exogenous_sales_predictability",
            "exogenous_sales_predictability",
        ),
        ("isinfloat", "exogenous_control", "exogenous_control"),
        ("isinfloat", "cash_availability", "cash_availability"),
        ("isinfloat", "shortfall_penalty", "shortfall_penalty"),
        ("isinfloat", "shortfall_penalty_dev", "shortfall_penalty_dev"),
        ("isinfloat", "disposal_cost", "disposal_cost"),
        ("isinfloat", "disposal_cost_dev", "disposal_cost_dev"),
        ("isinfloat", "storage_cost", "storage_cost"),
        ("isinfloat", "storage_cost_dev", "storage_cost_dev"),
        ("isinfloat", "exogenous_price_dev", "exogenous_price_dev"),
        ("isinfloat", "price_multiplier", "price_multiplier"),
        ("equal", "cost_increases_with_level", "cost_increases_with_level"),
        ("equal", "equal_exogenous_supply", "equal_exogenous_supply"),
        ("equal", "equal_exogenous_sales", "equal_exogenous_sales"),
        ("equal", "cap_exogenous_quantities", "cap_exogenous_quantities"),
        ("equal", "force_signing", "force_signing"),
        ("equal", "random_agent_types", "random_agent_types"),
        ("equal", "penalties_scale", "penalties_scale"),
        ("equal", "exogenous_generation_method", "exogenous_generation_method"),
    )

    def violates(kind: str, value: Any, allowed: Any):
        """Return a truthy value when ``value`` breaks the ``kind`` constraint."""
        if kind == "equal":
            return value != allowed
        if kind == "each_isin":
            return not all(isin(v, allowed) for v in value)
        if kind == "isinfloat":
            return not isinfloat(value, allowed)
        return not isin(value, allowed)  # "isin" and "optional_isin"

    for kind, world_key, config_key in checks:
        allowed = _safeget(config, "info", config_key)
        # A missing initial balance in the config means "unconstrained"; the world
        # value is not even read in that case.
        if kind == "optional_isin" and allowed is None:
            continue
        value = (
            getattr(world, world_key)
            if world_key in attribute_keys
            else world.info[world_key]
        )
        if _is(
            violates(kind, value, allowed),
            raise_on_failure,
            warn_on_failure,
            f"world {world_key}={value!r} does not satisfy "
            f"config['info'][{config_key!r}]={allowed!r} ({kind})",
        ):
            return False

    if expected_world_type and _not(
        isinstance(world, expected_world_type),
        raise_on_failure,
        warn_on_failure,
        f"world of type {type(world).__name__} is not an instance of {expected_world_type}",
    ):
        return False

    if expected_types:
        # Types of the agent objects wrapped by every non-system agent.
        world_agent_types = [
            type(agent._obj)  # type: ignore
            for aid, agent in world.agents.items()
            if not is_system_agent(aid)
        ]
        if _not(
            isinclass(world_agent_types, list(expected_types)),
            raise_on_failure,
            warn_on_failure,
            f"not isinclass({world_agent_types=}, {list(expected_types)=})",
        ):
            return False
    return True
def _config_matches_base(
    config: dict[str, Any],
    base: dict[str, Any],
    raise_on_failure: bool,
    warn_on_failure: bool,
):
    """Check that the settings in ``config["info"]`` fall within those of ``base["info"]``.

    Checks run in a fixed order and stop at the first failure, which is reported
    through `_is` (so it may raise or warn depending on the flags). Missing
    entries in either ``"info"`` dict are read as ``None``.

    Args:
        config: The configuration being tested.
        base: The configuration whose ``"info"`` values/ranges ``config`` must fit.
        raise_on_failure: Raise ``AssertionError`` on the first failing check.
        warn_on_failure: Emit a warning on the first failing check.

    Returns:
        ``True`` if every check passes, ``False`` otherwise.
    """
    # (kind, key in config["info"], key in base["info"]) in checking order. Kinds:
    #   equal         -> the two values must compare equal
    #   isin          -> isin(config value, base value) must hold
    #   each_isin     -> isin must hold for every element of the config value
    #   isinfloat     -> isinfloat(config value, base value) must hold
    #   optional_isin -> like isin, but skipped when the base value is None
    checks = (
        ("equal", "perishable", "perishable"),
        ("isin", "n_steps", "n_steps"),
        ("isin", "n_processes", "n_processes"),
        ("isin", "n_lines", "n_lines"),
        ("each_isin", "n_agents_per_process", "n_agents_per_process"),
        ("isin", "process_inputs_generator", "process_inputs"),
        ("isin", "process_outputs_generator", "process_outputs"),
        ("isin", "production_costs", "production_costs"),
        ("isinfloat", "profit_means", "profit_means"),
        ("isinfloat", "profit_stddevs", "profit_stddevs"),
        ("isinfloat", "max_productivity", "max_productivity"),
        ("optional_isin", "initial_balance", "initial_balance"),
        (
            "isinfloat",
            "exogenous_supply_predictability",
            "exogenous_supply_predictability",
        ),
        (
            "isinfloat",
            "exogenous_sales_predictability",
            "exogenous_sales_predictability",
        ),
        ("isinfloat", "exogenous_control", "exogenous_control"),
        ("isinfloat", "cash_availability", "cash_availability"),
        ("isinfloat", "shortfall_penalty", "shortfall_penalty"),
        ("isinfloat", "shortfall_penalty_dev", "shortfall_penalty_dev"),
        ("isinfloat", "disposal_cost", "disposal_cost"),
        ("isinfloat", "disposal_cost_dev", "disposal_cost_dev"),
        ("isinfloat", "storage_cost", "storage_cost"),
        ("isinfloat", "storage_cost_dev", "storage_cost_dev"),
        ("isinfloat", "exogenous_price_dev", "exogenous_price_dev"),
        ("isinfloat", "price_multiplier", "price_multiplier"),
        ("equal", "cost_increases_with_level", "cost_increases_with_level"),
        ("equal", "equal_exogenous_supply", "equal_exogenous_supply"),
        ("equal", "equal_exogenous_sales", "equal_exogenous_sales"),
        ("equal", "cap_exogenous_quantities", "cap_exogenous_quantities"),
        ("equal", "force_signing", "force_signing"),
        ("equal", "random_agent_types", "random_agent_types"),
        ("equal", "penalties_scale", "penalties_scale"),
        ("equal", "exogenous_generation_method", "exogenous_generation_method"),
    )

    def violates(kind: str, value: Any, allowed: Any):
        """Return a truthy value when ``value`` breaks the ``kind`` constraint."""
        if kind == "equal":
            return value != allowed
        if kind == "each_isin":
            return not all(isin(v, allowed) for v in value)
        if kind == "isinfloat":
            return not isinfloat(value, allowed)
        return not isin(value, allowed)  # "isin" and "optional_isin"

    for kind, config_key, base_key in checks:
        allowed = _safeget(base, "info", base_key)
        # The initial balance is only constrained when the base specifies one;
        # a None base value means "any balance is acceptable".
        if kind == "optional_isin" and allowed is None:
            continue
        value = _safeget(config, "info", config_key)
        if _is(
            violates(kind, value, allowed),
            raise_on_failure,
            warn_on_failure,
            f"config['info'][{config_key!r}]={value!r} does not satisfy "
            f"base['info'][{base_key!r}]={allowed!r} ({kind})",
        ):
            return False
    return True
@define
class BaseContext(Context, ABC):
    """A context that generates oneshot worlds with agents of a given `types` with predetermined structure and settings"""

    # Base name for generated worlds (passed to `unique_name` when no name is given).
    name: str | None = None
    # The world class instantiated by `make_world`.
    world_type: type[SCMLBaseWorld] = SCML2024OneShotWorld
    # Extra world-constructor arguments; entries of the generated config win on clashes.
    world_params: dict[str, Any] = field(factory=dict)
    # Agent types available to subclasses for the agents that are not placeholders.
    non_competitors: tuple[str | type[OneShotAgent], ...] = DefaultAgentsOneShot
    # Stand-in agent types put in generated configs and later replaced (via
    # `world_type.replace_agents`) by the types requested from `make_world`.
    placeholder_types: tuple[type[OneShotAgent], ...] = DEFAULT_PLACEHOLDER_AGENT_TYPES
    # Parameters used with `placeholder_types` when no types are requested.
    placeholder_params: tuple[dict, ...] | None = None
    # Optional production level of each placeholder (used by subclasses).
    placeholder_levels: tuple[int, ...] | None = None

    @abstractmethod
    def make_config(self) -> dict[str, Any]:
        """Generates a config for a world"""

    @abstractmethod
    def is_valid_world(  # type: ignore
        self,
        world: SCMLBaseWorld,
        raise_on_failure: bool = False,
        warn_on_failure: bool = False,
        types: tuple[str | type[OneShotAgent], ...] | None = None,
    ) -> bool:
        """Checks that the given world could have been generated from this context"""

    def extract_context_params(
        self, min_values: bool, level: int | None = None
    ) -> ContextParams:
        """Summarize this context as a `ContextParams`; subclasses must override it.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        _ = min_values, level
        raise NotImplementedError(
            f"{self.__class__.__name__} did not implement `extract_context_params`"
        )

    def make(
        self,
        types: tuple[type[OneShotAgent], ...] | None = None,
        params: tuple[dict[str, Any], ...] | None = None,
        name: str | None = None,
    ) -> SCMLBaseWorld:
        """Generates the oneshot world and assigns agents of the given `types` to it.

        When ``types`` is None the placeholder types and params are used.
        """
        if types is None:
            types = self.placeholder_types
            params = self.placeholder_params
        return self.make_world(types, params, name=name)

    def make_world(
        self,
        types: tuple[type[OneShotAgent], ...] | None = None,
        params: tuple[dict[str, Any], ...] | None = None,
        config: dict[str, Any] | None = None,
        name: str | None = None,
    ) -> SCMLBaseWorld:
        """Generates a world.

        Args:
            types: Agent types replacing the placeholders (placeholders if None).
            params: Parameters for ``types`` (placeholder params if ``types`` is None).
            config: A ready-made config; a new one is made by `make_config` if None.
                Worlds built from an explicitly passed config are checked with
                `is_valid_world` (via ``assert``).
            name: World name; a unique name derived from ``self.name`` if None.

        Returns:
            The constructed world.
        """
        if types is None:
            types = self.placeholder_types
            params = self.placeholder_params
        # Only externally supplied configs are validated. This must be a plain bool:
        # a 1-tuple such as ``(config is not None,)`` is always truthy.
        test_world = config is not None
        if config is None:
            config = self.make_config()
        config = self.world_type.replace_agents(
            config, self.placeholder_types, types, params
        )
        if name is None:
            name = unique_name(self.name, sep=".")
        if name is not None:
            config["name"] = name
        world = self.world_type(
            **(self.world_params | config),
            one_offer_per_step=True,
        )
        if test_world:
            assert self.is_valid_world(world, types=types)
        return world

    def generate(  # type: ignore
        self,
        types: tuple[type[OneShotAgent], ...] | None = None,
        params: tuple[dict[str, Any], ...] | None = None,
        name: str | None = None,
    ) -> tuple[SCMLBaseWorld, tuple[OneShotAgent, ...]]:
        """Generates the world and returns it with the agents of the requested types.

        A single agent type (or a single params dict) is accepted and wrapped into
        a one-element tuple.

        Returns:
            The world and the agent objects whose types match ``types``.
        """
        if not name:
            name = self.name
        if types is None:
            types = self.placeholder_types
            params = self.placeholder_params
        # `types` normally holds classes, so a lone class must be detected with
        # ``isinstance(types, type)``; an agent instance is still accepted as before.
        if isinstance(types, (type, OneShotAgent)):
            types = (types,)  # type: ignore
        if isinstance(params, dict):
            params = (params,)
        world = self.make(types, params, name)
        ids = []
        if types:
            ids = [
                aid
                for aid, agent in world.agents.items()
                if isinobject(agent._obj, types)  # type: ignore
            ]
            assert len(ids) == len(
                types
            ), f"Found the following agent of type {types=}: {ids=}"
        agents = tuple(world.agents[aid]._obj for aid in ids)  # type: ignore
        return world, agents  # type: ignore

    def is_valid_awi(
        self,
        awi: OneShotAWI,
        raise_on_failure: bool = False,
        warn_on_failure: bool = False,
        types: tuple[str | type[OneShotAgent], ...] | None = None,
    ) -> bool:  # type: ignore
        """Check the world behind ``awi`` with `is_valid_world`."""
        # TODO: decide how `types` should be used when validating through an AWI.
        return self.is_valid_world(
            awi._world,
            raise_on_failure=raise_on_failure,
            warn_on_failure=warn_on_failure,
            types=types,
        )

    def contains_context(
        self,
        context: Context,
        raise_on_failure: bool = False,
        warn_on_failure: bool = False,
        n_tests: int = NTESTS,
    ) -> bool:
        """Sample ``n_tests`` worlds from ``context`` and check each with `is_valid_world`.

        Returns:
            ``False`` at the first sampled world that is not valid, else ``True``.
        """
        for _ in range(n_tests):
            world, _agents = context.generate()
            if not self.is_valid_world(
                world,
                raise_on_failure=raise_on_failure,
                warn_on_failure=warn_on_failure,
            ):
                return False
        return True
@define
[docs]
class GeneralContext(BaseContext):
"""A context that generates oneshot worlds with agents of a given `types` with predetermined structure and settings"""
# std vs oneshot
[docs]
perishable: bool = True
# negotiation parameters
[docs]
price_multiplier: np.ndarray | tuple[float, float] | float = (1.5, 2.0)
[docs]
force_signing = True
# production graph parameters
[docs]
n_steps: tuple[int, int] | int = (20, 200)
[docs]
n_processes: tuple[int, int] | int = 2
[docs]
n_lines: tuple[int, int] | int = 10
[docs]
n_agents_per_process: np.ndarray | list[int] | tuple[int, int] | int = (
min(N_CONSUMERS[0], N_SUPPLIERS[0]),
max(N_CONSUMERS[-1], N_SUPPLIERS[-1]),
)
# profile parameters
[docs]
production_costs: np.ndarray | tuple[int, int] | int = (1, 4)
[docs]
cash_availability: tuple[float, float] | float = (1.5, 2.5)
[docs]
shortfall_penalty: tuple[float, float] | float = (0.2, 1.0)
[docs]
shortfall_penalty_dev: tuple[float, float] | float = (0.0, 0.1)
[docs]
disposal_cost: tuple[float, float] | float = (0.0, 0.2)
[docs]
disposal_cost_dev: tuple[float, float] | float = (0.0, 0.02)
[docs]
storage_cost: tuple[float, float] | float = (0.0, 0.02)
[docs]
storage_cost_dev: tuple[float, float] | float = 0
[docs]
cost_increases_with_level = True
[docs]
penalties_scale: str | list[str] = "trading"
[docs]
process_inputs: tuple[int, int] | int = 1
[docs]
process_outputs: np.ndarray | tuple[int, int] | int = 1
# exogenous contract generation parameters
[docs]
exogenous_generation_method = "profitable"
[docs]
profit_means: np.ndarray | tuple[float, float] | float = (0.1, 0.2)
[docs]
profit_stddevs: np.ndarray | tuple[float, float] | float = 0.05
[docs]
max_productivity: np.ndarray | tuple[float, float] | float = (0.8, 1.0)
[docs]
initial_balance: np.ndarray | tuple[int, int] | int | None = None
[docs]
exogenous_supply_predictability: tuple[float, float] | float = (0.6, 0.9)
[docs]
exogenous_sales_predictability: tuple[float, float] | float = (0.6, 0.9)
[docs]
exogenous_control: tuple[float, float] | float = -1
[docs]
exogenous_price_dev: tuple[float, float] | float = (0.1, 0.2)
[docs]
equal_exogenous_supply = False
[docs]
equal_exogenous_sales = False
[docs]
cap_exogenous_quantities: bool = True
[docs]
def __attrs_post_init__(self):
    """attrs post-init hook: check that ``perishable`` agrees with ``world_type``.

    A perishable context must not use a `StdWorld` subclass, and a non-perishable
    one must.

    Raises:
        AssertionError: if the two settings disagree.
    """
    # Imported locally, as before (presumably to avoid an import cycle -- confirm).
    from scml.std.world import StdWorld

    # A StdWorld subclass is required exactly when the context is non-perishable.
    assert issubclass(self.world_type, StdWorld) != bool(self.perishable)
[docs]
def make_predefined_config(
    self,
    agent_types: list[type[OneShotAgent]],
    agent_processes: list[int],
    agent_params: list[dict[str, Any]],
    n_agents_per_process: list[int],
) -> dict[str, Any]:
    """Build a world config with an explicitly given agent layout.

    The layout (types, processes, parameters and counts per process) comes from
    the arguments; every other setting is forwarded from this context's
    attributes to ``self.world_type.generate``.

    Args:
        agent_types: Type of each agent.
        agent_processes: Process (production level) of each agent.
        agent_params: Parameters of each agent; ``None`` means an empty dict each.
        n_agents_per_process: Number of agents in each process; its length is
            used as the number of processes.

    Returns:
        ``self.world_params`` merged with the generated config (generated
        entries win on clashes).
    """
    if agent_params is None:
        agent_params = [{} for _ in agent_types]
    # Context settings handed to the generator under the same name.
    forwarded = {
        attr: getattr(self, attr)
        for attr in (
            "perishable",
            "n_steps",
            "n_lines",
            "process_inputs",
            "process_outputs",
            "production_costs",
            "profit_means",
            "profit_stddevs",
            "max_productivity",
            "initial_balance",
            "exogenous_supply_predictability",
            "exogenous_sales_predictability",
            "exogenous_control",
            "cash_availability",
            "shortfall_penalty",
            "shortfall_penalty_dev",
            "disposal_cost",
            "disposal_cost_dev",
            "storage_cost",
            "storage_cost_dev",
            "exogenous_price_dev",
            "price_multiplier",
            "cost_increases_with_level",
            "equal_exogenous_supply",
            "equal_exogenous_sales",
            "cap_exogenous_quantities",
            "force_signing",
            "penalties_scale",
            "exogenous_generation_method",
        )
    }
    generated = self.world_type.generate(
        agent_types=agent_types,  # type: ignore
        agent_params=agent_params,
        agent_processes=agent_processes,
        n_processes=len(n_agents_per_process),
        n_agents_per_process=np.asarray(n_agents_per_process),
        random_agent_types=False,
        **forwarded,
    )
    return self.world_params | generated
[docs]
def contains_context(
    self,
    context: Context,
    raise_on_failure: bool = False,
    warn_on_failure: bool = False,
    n_tests: int = NTESTS,
) -> bool:
    """Check whether worlds from ``context`` could also come from this context.

    Another `GeneralContext` is compared directly through
    ``self.contains_general_context`` (defined outside this view; it receives
    neither the failure flags nor ``n_tests``). Any other context falls back to
    the sampling check of the parent class.
    """
    if not isinstance(context, GeneralContext):
        return super().contains_context(
            context, raise_on_failure, warn_on_failure, n_tests
        )
    return self.contains_general_context(context)
[docs]
def _assign_types(self, n_processes, types, params, levels, n_agents_per_process):
n_agents = sum(n_agents_per_process)
perlevel = defaultdict(list)
for i, t, p in zip(levels, types, params):
perlevel[i].append((t, p))
agent_types = list(random.choices(self.non_competitors, k=n_agents))
agent_params: list[dict[str, Any]] = list(dict() for _ in agent_types)
agent_processes = np.zeros(n_agents, dtype=int)
nxt, indx = 0, -1
rngs = []
for level in range(n_processes):
last = nxt + n_agents_per_process[level]
agent_processes[nxt:last] = level
rngs.append((nxt, last))
nxt += n_agents_per_process[level]
for i, tp in perlevel.items():
first, last = rngs[i]
assert (
last - first + 1 >= len(tp)
), f"Cannot put agents of type {tp=} in level {i} which has only {last - first + 1} agents"
random.shuffle(tp)
selected = list(range(first, last))
random.shuffle(selected)
selected = selected[: len(tp)]
for indx, (my_type, my_params) in zip(selected, tp):
agent_types[indx] = my_type
if params:
agent_params[indx]["controller_params"] = my_params
return agent_types, agent_processes, agent_params
[docs]
def _distribute_agents(self, n_types):
    """Sample the number of production levels and the agent count of each.

    Args:
        n_types: Passed to `make_array` as ``min_total`` (callers pass the number
            of placeholder types, presumably so each one gets a slot).

    Returns:
        ``(n_processes, n_agents_per_process)``.
    """
    levels = intin(self.n_processes)
    return levels, make_array(
        self.n_agents_per_process, levels, dtype=int, min_total=n_types
    )
[docs]
def make_config(self) -> dict[str, Any]:
    """Generate a random world config that contains this context's placeholders.

    The number of levels and agents per level come from `_distribute_agents`.
    Placeholders go to ``self.placeholder_levels`` or, when that is unset/empty,
    to uniformly random levels. Placeholders receive empty parameter dicts.

    Returns:
        The config built by `make_predefined_config`.
    """
    placeholders = self.placeholder_types
    empty_params = [{} for _ in placeholders]
    levels = self.placeholder_levels
    n_processes, per_process = self._distribute_agents(len(placeholders))
    assert len(per_process) == n_processes
    assert sum(per_process) >= len(placeholders)
    if not levels:
        levels = tuple(random.randint(0, n_processes - 1) for _ in placeholders)
    assigned = self._assign_types(
        n_processes, placeholders, empty_params, levels, per_process
    )
    return self.make_predefined_config(*assigned, per_process)  # type: ignore
[docs]
def is_valid_world(  # type: ignore
    self,
    world: SCMLBaseWorld,
    raise_on_failure: bool = False,
    warn_on_failure: bool = False,
    types: tuple[str | type[OneShotAgent], ...] | None = None,
) -> bool:
    """Checks that the given world could have been generated from this context.

    World attributes and the parameters recorded in ``world.info`` are compared
    with the corresponding parameters of this context. Checks run in a fixed
    order and stop at the first failure.

    Args:
        world: The world to check.
        raise_on_failure: Forwarded to the ``_is``/``_not`` reporting helpers.
        warn_on_failure: Forwarded to the ``_is``/``_not`` reporting helpers.
        types: If given, the classes of all non-system agents in the world must
            be among ``self.non_competitors`` plus ``types`` (via ``isinclass``).

    Returns:
        ``True`` if every check passes, ``False`` otherwise.
    """

    def checks():
        # Each item is (helper, condition, message). ``_is`` reports a failure
        # when its condition is true; ``_not`` reports one when it is false.
        # Being a generator, each condition (and each ``world.info`` lookup) is
        # only evaluated after all earlier checks have passed, exactly like
        # the original chain of early returns.
        yield (
            _is,
            world.perishable != self.perishable,
            f"{world.perishable=} != {self.perishable=}",
        )
        yield (
            _not,
            isin(world.n_steps, self.n_steps),
            f"not isin({world.n_steps=}, {self.n_steps=})",
        )
        yield (
            _not,
            isin(world.n_processes, self.n_processes),
            f"not isin({world.n_processes=}, {self.n_processes=})",
        )
        yield (
            _not,
            isin(world.info["n_lines"], self.n_lines),
            'not isin(world.info["n_lines"], self.n_lines)',
        )
        yield (
            _not,
            all(
                isin(_, self.n_agents_per_process)
                for _ in world.info["n_agents_per_process"]
            ),
            f"not all( isin(_, self.n_agents_per_process) for _ in world.info['n_agents_per_process'])\n"
            f"{self.n_agents_per_process=}\n{world.info['n_agents_per_process']=}",
        )
        # world.info key -> context attribute compared with isin
        for info_key, name in (
            ("process_inputs_generator", "process_inputs"),
            ("process_outputs_generator", "process_outputs"),
            ("production_costs", "production_costs"),
        ):
            yield (
                _not,
                isin(world.info[info_key], getattr(self, name)),
                f"not isin(world.info['{info_key}'], self.{name})",
            )
        for name in ("profit_means", "profit_stddevs", "max_productivity"):
            yield (
                _not,
                isinfloat(world.info[name], getattr(self, name)),
                f"not isinfloat(world.info['{name}'], self.{name})",
            )
        # a ``None`` initial balance on this context means "unconstrained"
        yield (
            _is,
            self.initial_balance is not None
            and not isin(world.info["initial_balance"], self.initial_balance),
            "self.initial_balance is not None and not isin(world.info['initial_balance'], self.initial_balance)",
        )
        yield (
            _not,
            isinfloat(
                world.info["exogenous_supply_predictability"],
                self.exogenous_supply_predictability,
            ),
            "not isinfloat( world.info['exogenous_supply_predictability'], self.exogenous_supply_predictability,)",
        )
        yield (
            _not,
            isinfloat(
                world.info["exogenous_sales_predictability"],
                self.exogenous_sales_predictability,
            ),
            "not isinfloat( world.info['exogenous_sales_predictability'], self.exogenous_sales_predictability,)",
        )
        for name in ("exogenous_control", "cash_availability", "shortfall_penalty"):
            yield (
                _not,
                isinfloat(world.info[name], getattr(self, name)),
                f"not isinfloat(world.info['{name}'], self.{name})",
            )
        yield (
            _not,
            isinfloat(world.info["shortfall_penalty_dev"], self.shortfall_penalty_dev),
            "not isinfloat( world.info['shortfall_penalty_dev'], self.shortfall_penalty_dev)",
        )
        for name in (
            "disposal_cost",
            "disposal_cost_dev",
            "storage_cost",
            "storage_cost_dev",
            "exogenous_price_dev",
            "price_multiplier",
        ):
            yield (
                _not,
                isinfloat(world.info[name], getattr(self, name)),
                f"not isinfloat(world.info['{name}'], self.{name})",
            )
        # settings that must match exactly
        for name in (
            "cost_increases_with_level",
            "equal_exogenous_supply",
            "equal_exogenous_sales",
            "cap_exogenous_quantities",
            "force_signing",
        ):
            yield (
                _is,
                world.info[name] != getattr(self, name),
                f"world.info['{name}'] != self.{name}",
            )
        yield (
            _is,
            world.info["random_agent_types"] is not False,
            "world.info['random_agent_types'] != False",
        )
        for name in ("penalties_scale", "exogenous_generation_method"):
            yield (
                _is,
                world.info[name] != getattr(self, name),
                f"world.info['{name}'] != self.{name}",
            )
        yield (
            _not,
            isinstance(world, self.world_type),
            "not isinstance(world, self.world_type)",
        )
        if types:
            world_agent_types = [
                type(_._obj)  # type: ignore
                for aid, _ in world.agents.items()
                if not is_system_agent(aid)
            ]
            yield (
                _not,
                isinclass(world_agent_types, list(self.non_competitors) + list(types)),
                f"not isinclass({world_agent_types=}, {list(self.non_competitors)=} + {list(types)=})",
            )

    for helper, condition, message in checks():
        if helper(condition, raise_on_failure, warn_on_failure, message):
            return False
    return True
[docs]
def contains_general_context(self, context: "GeneralContext") -> bool:
    """Checks that any world generated from the given `context` could have been generated from this context.

    Every parameter range of ``context`` must lie within the corresponding
    range of this context (``isin`` for integer-like parameters, ``isinfloat``
    for real-valued ones). Flag-like settings must be equal, the world type of
    ``context`` must be a subclass of ours, and its non-competitor classes must
    be among ours (``isinclass``).

    Args:
        context: The context to compare against this one.

    Returns:
        ``True`` if this context contains ``context``, ``False`` otherwise.
    """
    if context.perishable != self.perishable:
        return False
    for name in (
        "n_steps",
        "n_processes",
        "n_lines",
        "n_agents_per_process",
        "process_inputs",
        "process_outputs",
        "production_costs",
    ):
        if not isin(getattr(context, name), getattr(self, name)):
            return False
    for name in ("profit_means", "profit_stddevs", "max_productivity"):
        if not isinfloat(getattr(context, name), getattr(self, name)):
            return False
    # a ``None`` initial balance on this context means "unconstrained"
    if self.initial_balance is not None and not isin(
        context.initial_balance, self.initial_balance  # type: ignore
    ):
        return False
    for name in (
        "exogenous_supply_predictability",
        "exogenous_sales_predictability",
        "exogenous_control",
        "cash_availability",
        "shortfall_penalty",
        "shortfall_penalty_dev",
        "disposal_cost",
        "disposal_cost_dev",
        "storage_cost",
        "storage_cost_dev",
        "exogenous_price_dev",
        "price_multiplier",
    ):
        if not isinfloat(getattr(context, name), getattr(self, name)):
            return False
    # settings that must match exactly (random_agent_types is not compared here)
    for name in (
        "cost_increases_with_level",
        "equal_exogenous_supply",
        "equal_exogenous_sales",
        "cap_exogenous_quantities",
        "force_signing",
        "penalties_scale",
        "exogenous_generation_method",
    ):
        if getattr(context, name) != getattr(self, name):
            return False
    # ``world_type`` holds classes, so they must be compared with issubclass;
    # isinstance(class, class) is False and would never reject anything.
    if not issubclass(context.world_type, self.world_type):
        return False
    if not isinclass(list(context.non_competitors), list(self.non_competitors)):
        return False
    return True
@define
class LimitedPartnerNumbersContext(GeneralContext):
    """Generates a world limiting the range of the agent level, production capacity
    and the number of suppliers, consumers, and optionally same-level competitors.

    All placeholder (test) agents are placed on ``level``; a negative ``level``
    means the last production level. The number of suppliers and consumers is
    sampled from ``n_suppliers`` / ``n_consumers``. If ``buying_strength`` or
    ``selling_strength`` is set, it is instead sampled relative to the number of
    competitors with ``sample_with_strength``.
    """

    # (min, max) number of consumers of each test agent
    n_consumers: tuple[int, int] = N_CONSUMERS
    # (min, max) number of suppliers of each test agent
    n_suppliers: tuple[int, int] = (0, 0)
    # (min, max) number of other agents on the test agents' level
    n_competitors: tuple[int, int] = (N_CONSUMERS[0] - 1, N_CONSUMERS[1] - 1)
    # if set, controls the number of suppliers relative to the competitors
    buying_strength: Strength | None = None
    # if set, controls the number of consumers relative to the competitors
    selling_strength: Strength | None = None

    def __attrs_post_init__(self):
        """Validates that the partner ranges are compatible with ``n_agents_per_process`` and ``level``."""
        max_n_proceses = (
            max(self.n_processes)
            if isinstance(self.n_processes, Iterable)
            else self.n_processes
        )
        # the test agent's level holds the agent itself plus its competitors
        assert isin(
            tuple(_ + 1 for _ in self.n_competitors),  # type: ignore
            self.n_agents_per_process,  # type: ignore
        ), f"{self.n_competitors=}, {self.n_agents_per_process=}"
        assert (
            not (self.level > 0 and self.level < max_n_proceses - 1)
            or (self.n_suppliers[-1] > 1 and self.n_consumers[-1] > 1)
        ), f"{self.n_suppliers=}, {self.n_consumers=}, {self.level=}, {self.n_processes=}"
        if self.level == 0:
            # first level: consumers only, no suppliers
            assert isin(
                self.n_consumers, self.n_agents_per_process
            ), f"{self.n_consumers=}, {self.n_agents_per_process=}, {self.level=}"
            assert max(self.n_suppliers) < 1, f"{self.n_suppliers=}, {self.level=}"
            assert (
                min(self.n_consumers) > 0
            ), f"{self.n_consumers=}, {self.n_agents_per_process=}, {self.level=}"
        elif (
            self.level == -1
            or isinstance(self.n_processes, int)
            and self.level == max_n_proceses - 1
        ):
            # last level: suppliers only, no consumers
            assert self.level < max_n_proceses, f"{max_n_proceses=}, {self.level=}"
            assert isin(
                self.n_suppliers, self.n_agents_per_process
            ), f"{self.n_suppliers=}, {self.n_agents_per_process=}, {self.level=}"
            assert (
                max(self.n_consumers) < 1
            ), f"{self.n_consumers=}, {self.n_agents_per_process=}, {self.level=}"
            assert min(self.n_suppliers) > 0, f"{self.n_suppliers=}, {self.level=}"
        else:
            assert isin(
                self.n_consumers, self.n_agents_per_process
            ), f"{self.n_consumers=}, {self.n_agents_per_process=}, {self.level=}"
            assert isin(
                self.n_suppliers, self.n_agents_per_process
            ), f"{self.n_suppliers=}, {self.n_agents_per_process=}, {self.level=}"

    def make_config(self) -> dict[str, Any]:
        """Generates a config.

        All placeholder agents are put on ``self.level``. The sizes of the
        neighbouring levels are set to the sampled numbers of suppliers and
        consumers, and the size of the agents' own level is set from the
        sampled number of competitors.

        Returns:
            The world config returned by ``make_predefined_config``.

        Raises:
            AssertionError: If placeholder levels other than ``self.level`` were
                requested, or a resulting level size is not in
                ``self.n_agents_per_process``.
        """
        types = self.placeholder_types
        params = (
            [dict() for _ in types]
            if self.placeholder_params is None
            else self.placeholder_params
        )
        levels = self.placeholder_levels
        assert levels is None or all(_ == self.level for _ in levels), (
            "LimitedPartnerNumbersContext does not allow you to decide the levels of "
            "the agents when creating the config as it uses its internal level "
            f"and assigns all placeholder agents to it: {self.level=}, {levels=}"
        )
        levels = tuple(self.level for _ in types)
        n_processes, n_agents_per_process = self._distribute_agents(len(types))
        # a negative level means the last production level
        my_level = n_processes - 1 if self.level < 0 else self.level
        is_first_level = my_level == 0
        is_last_level = my_level == n_processes - 1
        # Sample partner counts from the configured ranges rather than reading
        # them back from the random distribution; otherwise the generated world
        # can violate the limits that is_valid_world enforces.
        n_competitors = self.n_competitors
        n_suppliers = 0 if is_first_level else self.n_suppliers
        n_consumers = 0 if is_last_level else self.n_consumers
        if self.buying_strength is not None:
            n_competitors, n_suppliers = sample_with_strength(
                n_competitors, self.n_suppliers, self.buying_strength
            )
        if self.selling_strength is not None:
            n_competitors, n_consumers = sample_with_strength(
                n_competitors, self.n_consumers, self.selling_strength
            )
        n_competitors = intin(n_competitors)
        n_suppliers = intin(n_suppliers)
        n_consumers = intin(n_consumers)
        # override the number of consumers and number of suppliers to match my choice
        if is_first_level:
            n_agents_per_process[1] = n_consumers
        elif is_last_level:
            n_agents_per_process[my_level - 1] = n_suppliers
        else:
            n_agents_per_process[my_level + 1] = n_consumers
            n_agents_per_process[my_level - 1] = n_suppliers
        n_agents_per_process[my_level] = max(len(types), n_competitors + 1)
        for k, n in enumerate(n_agents_per_process):
            assert isin(n, self.n_agents_per_process), (
                f"Level {k} has {n} agents which is not in {self.n_agents_per_process}"
                f": {self.n_suppliers=}, {self.n_competitors=}, {self.n_consumers=}"
                f": {self.selling_strength=}, {self.buying_strength=}\n{n_agents_per_process}"
                f"\n{self.n_agents_per_process}"
                f"\n {n_suppliers=}, {n_competitors=}, {n_consumers=}"
            )
        return self.make_predefined_config(
            *self._assign_types(
                n_processes, types, params, levels, n_agents_per_process
            ),
            n_agents_per_process,  # type: ignore
        )

    def find_test_agents(
        self,
        world: SCMLBaseWorld,
        types: tuple[type[OneShotAgent], ...] | None = None,
    ) -> list[str]:
        """Returns the IDs of the world agents matching ``types`` according to ``isinobject``.

        ``types`` defaults to ``self.placeholder_types``.
        """
        if types is None:
            types = self.placeholder_types
        return [aid for aid, agent in world.agents.items() if isinobject(agent, types)]

    @staticmethod
    def _strength_is_satisfied(strength, n_partners, n_competitors) -> bool:
        """Whether ``n_partners`` is compatible with ``strength`` given ``n_competitors``."""
        if strength == Strength.Strong:
            return n_partners >= n_competitors + 1
        if strength == Strength.Weak:
            return n_partners <= n_competitors - 1
        if strength == Strength.Balanced:
            return n_competitors - 1 <= n_partners <= n_competitors + 1
        return True

    def is_valid_world(  # type: ignore
        self,
        world: SCMLBaseWorld,
        types: tuple[type[OneShotAgent], ...] | None = None,
        raise_on_failure: bool = False,
        warn_on_failure: bool = False,
    ) -> bool:
        """Checks that the given world could have been generated from this context.

        Every test agent (see ``find_test_agents``) must be on the expected
        level. Its numbers of competitors, suppliers and consumers must lie in
        the configured ranges and satisfy any configured strengths. The
        ``GeneralContext`` checks are then applied.

        Args:
            world: The world to check.
            types: Test agent types (defaults to ``self.placeholder_types``).
            raise_on_failure: Forwarded to the ``_is``/``_not`` reporting helpers.
            warn_on_failure: Forwarded to the ``_is``/``_not`` reporting helpers.

        Returns:
            ``True`` if the world is valid, ``False`` otherwise.
        """
        if types is None:
            types = self.placeholder_types
        agent_ids = self.find_test_agents(world, types)
        n_processes = world.n_processes
        # normalise a negative level (last level) the same way make_config does
        expected_level = n_processes - 1 if self.level < 0 else self.level
        for aid in agent_ids:
            my_level = world.agent_profiles[aid].input_product
            # fail only when the agent is NOT on the expected level
            if _is(
                my_level != expected_level,
                raise_on_failure,
                warn_on_failure,
                f"Agent {aid} of type {world.agents[aid]._obj.__class__.__name__} "
                f"is on level {my_level} but expected to be on level {expected_level}",
            ):
                return False
            is_first_level = my_level == 0
            is_last_level = my_level == n_processes - 1
            my_suppliers = [
                _ for _ in world.agent_suppliers[aid] if not is_system_agent(_)
            ]
            my_consumers = [
                _ for _ in world.agent_consumers[aid] if not is_system_agent(_)
            ]
            # Agents on my level produce my output product (my_level + 1) and
            # consume my input product (my_level).
            my_competitors = (
                world.suppliers[my_level + 1]
                if not is_last_level
                else world.consumers[my_level]
            )
            assert (
                aid in my_competitors
            ), f"{aid} not found in its competitors!! {my_competitors=}"
            my_competitors = [_ for _ in my_competitors if _ != aid]
            n_consumers, n_suppliers = len(my_consumers), len(my_suppliers)
            n_competitors = len(my_competitors)
            if _not(
                isin(n_competitors, self.n_competitors),
                raise_on_failure=raise_on_failure,
                warn_on_failure=warn_on_failure,
                message=f"Invalid n_competitors: {n_competitors=} != {self.n_competitors=}",
            ):
                return False
            if self.buying_strength is not None and _not(
                self._strength_is_satisfied(
                    self.buying_strength, n_suppliers, n_competitors
                ),
                raise_on_failure=raise_on_failure,
                warn_on_failure=warn_on_failure,
                message=f"Strength {self.buying_strength} but {n_suppliers=} and {n_competitors=}",
            ):
                return False
            if self.selling_strength is not None and _not(
                self._strength_is_satisfied(
                    self.selling_strength, n_consumers, n_competitors
                ),
                raise_on_failure=raise_on_failure,
                warn_on_failure=warn_on_failure,
                message=f"Strength {self.selling_strength} but {n_consumers=} and {n_competitors=}",
            ):
                return False
            if is_first_level:
                if _not(
                    isin(n_consumers, self.n_consumers),
                    raise_on_failure=raise_on_failure,
                    warn_on_failure=warn_on_failure,
                    message=f"Invalid n_consumers: {n_consumers=} != {self.n_consumers=}",
                ):
                    return False
                if _is(
                    n_suppliers != 0,
                    raise_on_failure=raise_on_failure,
                    warn_on_failure=warn_on_failure,
                    message=(
                        f"Invalid n_suppliers for {aid} (at level {my_level} "
                        f"[of {world.n_processes} processes]): {len(my_suppliers)=} != 0\nAll Suppliers: {world.suppliers}"
                    ),
                ):
                    return False
            elif is_last_level:
                if _not(
                    isin(n_suppliers, self.n_suppliers),
                    raise_on_failure=raise_on_failure,
                    warn_on_failure=warn_on_failure,
                    message=f"Invalid n_suppliers: {n_suppliers=} != {self.n_suppliers=}",
                ):
                    return False
                if _is(
                    n_consumers != 0,
                    raise_on_failure=raise_on_failure,
                    warn_on_failure=warn_on_failure,
                    message=f"Invalid n_consumers: {len(my_consumers)=} != 0",
                ):
                    return False
            else:
                if _not(
                    isin(n_suppliers, self.n_suppliers),
                    raise_on_failure=raise_on_failure,
                    warn_on_failure=warn_on_failure,
                    message=f"Invalid n_suppliers: {n_suppliers=} not in {self.n_suppliers=}",
                ):
                    return False
                if _not(
                    isin(n_consumers, self.n_consumers),
                    raise_on_failure,
                    warn_on_failure,
                    message=f"Invalid n_consumers: {n_consumers=} not in {self.n_consumers=}",
                ):
                    return False
        return super().is_valid_world(
            world, raise_on_failure=raise_on_failure, warn_on_failure=warn_on_failure
        )

    def contains_limited_partner_context(
        self,
        context: "LimitedPartnerNumbersContext",
        raise_on_failure: bool = False,
        warn_on_failure: bool = False,
    ) -> bool:
        """Checks that the level and partner ranges of ``context`` lie within ours.

        The remaining parameters are then checked by the parent's
        ``contains_context``.
        """
        if _not(
            isin(context.n_processes, self.n_processes),
            raise_on_failure,
            warn_on_failure,
            "not isin(context.n_processes, self.n_processes)",
        ):
            return False
        if _not(
            isin(context.level, self.level),
            raise_on_failure,
            warn_on_failure,
            "not isin(context.level, self.level)",
        ):
            return False
        if _not(
            isin(context.n_consumers, self.n_consumers),
            raise_on_failure,
            warn_on_failure,
            "not isin(context.n_consumers, self.n_consumers)",
        ):
            return False
        if _not(
            isin(context.n_suppliers, self.n_suppliers),
            raise_on_failure,
            warn_on_failure,
            "not isin(context.n_suppliers, self.n_suppliers)",
        ):
            return False
        if _not(
            isin(context.n_competitors, self.n_competitors),
            raise_on_failure,
            warn_on_failure,
            "not isin(context.n_competitors, self.n_competitors)",
        ):
            return False
        return super().contains_context(context, raise_on_failure, warn_on_failure)

    def contains_context(
        self,
        context: Context,
        raise_on_failure: bool = False,
        warn_on_failure: bool = False,
        n_tests: int = NTESTS,
    ) -> bool:
        """Checks that the any world generated from the given `context` could have been generated from this context"""
        if isinstance(context, self.__class__):
            return self.contains_limited_partner_context(
                context,
                raise_on_failure=raise_on_failure,
                warn_on_failure=warn_on_failure,
            )
        return super().contains_context(
            context, raise_on_failure, warn_on_failure, n_tests
        )
@define
class FixedPartnerNumbersContext(LimitedPartnerNumbersContext):
    """A LimitedPartnerNumbersContext whose partner counts are exact numbers.

    ``n_consumers``, ``n_suppliers`` and ``n_competitors`` are single integers
    rather than ``(min, max)`` ranges.
    """

    # exact number of consumers of the test agent
    n_consumers: int = 4  # type: ignore
    # exact number of suppliers of the test agent
    n_suppliers: int = 0  # type: ignore
    # exact number of other agents on the test agent's level
    n_competitors: int = 3  # type: ignore

    def __attrs_post_init__(self):
        """Validate the fixed counts with the parent's range-based checks."""
        fixed_names = ("n_consumers", "n_suppliers", "n_competitors")
        # The parent's validation treats these as (min, max) ranges, so expose
        # each count as a degenerate range while it runs, then unwrap it.
        for attr_name in fixed_names:
            count = getattr(self, attr_name)
            object.__setattr__(self, attr_name, (count, count))
        super().__attrs_post_init__()
        for attr_name in fixed_names:
            object.__setattr__(self, attr_name, getattr(self, attr_name)[0])  # type: ignore
@define
class MonopolicContext(LimitedPartnerNumbersContext):
    """A context in which the test agent has no competitors on its own level."""

    # no other agent shares the test agent's level
    n_competitors: tuple[int, int] = (0, 0)
    # The converter discards any given value and always uses this range, so a
    # level may hold anything from one agent up to the largest partner count.
    n_agents_per_process: np.ndarray | list[int] | tuple[int, int] | int = field(
        default=(1, max(N_CONSUMERS[-1], N_SUPPLIERS[-1])),
        converter=lambda _: (1, max(N_CONSUMERS[-1], N_SUPPLIERS[-1])),
    )

    def __attrs_post_init__(self):
        """Make one agent per level admissible, then run the parent validation."""
        counts = self.n_agents_per_process
        if isinstance(counts, tuple):
            counts = (min(1, counts[0]), counts[1])
        elif isinstance(counts, int):
            assert (
                counts == 1
            ), f"You passed {self.n_agents_per_process=} to a MonopolicContext but this MUST be one in this case"
        else:
            # an explicit collection of sizes: make sure 1 is one of them
            counts = list(set(list(counts) + [1]))
        object.__setattr__(self, "n_agents_per_process", counts)
        return super().__attrs_post_init__()
@define
class SingleAgentPerLevelConsumerContext(MonopolicContext):
    """A world with exactly one factory per level in which the agent is the consumer.

    The agent sits on the last level (``level=-1``): it has exactly one
    supplier and no consumers.
    """

    # -1 selects the last production level
    level: int = -1
    # nobody buys from the last level
    n_consumers: tuple[int, int] = (0, 0)
    # the single factory of the previous level
    n_suppliers: tuple[int, int] = (1, 1)
    # the converter pins every level to exactly one factory
    n_agents_per_process: np.ndarray | list[int] | tuple[int, int] | int = field(
        default=1, converter=lambda _: 1
    )
@define
class SingleAgentPerLevelSupplierContext(MonopolicContext):
    """A world with exactly one factory per level in which the agent is the supplier.

    The agent has exactly one consumer and no suppliers.
    """

    # the single factory of the next level
    n_consumers: tuple[int, int] = (1, 1)
    # a supplier buys only from the exogenous market
    n_suppliers: tuple[int, int] = (0, 0)
    # the converter pins every level to exactly one factory
    n_agents_per_process: np.ndarray | list[int] | tuple[int, int] | int = field(
        default=1, converter=lambda _: 1
    )
@define
class EutopiaContext(MonopolicContext):
    """An idealised context: the agent is alone on its level and every other agent is a NiceAgent."""

    # the converter forces all non-test agents to be NiceAgent, whatever is passed
    non_competitors: tuple[str | type[OneShotAgent], ...] = field(
        default=(NiceAgent,), converter=lambda _: (NiceAgent,)
    )
@define
class EutopiaSupplierContext(EutopiaContext):
    """An idealised context with the agent as the only supplier; all consumers are NiceAgents."""

    # the converter always places the agent on the first level
    level: int = field(default=0, converter=lambda _: 0)
    # range of consumers the lone supplier may face
    n_consumers: tuple[int, int] = N_CONSUMERS
    # first-level agents have no suppliers
    n_suppliers: tuple[int, int] = (0, 0)
@define
class EutopiaConsumerContext(EutopiaContext):
    """An idealised context with the agent as the only consumer; all suppliers are NiceAgents."""

    # the converter always places the agent on the last level
    level: int = field(default=-1, converter=lambda _: -1)
    # last-level agents have no consumers
    n_consumers: tuple[int, int] = (0, 0)
    # range of suppliers the lone consumer may face
    n_suppliers: tuple[int, int] = N_SUPPLIERS
@define
class FixedPartnerNumbersOneShotContext(FixedPartnerNumbersContext):
    """Alias of FixedPartnerNumbersContext kept under a one-shot specific name."""
@define
class ANACContext(GeneralContext):
    """Generates a oneshot world constrained only to be compatible with one ANAC competition year.

    ``world_type`` is derived from ``year``; an unsupported year raises ``KeyError``.
    """

    def __attrs_post_init__(self):
        """Select the world class matching ``self.year``."""
        world_for_year = {
            2024: SCML2024OneShotWorld,
            2023: SCML2023OneShotWorld,
            2022: SCML2022OneShotWorld,
            2021: SCML2021OneShotWorld,
            2020: SCMLBaseWorld,
        }
        object.__setattr__(self, "world_type", world_for_year[self.year])
@define
class LimitedPartnerNumbersOneShotContext(LimitedPartnerNumbersContext):
    """A oneshot LimitedPartnerNumbersContext whose world class follows the ANAC ``year``.

    It limits the range of the agent level, production capacity and the number
    of suppliers, consumers, and optionally same-level competitors.
    """

    def __attrs_post_init__(self):
        """Run the parent validation, then pick the world class for ``self.year``."""
        super().__attrs_post_init__()
        world_for_year = {
            2024: SCML2024OneShotWorld,
            2023: SCML2023OneShotWorld,
            2022: SCML2022OneShotWorld,
            2021: SCML2021OneShotWorld,
            2020: SCMLBaseWorld,
        }
        object.__setattr__(self, "world_type", world_for_year[self.year])
@define
class ANACOneShotContext(ANACContext):
    """A oneshot world constrained only to be compatible with one ANAC competition year."""

    def __attrs_post_init__(self):
        """Pick the world class for ``self.year``.

        ANACContext already performs exactly this year-to-world mapping.
        """
        super().__attrs_post_init__()
@define
class SupplierContext(LimitedPartnerNumbersOneShotContext):
    """Context for a test agent on the first production level (a supplier).

    The partner-number settings and ``level`` are always overridden, whatever
    the caller passes.
    """

    def __init__(self, *args, **kwargs):
        # level sizes must cover both supplier-sized and consumer-sized levels
        smallest = min(N_SUPPLIERS[0], N_CONSUMERS[0])  # type: ignore
        largest = max(N_SUPPLIERS[1], N_CONSUMERS[1])  # type: ignore
        kwargs.update(
            n_suppliers=(0, 0),  # suppliers have no suppliers
            n_consumers=N_CONSUMERS,
            n_competitors=(N_SUPPLIERS[0] - 1, N_SUPPLIERS[1] - 1),
            n_agents_per_process=(smallest, largest),
            level=0,  # suppliers are always in the first level
        )
        super().__init__(*args, **kwargs)
@define
class StrongSupplierContext(SupplierContext):
    """A supplier with many consumers relative to its competitors."""

    def __init__(self, *args, **kwargs):
        kwargs["selling_strength"] = Strength.Strong
        super().__init__(*args, **kwargs)
@define
class BalancedSupplierContext(SupplierContext):
    """A supplier with about as many consumers as competitors."""

    def __init__(self, *args, **kwargs):
        kwargs["selling_strength"] = Strength.Balanced
        super().__init__(*args, **kwargs)
@define
class WeakSupplierContext(SupplierContext):
    """A supplier with few consumers relative to its competitors."""

    def __init__(self, *args, **kwargs):
        kwargs["selling_strength"] = Strength.Weak
        super().__init__(*args, **kwargs)
@define
class ConsumerContext(LimitedPartnerNumbersOneShotContext):
    """Context for a test agent on the last production level (a consumer).

    The partner-number settings and ``level`` are always overridden, whatever
    the caller passes.
    """

    def __init__(self, *args, **kwargs):
        # level sizes must cover both supplier-sized and consumer-sized levels
        smallest = min(N_SUPPLIERS[0], N_CONSUMERS[0])  # type: ignore
        largest = max(N_SUPPLIERS[1], N_CONSUMERS[1])  # type: ignore
        kwargs.update(
            n_suppliers=N_SUPPLIERS,
            n_consumers=(0, 0),  # consumers have no consumers
            n_competitors=(N_CONSUMERS[0] - 1, N_CONSUMERS[1] - 1),
            n_agents_per_process=(smallest, largest),
            level=-1,  # consumers are always in the last level
        )
        super().__init__(*args, **kwargs)
@define
class StrongConsumerContext(ConsumerContext):
    """A consumer with many suppliers relative to its competitors."""

    def __init__(self, *args, **kwargs):
        kwargs["buying_strength"] = Strength.Strong
        super().__init__(*args, **kwargs)
@define
class BalancedConsumerContext(ConsumerContext):
    """A consumer with about as many suppliers as competitors."""

    def __init__(self, *args, **kwargs):
        kwargs["buying_strength"] = Strength.Balanced
        super().__init__(*args, **kwargs)
@define
class WeakConsumerContext(ConsumerContext):
    """A consumer with few suppliers relative to its competitors."""

    def __init__(self, *args, **kwargs):
        kwargs["buying_strength"] = Strength.Weak
        super().__init__(*args, **kwargs)
@define
class OneShotContext(GeneralContext):
    """A basic context fixing stationary world config parameters.

    It adds nothing to GeneralContext; it only provides a distinct name.
    """
@define
class RepeatingContext(BaseContext):
    """Encapsulates one or more configs and switches between them when asked to generate or make something.

    Stored configs are either cycled through in order or, when ``randomize`` is
    set, picked at random. Placeholder agents in the chosen config are replaced
    with ``self.world_type.replace_agents``.
    """

    # world configs to choose from (default: one config from a GeneralContext)
    configs: tuple[dict[str, Any], ...] = field(
        factory=lambda: (GeneralContext().make_config(),)
    )
    # pick a random stored config on every call instead of cycling
    randomize: bool = True
    # give every generated config a fresh unique name
    rename: bool = True
    # index of the next config when cycling (not an __init__ argument)
    _next: int = field(init=False, default=0)

    def __attrs_post_init__(self):
        """Reject contexts without any config."""
        if not self.configs:
            raise ValueError("RepeatingContext with no configs")

    def make_config(
        self,
        types: tuple[type[OneShotAgent], ...] = DEFAULT_PLACEHOLDER_AGENT_TYPES,
        params: tuple[dict[str, Any], ...] | None = None,
    ) -> dict[str, Any]:
        """Returns the next stored config with its placeholder agents replaced.

        Args:
            types: Agent types that replace ``self.placeholder_types``.
            params: Optional parameters for ``types``.

        Returns:
            The config produced by ``self.world_type.replace_agents``, renamed
            when ``self.rename`` is set.

        Raises:
            ValueError: If there are no stored configs.
        """
        if not self.configs:
            raise ValueError("No configs to generate from")
        if self.randomize:
            self._next = random.randint(0, len(self.configs) - 1)
        config = self.configs[self._next]
        self._next = (self._next + 1) % len(self.configs)
        config = self.world_type.replace_agents(
            config, self.placeholder_types, types, params
        )
        if self.rename:
            # NOTE(review): whether replace_agents returns a copy is not visible
            # here; if it does not, this also renames the stored config.
            config["name"] = unique_name(
                f"c{self._next}", add_time=False, rand_digits=6, sep=""
            )
        return config

    @classmethod
    def from_context(
        cls: type,
        context: BaseContext,
        n: int = 1,
        types: tuple[type[OneShotAgent], ...] = DEFAULT_PLACEHOLDER_AGENT_TYPES,
        rename: bool = False,
        randomize: bool = False,
    ):
        """Creates a RepeatingContext from ``n`` configs generated by ``context``.

        NOTE(review): only the configs are taken from ``context``; its
        ``world_type`` and ``non_competitors`` are not copied.
        """
        return cls(
            configs=tuple(context.make_config() for _ in range(n)),
            placeholder_types=types,
            rename=rename,
            randomize=randomize,
        )

    def contains_repeating_context(
        self,
        context: "RepeatingContext",
        raise_on_failure: bool = False,
        warn_on_failure: bool = False,
    ) -> bool:
        """Checks that every config of ``context`` matches at least one of our configs.

        A single matching config is not enough: any config of ``context`` could
        be generated, so all of them must be reproducible by this context.
        """
        return all(
            any(
                _config_matches_base(config, base, raise_on_failure, warn_on_failure)
                for base in self.configs
            )
            for config in context.configs
        )

    def is_valid_world(
        self,
        world: SCMLBaseWorld,
        raise_on_failure=RAISE_ON_FAILURE,
        warn_on_failure=WARN_ON_FAILURE,
        types: tuple[str | type[OneShotAgent], ...] | None = None,
    ) -> bool:
        """Checks that the given world matches at least one of the stored configs.

        When ``types`` is given, the world's agents must also be among
        ``self.non_competitors`` plus ``types``.
        """
        expected_types = list(self.non_competitors) + list(types) if types else None
        for config in self.configs:
            if _world_matches_config(
                world,
                config,
                expected_types=expected_types,
                expected_world_type=self.world_type,
                raise_on_failure=raise_on_failure,
                warn_on_failure=warn_on_failure,
            ):
                return True
        return False

    def contains_context(
        self,
        context: Context,
        raise_on_failure: bool = False,
        warn_on_failure: bool = False,
        n_tests: int = NTESTS,
    ) -> bool:
        """Checks containment, comparing configs directly for RepeatingContexts."""
        if isinstance(context, RepeatingContext):
            return self.contains_repeating_context(
                context, raise_on_failure, warn_on_failure
            )
        return super().contains_context(
            context, raise_on_failure, warn_on_failure, n_tests
        )