Source code for sim_panel.panelists.enrich

from __future__ import annotations

import concurrent.futures
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


from sim_panel.utils.time import utc_now_iso
from sim_panel.utils.progress import tqdm_wrap
from sim_panel.backends import Backend
from sim_panel.backends.types import Message

from .records import PersonaRecord
from .render import render_persona_text_prompt


@dataclass(frozen=True)
class PersonaTextGenSettings:
    """Immutable knobs controlling how persona text is generated.

    The values are read by ``_enrich_one_persona`` (prompt rendering and the
    ``backend.chat()`` call) and by ``ensure_persona_text`` (worker count).
    """

    # Version tag handed to render_persona_text_prompt() and recorded in provenance.
    prompt_version: str = "v1"

    # Sampling temperature forwarded to backend.chat().
    temperature: float = 0.2

    # Completion length limit forwarded to backend.chat(); None is passed through as-is.
    max_tokens: Optional[int] = None

    # Optional metadata forwarded to backend.chat().
    metadata: Optional[Dict[str, Any]] = None

    # Thread-pool size; values above 1 enable parallel generation when more
    # than one record needs text.
    max_workers: int = 1
def _enrich_one_persona(
    r: PersonaRecord,
    *,
    backend: Backend,
    settings: PersonaTextGenSettings,
) -> Tuple[str, Dict[str, Any]]:
    """Produce persona text for one record via a single backend chat call.

    The record itself is not modified; the caller decides how to store the
    result.

    Args:
        r: Record whose ``attributes`` are rendered into the prompt.
        backend: Chat backend used to generate the text.
        settings: Prompt version, sampling parameters and optional metadata.

    Returns:
        A ``(persona_text, provenance)`` pair, where ``persona_text`` is the
        stripped response content and ``provenance`` describes how it was
        generated (timestamp, settings, backend/model and, when the backend
        config has ``return_usage`` set, token usage).

    Raises:
        ValueError: If ``r.attributes`` is ``None``.
    """
    # Guard: without attributes there is nothing to render into a prompt.
    if r.attributes is None:
        raise ValueError(
            f"Cannot generate persona_text for persona_id={r.persona_id}: attributes is None."
        )

    rendered = render_persona_text_prompt(
        r.attributes, prompt_version=settings.prompt_version
    )
    # One system message followed by one user message, taken from the rendered prompt.
    chat_messages: List[Message] = [
        {"role": role, "content": rendered[role]} for role in ("system", "user")
    ]

    reply = backend.chat(
        chat_messages,
        temperature=settings.temperature,
        max_tokens=settings.max_tokens,
        metadata=settings.metadata,
    )
    text = reply.content.strip()

    provenance: Dict[str, Any] = {
        "generated_at": utc_now_iso(),
        "prompt_version": settings.prompt_version,
        "temperature": settings.temperature,
        "max_tokens": settings.max_tokens,
        "backend": {"name": backend.config.name, "model": reply.model},
    }
    # Token usage is recorded only when the backend is configured to return it.
    if backend.config.return_usage:
        provenance["usage"] = {
            "prompt_tokens": reply.usage.prompt_tokens,
            "completion_tokens": reply.usage.completion_tokens,
            "total_tokens": reply.usage.total_tokens,
        }
    else:
        provenance["usage"] = None

    return text, provenance
def _apply_persona_text(
    record: PersonaRecord,
    persona_text: str,
    prov: Dict[str, Any],
) -> PersonaRecord:
    """Store generated text and its provenance on ``record`` in place and return it."""
    record.persona_text = persona_text
    # text_key is cleared before compute_keys() is called, as in the original
    # code; presumably this forces the key to be recomputed for the new text.
    record.text_key = None
    record.compute_keys()
    record.provenance = {**record.provenance, "persona_text": prov}
    return record


def ensure_persona_text(
    records: List[PersonaRecord],
    *,
    backend: Backend,
    settings: PersonaTextGenSettings,
    variant: str = "default",
    overwrite: bool = False,
    progress: bool = True,
) -> List[PersonaRecord]:
    """Make sure every record of the given variant has persona text.

    For each record whose ``persona_text_variant`` equals ``variant``:
      - if ``persona_text`` is missing or blank (or ``overwrite=True``),
        generate it from the record's attributes;
      - write the text and a ``"persona_text"`` provenance entry onto the record.

    Records of other variants, and records that already have text, are passed
    through untouched. Records are updated in place; the returned list holds
    the same objects in their original order.

    Args:
        records: Records to inspect.
        backend: Chat backend used for generation.
        settings: Generation settings; ``settings.max_workers > 1`` enables a
            thread pool when more than one record needs text.
        variant: Only records with this ``persona_text_variant`` are considered.
        overwrite: Regenerate text even when a record already has some.
        progress: Forwarded to ``tqdm_wrap`` as ``enabled``.

    Returns:
        A new list containing every input record, in input order.

    Raises:
        ValueError: If a record that needs text has ``attributes`` set to
            ``None`` (raised by ``_enrich_one_persona``).
        Exception: Any error raised by the backend call propagates. In
            parallel mode, work that has not started yet is cancelled first.
    """
    # Partition records: those needing generation vs pass-through.
    to_generate: List[Tuple[int, PersonaRecord]] = []
    out: List[Optional[PersonaRecord]] = [None] * len(records)
    n_total = 0  # number of records belonging to the requested variant
    for i, r in enumerate(records):
        if r.persona_text_variant != variant:
            out[i] = r
            continue
        n_total += 1
        needs_text = overwrite or (r.persona_text is None or not r.persona_text.strip())
        if needs_text:
            to_generate.append((i, r))
        else:
            out[i] = r

    # Nothing to generate: every slot already holds its record, so skip the
    # progress bar and executor entirely.
    if not to_generate:
        return list(records)

    desc = f"Enrich personas ({len(to_generate)}/{n_total})"
    use_parallel = settings.max_workers > 1 and len(to_generate) > 1

    if use_parallel:
        with concurrent.futures.ThreadPoolExecutor(max_workers=settings.max_workers) as pool:
            future_to_idx = {
                pool.submit(
                    _enrich_one_persona, r, backend=backend, settings=settings
                ): (i, r)
                for i, r in to_generate
            }
            try:
                for future in tqdm_wrap(
                    concurrent.futures.as_completed(future_to_idx),
                    total=len(future_to_idx),
                    desc=desc,
                    enabled=progress,
                ):
                    i, r = future_to_idx[future]
                    persona_text, prov = future.result()
                    # Records are written back on the calling thread only.
                    out[i] = _apply_persona_text(r, persona_text, prov)
            except BaseException:
                # Leaving the ``with`` block waits for every queued task.
                # Cancel the ones that have not started so a failure (or
                # Ctrl-C) does not trigger further backend calls.
                for pending in future_to_idx:
                    pending.cancel()
                raise
    else:
        for i, r in tqdm_wrap(to_generate, total=len(to_generate), desc=desc, enabled=progress):
            persona_text, prov = _enrich_one_persona(r, backend=backend, settings=settings)
            out[i] = _apply_persona_text(r, persona_text, prov)

    # Every slot is filled by now; the filter only narrows the Optional type.
    return [r for r in out if r is not None]