from __future__ import annotations
import concurrent.futures
import sys
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Sequence
from sim_panel.utils.progress import tqdm_wrap
from sim_panel.decisions.selection import (
render_selection_prompt,
parse_selection_response,
apply_execution_rules,
)
from sim_panel.decisions.types import SelectionContext
from sim_panel.generators.rng import make_rng, stable_event_id
from sim_panel.generators.types import GeneratorConfig
from sim_panel.outcomes.base import EvaluationContext
from sim_panel.outcomes.registry import build_outcome_model
from sim_panel.policies.registry import build_policy
from sim_panel.policies.types import ExposureDecision
from sim_panel.policies.random import RandomAssignmentPolicy # renamed file
from sim_panel.schema.validate import (
validate_rows,
validate_unique_event_id,
validate_self_selection_links,
)
from sim_panel.panelists.panelist import Panelist
from sim_panel.products.product import Product
# Type for checkpoint callback: called with (period, rows_for_period)
OnPeriodComplete = Callable[[int, List[Dict[str, Any]]], None]
[docs]
@dataclass
class EventGenerator:
"""
Orchestrates policy exposure + panelist selection/evaluation + outcomes into schema rows.
Expectations:
- panelists: Sequence[Panelist] runtime agents (LLM-capable)
- products: Sequence[Product] runtime wrappers (record + display())
"""
cfg: GeneratorConfig
[docs]
def generate(
self,
*,
panelists: Sequence[Panelist],
products: Sequence[Product],
progress: bool = True,
resume_from_period: int = 0,
prior_rows: Optional[List[Dict[str, Any]]] = None,
on_period_complete: Optional[OnPeriodComplete] = None,
) -> List[Dict[str, Any]]:
rng = make_rng(self.cfg.seed)
policy = build_policy(self.cfg.policy)
outcome_model = build_outcome_model(self.cfg.outcome) if self.cfg.outcome is not None else None
panelist_ids = [p.panelist_id for p in panelists]
product_ids = [p.product_id for p in products]
product_by_id = {p.product_id: p for p in products}
panelist_by_id = {p.panelist_id: p for p in panelists}
# Start with prior rows from checkpoint (if resuming)
rows: List[Dict[str, Any]] = list(prior_rows) if prior_rows else []
# Advance RNG past completed periods so seeds match a fresh run
for t in range(resume_from_period):
self._decide_for_period(
rng=rng, policy=policy, panelist_ids=panelist_ids,
t=t, product_ids=product_ids,
)
if resume_from_period > 0:
print(
f"[checkpoint] Resuming from period {resume_from_period}/{self.cfg.n_periods} "
f"({len(rows)} prior rows loaded)",
file=sys.stderr, flush=True,
)
remaining = range(resume_from_period, self.cfg.n_periods)
for t in tqdm_wrap(remaining, total=len(remaining), desc="Periods", enabled=progress):
# generator owns time; update runtime agents
for p in tqdm_wrap(panelists, total=len(panelists), desc=f"Set t={t}", enabled=progress and len(panelists) > 1000):
p.state.t = t
# Decide exposures for this period
decisions = self._decide_for_period(
rng=rng,
policy=policy,
panelist_ids=panelist_ids,
t=t,
product_ids=product_ids,
)
period_rows: List[Dict[str, Any]] = []
# Execute decisions into events
use_parallel = self.cfg.max_workers > 1 and len(decisions) > 1
if use_parallel:
ordered_batches: List[Optional[List[Dict[str, Any]]]] = [None] * len(decisions)
with concurrent.futures.ThreadPoolExecutor(max_workers=self.cfg.max_workers) as pool:
future_to_meta = {
pool.submit(
self._execute_decision,
dec=dec,
panelist=panelist_by_id[dec.panelist_id],
product_by_id=product_by_id,
t=t,
outcome_model=outcome_model,
): (idx, dec)
for idx, dec in enumerate(decisions)
}
for future in tqdm_wrap(
concurrent.futures.as_completed(future_to_meta),
total=len(future_to_meta),
desc=f"Execute t={t}",
enabled=progress,
):
idx, dec = future_to_meta[future]
try:
ordered_batches[idx] = future.result()
except Exception as exc:
raise RuntimeError(
f"Decision execution failed at t={t}, idx={idx}, panelist_id={dec.panelist_id}"
) from exc
for batch in ordered_batches:
if batch is None:
raise RuntimeError("Parallel execution finished with a missing result batch.")
period_rows.extend(batch)
else:
for dec in tqdm_wrap(decisions, total=len(decisions), desc=f"Execute t={t}", enabled=progress):
period_rows.extend(
self._execute_decision(
dec=dec,
panelist=panelist_by_id[dec.panelist_id],
product_by_id=product_by_id,
t=t,
outcome_model=outcome_model,
)
)
rows.extend(period_rows)
if on_period_complete is not None:
on_period_complete(t, period_rows)
if self.cfg.validate_on_finish:
self._validate(rows)
return rows
def _execute_decision(
self,
*,
dec: ExposureDecision,
panelist: Panelist,
product_by_id: Dict[str, Product],
t: int,
outcome_model: Optional[Any],
) -> List[Dict[str, Any]]:
"""Process a single ExposureDecision and return the resulting event rows."""
result_rows: List[Dict[str, Any]] = []
if dec.evaluate_product_ids is not None:
# random/manual: directly evaluate assigned products
for prod_id in dec.evaluate_product_ids:
prod = product_by_id.get(prod_id)
if prod is None:
raise ValueError(f"Policy assigned unknown product_id={prod_id!r}")
result_rows.append(
self._emit_evaluation_event(
panelist=panelist,
product=prod,
t=t,
selection_id=None,
outcome_model=outcome_model,
)
)
return result_rows
if dec.selection is not None:
# self_selection: show choice_set, let panelist request any number, then apply execution rules
choice_set = list(dec.selection.choice_set)
products_shown: List[Dict[str, Any]] = []
for pid in choice_set:
prod = product_by_id.get(pid)
if prod is None:
continue
item: Dict[str, Any] = {
"product_id": pid,
"product_display": prod.display(),
}
if self.cfg.selection.include_features:
item["product_features"] = dict(prod.attributes)
products_shown.append(item)
sel_ctx = SelectionContext(
panelist_id=panelist.panelist_id,
t=t,
products_shown=products_shown,
)
strategy = self.cfg.prompting_strategy
sel_prompt = render_selection_prompt(ctx=sel_ctx, cfg=self.cfg.selection, prompting_strategy=strategy)
# For zero_shot / few_shot, override system prompt to remove persona
sel_system_prompt = None
if strategy in ("zero_shot", "few_shot"):
sel_system_prompt = "You are evaluating consumer products. Provide honest, thoughtful responses."
raw_sel = panelist.select(
task_prompt=sel_prompt,
choice_set=choice_set,
metadata={"module": "generators.selection", "policy": self.cfg.policy.name, "t": t},
system_prompt=sel_system_prompt,
)
parsed_sel = parse_selection_response(
raw_text=raw_sel,
choice_set_ids=choice_set,
cfg=self.cfg.selection,
)
# Emit selection event (records free-will request)
sel_row = self._emit_selection_event(
panelist_id=panelist.panelist_id,
t=t,
choice_set=choice_set,
selected_product_ids=parsed_sel.requested_product_ids,
selection_traces=parsed_sel.traces,
selection_errors=parsed_sel.errors,
)
result_rows.append(sel_row)
selection_id = sel_row["event_id"]
executed, dropped = apply_execution_rules(
requested_product_ids=parsed_sel.requested_product_ids,
choice_set_ids=choice_set,
rules=self.cfg.execution.rules,
)
if not self.cfg.execution.rules.allow_empty and len(executed) == 0 and len(choice_set) > 0:
# v0 fallback: evaluate the first shown item deterministically
executed = [choice_set[0]]
# Evaluate executed subset
for prod_id in executed:
prod = product_by_id.get(prod_id)
if prod is None:
continue
result_rows.append(
self._emit_evaluation_event(
panelist=panelist,
product=prod,
t=t,
selection_id=selection_id,
outcome_model=outcome_model,
)
)
# Store operational details as traces on the selection row
if dropped or executed:
sel_row.setdefault("traces", {})
if isinstance(sel_row["traces"], dict):
if dropped:
sel_row["traces"]["dropped_product_ids"] = dropped
sel_row["traces"]["executed_product_ids"] = executed
return result_rows
raise ValueError("ExposureDecision must have either evaluate_product_ids or selection.")
def _decide_for_period(
self,
*,
rng,
policy,
panelist_ids: Sequence[str],
t: int,
product_ids: Sequence[str],
) -> List[ExposureDecision]:
# Balanced RCT-like assignment is inherently global; use batch API when available.
if isinstance(policy, RandomAssignmentPolicy) and getattr(policy.cfg, "random_mode", None) == "balanced_quota":
return policy.decide_batch(rng=rng, panelist_ids=panelist_ids, t=t, product_ids=product_ids) # type: ignore[attr-defined]
return [
policy.decide(rng=rng, panelist_id=pid, t=t, product_ids=product_ids)
for pid in panelist_ids
]
def _emit_selection_event(
self,
*,
panelist_id: str,
t: int,
choice_set: List[str],
selected_product_ids: List[str],
selection_traces: Optional[Dict[str, Any]],
selection_errors: Optional[List[str]],
) -> Dict[str, Any]:
base: Dict[str, Any] = {
"schema_version": self.cfg.schema_version,
"event_type": "selection",
"policy": self.cfg.policy.name,
"panelist_id": panelist_id,
"t": t,
"choice_set": list(choice_set),
"selected_product_ids": list(selected_product_ids),
"selection_id": None,
"outcomes": None,
"traces": selection_traces if selection_traces is not None else None,
}
if selection_errors:
base.setdefault("traces", {})
if isinstance(base["traces"], dict):
base["traces"]["selection_errors"] = selection_errors
event_id = stable_event_id(
self.cfg.event_namespace,
{
"schema_version": self.cfg.schema_version,
"event_type": "selection",
"policy": self.cfg.policy.name,
"panelist_id": panelist_id,
"t": t,
"choice_set": list(choice_set),
"selected_product_ids": list(selected_product_ids),
},
)
return {"event_id": event_id, **base, **(self.cfg.row_meta or {})}
def _emit_evaluation_event(
self,
*,
panelist: Panelist,
product: Product,
t: int,
selection_id: Optional[str],
outcome_model: Optional[Any],
) -> Dict[str, Any]:
panelist_id = panelist.panelist_id
product_id = product.product_id
product_display = product.display()
panelist_features: Dict[str, Any] = {}
if self.cfg.include_panelist_features_in_events:
panelist_features = dict(getattr(panelist, "attributes", {}) or {})
product_features: Dict[str, Any] = {}
if self.cfg.include_product_features_in_events:
product_features = dict(product.attributes)
outcomes = None
traces = None
if outcome_model is not None:
ctx = EvaluationContext(
panelist_id=panelist_id,
product_id=product_id,
t=t,
product_display=product_display,
panelist_features=panelist_features,
product_features=product_features,
)
res = outcome_model.evaluate(panelist=panelist, ctx=ctx, prompting_strategy=self.cfg.prompting_strategy)
outcomes = res.outcomes
traces = res.traces
if res.errors:
traces = dict(traces or {})
traces["outcome_errors"] = res.errors
base = {
"schema_version": self.cfg.schema_version,
"event_type": "evaluation",
"policy": self.cfg.policy.name,
"panelist_id": panelist_id,
"t": t,
"product_id": product_id,
"product_display": product_display,
"panelist_features": panelist_features,
"product_features": product_features,
"outcomes": outcomes,
"traces": traces,
"selection_id": selection_id,
}
event_id = stable_event_id(
self.cfg.event_namespace,
{
"schema_version": self.cfg.schema_version,
"event_type": "evaluation",
"policy": self.cfg.policy.name,
"panelist_id": panelist_id,
"t": t,
"product_id": product_id,
"selection_id": selection_id,
},
)
return {"event_id": event_id, **base, **(self.cfg.row_meta or {})}
def _validate(self, rows: Sequence[Dict[str, Any]]) -> None:
report = validate_rows(rows, schema_version=None, max_errors=self.cfg.max_errors)
if not report.ok:
raise ValueError(f"Schema validation failed: {report.summary()}")
ok, msg = validate_unique_event_id(rows)
if not ok:
raise ValueError(f"event_id uniqueness check failed: {msg}")
if self.cfg.policy.name == "self_selection":
ok2, problems = validate_self_selection_links(rows)
if not ok2:
raise ValueError(f"self_selection link check failed: {problems}")