from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional, Tuple
from dataclasses import replace
from sim_panel.config.types import RunBundle, RunConfig
from sim_panel.config.yaml_loader import load_yaml
from sim_panel.generators.pipeline import EventGenerator
from sim_panel.generators.types import GeneratorConfig, ExecutionConfig
from sim_panel.policies.base import PolicyConfig
from sim_panel.decisions.types import SelectionConfig, ExecutionRules
from sim_panel.outcomes.registry import outcome_config_from_yaml_dict
from sim_panel.panelists.factory import build_panelists
from sim_panel.panelists.panelist import EvalSettings, SelectSettings
from sim_panel.panelists.io import load_persona_records, save_persona_records
from sim_panel.panelists.enrich import ensure_persona_text, PersonaTextGenSettings
from sim_panel.products.io import load_product_records, save_product_records
from sim_panel.products.product import Product
from sim_panel.products.factory import build_products
from sim_panel.products.enrich import ensure_display_text, ProductDisplayTextGenSettings
from sim_panel.backends import Backend
from sim_panel.io.manual_schedule import load_manual_schedule
[docs]
def build_run_from_yaml(path: str) -> RunBundle:
d = load_yaml(path)
return build_run_from_dict(d, config_path=path)
[docs]
def build_run_from_dict(d: Mapping[str, Any], *, config_path: Optional[str] = None) -> RunBundle:
"""
Build a complete run bundle from a YAML-parsed dict.
Required YAML sections:
panelists:
source: path/to/personas.jsonl
variant: default
enrich: (optional)
products:
source: path/to/products.jsonl
variant: default
enrich: (optional)
policy:
name: random | manual | self_selection
Optional sections:
generator, selection, execution, outcomes_model, questionnaire, backend, output_dir
"""
cfg_snapshot: Dict[str, Any] = dict(d)
# --- required source sections ---
panel_cfg = _require_mapping(d, "panelists")
prod_cfg = _require_mapping(d, "products")
personas_path = _require_str(panel_cfg, "source")
products_path = _require_str(prod_cfg, "source")
persona_variant = _get_str(panel_cfg, "variant", default="default") or "default"
product_variant = _get_str(prod_cfg, "variant", default="default") or "default"
output_dir = _get_str(d, "output_dir", default=None)
# --- generator basics ---
gen_cfg_raw = _get_mapping(d, "generator", default={})
schema_version = _get_str(gen_cfg_raw, "schema_version", default="0.1.0")
seed = _get_int(gen_cfg_raw, "seed", default=0)
n_periods = _get_int(gen_cfg_raw, "n_periods", default=1)
validate_on_finish = _get_bool(gen_cfg_raw, "validate_on_finish", default=True)
max_errors = _get_int(gen_cfg_raw, "max_errors", default=50)
event_namespace = _get_str(gen_cfg_raw, "event_namespace", default="sim_panel.v0")
max_workers = _get_int(gen_cfg_raw, "max_workers", default=1)
prompting_strategy = _get_str(gen_cfg_raw, "prompting_strategy", default="persona") or "persona"
# --- policy config (required) ---
policy_cfg_raw = _require_mapping(d, "policy")
policy_cfg = _build_policy_config(policy_cfg_raw)
# --- selection config ---
selection_cfg_raw = _get_mapping(d, "selection", default={})
selection_cfg = SelectionConfig(
allow_empty=_get_bool(selection_cfg_raw, "allow_empty", default=True),
# NOTE: SelectionConfig uses `include_features` (product features only).
# Keep YAML key `include_product_features` for readability/back-compat.
include_features=_get_bool(selection_cfg_raw, "include_product_features", default=True),
require_json_only=_get_bool(selection_cfg_raw, "require_json_only", default=True),
max_selected_soft=_get_optional_int(selection_cfg_raw, "max_selected_soft", default=None),
include_raw_text=_get_bool(selection_cfg_raw, "include_raw_text", default=True),
custom_few_shot_example=_get_mapping(selection_cfg_raw, "custom_few_shot_example", default=None),
)
# --- execution rules ---
exec_cfg_raw = _get_mapping(d, "execution", default={})
rules = ExecutionRules(
enforce_subset_of_choice_set=_get_bool(exec_cfg_raw, "enforce_subset_of_choice_set", default=True),
max_evals_per_panelist_per_t=_get_optional_int(exec_cfg_raw, "max_evals_per_panelist_per_t", default=None),
allow_empty=_get_bool(exec_cfg_raw, "allow_empty", default=True),
keep_strategy=_get_str(exec_cfg_raw, "keep_strategy", default="keep_first"),
)
execution_cfg = ExecutionConfig(rules=rules)
# --- outcomes (optional) ---
outcome_cfg = None
if "questionnaire" in d or "outcomes_model" in d:
outcome_cfg = outcome_config_from_yaml_dict(d)
# --- backend (optional unless enrichment or llm outcomes enabled) ---
backend_cfg = _get_mapping(d, "backend", default=None)
backend: Optional[Backend] = None
if backend_cfg is not None:
backend = _build_backend(backend_cfg)
# --- load records ---
persona_records = load_persona_records(personas_path)
product_records = load_product_records(products_path)
# --- persisted enrichment (optional) ---
persona_records, personas_path = _maybe_enrich_personas(
persona_records,
source_path=personas_path,
section=panel_cfg,
variant=persona_variant,
backend=backend,
)
product_records, products_path = _maybe_enrich_products(
product_records,
source_path=products_path,
section=prod_cfg,
variant=product_variant,
backend=backend,
)
# --- manual policy mapping injection (optional; required if policy.name == "manual") ---
if policy_cfg.name == "manual":
manual_section = policy_cfg_raw.get("manual")
if not isinstance(manual_section, Mapping):
raise ValueError("policy.name=manual requires a 'policy.manual' mapping section.")
# Validate against IDs available for the selected variants (post-enrichment).
panelist_ids = [
r.persona_id for r in persona_records
if getattr(r, "persona_text_variant", "default") == persona_variant
]
product_ids = [
r.product_id for r in product_records
if getattr(r, "display_variant", "default") == product_variant
]
fmt = _require_str(manual_section, "format")
path = _require_str(manual_section, "path")
on_unknown = _get_str(manual_section, "on_unknown", default="error") or "error"
schedule = load_manual_schedule(
path=path,
format=fmt,
panelist_ids=panelist_ids,
product_ids=product_ids,
on_unknown=on_unknown,
panelist_id_col=_get_str(manual_section, "panelist_id_col", default="panelist_id") or "panelist_id",
product_id_col=_get_str(manual_section, "product_id_col", default="product_id") or "product_id",
t_col=_get_str(manual_section, "t_col", default="t") or "t",
default_t=_get_int(manual_section, "default_t", default=0),
)
policy_cfg = replace(policy_cfg, manual_assignment_fn=schedule.to_fn(on_unknown=on_unknown))
# --- panelist runtime settings ---
eval_settings = _build_eval_settings(panel_cfg.get("eval_settings"))
select_settings = _build_select_settings(panel_cfg.get("select_settings"))
# If outcomes model is LLM, require backend
if outcome_cfg is not None and outcome_cfg.name == "llm" and backend is None:
raise ValueError("outcomes_model.name=llm requires a backend configuration.")
# --- build runtime objects ---
panelists = build_panelists(
persona_records,
backend=backend,
variant=persona_variant,
eval_settings=eval_settings,
select_settings=select_settings,
)
products = build_products(product_records, variant=product_variant)
gen_cfg = GeneratorConfig(
schema_version=schema_version,
seed=seed,
n_periods=n_periods,
policy=policy_cfg,
selection=selection_cfg,
execution=execution_cfg,
outcome=outcome_cfg,
validate_on_finish=validate_on_finish,
max_errors=max_errors,
event_namespace=event_namespace,
max_workers=max_workers,
prompting_strategy=prompting_strategy,
)
generator = EventGenerator(gen_cfg)
run_cfg = RunConfig(
generator=gen_cfg,
personas_path=personas_path,
products_path=products_path,
persona_variant=persona_variant,
product_variant=product_variant,
output_dir=output_dir,
)
return RunBundle(
generator=generator,
panelists=panelists,
products=products,
config_snapshot=cfg_snapshot,
run_config=run_cfg,
)
# ----------------------------
# Enrichment orchestrators
# ----------------------------
def _maybe_enrich_personas(
records,
*,
source_path: str,
section: Mapping[str, Any],
variant: str,
backend: Optional[Backend],
):
enrich = section.get("enrich")
if not isinstance(enrich, Mapping):
return records, source_path
enabled = bool(enrich.get("enabled", False))
if not enabled:
return records, source_path
if backend is None:
raise ValueError("panelists.enrich.enabled=true requires a backend configuration.")
overwrite = bool(enrich.get("overwrite", False))
save_target = enrich.get("save", "in_place")
save_path = _resolve_save_path(source_path, save_target)
settings_raw = enrich.get("settings", {})
if not isinstance(settings_raw, Mapping):
raise ValueError("panelists.enrich.settings must be a mapping if provided.")
settings = PersonaTextGenSettings(
prompt_version=str(settings_raw.get("prompt_version", "v1")),
temperature=float(settings_raw.get("temperature", 0.2)),
max_tokens=settings_raw.get("max_tokens"),
metadata=settings_raw.get("metadata"),
max_workers=max(1, int(settings_raw.get("max_workers", 1))),
)
updated = ensure_persona_text(
records,
backend=backend,
settings=settings,
variant=variant,
overwrite=overwrite,
)
save_persona_records(save_path, updated)
return updated, save_path
def _maybe_enrich_products(
records,
*,
source_path: str,
section: Mapping[str, Any],
variant: str,
backend: Optional[Backend],
):
enrich = section.get("enrich")
if not isinstance(enrich, Mapping):
return records, source_path
enabled = bool(enrich.get("enabled", False))
if not enabled:
return records, source_path
if backend is None:
raise ValueError("products.enrich.enabled=true requires a backend configuration.")
overwrite = bool(enrich.get("overwrite", False))
save_target = enrich.get("save", "in_place")
save_path = _resolve_save_path(source_path, save_target)
settings_raw = enrich.get("settings", {})
if not isinstance(settings_raw, Mapping):
raise ValueError("products.enrich.settings must be a mapping if provided.")
settings = ProductDisplayTextGenSettings(
prompt_version=str(settings_raw.get("prompt_version", "v1")),
temperature=float(settings_raw.get("temperature", 0.2)),
max_tokens=settings_raw.get("max_tokens"),
metadata=settings_raw.get("metadata"),
campaign=settings_raw.get("campaign"),
tone=str(settings_raw.get("tone", "neutral")),
length=str(settings_raw.get("length", "short")),
max_workers=max(1, int(settings_raw.get("max_workers", 1))),
)
updated = ensure_display_text(
records,
backend=backend,
settings=settings,
variant=variant,
overwrite=overwrite,
)
save_product_records(save_path, updated)
return updated, save_path
def _resolve_save_path(source_path: str, save: Any) -> str:
"""
save can be:
- "in_place" (default): overwrite source_path
- {"path": "..."}: write to explicit path
"""
if save is None:
return source_path
if isinstance(save, str):
if save == "in_place":
return source_path
raise ValueError("enrich.save must be 'in_place' or a mapping with {path: ...}")
if isinstance(save, Mapping):
p = save.get("path")
if not isinstance(p, str) or not p:
raise ValueError("enrich.save.path must be a non-empty string")
return p
raise ValueError("enrich.save must be 'in_place' or a mapping with {path: ...}")
# ----------------------------
# Builders
# ----------------------------
def _build_policy_config(d: Mapping[str, Any]) -> PolicyConfig:
name = _require_str(d, "name")
evals_per_period = _get_int(d, "evals_per_period", default=1)
random_mode = _get_str(d, "random_mode", default="balanced_quota")
product_probs = d.get("product_probs")
if product_probs is not None and not isinstance(product_probs, dict):
raise ValueError("policy.product_probs must be a mapping if provided.")
choice_set_size = d.get("choice_set_size", None)
if choice_set_size is not None and not isinstance(choice_set_size, int):
raise ValueError("policy.choice_set_size must be int or null.")
allow_empty_selection = _get_bool(d, "allow_empty_selection", default=True)
# manual mapping injection handled later (io/manual_schedule)
manual_assignment_fn = None
return PolicyConfig(
name=name, # type: ignore[arg-type]
evals_per_period=evals_per_period,
random_mode=random_mode, # type: ignore[arg-type]
product_probs=product_probs,
choice_set_size=choice_set_size,
allow_empty_selection=allow_empty_selection,
manual_assignment_fn=manual_assignment_fn,
)
def _build_eval_settings(x: Any) -> Optional[EvalSettings]:
if x is None:
return None
if not isinstance(x, Mapping):
raise ValueError("panelists.eval_settings must be a mapping if provided.")
return EvalSettings(
temperature=float(x.get("temperature", 0.2)),
max_tokens=x.get("max_tokens"),
metadata=x.get("metadata"),
)
def _build_select_settings(x: Any) -> Optional[SelectSettings]:
if x is None:
return None
if not isinstance(x, Mapping):
raise ValueError("panelists.select_settings must be a mapping if provided.")
return SelectSettings(
temperature=float(x.get("temperature", 0.2)),
max_tokens=x.get("max_tokens"),
metadata=x.get("metadata"),
)
def _build_backend(d: Mapping[str, Any]) -> Backend:
"""
Wire backend from YAML.
Expects a registry builder at:
sim_panel.backends.registry.build_backend_from_dict
"""
try:
from sim_panel.backends.registry import build_backend_from_dict # type: ignore
except Exception as e:
raise NotImplementedError(
"backend config provided, but sim_panel.backends.registry.build_backend_from_dict "
"is not available. Implement that builder or omit 'backend' for deterministic runs."
) from e
return build_backend_from_dict(d)
# ----------------------------
# Parsing helpers
# ----------------------------
def _require_mapping(d: Mapping[str, Any], key: str) -> Mapping[str, Any]:
if key not in d:
raise ValueError(f"Missing required key: {key}")
v = d[key]
if not isinstance(v, Mapping):
raise ValueError(f"{key} must be a mapping/dict, got {type(v).__name__}")
return v
def _get_mapping(d: Mapping[str, Any], key: str, default: Any) -> Any:
if key not in d:
return default
v = d[key]
if v is None:
return default
if not isinstance(v, Mapping):
raise ValueError(f"{key} must be a mapping/dict, got {type(v).__name__}")
return v
def _require_str(d: Mapping[str, Any], key: str) -> str:
v = d.get(key)
if not isinstance(v, str) or not v:
raise ValueError(f"{key} must be a non-empty string")
return v
def _get_str(d: Mapping[str, Any], key: str, default: Optional[str]) -> Optional[str]:
if key not in d or d.get(key) is None:
return default
v = d.get(key)
if not isinstance(v, str):
raise ValueError(f"{key} must be a string")
return v
def _get_int(d: Mapping[str, Any], key: str, default: int) -> int:
if key not in d or d.get(key) is None:
return default
v = d.get(key)
if not isinstance(v, int):
raise ValueError(f"{key} must be an int")
return v
def _get_optional_int(d: Mapping[str, Any], key: str, default: Optional[int]) -> Optional[int]:
if key not in d or d.get(key) is None:
return default
v = d.get(key)
if not isinstance(v, int):
raise ValueError(f"{key} must be an int or null")
return v
def _get_bool(d: Mapping[str, Any], key: str, default: bool) -> bool:
if key not in d or d.get(key) is None:
return default
v = d.get(key)
if not isinstance(v, bool):
raise ValueError(f"{key} must be a bool")
return v