knn scatter: auto-fit projection to running data spread
Project around mean ± k·σ instead of the raw [0,1]³ producer-unit cube. PCA-3 outputs are Gaussian-ish, so even after the producer's min/max rescale the bulk of the points clusters near the centroid; without auto-fit the scatter looks dead-centre and tiny. Implementation: incremental per-axis statistics kept as running sums (Σx and Σx²; the plain one-pass formula rather than Welford's update), with mean and σ recomputed lazily on the first frame after new data arrives. project() centers and σ-scales each point to ~[-0.5, 0.5]; outliers clamp to ±0.7 so they're visible just outside the cube. The bounding cube now traces mean ± k·σ instead of [0,1]³, which is also the natural visual unit for the "data spread" the user reads off the screen. resetStats() runs on demo toggle and is implicit when points are cleared. SPREAD_K=2.5 puts ~99% of normally-distributed data inside the cube along each axis (roughly 96% jointly across three independent axes); MIN_STD=0.02 keeps degenerate (all-equal) data from exploding the divisor. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
aa6187042b
commit
2abc55a59b
2 changed files with 85 additions and 22 deletions
|
|
@ -1682,22 +1682,66 @@ for epoch in range(20):
|
|||
}
|
||||
if (window.ResizeObserver) new ResizeObserver(resize).observe(canvas);
|
||||
|
||||
// (x,y,z) ∈ [0,1]³ → canvas pixels: rotateY then rotateX,
|
||||
// perspective from a fixed camera distance.
|
||||
function project(p) {
|
||||
const x = (p.x ?? 0.5) - 0.5;
|
||||
const y = (p.y ?? 0.5) - 0.5;
|
||||
const z = (p.z ?? 0.5) - 0.5;
|
||||
// ── Auto-fit: running mean / std per axis ─────────────────────
|
||||
// The producer rescales PCA output to [0,1]³ by min-max of its fit
|
||||
// subsample, but PCA-3 is Gaussian-ish so the bulk lands in a
|
||||
// narrow band near the centroid. We track running mean+std as
|
||||
// points arrive and project around mean ± SPREAD_K·σ → [-0.5,0.5]
|
||||
// so the data fills the bounding cube regardless of where in
|
||||
// [0,1] the producer happens to put it. Outliers clamp to ±0.7
|
||||
// so they're visible just outside the cube.
|
||||
// Half-width of the fitted cube, in standard deviations per axis.
const SPREAD_K = 2.5;
// Floor on the per-axis std so degenerate (all-equal) data doesn't blow up
// the divisor used when normalizing points.
const MIN_STD = 0.02;

// Running per-axis statistics behind the auto-fit projection:
//   n               number of points accumulated so far
//   sx, sy, sz      running sums of each coordinate
//   sx2, sy2, sz2   running sums of each coordinate squared
//   mx, my, mz      cached means
//   dx, dy, dz      cached standard deviations
//   dirty           true when the sums changed after the cache was computed
// The fields are filled in by resetStats() so the defaults live in one place.
const stats = {};

/**
 * Put `stats` back into its empty state: zeroed accumulators, mean 0.5 on
 * every axis, and a std chosen so that SPREAD_K·σ = 0.4 (i.e. the fitted
 * range is 0.5 ± 0.4 per axis until real data arrives).
 *
 * Mutates `stats` in place, so existing references to it stay valid.
 */
function resetStats() {
  const defaultStd = 0.4 / SPREAD_K;
  Object.assign(stats, {
    n: 0,
    sx: 0, sx2: 0, sy: 0, sy2: 0, sz: 0, sz2: 0,
    mx: 0.5, my: 0.5, mz: 0.5,
    dx: defaultStd, dy: defaultStd, dz: defaultStd,
    dirty: false,
  });
}

// Establish the initial state through the same path used for later resets.
resetStats();
|
||||
/**
 * Fold one data point into the running per-axis sums and mark the cached
 * mean/std as stale.
 *
 * @param {{x: number, y: number, z?: number}} p - Point to accumulate.
 *   A missing or non-finite `z` is counted as 0.5.
 *
 * Points whose `x` or `y` is not a finite number are ignored: a single NaN
 * or Infinity added to the sums would make every later mean/std NaN until
 * the next resetStats().
 */
function addStat(p) {
  // typeof NaN === 'number', so finiteness has to be tested explicitly.
  if (!Number.isFinite(p.x) || !Number.isFinite(p.y)) return;
  const z = Number.isFinite(p.z) ? p.z : 0.5;

  stats.n++;
  stats.sx += p.x;
  stats.sx2 += p.x * p.x;
  stats.sy += p.y;
  stats.sy2 += p.y * p.y;
  stats.sz += z;
  stats.sz2 += z * z;
  stats.dirty = true;
}
|
||||
/**
 * Refresh the cached per-axis mean and standard deviation from the running
 * sums, then clear the dirty flag.
 *
 * Uses the population variance E[v²] − E[v]², clamped at 0 before the square
 * root (rounding can push it slightly negative) and floored at MIN_STD.
 * With fewer than two points the cached values are left as they are.
 * An axis whose sums are not finite keeps its previous mean/std rather than
 * turning the cache into NaN.
 */
function recomputeStats() {
  stats.dirty = false;
  const n = stats.n;
  if (n < 2) return;

  // [sum, sum of squares, cached mean, cached std] field names per axis.
  const axes = [
    ['sx', 'sx2', 'mx', 'dx'],
    ['sy', 'sy2', 'my', 'dy'],
    ['sz', 'sz2', 'mz', 'dz'],
  ];
  for (const [sumKey, sumSqKey, meanKey, stdKey] of axes) {
    const mean = stats[sumKey] / n;
    const variance = stats[sumSqKey] / n - mean * mean;
    if (!Number.isFinite(mean) || !Number.isFinite(variance)) continue;
    stats[meanKey] = mean;
    stats[stdKey] = Math.max(MIN_STD, Math.sqrt(Math.max(0, variance)));
  }
}
|
||||
|
||||
/**
 * Constrain `v` to the closed interval [lo, hi].
 *
 * @param {number} v - Value to constrain.
 * @param {number} lo - Lower bound, returned when `v` is below it.
 * @param {number} hi - Upper bound, returned when `v` is above it.
 * @returns {number} `lo`, `hi`, or `v` itself. A NaN `v` fails both
 *   comparisons and is returned unchanged.
 */
function clamp(v, lo, hi) {
  if (v < lo) return lo;
  if (v > hi) return hi;
  return v;
}
|
||||
|
||||
// Project already-normalized (centered, σ-scaled) coords to canvas
|
||||
// pixels. nx, ny, nz are in roughly [-0.5, 0.5] for the bulk of
|
||||
// the data; outliers go a bit beyond.
|
||||
function projectNorm(nx, ny, nz) {
|
||||
const cy_ = Math.cos(rotY), sy_ = Math.sin(rotY);
|
||||
const cx_ = Math.cos(rotX), sx_ = Math.sin(rotX);
|
||||
const x1 = x * cy_ + z * sy_;
|
||||
const z1 = -x * sy_ + z * cy_;
|
||||
const y2 = y * cx_ - z1 * sx_;
|
||||
const z2 = y * sx_ + z1 * cx_;
|
||||
const x1 = nx * cy_ + nz * sy_;
|
||||
const z1 = -nx * sy_ + nz * cy_;
|
||||
const y2 = ny * cx_ - z1 * sx_;
|
||||
const z2 = ny * sx_ + z1 * cx_;
|
||||
const camZ = 2.5;
|
||||
const persp = camZ / (camZ - z2);
|
||||
const w = canvas.clientWidth, h = canvas.clientHeight;
|
||||
const span = Math.min(w, h) * 0.4;
|
||||
const span = Math.min(w, h) * 0.46;
|
||||
return {
|
||||
sx: w / 2 + x1 * span * persp,
|
||||
sy: h / 2 + y2 * span * persp,
|
||||
|
|
@ -1706,18 +1750,32 @@ for epoch in range(20):
|
|||
};
|
||||
}
|
||||
|
||||
// Project a raw data point: normalize via running stats, then
|
||||
// hand off to projectNorm.
|
||||
/**
 * Project a raw data point to canvas pixels: normalize it against the
 * running per-axis statistics, then hand the result to projectNorm().
 *
 * Each coordinate is centered on its cached mean and scaled so that
 * mean ± SPREAD_K·σ maps to ±0.5; anything further out is clamped to ±0.7.
 * The cached statistics are refreshed first if new data has arrived.
 *
 * @param {{x?: number, y?: number, z?: number}} p - Raw point. Any
 *   coordinate that is missing or not finite is treated as the axis mean,
 *   so it lands on the centre of that axis instead of producing NaN pixels.
 * @returns {*} Whatever projectNorm() returns for the normalized point.
 */
function project(p) {
  if (stats.dirty) recomputeStats();

  // The implementation this replaced defaulted every missing coordinate to
  // the centre; keep that guarantee for x and y as well as z.
  const x = Number.isFinite(p.x) ? p.x : stats.mx;
  const y = Number.isFinite(p.y) ? p.y : stats.my;
  const z = Number.isFinite(p.z) ? p.z : stats.mz;

  const nx = clamp(((x - stats.mx) / (SPREAD_K * stats.dx)) * 0.5, -0.7, 0.7);
  const ny = clamp(((y - stats.my) / (SPREAD_K * stats.dy)) * 0.5, -0.7, 0.7);
  const nz = clamp(((z - stats.mz) / (SPREAD_K * stats.dz)) * 0.5, -0.7, 0.7);
  return projectNorm(nx, ny, nz);
}
|
||||
|
||||
const cubeEdges = [
|
||||
[0,1],[1,3],[3,2],[2,0],[4,5],[5,7],[7,6],[6,4],
|
||||
[0,4],[1,5],[2,6],[3,7],
|
||||
];
|
||||
function drawCube() {
|
||||
// The cube outlines mean ± k·σ — i.e. the data spread, not the
|
||||
// raw [0,1]³ producer-unit cube. Stays consistent with the
|
||||
// auto-fit projection above.
|
||||
const corners = [];
|
||||
for (let i = 0; i < 8; i++) {
|
||||
corners.push(project({
|
||||
x: (i & 1) ? 1 : 0,
|
||||
y: (i & 2) ? 1 : 0,
|
||||
z: (i & 4) ? 1 : 0,
|
||||
}));
|
||||
corners.push(projectNorm(
|
||||
(i & 1) ? 0.5 : -0.5,
|
||||
(i & 2) ? 0.5 : -0.5,
|
||||
(i & 4) ? 0.5 : -0.5,
|
||||
));
|
||||
}
|
||||
ctx.save();
|
||||
ctx.strokeStyle = cssColor('var(--line)');
|
||||
|
|
@ -1797,6 +1855,7 @@ for epoch in range(20):
|
|||
// differs from ground truth.
|
||||
function loadSynthetic() {
|
||||
points.length = 0;
|
||||
resetStats();
|
||||
let seed = 7;
|
||||
const rand = () => { seed = (seed * 1664525 + 1013904223) >>> 0; return ((seed & 0xffff) / 0xffff) - 0.5; };
|
||||
const wrand = () => { seed = (seed * 1664525 + 1013904223) >>> 0; return (seed & 0xffff) / 0xffff; };
|
||||
|
|
@ -1807,30 +1866,34 @@ for epoch in range(20):
|
|||
const predicted = wrong
|
||||
? PHASES[(idx + 1 + Math.floor(wrand() * 4)) % PHASES.length]
|
||||
: p;
|
||||
points.push({
|
||||
const pt = {
|
||||
x: cx + rand() * 0.18,
|
||||
y: cy + rand() * 0.18,
|
||||
z: cz + rand() * 0.18,
|
||||
phase: p,
|
||||
predicted,
|
||||
cluster: idx,
|
||||
});
|
||||
};
|
||||
points.push(pt);
|
||||
addStat(pt);
|
||||
}
|
||||
});
|
||||
rebuildLegend();
|
||||
}
|
||||
|
||||
on('demo_start', loadSynthetic);
|
||||
on('demo_stop', () => { points.length = 0; rebuildLegend(); });
|
||||
on('demo_stop', () => { points.length = 0; resetStats(); rebuildLegend(); });
|
||||
on('embedding', m => {
|
||||
if (typeof m.x !== 'number' || typeof m.y !== 'number') return;
|
||||
points.push({
|
||||
const pt = {
|
||||
x: m.x, y: m.y,
|
||||
z: typeof m.z === 'number' ? m.z : 0.5,
|
||||
phase: m.phase,
|
||||
predicted: m.predicted,
|
||||
cluster: typeof m.cluster === 'number' ? m.cluster : undefined,
|
||||
});
|
||||
};
|
||||
points.push(pt);
|
||||
addStat(pt);
|
||||
rebuildLegend();
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -533,6 +533,6 @@
|
|||
</article>
|
||||
</div>
|
||||
|
||||
<script src="/static/dashboard.js?v=7e81783b"></script>
|
||||
<script src="/static/dashboard.js?v=fbac7a5c"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue