knn scatter: auto-fit projection to running data spread

Project around mean ± k·σ instead of the raw [0,1]³ producer-unit
cube. PCA-3 outputs are Gaussian-ish so even after the producer's
min/max rescale, the bulk of points clusters near the centroid;
without auto-fit the scatter looks dead-centre and tiny.

Implementation: incremental per-axis stats (a running sum and sum of
squares, not Welford's update), with mean and σ recomputed lazily on
the first frame after new data arrives. project() centers and
σ-scales each point to ~[-0.5, 0.5];
outliers clamp to ±0.7 so they're visible just outside the cube.
The bounding cube now traces mean ± k·σ instead of [0,1]³, which is
also the natural visual unit for the "data spread" the user reads
off the screen.

resetStats() runs on demo toggle and is implicit when points are
cleared. SPREAD_K=2.5 puts ~99% of normally-distributed data inside
the cube; MIN_STD=0.02 keeps degenerate (all-equal) data from
exploding the divisor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Max Gorog 2026-05-08 13:33:19 -05:00
parent aa6187042b
commit 2abc55a59b
2 changed files with 85 additions and 22 deletions

View file

@ -1682,22 +1682,66 @@ for epoch in range(20):
}
if (window.ResizeObserver) new ResizeObserver(resize).observe(canvas);
// (x,y,z) ∈ [0,1]³ → canvas pixels: rotateY then rotateX,
// perspective from a fixed camera distance.
function project(p) {
const x = (p.x ?? 0.5) - 0.5;
const y = (p.y ?? 0.5) - 0.5;
const z = (p.z ?? 0.5) - 0.5;
// ── Auto-fit: per-axis running mean / std ─────────────────────
// The producer min-max rescales PCA output into [0,1]³ using its fit
// subsample, but PCA-3 is Gaussian-ish, so most points end up in a
// narrow band around the centroid. Running mean and std are tracked
// as points arrive, and points are projected so that
// mean ± SPREAD_K·σ maps onto [-0.5, 0.5]; the data then fills the
// bounding cube wherever in [0,1] the producer placed it. Outliers
// are clamped to ±0.7 so they stay visible just outside the cube.
const SPREAD_K = 2.5;
const MIN_STD = 0.02; // floor so degenerate (all-equal) data doesn't blow up
const stats = {
  // Number of samples accumulated so far.
  n: 0,
  // Running sums of each coordinate.
  sx: 0,
  sy: 0,
  sz: 0,
  // Running sums of each squared coordinate.
  sx2: 0,
  sy2: 0,
  sz2: 0,
  // Cached per-axis means; 0.5 until recomputed from data.
  mx: 0.5,
  my: 0.5,
  mz: 0.5,
  // Cached per-axis std; the initial value makes SPREAD_K·σ equal 0.4.
  dx: 0.4 / SPREAD_K,
  dy: 0.4 / SPREAD_K,
  dz: 0.4 / SPREAD_K,
  // True when sums have changed since the cached mean/std were computed.
  dirty: false,
};
// Put the accumulator back into its empty state: zero samples and
// sums, means at 0.5, std at 0.4 / SPREAD_K, and nothing pending.
function resetStats() {
  const initialStd = 0.4 / SPREAD_K;
  Object.assign(stats, {
    n: 0,
    sx: 0,
    sx2: 0,
    sy: 0,
    sy2: 0,
    sz: 0,
    sz2: 0,
    mx: 0.5,
    my: 0.5,
    mz: 0.5,
    dx: initialStd,
    dy: initialStd,
    dz: initialStd,
    dirty: false,
  });
}
// Fold one point into the per-axis running sums and mark the cached
// mean/std as stale.
//
// p.x and p.y are read directly; a non-numeric p.z counts as 0.5.
// Samples with any non-finite coordinate (NaN, ±Infinity, or a
// missing x/y) are ignored: `typeof NaN === 'number'`, so such values
// can pass a plain type check upstream, and a single one would turn
// every running sum — and therefore every later mean, σ and
// projection — into NaN until the next reset.
function addStat(p) {
  const z = (typeof p.z === 'number') ? p.z : 0.5;
  if (!Number.isFinite(p.x) || !Number.isFinite(p.y) || !Number.isFinite(z)) {
    return;
  }
  stats.n++;
  stats.sx += p.x;
  stats.sx2 += p.x * p.x;
  stats.sy += p.y;
  stats.sy2 += p.y * p.y;
  stats.sz += z;
  stats.sz2 += z * z;
  stats.dirty = true;
}
// Refresh the cached per-axis mean and std from the running sums and
// clear the dirty flag. With fewer than two samples the cached values
// are left untouched. The variance is E[v²] − mean², floored at 0
// before the square root, and the resulting std is floored at MIN_STD.
function recomputeStats() {
  const count = stats.n;
  if (count >= 2) {
    const stdFrom = (sumSq, mean) =>
      Math.max(MIN_STD, Math.sqrt(Math.max(0, sumSq / count - mean * mean)));
    const meanX = stats.sx / count;
    const meanY = stats.sy / count;
    const meanZ = stats.sz / count;
    stats.mx = meanX;
    stats.my = meanY;
    stats.mz = meanZ;
    stats.dx = stdFrom(stats.sx2, meanX);
    stats.dy = stdFrom(stats.sy2, meanY);
    stats.dz = stdFrom(stats.sz2, meanZ);
  }
  stats.dirty = false;
}
// Limit v to the range [lo, hi]. The lower bound is checked first;
// a value that fails both comparisons (e.g. NaN) is returned as-is.
function clamp(v, lo, hi) {
  if (v < lo) {
    return lo;
  }
  if (v > hi) {
    return hi;
  }
  return v;
}
// Project already-normalized (centered, σ-scaled) coords to canvas
// pixels. nx, ny, nz are in roughly [-0.5, 0.5] for the bulk of
// the data; outliers go a bit beyond.
function projectNorm(nx, ny, nz) {
const cy_ = Math.cos(rotY), sy_ = Math.sin(rotY);
const cx_ = Math.cos(rotX), sx_ = Math.sin(rotX);
const x1 = x * cy_ + z * sy_;
const z1 = -x * sy_ + z * cy_;
const y2 = y * cx_ - z1 * sx_;
const z2 = y * sx_ + z1 * cx_;
const x1 = nx * cy_ + nz * sy_;
const z1 = -nx * sy_ + nz * cy_;
const y2 = ny * cx_ - z1 * sx_;
const z2 = ny * sx_ + z1 * cx_;
const camZ = 2.5;
const persp = camZ / (camZ - z2);
const w = canvas.clientWidth, h = canvas.clientHeight;
const span = Math.min(w, h) * 0.4;
const span = Math.min(w, h) * 0.46;
return {
sx: w / 2 + x1 * span * persp,
sy: h / 2 + y2 * span * persp,
@ -1706,18 +1750,32 @@ for epoch in range(20):
};
}
// Map a raw data point to canvas pixels. Stale stats are refreshed
// first; each coordinate is then centred on its running mean, scaled
// so mean ± SPREAD_K·σ lands on ±0.5, clamped to ±0.7, and passed to
// projectNorm. A non-numeric p.z falls back to the running z mean.
function project(p) {
  if (stats.dirty) recomputeStats();
  const rawZ = (typeof p.z === 'number') ? p.z : stats.mz;
  const fit = (value, mean, std) =>
    clamp(((value - mean) / (SPREAD_K * std)) * 0.5, -0.7, 0.7);
  return projectNorm(
    fit(p.x, stats.mx, stats.dx),
    fit(p.y, stats.my, stats.dy),
    fit(rawZ, stats.mz, stats.dz),
  );
}
// The 12 edges of the bounding cube, as pairs of corner indices 0–7.
// drawCube builds corner i from the bits of i (1 → x, 2 → y, 4 → z).
const cubeEdges = [
  // Face with the z bit clear.
  [0, 1],
  [1, 3],
  [3, 2],
  [2, 0],
  // Face with the z bit set.
  [4, 5],
  [5, 7],
  [7, 6],
  [6, 4],
  // Edges joining the two faces along z.
  [0, 4],
  [1, 5],
  [2, 6],
  [3, 7],
];
function drawCube() {
// The cube outlines mean ± k·σ — i.e. the data spread, not the
// raw [0,1]³ producer-unit cube. Stays consistent with the
// auto-fit projection above.
const corners = [];
for (let i = 0; i < 8; i++) {
corners.push(project({
x: (i & 1) ? 1 : 0,
y: (i & 2) ? 1 : 0,
z: (i & 4) ? 1 : 0,
}));
corners.push(projectNorm(
(i & 1) ? 0.5 : -0.5,
(i & 2) ? 0.5 : -0.5,
(i & 4) ? 0.5 : -0.5,
));
}
ctx.save();
ctx.strokeStyle = cssColor('var(--line)');
@ -1797,6 +1855,7 @@ for epoch in range(20):
// differs from ground truth.
function loadSynthetic() {
points.length = 0;
resetStats();
let seed = 7;
const rand = () => { seed = (seed * 1664525 + 1013904223) >>> 0; return ((seed & 0xffff) / 0xffff) - 0.5; };
const wrand = () => { seed = (seed * 1664525 + 1013904223) >>> 0; return (seed & 0xffff) / 0xffff; };
@ -1807,30 +1866,34 @@ for epoch in range(20):
const predicted = wrong
? PHASES[(idx + 1 + Math.floor(wrand() * 4)) % PHASES.length]
: p;
points.push({
const pt = {
x: cx + rand() * 0.18,
y: cy + rand() * 0.18,
z: cz + rand() * 0.18,
phase: p,
predicted,
cluster: idx,
});
};
points.push(pt);
addStat(pt);
}
});
rebuildLegend();
}
on('demo_start', loadSynthetic);
on('demo_stop', () => { points.length = 0; rebuildLegend(); });
on('demo_stop', () => { points.length = 0; resetStats(); rebuildLegend(); });
on('embedding', m => {
if (typeof m.x !== 'number' || typeof m.y !== 'number') return;
points.push({
const pt = {
x: m.x, y: m.y,
z: typeof m.z === 'number' ? m.z : 0.5,
phase: m.phase,
predicted: m.predicted,
cluster: typeof m.cluster === 'number' ? m.cluster : undefined,
});
};
points.push(pt);
addStat(pt);
rebuildLegend();
});

View file

@ -533,6 +533,6 @@
</article>
</div>
<script src="/static/dashboard.js?v=7e81783b"></script>
<script src="/static/dashboard.js?v=fbac7a5c"></script>
</body>
</html>