lean4-htt/src/Lean/Compiler/LCNF/ToExpr.lean
Henrik Böving 85645958f9
fix: overeager specialisation reuse in codegen (#10429)
This PR fixes an overeager reuse of specialisation in the code
generator.

The issue was originally discovered in
https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/Miscompilation.20.28incorrect.20code.29.20in.20new.20compiler/near/540037917
and occurs because the specialisation cache didn't take the name of
alternatives in pattern matches
into account.
2025-09-17 17:35:40 +00:00

119 lines
3.8 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Compiler.LCNF.Basic
public section
namespace Lean.Compiler.LCNF
namespace ToExpr
/--
Map from a free variable to the `Nat` level recorded for it.
`withFVar` fills it in, and `FVarId.toExpr` turns a recorded level into a `.bvar` index.
-/
abbrev LevelMap := FVarIdMap Nat
/--
Translate `fvarId` into an expression.

If `m` records a level for `fvarId`, the result is the bound variable
`.bvar (offset - level - 1)`; otherwise the free variable is kept as `.fvar fvarId`.
The subtraction is on `Nat`, so it truncates at zero.
-/
private def _root_.Lean.FVarId.toExpr (offset : Nat) (m : LevelMap) (fvarId : FVarId) : Expr :=
  if let some level := m.get? fvarId then
    .bvar (offset - level - 1)
  else
    .fvar fvarId
/--
Replace every free variable of `e` that is recorded in `m` by the corresponding bound
variable (see `FVarId.toExpr`), starting at `offset`.

Literals, constants, sorts, metavariables and existing bound variables are returned unchanged.
-/
private def _root_.Lean.Expr.abstract' (offset : Nat) (m : LevelMap) (e : Expr) : Expr :=
  go offset e
where
  /--
  Worker for `abstract'`. `depth` is the offset in effect at `expr`; it is increased by one
  when descending into the body of a `forallE`, `lam` or `letE` binder.
  -/
  go (depth : Nat) (expr : Expr) : Expr :=
    match expr with
    | .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => expr
    | .fvar fvarId => fvarId.toExpr depth m
    | .mdata data body => .mdata data (go depth body)
    | .proj structName idx struct => .proj structName idx (go depth struct)
    | .app fn arg => .app (go depth fn) (go depth arg)
    | .lam name dom body info => .lam name (go depth dom) (go (depth + 1) body) info
    | .forallE name dom body info => .forallE name (go depth dom) (go (depth + 1) body) info
    | .letE name type value body nonDep =>
      .letE name (go depth type) (go depth value) (go (depth + 1) body) nonDep
/--
Monad used by the translation: a `Nat` reader (the current offset) layered over a
`LevelMap` state.
-/
abbrev ToExprM := ReaderT Nat (StateM LevelMap)
/--
Wrap `e` in one `.lam` binder per entry of `params`, using the current offset and level map.

The last parameter becomes the innermost binder. Every binder uses the parameter's
`binderName`, its abstracted `type` as domain, and `.default` binder info.
-/
@[inline] def mkLambdaM (params : Array Param) (e : Expr) : ToExprM Expr := do
  let offset ← read
  let levels ← get
  return go offset levels params.size e
where
  /--
  Add binders for `params[0], ..., params[i-1]` around `e`, processing them from the last to
  the first. The offset is decremented (truncating at zero) before each parameter type is
  abstracted.
  -/
  go (offset : Nat) (m : LevelMap) (i : Nat) (e : Expr) : Expr :=
    match i with
    | 0 => e
    | idx + 1 =>
      let param := params[idx]!
      let offset' := offset - 1
      let domain := param.type.abstract' offset' m
      go offset' m idx (.lam param.binderName domain e .default)
/-- Monadic version of `FVarId.toExpr`: the offset comes from the reader and the level map from the state. -/
private abbrev _root_.Lean.FVarId.toExprM (fvarId : FVarId) : ToExprM Expr := do
  let offset ← read
  let levels ← get
  return fvarId.toExpr offset levels
/-- Monadic version of `Expr.abstract'`: abstracts `e` with the current offset and level map. -/
@[inline] def abstractM (e : Expr) : ToExprM Expr := do
  let offset ← read
  let levels ← get
  return e.abstract' offset levels
/--
Run `k` with `fvarId` bound: the current offset is stored as the level of `fvarId` in the
level map, and `k` is executed with the offset increased by one.

Only the reader change is scoped to `k`; the insertion into the level map is a plain state
update.
-/
@[inline] def withFVar (fvarId : FVarId) (k : ToExprM α) : ToExprM α := do
  let level ← read
  modify (·.insert fvarId level)
  withReader (fun depth => depth + 1) k
/--
Run `k` with every parameter in `params` bound via `withFVar`, in array order
(the first parameter gets the lowest level).
-/
@[inline] partial def withParams (params : Array Param) (k : ToExprM α) : ToExprM α :=
  go 0
where
  /-- Bind `params[idx]` and continue with the next index; once past the end, run `k`. -/
  @[specialize] go (idx : Nat) : ToExprM α := do
    if inBounds : idx < params.size then
      withFVar params[idx].fvarId (go (idx + 1))
    else
      k
/--
Execute `x` starting from the given `offset` (default `0`) and `levelMap` (default empty),
returning its result and discarding the final level map.
-/
@[inline] def run (x : ToExprM α) (offset : Nat := 0) (levelMap : LevelMap := {}) : α :=
  (x.run offset).run' levelMap
/--
Execute `x` with the free variables `xs` already bound.

Each variable is inserted into the initial level map with the map's size at that moment as
its level, and the initial offset is `xs.size`.
-/
@[inline] def run' (x : ToExprM α) (xs : Array FVarId) : α :=
  let levels : LevelMap := xs.foldl (init := {}) fun acc fvarId => acc.insert fvarId acc.size
  run x xs.size levels
end ToExpr
open ToExpr
/-- Convert `arg` with `Arg.toExpr` and abstract the result with the current offset and level map. -/
private def Arg.toExprM (arg : Arg) : ToExprM Expr := do
  let offset ← read
  let levels ← get
  return arg.toExpr.abstract' offset levels
mutual

/--
Translate a function declaration: its parameters are bound with `withParams`, the body is
translated, and the result is wrapped in lambdas with `mkLambdaM`.
-/
partial def FunDecl.toExprM (decl : FunDecl) : ToExprM Expr :=
  withParams decl.params do
    let body ← decl.value.toExprM
    mkLambdaM decl.params body

/--
Translate `code` into an `Expr`.

* `let`, `fun` and `jp` become `.letE` nodes whose last argument is `true`; the continuation
  is translated with the declared variable bound.
* `return` becomes the translated variable, `jmp` an application of it to the translated
  arguments.
* `unreach` becomes `lcUnreachable` applied to the abstracted type.
* `cases` becomes the constant `cases` applied to the discriminant followed by one
  expression per alternative. A constructor alternative is encoded as the constructor
  constant applied to the lambda over its parameters, so the constructor name is part of
  the result.
-/
partial def Code.toExprM (code : Code) : ToExprM Expr := do
  match code with
  | .return fvarId => fvarId.toExprM
  | .unreach type =>
    let type ← abstractM type
    return mkApp (mkConst ``lcUnreachable) type
  | .jmp fvarId args =>
    let target ← fvarId.toExprM
    let args ← args.mapM Arg.toExprM
    return mkAppN target args
  | .let decl k =>
    -- Type and value are abstracted before the declared variable is bound for the continuation.
    let type ← abstractM decl.type
    let value ← abstractM decl.value.toExpr
    let body ← withFVar decl.fvarId k.toExprM
    return .letE decl.binderName type value body true
  | .fun decl k | .jp decl k =>
    let type ← abstractM decl.type
    let value ← decl.toExprM
    let body ← withFVar decl.fvarId k.toExprM
    return .letE decl.binderName type value body true
  | .cases c =>
    let alts ← c.alts.mapM fun
      | .default k => k.toExprM
      | .alt ctorName params k => do
        let lam ← withParams params do
          let rhs ← k.toExprM
          mkLambdaM params rhs
        return mkApp (mkConst ctorName) lam
    -- The discriminant is translated after the alternatives, as in the original nested-action form.
    let discr ← c.discr.toExprM
    return mkAppN (mkConst `cases) (#[discr] ++ alts)

end
/--
Convert `code` into an `Expr`, treating the free variables in `xs` (default: none) as
already bound; see `ToExpr.run'` for how their levels are assigned.
-/
public def Code.toExpr (code : Code) (xs : Array FVarId := #[]) : Expr :=
  ToExpr.run' code.toExprM xs
/--
Convert the function declaration `decl` into an `Expr`, treating the free variables in `xs`
(default: none) as already bound; see `ToExpr.run'` for how their levels are assigned.
-/
public def FunDecl.toExpr (decl : FunDecl) (xs : Array FVarId := #[]) : Expr :=
  ToExpr.run' decl.toExprM xs
end Lean.Compiler.LCNF