From 2abc55a59b136c38520acfd0d3566a5724ae388e Mon Sep 17 00:00:00 2001 From: Max Gorog Date: Fri, 8 May 2026 13:33:19 -0500 Subject: [PATCH] knn scatter: auto-fit projection to running data spread MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Project around mean ± k·σ instead of the raw [0,1]³ producer-unit cube. PCA-3 outputs are Gaussian-ish so even after the producer's min/max rescale, the bulk of points clusters near the centroid; without auto-fit the scatter looks dead-centre and tiny. Implementation: incremental running-sum stats (Σx and Σx² per axis, i.e. the naive one-pass variance formula rather than Welford's update), with mean and σ recomputed lazily on the first frame after new data arrives. project() centers and σ-scales each point to ~[-0.5, 0.5]; outliers clamp to ±0.7 so they're visible just outside the cube. The bounding cube now traces mean ± k·σ instead of [0,1]³, which is also the natural visual unit for the "data spread" the user reads off the screen. resetStats() runs on demo toggle and is implicit when points are cleared. SPREAD_K=2.5 puts ~99% of normally-distributed data inside mean ± k·σ on each axis (~96% inside the full cube when the three axes are independent); MIN_STD=0.02 keeps degenerate (all-equal) data from exploding the divisor. Co-Authored-By: Claude Opus 4.7 (1M context) --- training/dashboard/static/dashboard.js | 105 ++++++++++++++++++++----- training/dashboard/static/index.html | 2 +- 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/training/dashboard/static/dashboard.js b/training/dashboard/static/dashboard.js index 73a15eb..f6fcb79 100644 --- a/training/dashboard/static/dashboard.js +++ b/training/dashboard/static/dashboard.js @@ -1682,22 +1682,66 @@ for epoch in range(20): } if (window.ResizeObserver) new ResizeObserver(resize).observe(canvas); - // (x,y,z) ∈ [0,1]³ → canvas pixels: rotateY then rotateX, - // perspective from a fixed camera distance. - function project(p) { - const x = (p.x ?? 0.5) - 0.5; - const y = (p.y ?? 0.5) - 0.5; - const z = (p.z ?? 
0.5) - 0.5; + // ── Auto-fit: running mean / std per axis ───────────────────── + // The producer rescales PCA output to [0,1]³ by min-max of its fit + // subsample, but PCA-3 is Gaussian-ish so the bulk lands in a + // narrow band near the centroid. We track running mean+std as + // points arrive and project around mean ± SPREAD_K·σ → [-0.5,0.5] + // so the data fills the bounding cube regardless of where in + // [0,1] the producer happens to put it. Outliers clamp to ±0.7 + // so they're visible just outside the cube. + const SPREAD_K = 2.5; + const MIN_STD = 0.02; // floor so degenerate (all-equal) data doesn't blow up + const stats = { + n: 0, + sx: 0, sx2: 0, sy: 0, sy2: 0, sz: 0, sz2: 0, + mx: 0.5, my: 0.5, mz: 0.5, + dx: 0.4 / SPREAD_K, dy: 0.4 / SPREAD_K, dz: 0.4 / SPREAD_K, + dirty: false, + }; + function resetStats() { + stats.n = 0; + stats.sx = stats.sx2 = stats.sy = stats.sy2 = stats.sz = stats.sz2 = 0; + stats.mx = stats.my = stats.mz = 0.5; + stats.dx = stats.dy = stats.dz = 0.4 / SPREAD_K; + stats.dirty = false; + } + function addStat(p) { + const z = (typeof p.z === 'number') ? p.z : 0.5; + stats.n++; + stats.sx += p.x; stats.sx2 += p.x * p.x; + stats.sy += p.y; stats.sy2 += p.y * p.y; + stats.sz += z; stats.sz2 += z * z; + stats.dirty = true; + } + function recomputeStats() { + if (stats.n < 2) { stats.dirty = false; return; } + const n = stats.n; + stats.mx = stats.sx / n; + stats.my = stats.sy / n; + stats.mz = stats.sz / n; + stats.dx = Math.max(MIN_STD, Math.sqrt(Math.max(0, stats.sx2 / n - stats.mx * stats.mx))); + stats.dy = Math.max(MIN_STD, Math.sqrt(Math.max(0, stats.sy2 / n - stats.my * stats.my))); + stats.dz = Math.max(MIN_STD, Math.sqrt(Math.max(0, stats.sz2 / n - stats.mz * stats.mz))); + stats.dirty = false; + } + + function clamp(v, lo, hi) { return v < lo ? lo : v > hi ? hi : v; } + + // Project already-normalized (centered, σ-scaled) coords to canvas + // pixels. 
nx, ny, nz are in roughly [-0.5, 0.5] for the bulk of + // the data; outliers go a bit beyond. + function projectNorm(nx, ny, nz) { const cy_ = Math.cos(rotY), sy_ = Math.sin(rotY); const cx_ = Math.cos(rotX), sx_ = Math.sin(rotX); - const x1 = x * cy_ + z * sy_; - const z1 = -x * sy_ + z * cy_; - const y2 = y * cx_ - z1 * sx_; - const z2 = y * sx_ + z1 * cx_; + const x1 = nx * cy_ + nz * sy_; + const z1 = -nx * sy_ + nz * cy_; + const y2 = ny * cx_ - z1 * sx_; + const z2 = ny * sx_ + z1 * cx_; const camZ = 2.5; const persp = camZ / (camZ - z2); const w = canvas.clientWidth, h = canvas.clientHeight; - const span = Math.min(w, h) * 0.4; + const span = Math.min(w, h) * 0.46; return { sx: w / 2 + x1 * span * persp, sy: h / 2 + y2 * span * persp, @@ -1706,18 +1750,32 @@ for epoch in range(20): }; } + // Project a raw data point: normalize via running stats, then + // hand off to projectNorm. + function project(p) { + if (stats.dirty) recomputeStats(); + const z = (typeof p.z === 'number') ? p.z : stats.mz; + const nx = clamp(((p.x - stats.mx) / (SPREAD_K * stats.dx)) * 0.5, -0.7, 0.7); + const ny = clamp(((p.y - stats.my) / (SPREAD_K * stats.dy)) * 0.5, -0.7, 0.7); + const nz = clamp(((z - stats.mz) / (SPREAD_K * stats.dz)) * 0.5, -0.7, 0.7); + return projectNorm(nx, ny, nz); + } + const cubeEdges = [ [0,1],[1,3],[3,2],[2,0],[4,5],[5,7],[7,6],[6,4], [0,4],[1,5],[2,6],[3,7], ]; function drawCube() { + // The cube outlines mean ± k·σ — i.e. the data spread, not the + // raw [0,1]³ producer-unit cube. Stays consistent with the + // auto-fit projection above. const corners = []; for (let i = 0; i < 8; i++) { - corners.push(project({ - x: (i & 1) ? 1 : 0, - y: (i & 2) ? 1 : 0, - z: (i & 4) ? 1 : 0, - })); + corners.push(projectNorm( + (i & 1) ? 0.5 : -0.5, + (i & 2) ? 0.5 : -0.5, + (i & 4) ? 0.5 : -0.5, + )); } ctx.save(); ctx.strokeStyle = cssColor('var(--line)'); @@ -1797,6 +1855,7 @@ for epoch in range(20): // differs from ground truth. 
function loadSynthetic() { points.length = 0; + resetStats(); let seed = 7; const rand = () => { seed = (seed * 1664525 + 1013904223) >>> 0; return ((seed & 0xffff) / 0xffff) - 0.5; }; const wrand = () => { seed = (seed * 1664525 + 1013904223) >>> 0; return (seed & 0xffff) / 0xffff; }; @@ -1807,30 +1866,34 @@ for epoch in range(20): const predicted = wrong ? PHASES[(idx + 1 + Math.floor(wrand() * 4)) % PHASES.length] : p; - points.push({ + const pt = { x: cx + rand() * 0.18, y: cy + rand() * 0.18, z: cz + rand() * 0.18, phase: p, predicted, cluster: idx, - }); + }; + points.push(pt); + addStat(pt); } }); rebuildLegend(); } on('demo_start', loadSynthetic); - on('demo_stop', () => { points.length = 0; rebuildLegend(); }); + on('demo_stop', () => { points.length = 0; resetStats(); rebuildLegend(); }); on('embedding', m => { if (typeof m.x !== 'number' || typeof m.y !== 'number') return; - points.push({ + const pt = { x: m.x, y: m.y, z: typeof m.z === 'number' ? m.z : 0.5, phase: m.phase, predicted: m.predicted, cluster: typeof m.cluster === 'number' ? m.cluster : undefined, - }); + }; + points.push(pt); + addStat(pt); rebuildLegend(); }); diff --git a/training/dashboard/static/index.html b/training/dashboard/static/index.html index f58ea07..f7509e2 100644 --- a/training/dashboard/static/index.html +++ b/training/dashboard/static/index.html @@ -533,6 +533,6 @@ - +