The model layer of the project, built honestly:
- tools/dataset_validate.py — full-sweep validator over the receiver
store (sha256, schema, monotonic labels, telemetry-row gate). On the
current corpus: 64,798 accepted + 8,154 degraded + 3,701 rejected +
7 errored across 76,660 shipped episodes. data/processed/validation_v1.parquet
is committed as the per-episode acceptance index.
- training/_features.py — channel registry (46 channels across
proc/guest/qmp/netflow), summary-stat windowing, and channel×time
tensor extraction over 10s windows with a 5s stride. Time alignment
uses t_wall_ns (Unix ns), a tested fix for a real netflow-vs-host
clock-base inconsistency that was silently dropping every netflow
channel.
- training/_split.py — three held-out recipes (host / sample / time)
with profile-stratification assertions. held_out_host carries
untested_profiles for cases like scan-and-dial absent from the test
host (5 of 6 profiles tested cross-device, never silently averaged).
- training/models/ — 6 architectures behind a common BaseModel
interface: gbt (XGBoost), mlp, cnn, gru, lstm, transformer. Each
trained twice (realistic / oracle) per the deployment threat model.
Schema-hashed checkpoints refuse to load if _features.py changed
since training (silent-input-drift protection, tested).
- training/trainer/ — unified training loop: class-weighted CE, LR
warmup + cosine, gradient clipping, mixed precision when CUDA,
early stopping on val macro F1, best-on-val checkpoint. Same loop
runs MLP/CNN/GRU/LSTM/Transformer; GBT uses XGBoost
early_stopping_rounds on val mlogloss.
- training/eval_/ — bootstrap 95% CIs on macro F1, per-class F1,
per-profile and per-host breakdown, paired-bootstrap significance
for model-vs-model gap. Confusion matrix uses union of seen labels.
- training/dashboard/producers/ — replay/metrics/perf/profiles
emitting the six event types the dashboard's awaiting scenes
consume; on-demand tensor extraction so the Pi can run live
inference without 65 GB of shards.
- 17 unit tests (split coverage, features round-trip, schema mismatch,
determinism, time-base alignment regression).
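The _split bullet describes held_out_host recording untested_profiles instead of silently averaging over profiles the test host never ran. A minimal sketch of that bookkeeping, assuming episodes are dicts with `episode_id`, `host_id` and `profile` keys; `HostSplit` and `held_out_host_split` are hypothetical names, and the profile-stratification assertions the real recipe performs are omitted.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class HostSplit:
    train_ids: list[str]
    test_ids: list[str]
    # Profiles present in training but never seen on the test host; their
    # cross-device score is reported as "untested", not averaged in.
    untested_profiles: list[str]


def held_out_host_split(episodes: list[dict], test_host: str) -> HostSplit:
    """Hold out every episode from `test_host`; train on all other hosts."""
    train = [e for e in episodes if e["host_id"] != test_host]
    test = [e for e in episodes if e["host_id"] == test_host]
    if not train or not test:
        raise ValueError(f"host {test_host!r} leaves an empty train or test side")
    train_profiles = {e["profile"] for e in train}
    test_profiles = {e["profile"] for e in test}
    return HostSplit(
        train_ids=[e["episode_id"] for e in train],
        test_ids=[e["episode_id"] for e in test],
        untested_profiles=sorted(train_profiles - test_profiles),
    )
```

With hosts `hostA`/`hostB` and a profile `p2` made up for illustration, holding out `hostB` when only `hostA` ran scan-and-dial reports `["scan-and-dial"]` as untested.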
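The models bullet says schema-hashed checkpoints refuse to load when _features.py changes. One way to get that behavior, sketched here as an assumption rather than the project's checkpoint format, is to hash a canonical JSON description of the feature schema and compare it on load. `schema_hash`, `check_checkpoint`, the `feature_schema_hash` key and `SchemaMismatch` are all hypothetical.

```python
import hashlib
import json


class SchemaMismatch(RuntimeError):
    """Raised when a checkpoint was trained against a different feature schema."""


def schema_hash(channel_names, in_deployment, stat_names,
                window_s: float, stride_s: float, hz: float) -> str:
    """SHA-256 over a canonical (sorted-key, compact) JSON encoding of the schema."""
    payload = {
        "channels": list(channel_names),
        "in_deployment": [bool(x) for x in in_deployment],
        "stats": list(stat_names),
        "window": [float(window_s), float(stride_s), float(hz)],
    }
    blob = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode()
    return hashlib.sha256(blob).hexdigest()


def check_checkpoint(ckpt: dict, current_hash: str) -> None:
    """Refuse to proceed if the stored hash differs from the current schema's."""
    saved = ckpt.get("feature_schema_hash")
    if saved != current_hash:
        raise SchemaMismatch(f"checkpoint schema {saved!r} != current {current_hash!r}")
```

Hashing the schema rather than the file's bytes is a design choice: a comment-only edit to the module leaves checkpoints loadable, while reordering a channel or flipping an in_deployment flag invalidates them.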
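The trainer's "LR warmup + cosine" schedule can be written as a pure function of the step index. This sketch assumes linear warmup to `base_lr` followed by a half-cosine decay to `min_lr`; `lr_at` and its parameters are hypothetical, and the actual trainer may differ in details such as whether warmup starts from zero.

```python
import math


def lr_at(step: int, *, base_lr: float, warmup_steps: int, total_steps: int,
          min_lr: float = 0.0) -> float:
    """Linear warmup to base_lr, then half-cosine decay to min_lr at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)  # hold at min_lr after total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

In PyTorch such a function could be wrapped in `torch.optim.lr_scheduler.LambdaLR`, which expects a multiplicative factor, by returning `lr_at(step, ...) / base_lr`.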
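The eval_ bullet mentions paired-bootstrap significance for model-vs-model gaps. A minimal sketch of the idea, assuming per-window predictions from two models on the same test windows: both models are scored on the same resampled indices, so the spread of the difference reflects the gap rather than test-set composition. `macro_f1` and `paired_bootstrap_macro_f1` are hypothetical names, not the eval_ package's API, and the p-value (share of resamples in which A does not beat B) is one common convention among several.

```python
import numpy as np


def macro_f1(y_true: np.ndarray, y_pred: np.ndarray, labels: np.ndarray) -> float:
    """Unweighted mean of per-class F1 over `labels`.

    A class with no true and no predicted instances scores 0 here; that is a
    modelling choice, not a universal convention.
    """
    scores = []
    for c in labels:
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(scores))


def paired_bootstrap_macro_f1(y_true, pred_a, pred_b, n_boot=1000, seed=0):
    """Observed macro-F1 gap (A - B), its 95% percentile CI, and a one-sided p-value."""
    y_true, pred_a, pred_b = map(np.asarray, (y_true, pred_a, pred_b))
    rng = np.random.default_rng(seed)
    # Union of labels seen anywhere, so absent classes are scored consistently.
    labels = np.union1d(y_true, np.union1d(pred_a, pred_b))
    n = len(y_true)
    observed = macro_f1(y_true, pred_a, labels) - macro_f1(y_true, pred_b, labels)
    deltas = np.empty(n_boot)
    for i in range(n_boot):
        idx = rng.integers(0, n, size=n)  # one resample shared by both models
        deltas[i] = (macro_f1(y_true[idx], pred_a[idx], labels)
                     - macro_f1(y_true[idx], pred_b[idx], labels))
    lo, hi = np.percentile(deltas, [2.5, 97.5])
    p_value = float(np.mean(deltas <= 0.0))  # share of resamples where A does not beat B
    return observed, (float(lo), float(hi)), p_value
```

Windows from the same episode are correlated, so resampling whole episodes (a block bootstrap) may give more honest intervals; this sketch resamples windows independently.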
End-to-end smoke-trained all six on a 567-episode subset; held-out
test macro F1 reported with paired-bootstrap significance. The
methodology now reports honest cross-device generalization, not
in-distribution validation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
467 lines
19 KiB
Python
"""Per-episode and per-window feature extraction.

Channel registry, kept explicit so we know the dim count is stable. Each
channel has:
  - source: one of proc / guest / qmp / netflow
  - kind: "level" (instantaneous) or "counter" (cumulative; we diff to rate)
  - getter: f(row) -> float | None
  - in_deployment: bool, whether this signal is available to the deployed
    model (oracle sees all; realistic sees only True)

Time alignment: each row carries both ``t_mono_ns`` and ``t_wall_ns``,
but the producers are inconsistent about what ``t_mono_ns`` means:
labels / proc / guest / qmp use *episode-relative monotonic* (tiny
values, zeroed at episode start), while netflow uses *system uptime*
(boot-relative, billions of ns). Aligning on ``t_mono_ns`` would silently
push every netflow sample to t≈86000s and drop it from every window.

We therefore use ``t_wall_ns`` as the canonical clock: every source has
it on the same Unix-nano scale. Episode-relative seconds are computed as
``(row.t_wall_ns - first_label.t_wall_ns) / 1e9``. Netflow rounds to
1-second precision, which can move a sample by up to one second. Window
edges sit at multiples of the stride after the first label, whose wall
time need not fall on a whole second, so a netflow sample close to an
edge can land in the adjacent window. The shift is small next to the
10s window and 5s stride, but it is not zero.

A window is a half-open interval ``[t0, t0+W)`` in episode-relative
seconds. The phase label of a window is the most recent ``labels.jsonl``
entry whose time <= window center ("clean" if no label precedes it).
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable

import numpy as np


@dataclass(frozen=True)
class Channel:
    name: str  # "proc.cpu_user_jiffies"
    source: str  # "proc" | "guest" | "qmp" | "netflow"
    kind: str  # "level" | "counter"
    getter: Callable[[dict], float | None]
    in_deployment: bool


def _g(*path):
    """Build a getter from a nested-key path; returns None on miss."""
    def f(row):
        cur = row
        for p in path:
            if cur is None or not isinstance(cur, dict):
                return None
            cur = cur.get(p)
        if cur is None:
            return None
        try:
            return float(cur)
        except (TypeError, ValueError):
            return None
    return f


# host_proc: observable only inside the host (NOT in deployment)
PROC_CHANNELS = [
    Channel("proc.cpu_user_jiffies", "proc", "counter", _g("cpu_user_jiffies"), False),
    Channel("proc.cpu_sys_jiffies", "proc", "counter", _g("cpu_sys_jiffies"), False),
    Channel("proc.rss_bytes", "proc", "level", _g("rss_bytes"), False),
    Channel("proc.vsize_bytes", "proc", "level", _g("vsize_bytes"), False),
    Channel("proc.io_read_bytes", "proc", "counter", _g("io_read_bytes"), False),
    Channel("proc.io_write_bytes", "proc", "counter", _g("io_write_bytes"), False),
    Channel("proc.voluntary_ctxsw", "proc", "counter", _g("voluntary_ctxsw"), False),
    Channel("proc.involuntary_ctxsw", "proc", "counter", _g("involuntary_ctxsw"), False),
    Channel("proc.minor_faults", "proc", "counter", _g("minor_faults"), False),
    Channel("proc.major_faults", "proc", "counter", _g("major_faults"), False),
]

# guest_agent: IN deployment (this is what the deployed model sees)
GUEST_CHANNELS = [
    Channel("guest.cpu_user", "guest", "counter", _g("cpu_total_jiffies", "user"), True),
    Channel("guest.cpu_sys", "guest", "counter", _g("cpu_total_jiffies", "system"), True),
    Channel("guest.cpu_idle", "guest", "counter", _g("cpu_total_jiffies", "idle"), True),
    Channel("guest.cpu_iowait", "guest", "counter", _g("cpu_total_jiffies", "iowait"), True),
    Channel("guest.cpu_softirq", "guest", "counter", _g("cpu_total_jiffies", "softirq"), True),
    Channel("guest.load_1m", "guest", "level",
            lambda r: (r.get("load_1m_5m_15m") or [None])[0]
            if isinstance(r.get("load_1m_5m_15m"), list) else None, True),
    Channel("guest.mem_available", "guest", "level", _g("mem_available_bytes"), True),
    Channel("guest.mem_buffers", "guest", "level", _g("mem_buffers_bytes"), True),
    Channel("guest.mem_cached", "guest", "level", _g("mem_cached_bytes"), True),
    Channel("guest.swap_used", "guest", "level", _g("swap_used_bytes"), True),
    Channel("guest.eth0_rx_bytes", "guest", "counter",
            _g("net", "eth0", "rx_bytes"), True),
    Channel("guest.eth0_tx_bytes", "guest", "counter",
            _g("net", "eth0", "tx_bytes"), True),
    Channel("guest.eth0_rx_pkts", "guest", "counter",
            _g("net", "eth0", "rx_pkts"), True),
    Channel("guest.eth0_tx_pkts", "guest", "counter",
            _g("net", "eth0", "tx_pkts"), True),
    Channel("guest.n_listen_ports", "guest", "level",
            lambda r: float(len(r.get("listen_ports") or [])), True),
    Channel("guest.n_top_procs", "guest", "level",
            lambda r: float(len(r.get("top_procs") or [])), True),
]

# host_qmp: host-side QEMU introspection, NOT in deployment
QMP_CHANNELS = [
    Channel("qmp.virtio0_rd_ops", "qmp", "counter",
            _g("blockstats", "virtio0", "rd_ops"), False),
    Channel("qmp.virtio0_wr_ops", "qmp", "counter",
            _g("blockstats", "virtio0", "wr_ops"), False),
    Channel("qmp.virtio0_rd_bytes", "qmp", "counter",
            _g("blockstats", "virtio0", "rd_bytes"), False),
    Channel("qmp.virtio0_wr_bytes", "qmp", "counter",
            _g("blockstats", "virtio0", "wr_bytes"), False),
    Channel("qmp.kvm_tlb_flush", "qmp", "counter",
            _g("kvm_stats", "remote_tlb_flush"), False),
    Channel("qmp.kvm_pages_4k", "qmp", "level",
            _g("kvm_stats", "pages_4k"), False),
    Channel("qmp.kvm_pages_2m", "qmp", "level",
            _g("kvm_stats", "pages_2m"), False),
]

# bridge_pcap: observable in deployment (network monitor)
NETFLOW_CHANNELS = [
    Channel("netflow.pkts_in", "netflow", "level", _g("pkts_in"), True),
    Channel("netflow.pkts_out", "netflow", "level", _g("pkts_out"), True),
    Channel("netflow.bytes_in", "netflow", "level", _g("bytes_in"), True),
    Channel("netflow.bytes_out", "netflow", "level", _g("bytes_out"), True),
    Channel("netflow.syn_count", "netflow", "level", _g("syn_count"), True),
    Channel("netflow.fin_count", "netflow", "level", _g("fin_count"), True),
    Channel("netflow.rst_count", "netflow", "level", _g("rst_count"), True),
    Channel("netflow.udp_count", "netflow", "level", _g("udp_count"), True),
    Channel("netflow.tcp_count", "netflow", "level", _g("tcp_count"), True),
    Channel("netflow.dns_query_count", "netflow", "level", _g("dns_query_count"), True),
    Channel("netflow.unique_dst_ips", "netflow", "level", _g("unique_dst_ips"), True),
    Channel("netflow.unique_dst_ports", "netflow", "level", _g("unique_dst_ports"), True),
    Channel("netflow.tcp_new_flows", "netflow", "level", _g("tcp_new_flows"), True),
]

ALL_CHANNELS: list[Channel] = PROC_CHANNELS + GUEST_CHANNELS + QMP_CHANNELS + NETFLOW_CHANNELS

PHASES = ["clean", "armed", "infecting", "infected_running", "dormant", "failed"]
PHASE_TO_INT = {p: i for i, p in enumerate(PHASES)}

# Default window geometry, used by both summary-stat and tensor extractors.
# 10s aligns with PIPELINE.md / dashboard scene 7 ("10-second windows · model
# input shape"). Stride 5s gives 50% overlap, ~9 windows per ~50s episode.
DEFAULT_WINDOW_S = 10.0
DEFAULT_STRIDE_S = 5.0
TENSOR_HZ = 10.0  # uniform grid for tensor windows
TENSOR_TIMESTEPS = int(DEFAULT_WINDOW_S * TENSOR_HZ)


def channel_names() -> list[str]:
    return [c.name for c in ALL_CHANNELS]


def channel_in_deployment_mask() -> np.ndarray:
    """Per-channel (not per-feature-statistic) realistic mask. Used for
    tensor-input models, which see len(ALL_CHANNELS) rows. The summary-feature
    mask is in_deployment_mask() (length len(ALL_CHANNELS) * N_STATS)."""
    return np.asarray([c.in_deployment for c in ALL_CHANNELS], dtype=bool)


def _to_array(rows: list[dict], ch: Channel, t0_wall_ns: int
              ) -> tuple[np.ndarray, np.ndarray]:
    """Return (t_seconds_relative_to_episode_start, values) using
    t_wall_ns as the canonical clock. NaN where the value is missing."""
    ts: list[float] = []
    vs: list[float] = []
    for r in rows:
        tw = r.get("t_wall_ns")
        if tw is None:
            # Deliberately no fallback to t_mono_ns: netflow stores system
            # uptime there, so it is not comparable across sources. A missing
            # t_wall_ns is a producer-side bug; skip the row.
            continue
        v = ch.getter(r)
        ts.append((tw - t0_wall_ns) / 1e9)
        vs.append(np.nan if v is None else v)
    return np.asarray(ts, dtype=np.float64), np.asarray(vs, dtype=np.float64)


def _diff_counter(ts: np.ndarray, vs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Counter -> per-second rate via finite differences. Drops first sample."""
    if len(ts) < 2:
        return ts[:0], vs[:0]
    dt = np.diff(ts)
    dv = np.diff(vs)
    # Avoid div-by-zero; rates with dt==0 become NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        rate = np.where(dt > 0, dv / dt, np.nan)
    return ts[1:], rate


_STAT_NAMES = ("mean", "std", "p50", "p95", "slope")
N_STATS = len(_STAT_NAMES)


def _stats(ts: np.ndarray, vs: np.ndarray) -> np.ndarray:
    """Return [mean, std, p50, p95, slope] over a window. NaN-tolerant."""
    if len(vs) == 0:
        return np.full(N_STATS, np.nan)
    finite = np.isfinite(vs)
    if not finite.any():
        return np.full(N_STATS, np.nan)
    v = vs[finite]
    t = ts[finite]
    out = np.empty(N_STATS)
    out[0] = float(np.mean(v))
    out[1] = float(np.std(v)) if len(v) > 1 else 0.0
    out[2] = float(np.percentile(v, 50))
    out[3] = float(np.percentile(v, 95))
    if len(v) >= 2 and t.max() > t.min():
        # Simple linear slope by least-squares
        out[4] = float(np.polyfit(t, v, 1)[0])
    else:
        out[4] = 0.0
    return out


def _phase_at(label_times: np.ndarray, label_phases: list[str], t: float) -> str:
    """Phase of the most recent label whose time is <= t; 'clean' before the
    first label. Assumes label_times is sorted ascending (searchsorted)."""
    if len(label_times) == 0:
        return "clean"
    idx = int(np.searchsorted(label_times, t, side="right")) - 1
    if idx < 0:
        return "clean"
    return label_phases[idx]


def feature_names_episode() -> list[str]:
    return [f"{c.name}.{s}" for c in ALL_CHANNELS for s in _STAT_NAMES]


def feature_names_window() -> list[str]:
    return feature_names_episode()


def in_deployment_mask() -> np.ndarray:
    """Boolean mask of length n_features marking realistic-deployment cols."""
    out = []
    for c in ALL_CHANNELS:
        out.extend([c.in_deployment] * N_STATS)
    return np.asarray(out, dtype=bool)


def episode_t0_wall_ns(epi) -> int:
    """Canonical episode-zero point: first label's t_wall_ns."""
    if not epi.labels:
        raise ValueError("episode has no labels — cannot define t0")
    return int(epi.labels[0]["t_wall_ns"])


def channel_arrays(epi, t0_wall_ns: int) -> dict[str, tuple[np.ndarray, np.ndarray]]:
    """Convert each channel into (t_rel_sec, value) arrays, counters → rate.

    ``t0_wall_ns`` is the wall-clock-nanosecond reference (typically
    ``epi.labels[0]["t_wall_ns"]``) that defines episode-relative seconds.
    """
    out: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    src_rows = {
        "proc": epi.proc,
        "guest": epi.guest,
        "qmp": epi.qmp,
        "netflow": epi.netflow,
    }
    for ch in ALL_CHANNELS:
        rows = src_rows[ch.source]
        ts, vs = _to_array(rows, ch, t0_wall_ns)
        # Sort by time: netflow rounds to 1s precision, so consecutive rows
        # can share a timestamp or appear out of order on disk. Sorting keeps
        # np.interp and the counter diff well-defined (equal timestamps give
        # dt == 0, which _diff_counter turns into NaN).
        if ts.size:
            order = np.argsort(ts, kind="stable")
            ts = ts[order]
            vs = vs[order]
        if ch.kind == "counter":
            ts, vs = _diff_counter(ts, vs)
        out[ch.name] = (ts, vs)
    return out


def episode_features(epi) -> tuple[np.ndarray, dict]:
    """One feature vector summarizing the whole episode."""
    if not epi.labels:
        return np.full(len(ALL_CHANNELS) * N_STATS, np.nan), {}
    t0 = episode_t0_wall_ns(epi)
    arrs = channel_arrays(epi, t0)
    feats = []
    for ch in ALL_CHANNELS:
        ts, vs = arrs[ch.name]
        feats.append(_stats(ts, vs))
    return np.concatenate(feats), _episode_info(epi)


def _window_starts(epi, arrs: dict, window_s: float, stride_s: float
                   ) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Helper: derive window start times + cached label arrays.

    Returns (starts, label_t, label_p). ``starts`` may be empty if the
    episode is too short for one full window. All times are in
    episode-relative seconds (using t_wall_ns as the canonical clock)."""
    t0 = episode_t0_wall_ns(epi)
    label_t = np.asarray([(L["t_wall_ns"] - t0) / 1e9 for L in epi.labels],
                         dtype=np.float64)
    label_p = [L["phase"] for L in epi.labels]

    duration = (epi.meta.get("result") or {}).get("duration_observed_s") or 0.0
    if duration <= 0:
        max_t = 0.0
        for ts, _ in arrs.values():
            if len(ts):
                max_t = max(max_t, float(ts.max()))
        duration = max_t

    if duration < window_s:
        return np.zeros(0, dtype=np.float64), label_t, label_p

    starts = np.arange(0.0, duration - window_s + 1e-9, stride_s)
    return starts, label_t, label_p


def summary_windows(
    epi,
    *,
    window_s: float = DEFAULT_WINDOW_S,
    stride_s: float = DEFAULT_STRIDE_S,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Per-window summary-stat features. Feeds GBT and MLP-on-summaries.

    Returns (X, y_phase, t_center, info) where:
        X         shape (n_windows, n_features) float32
                  n_features = len(ALL_CHANNELS) * N_STATS
        y_phase   shape (n_windows,) int64
        t_center  shape (n_windows,) float64
        info      dict, episode-level metadata
    """
    n_feat = len(ALL_CHANNELS) * N_STATS
    empty = (np.zeros((0, n_feat), dtype=np.float32),
             np.zeros((0,), dtype=np.int64),
             np.zeros((0,), dtype=np.float64), {})
    if not epi.labels:
        return empty
    arrs = channel_arrays(epi, episode_t0_wall_ns(epi))
    starts, label_t, label_p = _window_starts(epi, arrs, window_s, stride_s)
    if starts.size == 0:
        info_min = {"episode_id": epi.episode_id, "host_id": epi.host_id}
        return (empty[0], empty[1], empty[2], info_min)

    rows, phases, centers = [], [], []
    for t_start in starts:
        t_end = t_start + window_s
        t_mid = t_start + window_s / 2
        feats = []
        for ch in ALL_CHANNELS:
            ts, vs = arrs[ch.name]
            mask = (ts >= t_start) & (ts < t_end)
            feats.append(_stats(ts[mask], vs[mask]))
        rows.append(np.concatenate(feats))
        phases.append(PHASE_TO_INT.get(
            _phase_at(label_t, label_p, t_mid), PHASE_TO_INT["clean"]))
        centers.append(t_mid)

    info = _episode_info(epi)
    return (np.asarray(rows, dtype=np.float32),
            np.asarray(phases, dtype=np.int64),
            np.asarray(centers, dtype=np.float64),
            info)


def tensor_windows(
    epi,
    *,
    window_s: float = DEFAULT_WINDOW_S,
    stride_s: float = DEFAULT_STRIDE_S,
    hz: float = TENSOR_HZ,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
    """Per-window channel × time tensor features. Feeds CNN/GRU/LSTM/Transformer.

    Resamples each channel onto a uniform grid (default 10 Hz) within
    each window. Counter channels are first differenced to rate. Grid
    points before a channel's first or after its last finite sample are
    not extrapolated: they are left at 0.0 in X and flagged False in the
    returned missingness mask, so models (and any downstream
    standardization) can discount sparse signals instead of treating the
    fill as data.

    Returns (X, y_phase, t_center, mask, info) where:
        X         shape (n_windows, n_channels, n_timesteps) float32
        y_phase   shape (n_windows,) int64
        t_center  shape (n_windows,) float64
        mask      shape (n_windows, n_channels, n_timesteps) bool
                  True where the value was interpolated from real data,
                  False where no data was available and X holds the 0.0
                  fill. A channel with zero observations in this episode
                  has mask all-False for every window.
        info      dict, episode-level metadata
    """
    n_ch = len(ALL_CHANNELS)
    n_t = int(round(window_s * hz))
    empty = (np.zeros((0, n_ch, n_t), dtype=np.float32),
             np.zeros((0,), dtype=np.int64),
             np.zeros((0,), dtype=np.float64),
             np.zeros((0, n_ch, n_t), dtype=bool), {})
    if not epi.labels:
        return empty
    arrs = channel_arrays(epi, episode_t0_wall_ns(epi))
    starts, label_t, label_p = _window_starts(epi, arrs, window_s, stride_s)
    if starts.size == 0:
        info_min = {"episode_id": epi.episode_id, "host_id": epi.host_id}
        return (empty[0], empty[1], empty[2], empty[3], info_min)

    n_w = len(starts)
    X = np.zeros((n_w, n_ch, n_t), dtype=np.float32)
    M = np.zeros((n_w, n_ch, n_t), dtype=bool)
    y = np.zeros(n_w, dtype=np.int64)
    centers = np.zeros(n_w, dtype=np.float64)

    for w, t_start in enumerate(starts):
        t_mid = t_start + window_s / 2
        # Grid covers [t_start, t_start + window_s) at 1/hz spacing; the
        # window end is excluded, matching the half-open window convention.
        grid = t_start + np.arange(n_t) / hz
        for c, ch in enumerate(ALL_CHANNELS):
            ts, vs = arrs[ch.name]
            if ts.size == 0:
                # Channel has no data this episode; leave zeros, mask False
                continue
            finite = np.isfinite(vs)
            if not finite.any():
                continue
            ts_f = ts[finite]
            vs_f = vs[finite]
            # left/right=NaN so out-of-range grid points become NaN (masked below)
            interp = np.interp(grid, ts_f, vs_f,
                               left=np.nan, right=np.nan)
            valid = np.isfinite(interp)
            X[w, c, valid] = interp[valid]
            M[w, c] = valid
        y[w] = PHASE_TO_INT.get(
            _phase_at(label_t, label_p, t_mid), PHASE_TO_INT["clean"])
        centers[w] = t_mid

    return X, y, centers, M, _episode_info(epi)


def _episode_info(epi) -> dict:
    return {
        "episode_id": epi.episode_id,
        "host_id": epi.host_id,
        "profile": (epi.meta.get("sample") or {}).get("profile"),
        "sample_name": (epi.meta.get("sample") or {}).get("name"),
        "sample_kind": (epi.meta.get("sample") or {}).get("kind"),
        "duration_s": (epi.meta.get("result") or {}).get("duration_observed_s"),
    }


# Backwards-compat alias kept until producers/build_features migrate.
# Note: it keeps its own 1s window / 0.5s stride defaults, not
# DEFAULT_WINDOW_S / DEFAULT_STRIDE_S.
def window_features(epi, window_s: float = 1.0, stride_s: float = 0.5):
    return summary_windows(epi, window_s=window_s, stride_s=stride_s)
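The module docstring's reasoning about netflow's 1-second rounding can be checked on made-up timestamps. When the first label's wall time is not second-aligned, rounding a netflow timestamp can move it across a window edge. This sketch assumes round-to-nearest; `round_to_second` and `in_window` are illustrative helpers, not part of the module.

```python
NS = 1_000_000_000


def round_to_second(t_ns: int) -> int:
    """Round a Unix-ns timestamp to the nearest whole second (halves round up).

    Assumption: the netflow producer rounds to nearest; truncation can cross
    an edge in the same way.
    """
    return (t_ns + NS // 2) // NS * NS


def in_window(t_rel_s: float, start_s: float, width_s: float) -> bool:
    """Half-open membership test, matching the [start, start + W) convention."""
    return start_s <= t_rel_s < start_s + width_s


t0_wall_ns = 1_700_000_000_300_000_000   # made-up first-label wall time, not second-aligned
true_ns = 1_700_000_010_400_000_000      # made-up true time of a netflow sample
true_rel = (true_ns - t0_wall_ns) / 1e9                      # 10.1 s: outside [0, 10)
rounded_rel = (round_to_second(true_ns) - t0_wall_ns) / 1e9  # 9.7 s: inside [0, 10)
print(in_window(true_rel, 0.0, 10.0), in_window(rounded_rel, 0.0, 10.0))  # False True
```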
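The two realistic-deployment masks differ only in granularity: in_deployment_mask() repeats each channel's flag N_STATS times because summary features are laid out channel-major (all five statistics of one channel, then the next). With the registry above (10 proc, 16 guest, 7 qmp and 13 netflow channels, of which guest and netflow are in deployment), that gives 230 summary features, 145 of them realistic. A sketch of selecting the realistic columns, using a hard-coded copy of the flags rather than importing the module:

```python
import numpy as np

N_STATS = 5
# Per-channel in_deployment flags in registry order: proc, guest, qmp, netflow.
channel_mask = np.array([False] * 10 + [True] * 16 + [False] * 7 + [True] * 13)

# Channel-major expansion, equivalent to the loop in in_deployment_mask().
feature_mask = np.repeat(channel_mask, N_STATS)

# Toy summary features for 4 windows; the realistic model sees only masked columns.
X = np.random.default_rng(0).normal(size=(4, channel_mask.size * N_STATS))
X_realistic = X[:, feature_mask]
```

For tensor inputs the per-channel mask indexes the channel axis directly, e.g. `X_tensor[:, channel_mask, :]`.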
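`_diff_counter` turns cumulative counters into per-second rates. A standalone copy on a toy series shows two edge cases: duplicate timestamps produce NaN rather than inf, and a counter that goes backwards (for example after a process restart) yields a negative rate, because the function does not special-case resets. `counter_to_rate` is an illustrative copy, not an import.

```python
import numpy as np


def counter_to_rate(ts: np.ndarray, vs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Same arithmetic as _diff_counter: finite differences, first sample dropped."""
    if len(ts) < 2:
        return ts[:0], vs[:0]
    dt = np.diff(ts)
    dv = np.diff(vs)
    with np.errstate(divide="ignore", invalid="ignore"):
        rate = np.where(dt > 0, dv / dt, np.nan)
    return ts[1:], rate


ts = np.array([0.0, 1.0, 1.0, 3.0, 4.0])
vs = np.array([100.0, 150.0, 160.0, 260.0, 20.0])  # last step: counter reset
t_out, rate = counter_to_rate(ts, vs)
# rate is 50/s, NaN (dt == 0), 50/s, then -240/s at the reset
```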
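`_stats` drops non-finite samples before computing its five statistics, and the slope is an ordinary least-squares fit over the remaining (t, v) pairs. A standalone copy on a linear series with one missing sample recovers the slope exactly. `window_stats` is an illustrative copy; note that `std` is the population standard deviation (NumPy's default `ddof=0`) and percentiles use NumPy's default linear interpolation.

```python
import numpy as np


def window_stats(ts: np.ndarray, vs: np.ndarray) -> np.ndarray:
    """[mean, std, p50, p95, slope] with the same NaN handling as _stats."""
    finite = np.isfinite(vs)
    if not finite.any():
        return np.full(5, np.nan)
    v, t = vs[finite], ts[finite]
    std = float(np.std(v)) if len(v) > 1 else 0.0
    slope = float(np.polyfit(t, v, 1)[0]) if len(v) >= 2 and t.max() > t.min() else 0.0
    return np.array([np.mean(v), std, np.percentile(v, 50), np.percentile(v, 95), slope])


ts = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
vs = np.array([1.0, 3.0, np.nan, 7.0, 9.0])  # v = 2t + 1 with one missing sample
stats = window_stats(ts, vs)                  # mean 5, std sqrt(10), p50 5, p95 8.7, slope 2
```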
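The windowing and labeling rules (full windows at a fixed stride, phase = most recent label at or before the window center) can be checked on a toy episode. This sketch mirrors the arithmetic of `_window_starts` and `_phase_at` rather than importing them; `window_starts` and `phase_at` are illustrative copies and the label times are made up.

```python
import numpy as np


def window_starts(duration_s: float, window_s: float = 10.0, stride_s: float = 5.0) -> np.ndarray:
    """Start times of full windows, mirroring _window_starts' arange call."""
    if duration_s < window_s:
        return np.zeros(0, dtype=np.float64)
    # The 1e-9 keeps the final window when (duration - window) is an exact
    # multiple of the stride, since np.arange excludes its stop value.
    return np.arange(0.0, duration_s - window_s + 1e-9, stride_s)


def phase_at(label_times: np.ndarray, label_phases: list, t: float) -> str:
    """Most recent label at or before t (side='right'), else 'clean'."""
    idx = int(np.searchsorted(label_times, t, side="right")) - 1
    return label_phases[idx] if idx >= 0 else "clean"


# Toy 30 s episode with made-up label times (episode-relative seconds).
label_t = np.array([0.0, 8.0, 17.0])
label_p = ["clean", "armed", "infecting"]
starts = window_starts(30.0)          # [0, 5, 10, 15, 20]
centers = starts + 10.0 / 2           # [5, 10, 15, 20, 25]
phases = [phase_at(label_t, label_p, c) for c in centers]
print(phases)  # ['clean', 'armed', 'armed', 'infecting', 'infecting']
```

With the defaults a 50 s episode yields 9 windows, consistent with the comment next to DEFAULT_STRIDE_S; because of `side="right"`, a window whose center coincides exactly with a label time takes that label's phase.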
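tensor_windows' resampling step for a single channel and window, isolated on toy data at 2 Hz (the module default is 10 Hz). `resample_window` is an illustrative helper. It shows what the mask does and does not capture: grid points before the first or after the last finite sample are masked out, but an interior gap (the NaN at t=4 below) is linearly bridged by `np.interp` and reported as real data.

```python
import numpy as np


def resample_window(ts, vs, t_start, window_s=10.0, hz=2.0):
    """One channel, one window: the same np.interp + mask steps as tensor_windows."""
    n_t = int(round(window_s * hz))
    grid = t_start + np.arange(n_t) / hz           # excludes t_start + window_s
    x = np.zeros(n_t, dtype=np.float32)
    m = np.zeros(n_t, dtype=bool)
    finite = np.isfinite(vs)
    if not finite.any():
        return grid, x, m
    interp = np.interp(grid, ts[finite], vs[finite], left=np.nan, right=np.nan)
    m = np.isfinite(interp)                        # False outside [first, last] sample
    x[m] = interp[m]
    return grid, x, m


ts = np.array([2.0, 4.0, 6.0])
vs = np.array([10.0, np.nan, 30.0])                # interior gap at t=4
grid, x, m = resample_window(ts, vs, t_start=0.0)
# 9 of 20 grid points (t = 2.0 ... 6.0) are marked real; x at t=4.0 is 20.0
```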