From 6d4682959939994c2fdbbfae6bfdb7d2c683d290 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 24 Oct 2022 19:53:08 -0700 Subject: [PATCH] chore: new LCNF representation This is the first of a series of commits to change the LCNF representation. --- src/Lean/Compiler/LCNF/AlphaEqv.lean | 52 +++++--- src/Lean/Compiler/LCNF/Basic.lean | 167 +++++++++++++++++++----- src/Lean/Compiler/LCNF/CSE.lean | 7 +- src/Lean/Compiler/LCNF/Closure.lean | 35 +++-- src/Lean/Compiler/LCNF/CompilerM.lean | 57 ++++++-- src/Lean/Compiler/LCNF/DependsOn.lean | 25 +++- src/Lean/Compiler/LCNF/ElimDead.lean | 35 +++-- src/Lean/Compiler/LCNF/FixedParams.lean | 35 +++-- src/Lean/Compiler/LCNF/ForEachExpr.lean | 40 +----- src/Lean/Compiler/LCNF/InferType.lean | 13 +- src/Lean/Compiler/LCNF/LCtx.lean | 2 +- src/Lean/Compiler/LCNF/Level.lean | 31 ++++- src/Lean/Compiler/LCNF/Simp/Basic.lean | 22 ++-- src/Lean/Compiler/LCNF/ToExpr.lean | 7 +- 14 files changed, 361 insertions(+), 167 deletions(-) diff --git a/src/Lean/Compiler/LCNF/AlphaEqv.lean b/src/Lean/Compiler/LCNF/AlphaEqv.lean index 01c4775286..b7d524cb45 100644 --- a/src/Lean/Compiler/LCNF/AlphaEqv.lean +++ b/src/Lean/Compiler/LCNF/AlphaEqv.lean @@ -19,25 +19,47 @@ def eqvFVar (fvarId₁ fvarId₂ : FVarId) : EqvM Bool := do let fvarId₂ := (← read).find? fvarId₂ |>.getD fvarId₂ return fvarId₁ == fvarId₂ -def eqvExpr (e₁ e₂ : Expr) : EqvM Bool := do +def eqvType (e₁ e₂ : Expr) : EqvM Bool := do match e₁, e₂ with - | .app f₁ a₁, .app f₂ a₂ => eqvExpr a₁ a₂ <&&> eqvExpr f₁ f₂ - | .proj s₁ i₁ e₁, .proj s₂ i₂ e₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvExpr e₁ e₂ - | .mdata m₁ e₁, .mdata m₂ e₂ => pure (m₁ == m₂) <&&> eqvExpr e₁ e₂ + | .app f₁ a₁, .app f₂ a₂ => eqvType a₁ a₂ <&&> eqvType f₁ f₂ | .fvar fvarId₁, .fvar fvarId₂ => eqvFVar fvarId₁ fvarId₂ - | .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ => eqvExpr d₁ d₂ <&&> eqvExpr b₁ b₂ - | .letE .., _ | _, .letE .. => unreachable! + | .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ => eqvType d₁ d₂ <&&> eqvType b₁ b₂ | _, _ => return e₁ == e₂ -def eqvExprs (es₁ es₂ : Array Expr) : EqvM Bool := do +def eqvTypes (es₁ es₂ : Array Expr) : EqvM Bool := do if es₁.size = es₂.size then for e₁ in es₁, e₂ in es₂ do - unless (← eqvExpr e₁ e₂) do + unless (← eqvType e₁ e₂) do return false return true else return false +def eqvArg (a₁ a₂ : Arg) : EqvM Bool := do + match a₁, a₂ with + | .type e₁, .type e₂ => eqvType e₁ e₂ + | .fvar x₁, .fvar x₂ => eqvFVar x₁ x₂ + | .erased, .erased => return true + | _, _ => return false + +def eqvArgs (as₁ as₂ : Array Arg) : EqvM Bool := do + if as₁.size = as₂.size then + for a₁ in as₁, a₂ in as₂ do + unless (← eqvArg a₁ a₂) do + return false + return true + else + return false + +def eqvLetExpr (e₁ e₂ : LetExpr) : EqvM Bool := do + match e₁, e₂ with + | .value v₁, .value v₂ => return v₁ == v₂ + | .erased, .erased => return true + | .proj s₁ i₁ x₁, .proj s₂ i₂ x₂ => pure (s₁ == s₂ && i₁ == i₂) <&&> eqvFVar x₁ x₂ + | .const n₁ us₁ as₁, .const n₂ us₂ as₂ => pure (n₁ == n₂ && us₁ == us₂) <&&> eqvArgs as₁ as₂ + | .fvar f₁ as₁, .fvar f₂ as₂ => eqvFVar f₁ f₂ <&&> eqvArgs as₁ as₂ + | _, _ => return false + @[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α := withReader (·.insert fvarId₂ fvarId₁) x @@ -48,7 +70,7 @@ def eqvExprs (es₁ es₂ : Array Expr) : EqvM Bool := do let p₁ := params₁[i] have : i < params₂.size := by simp_all_arith let p₂ := params₂[i] - unless (← eqvExpr p₁.type p₂.type) do return false + unless (← eqvType p₁.type p₂.type) do return false withFVar p₁.fvarId p₂.fvarId do go (i+1) else @@ -84,20 +106,20 @@ partial def eqvAlts (alts₁ alts₂ : Array Alt) : EqvM Bool := do partial def eqv (code₁ code₂ : Code) : EqvM Bool := do match code₁, code₂ with | .let decl₁ k₁, .let decl₂ k₂ => - eqvExpr decl₁.type decl₂.type <&&> - eqvExpr decl₁.value decl₂.value <&&> + eqvType decl₁.type decl₂.type <&&> + eqvLetExpr decl₁.value decl₂.value <&&> withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂) | .fun decl₁ k₁, .fun decl₂ k₂ | .jp decl₁ k₁, .jp decl₂ k₂ => - eqvExpr decl₁.type decl₂.type <&&> + eqvType decl₁.type decl₂.type <&&> withParams decl₁.params decl₂.params (eqv decl₁.value decl₂.value) <&&> withFVar decl₁.fvarId decl₂.fvarId (eqv k₁ k₂) | .return fvarId₁, .return fvarId₂ => eqvFVar fvarId₁ fvarId₂ - | .unreach type₁, .unreach type₂ => eqvExpr type₁ type₂ - | .jmp fvarId₁ args₁, .jmp fvarId₂ args₂ => eqvFVar fvarId₁ fvarId₂ <&&> eqvExprs args₁ args₂ + | .unreach type₁, .unreach type₂ => eqvType type₁ type₂ + | .jmp fvarId₁ args₁, .jmp fvarId₂ args₂ => eqvFVar fvarId₁ fvarId₂ <&&> eqvArgs args₁ args₂ | .cases c₁, .cases c₂ => eqvFVar c₁.discr c₂.discr <&&> - eqvExpr c₁.resultType c₂.resultType <&&> + eqvType c₁.resultType c₂.resultType <&&> eqvAlts c₁.alts c₂.alts | _, _ => return false diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index ffe9de5320..baff11af71 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -7,6 +7,7 @@ import Lean.Expr import Lean.Meta.Instances import Lean.Compiler.InlineAttrs import Lean.Compiler.Specialize +import Lean.Compiler.LCNF.Types namespace Lean.Compiler.LCNF @@ -34,11 +35,90 @@ inductive AltCore (Code : Type) where | default (code : Code) deriving Inhabited +inductive Value where + | natVal (val : Nat) + | strVal (val : String) + -- TODO: add constructors for `Int`, `Float`, `UInt` ... + deriving Inhabited, BEq, Hashable + +inductive Arg where + | erased + | fvar (fvarId : FVarId) + | type (expr : Expr) + deriving Inhabited, BEq, Hashable + +def Arg.toExpr (arg : Arg) : Expr := + match arg with + | .erased => erasedExpr + | .fvar fvarId => .fvar fvarId + | .type e => e + +private unsafe def Arg.updateTypeImp (arg : Arg) (type' : Expr) : Arg := + match arg with + | .type ty => if ptrEq ty type' then arg else .type type' + | _ => unreachable! + +@[implemented_by Arg.updateTypeImp] opaque Arg.updateType! (arg : Arg) (type : Expr) : Arg + +private unsafe def Arg.updateFVarImp (arg : Arg) (fvarId' : FVarId) : Arg := + match arg with + | .fvar fvarId => if fvarId' == fvarId then arg else .fvar fvarId' + | _ => unreachable! + +@[implemented_by Arg.updateFVarImp] opaque Arg.updateFVar! (arg : Arg) (fvarId' : FVarId) : Arg + +inductive LetExpr where + | value (value : Value) + | erased + | proj (typeName : Name) (idx : Nat) (struct : FVarId) + | const (declName : Name) (us : List Level) (args : Array Arg) + | fvar (fvarId : FVarId) (args : Array Arg) + -- TODO: add constructors for mono and impure phases + deriving Inhabited, BEq, Hashable + +private unsafe def LetExpr.updateProjImp (e : LetExpr) (fvarId' : FVarId) : LetExpr := + match e with + | .proj s i fvarId => if fvarId == fvarId' then e else .proj s i fvarId' + | _ => unreachable! + +@[implemented_by LetExpr.updateProjImp] opaque LetExpr.updateProj! (e : LetExpr) (fvarId' : FVarId) : LetExpr + +private unsafe def LetExpr.updateConstImp (e : LetExpr) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetExpr := + match e with + | .const declName us args => if declName == declName' && ptrEq us us' && ptrEq args args' then e else .const declName' us' args' + | _ => unreachable! + +@[implemented_by LetExpr.updateConstImp] opaque LetExpr.updateConst! (e : LetExpr) (declName' : Name) (us' : List Level) (args' : Array Arg) : LetExpr + +private unsafe def LetExpr.updateFVarImp (e : LetExpr) (fvarId' : FVarId) (args' : Array Arg) : LetExpr := + match e with + | .fvar fvarId args => if fvarId == fvarId' && ptrEq args args' then e else .fvar fvarId' args' + | _ => unreachable! + +@[implemented_by LetExpr.updateFVarImp] opaque LetExpr.updateFVar! (e : LetExpr) (fvarId' : FVarId) (args' : Array Arg) : LetExpr + +private unsafe def LetExpr.updateArgsImp (e : LetExpr) (args' : Array Arg) : LetExpr := + match e with + | .const declName us args => if ptrEq args args' then e else .const declName us args' + | .fvar fvarId args => if ptrEq args args' then e else .fvar fvarId args' + | _ => unreachable! + +@[implemented_by LetExpr.updateArgsImp] opaque LetExpr.updateArgs! (e : LetExpr) (args' : Array Arg) : LetExpr + +def LetExpr.toExpr (e : LetExpr) : Expr := + match e with + | .value (.natVal val) => .lit (.natVal val) + | .value (.strVal val) => .lit (.strVal val) + | .erased => erasedExpr + | .proj n i s => .proj n i (.fvar s) + | .const n us as => mkAppN (.const n us) (as.map Arg.toExpr) + | .fvar fvarId as => mkAppN (.fvar fvarId) (as.map Arg.toExpr) + structure LetDecl where fvarId : FVarId binderName : Name type : Expr - value : Expr + value : LetExpr deriving Inhabited, BEq structure FunDeclCore (Code : Type) where @@ -63,7 +143,7 @@ inductive Code where | let (decl : LetDecl) (k : Code) | fun (decl : FunDeclCore Code) (k : Code) | jp (decl : FunDeclCore Code) (k : Code) - | jmp (fvarId : FVarId) (args : Array Expr) + | jmp (fvarId : FVarId) (args : Array Arg) | cases (cases : CasesCore Code) | return (fvarId : FVarId) | unreach (type : Expr) @@ -218,12 +298,12 @@ private unsafe def updateAltImp (alt : Alt) (ps' : Array Param) (k' : Code) : Al @[implemented_by updateReturnImp] opaque Code.updateReturn! (c : Code) (fvarId' : FVarId) : Code -@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code := +@[inline] private unsafe def updateJmpImp (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code := match c with | .jmp fvarId args => if fvarId == fvarId' && ptrEq args args' then c else .jmp fvarId' args' | _ => unreachable! -@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Expr) : Code +@[implemented_by updateJmpImp] opaque Code.updateJmp! (c : Code) (fvarId' : FVarId) (args' : Array Arg) : Code @[inline] private unsafe def updateUnreachImp (c : Code) (type' : Expr) : Code := match c with @@ -245,7 +325,7 @@ to be updated. -/ @[implemented_by updateParamCoreImp] opaque Param.updateCore (p : Param) (type : Expr) : Param -private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : Expr) : LetDecl := +private unsafe def updateLetDeclCoreImp (decl : LetDecl) (type : Expr) (value : LetExpr) : LetDecl := if ptrEq type decl.type && ptrEq value decl.value then decl else @@ -256,7 +336,7 @@ Low-level update `LetDecl` function. It does not update the local context. Consider using `LetDecl.update : LetDecl → Expr → Expr → CompilerM LetDecl` if you want the local context to be updated. -/ -@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : Expr) : LetDecl +@[implemented_by updateLetDeclCoreImp] opaque LetDecl.updateCore (decl : LetDecl) (type : Expr) (value : LetExpr) : LetDecl private unsafe def updateFunDeclCoreImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : FunDecl := if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then @@ -459,6 +539,9 @@ def Decl.instantiateParamsLevelParams (decl : Decl) (us : List Level) : Array Pa partial def Decl.instantiateValueLevelParams (decl : Decl) (us : List Level) : Code := instCode decl.value where + instLevel (u : Level) := + u.instantiateParams decl.levelParams us + instExpr (e : Expr) := e.instantiateLevelParamsNoCache decl.levelParams us @@ -470,8 +553,19 @@ where | .default k => alt.updateCode (instCode k) | .alt _ ps k => alt.updateAlt! (instParams ps) (instCode k) + instArg (arg : Arg) : Arg := + match arg with + | .type e => arg.updateType! (instExpr e) + | .fvar .. | .erased => arg + + instLetExpr (e : LetExpr) : LetExpr := + match e with + | .const declName vs args => e.updateConst! declName (vs.mapMono instLevel) (args.mapMono instArg) + | .fvar fvarId args => e.updateFVar! fvarId (args.mapMono instArg) + | .proj .. | .value .. | .erased => e + instLetDecl (decl : LetDecl) := - decl.updateCore (instExpr decl.type) (instExpr decl.value) + decl.updateCore (instExpr decl.type) (instLetExpr decl.value) instFunDecl (decl : FunDecl) := decl.updateCore (instExpr decl.type) (instParams decl.params) (instCode decl.value) @@ -481,7 +575,7 @@ where | .let decl k => code.updateLet! (instLetDecl decl) (instCode k) | .jp decl k | .fun decl k => code.updateFun! (instFunDecl decl) (instCode k) | .cases c => code.updateCases! (instExpr c.resultType) c.discr (c.alts.mapMono instAlt) - | .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instExpr) + | .jmp fvarId args => code.updateJmp! fvarId (args.mapMono instArg) | .return .. => code | .unreach type => code.updateUnreach! (instExpr type) @@ -506,45 +600,56 @@ def Decl.isTemplateLike (decl : Decl) : CoreM Bool := do else return false -mutual -partial def FunDeclCore.collectUsed (decl : FunDecl) (s : FVarIdSet := {}) : FVarIdSet := - decl.value.collectUsed <| collectParams decl.params <| collectExpr decl.type s +private partial def collectType (e : Expr) : FVarIdSet → FVarIdSet := + match e with + | .forallE _ d b _ => collectType b ∘ collectType d + | .lam _ d b _ => collectType b ∘ collectType d + | .app f a => collectType f ∘ collectType a + | .fvar fvarId => fun s => s.insert fvarId + | .proj .. | .letE .. | .mdata .. => unreachable! + | _ => id + +private def collectArg (arg : Arg) (s : FVarIdSet) : FVarIdSet := + match arg with + | .erased => s + | .fvar fvarId => s.insert fvarId + | .type e => collectType e s + +private def collectArgs (args : Array Arg) (s : FVarIdSet) : FVarIdSet := + args.foldl (init := s) fun s arg => collectArg arg s + +private def collectLetExpr (e : LetExpr) (s : FVarIdSet) : FVarIdSet := + match e with + | .fvar fvarId args => collectArgs args <| s.insert fvarId + | .const _ _ args => collectArgs args s + | .proj _ _ fvarId => s.insert fvarId + | .value .. | .erased => s private partial def collectParams (ps : Array Param) (s : FVarIdSet) : FVarIdSet := - ps.foldl (init := s) fun s p => collectExpr p.type s + ps.foldl (init := s) fun s p => collectType p.type s -private partial def collectExprs (es : Array Expr) (s : FVarIdSet) : FVarIdSet := - es.foldl (init := s) fun s e => collectExpr e s - -private partial def collectExpr (e : Expr) : FVarIdSet → FVarIdSet := - match e with - | .proj _ _ e => collectExpr e - | .forallE _ d b _ => collectExpr b ∘ collectExpr d - | .lam _ d b _ => collectExpr b ∘ collectExpr d - | .letE .. => unreachable! - | .app f a => collectExpr f ∘ collectExpr a - | .mdata _ b => collectExpr b - | .fvar fvarId => fun s => s.insert fvarId - | _ => id +mutual +partial def FunDeclCore.collectUsed (decl : FunDecl) (s : FVarIdSet := {}) : FVarIdSet := + decl.value.collectUsed <| collectParams decl.params <| collectType decl.type s partial def Code.collectUsed (code : Code) (s : FVarIdSet := {}) : FVarIdSet := match code with - | .let decl k => k.collectUsed <| collectExpr decl.value <| collectExpr decl.type s + | .let decl k => k.collectUsed <| collectLetExpr decl.value <| collectType decl.type s | .jp decl k | .fun decl k => k.collectUsed <| decl.collectUsed s | .cases c => let s := s.insert c.discr - let s := collectExpr c.resultType s + let s := collectType c.resultType s c.alts.foldl (init := s) fun s alt => match alt with | .default k => k.collectUsed s | .alt _ ps k => k.collectUsed <| collectParams ps s | .return fvarId => s.insert fvarId - | .unreach type => collectExpr type s - | .jmp fvarId args => collectExprs args <| s.insert fvarId + | .unreach type => collectType type s + | .jmp fvarId args => collectArgs args <| s.insert fvarId end abbrev collectUsedAtExpr (s : FVarIdSet) (e : Expr) : FVarIdSet := - collectExpr e s + collectType e s /-- Traverse the given block of potentially mutually recursive functions @@ -568,7 +673,7 @@ where | .cases c => c.alts.forM fun alt => visit alt.getCode | .unreach .. | .jmp .. | .return .. => return () | .let decl k => - if let .const declName _ := decl.value.getAppFn then + if let .const declName _ _ := decl.value then if decls.any (·.name == declName) then modify fun s => s.insert declName visit k diff --git a/src/Lean/Compiler/LCNF/CSE.lean b/src/Lean/Compiler/LCNF/CSE.lean index 6e1761c204..5a77fdb98e 100644 --- a/src/Lean/Compiler/LCNF/CSE.lean +++ b/src/Lean/Compiler/LCNF/CSE.lean @@ -57,12 +57,13 @@ where | .let decl k => let decl ← normLetDecl decl -- We only apply CSE to pure code - match (← get).map.find? decl.value with + let key := decl.value.toExpr + match (← get).map.find? key with | some fvarId => replaceLet decl fvarId go k | none => - addEntry decl.value decl.fvarId + addEntry key decl.fvarId return code.updateLet! decl (← go k) | .fun decl k => let decl ← goFunDecl decl @@ -91,7 +92,7 @@ where | .default k => withNewScope do return alt.updateCode (← go k) return code.updateCases! resultType discr alts | .return fvarId => return code.updateReturn! (← normFVar fvarId) - | .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args) + | .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normArgs args) | .unreach .. => return code end CSE diff --git a/src/Lean/Compiler/LCNF/Closure.lean b/src/Lean/Compiler/LCNF/Closure.lean index 927a82e51f..ce2f5b36a6 100644 --- a/src/Lean/Compiler/LCNF/Closure.lean +++ b/src/Lean/Compiler/LCNF/Closure.lean @@ -71,7 +71,20 @@ mutual contain other type parameters. -/ partial def collectParams (params : Array Param) : ClosureM Unit := - params.forM (collectExpr ·.type) + params.forM (collectType ·.type) + + partial def collectArg (arg : Arg) : ClosureM Unit := + match arg with + | .erased => return () + | .type e => collectType e + | .fvar fvarId => collectFVar fvarId + + partial def collectLetExpr (e : LetExpr) : ClosureM Unit := do + match e with + | .erased | .value .. => return () + | .proj _ _ fvarId => collectFVar fvarId + | .const _ _ args => args.forM collectArg + | .fvar fvarId args => collectFVar fvarId; args.forM collectArg /-- Collect dependencies in the given code. We need this function to be able @@ -79,22 +92,22 @@ mutual -/ partial def collectCode (c : Code) : ClosureM Unit := do match c with - | .let decl k => collectExpr decl.type; collectExpr decl.value; collectCode k + | .let decl k => collectType decl.type; collectLetExpr decl.value; collectCode k | .fun decl k | .jp decl k => collectFunDecl decl; collectCode k | .cases c => - collectExpr c.resultType + collectType c.resultType collectFVar c.discr c.alts.forM fun alt => do match alt with | .default k => collectCode k | .alt _ ps k => collectParams ps; collectCode k - | .jmp _ args => args.forM collectExpr - | .unreach type => collectExpr type + | .jmp _ args => args.forM collectArg + | .unreach type => collectType type | .return fvarId => collectFVar fvarId /-- Collect dependencies of a local function declaration. -/ partial def collectFunDecl (decl : FunDecl) : ClosureM Unit := do - collectExpr decl.type + collectType decl.type collectParams decl.params collectCode decl.value @@ -114,21 +127,21 @@ mutual collectFunDecl funDecl modify fun s => { s with decls := s.decls.push <| .fun funDecl } else if let some param ← findParam? fvarId then - collectExpr param.type + collectType param.type modify fun s => { s with params := s.params.push param } else if let some letDecl ← findLetDecl? fvarId then - collectExpr letDecl.type + collectType letDecl.type if (← read).abstract letDecl.fvarId then modify fun s => { s with params := s.params.push <| { letDecl with borrow := false } } else - collectExpr letDecl.value + collectLetExpr letDecl.value modify fun s => { s with decls := s.decls.push <| .let letDecl } else unreachable! /-- Collect dependencies of the given expression. -/ - partial def collectExpr (e : Expr) : ClosureM Unit := do - e.forEach fun e => do + partial def collectType (type : Expr) : ClosureM Unit := do + type.forEach fun e => do match e with | .fvar fvarId => collectFVar fvarId | _ => pure () diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index 877f310c34..eda0a0de0f 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -190,6 +190,23 @@ where else e +/-- +Replace the free variables in `arg` using the given substitution. + +See `normExprImp` +-/ +private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) : Arg := + match arg with + | .erased => arg + | .fvar fvarId => + match s.find? fvarId with + | some (.fvar fvarId') => + let arg' := .fvar fvarId' + if translator then arg' else normArgImp s arg' translator + | some e => if e.isErased then .erased else .type e + | none => arg + | .type e => arg.updateType! (normExprImp s e translator) + /-- Normalize the given free variable. See `normExprImp` for documentation on the `translator` parameter. @@ -208,6 +225,22 @@ private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : | some e => panic! s!"invalid LCNF substitution of free variable with expression {e}" | none => fvarId + +private def normArgsImp (s : FVarSubst) (args : Array Arg) (translator : Bool) : Array Arg := + args.mapMono (normArgImp s · translator) + +/-- +Replace the free variables in `e` using the given substitution. + +See `normExprImp` +-/ +private partial def normLetExprImp (s : FVarSubst) (e : LetExpr) (translator : Bool) : LetExpr := + match e with + | .erased | .value .. => e + | .proj _ _ fvarId => e.updateProj! (normFVarImp s fvarId translator) + | .const _ _ args => e.updateArgs! (normArgsImp s args translator) + | .fvar fvarId args => e.updateFVar! (normFVarImp s fvarId translator) (normArgsImp s args translator) + /-- Interface for monads that have a free substitutions. -/ @@ -248,15 +281,21 @@ See `Check.lean` for the free variable substitution checker. @[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m t] [Monad m] (e : Expr) : m Expr := return normExprImp (← getSubst) e t +@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m t] [Monad m] (arg : Arg) : m Arg := + return normArgImp (← getSubst) arg t + +@[inline, inherit_doc normLetExprImp] def normLetExpr [MonadFVarSubst m t] [Monad m] (e : LetExpr) : m LetExpr := + return normLetExprImp (← getSubst) e t + @[inherit_doc normExprImp] abbrev normExprCore (s : FVarSubst) (e : Expr) (translator : Bool) : Expr := normExprImp s e translator /-- -Normalize the given expressions using the current substitution. +Normalize the given arguments using the current substitution. -/ -def normExprs [MonadFVarSubst m t] [Monad m] (es : Array Expr) : m (Array Expr) := - es.mapMonoM normExpr +def normArgs [MonadFVarSubst m t] [Monad m] (args : Array Arg) : m (Array Arg) := + return normArgsImp (← getSubst) args t def mkFreshBinderName (binderName := `_x): CompilerM Name := do let declName := .num binderName (← get).nextIdx @@ -280,7 +319,7 @@ def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM Param modifyLCtx fun lctx => lctx.addParam param return param -def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) : CompilerM LetDecl := do +def mkLetDecl (binderName : Name) (type : Expr) (value : LetExpr) : CompilerM LetDecl := do let fvarId ← mkFreshFVarId let binderName ← ensureNotAnonymous binderName `_x let decl := { fvarId, binderName, type, value } @@ -304,7 +343,7 @@ private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := @[implemented_by updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param -private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl := do +private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetExpr) : CompilerM LetDecl := do if ptrEq type decl.type && ptrEq value decl.value then return decl else @@ -312,9 +351,9 @@ private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : Expr modifyLCtx fun lctx => lctx.addLetDecl decl return decl -@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : Expr) : CompilerM LetDecl +@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : LetExpr) : CompilerM LetDecl -def LetDecl.updateValue (decl : LetDecl) (value : Expr) : CompilerM LetDecl := +def LetDecl.updateValue (decl : LetDecl) (value : LetExpr) : CompilerM LetDecl := decl.update decl.type value private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do @@ -340,7 +379,7 @@ def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (ps : Arr ps.mapMonoM normParam def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : LetDecl) : m LetDecl := do - decl.update (← normExpr decl.type) (← normExpr decl.value) + decl.update (← normExpr decl.type) (← normLetExpr decl.value) abbrev NormalizerM (_translator : Bool) := ReaderT FVarSubst CompilerM @@ -359,7 +398,7 @@ mutual | .let decl k => return code.updateLet! (← normLetDecl decl) (← normCodeImp k) | .fun decl k | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k) | .return fvarId => return code.updateReturn! (← normFVar fvarId) - | .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normExprs args) + | .jmp fvarId args => return code.updateJmp! (← normFVar fvarId) (← normArgs args) | .unreach type => return code.updateUnreach! (← normExpr type) | .cases c => let resultType ← normExpr c.resultType diff --git a/src/Lean/Compiler/LCNF/DependsOn.lean b/src/Lean/Compiler/LCNF/DependsOn.lean index baf134a959..b50ec24629 100644 --- a/src/Lean/Compiler/LCNF/DependsOn.lean +++ b/src/Lean/Compiler/LCNF/DependsOn.lean @@ -12,19 +12,32 @@ private abbrev M := ReaderT FVarIdSet Id private def fvarDepOn (fvarId : FVarId) : M Bool := return (← read).contains fvarId -private def exprDepOn (e : Expr) : M Bool := do +private def typeDepOn (e : Expr) : M Bool := do let s ← read return e.hasAnyFVar fun fvarId => s.contains fvarId +private def argDepOn (a : Arg) : M Bool := do + match a with + | .erased => return false + | .fvar fvarId => fvarDepOn fvarId + | .type e => typeDepOn e + +private def letExprDepOn (e : LetExpr) : M Bool := + match e with + | .erased | .value .. => return false + | .proj _ _ fvarId => fvarDepOn fvarId + | .fvar fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn + | .const _ _ args => args.anyM argDepOn + private def LetDecl.depOn (decl : LetDecl) : M Bool := - exprDepOn decl.type <||> exprDepOn decl.value + typeDepOn decl.type <||> letExprDepOn decl.value private partial def depOn (c : Code) : M Bool := match c with | .let decl k => decl.depOn <||> depOn k - | .jp decl k | .fun decl k => exprDepOn decl.type <||> depOn decl.value <||> depOn k - | .cases c => exprDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode - | .jmp fvarId args => fvarDepOn fvarId <||> args.anyM exprDepOn + | .jp decl k | .fun decl k => typeDepOn decl.type <||> depOn decl.value <||> depOn k + | .cases c => typeDepOn c.resultType <||> fvarDepOn c.discr <||> c.alts.anyM fun alt => depOn alt.getCode + | .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn | .return fvarId => fvarDepOn fvarId | .unreach _ => return false @@ -32,7 +45,7 @@ abbrev LetDecl.dependsOn (decl : LetDecl) (s : FVarIdSet) : Bool := decl.depOn s abbrev FunDecl.dependsOn (decl : FunDecl) (s : FVarIdSet) : Bool := - exprDepOn decl.type s || depOn decl.value s + typeDepOn decl.type s || depOn decl.value s def CodeDecl.dependsOn (decl : CodeDecl) (s : FVarIdSet) : Bool := match decl with diff --git a/src/Lean/Compiler/LCNF/ElimDead.lean b/src/Lean/Compiler/LCNF/ElimDead.lean index 156004bcc2..8f84ff72be 100644 --- a/src/Lean/Compiler/LCNF/ElimDead.lean +++ b/src/Lean/Compiler/LCNF/ElimDead.lean @@ -13,26 +13,43 @@ abbrev UsedLocalDecls := FVarIdHashSet Collect set of (let) free variables in a LCNF value. This code exploits the LCNF property that local declarations do not occur in types. -/ -def collectLocalDecls (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls := - go s e +def collectLocalDeclsType (s : UsedLocalDecls) (type : Expr) : UsedLocalDecls := + go s type where go (s : UsedLocalDecls) (e : Expr) : UsedLocalDecls := match e with - | .proj _ _ e => go s e | .forallE .. => s | .lam _ _ b _ => go s b - | .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations | .app f a => go (go s a) f - | .mdata _ b => go s b | .fvar fvarId => s.insert fvarId + | .letE .. | .proj .. | .mdata .. => unreachable! -- Valid LCNF type does not contain this kind of expr | _ => s +def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg) : UsedLocalDecls := + match arg with + | .erased => s + | .type e => collectLocalDeclsType s e + | .fvar fvarId => s.insert fvarId + +def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array Arg) : UsedLocalDecls := + args.foldl (init := s) collectLocalDeclsArg + +def collectLocalDeclsLetExpr (s : UsedLocalDecls) (e : LetExpr) : UsedLocalDecls := + match e with + | .erased | .value .. => s + | .proj _ _ fvarId => s.insert fvarId + | .const _ _ args => collectLocalDeclsArgs s args + | .fvar fvarId args => collectLocalDeclsArgs (s.insert fvarId) args + namespace ElimDead abbrev M := StateRefT UsedLocalDecls CompilerM -private abbrev collectExprM (e : Expr) : M Unit := - modify (collectLocalDecls · e) +private abbrev collectArgM (arg : Arg) : M Unit := + modify (collectLocalDeclsArg · arg) + +private abbrev collectLetExprM (e : LetExpr) : M Unit := + modify (collectLocalDeclsLetExpr · e) private abbrev collectFVarM (fvarId : FVarId) : M Unit := modify (·.insert fvarId) @@ -48,7 +65,7 @@ partial def elimDead (code : Code) : M Code := do let k ← elimDead k if (← get).contains decl.fvarId then /- Remark: we don't need to collect `decl.type` because LCNF local declarations do not occur in types. -/ - collectExprM decl.value + collectLetExprM decl.value return code.updateCont! k else eraseLetDecl decl @@ -66,7 +83,7 @@ partial def elimDead (code : Code) : M Code := do collectFVarM c.discr return code.updateAlts! alts | .return fvarId => collectFVarM fvarId; return code - | .jmp fvarId args => collectFVarM fvarId; args.forM collectExprM; return code + | .jmp fvarId args => collectFVarM fvarId; args.forM collectArgM; return code | .unreach .. => return code end diff --git a/src/Lean/Compiler/LCNF/FixedParams.lean b/src/Lean/Compiler/LCNF/FixedParams.lean index c6e8e21e93..36b8f42e68 100644 --- a/src/Lean/Compiler/LCNF/FixedParams.lean +++ b/src/Lean/Compiler/LCNF/FixedParams.lean @@ -34,7 +34,7 @@ marked as not fixed. -/ /-- Abstract value for the "fixed parameter" analysis. -/ -inductive Value where +inductive AbsValue where | top | erased | val (i : Nat) @@ -52,7 +52,7 @@ structure Context where The assignment maps free variable ids in the current code being analyzed to abstract values. We only track the abstract value assigned to parameters. -/ - assignment : FVarIdMap Value + assignment : FVarIdMap AbsValue structure State where /-- @@ -61,7 +61,7 @@ structure State where Whenever there is function application `f a₁ ... aₙ`, where `f` is in `decls`, `f` is not `main`, and we visit with the abstract values assigned to `aᵢ`, but first we record the visit here. -/ - visited : HashSet (Name × Array Value) := {} + visited : HashSet (Name × Array AbsValue) := {} /-- Bitmask containing the result, i.e., which parameters of `main` are fixed. We initialize it with `true` everywhere. @@ -76,17 +76,18 @@ abbrev abort : FixParamM α := do modify fun s => { s with fixed := s.fixed.map fun _ => false } throw () -def evalArg (arg : Expr) : FixParamM Value := do - if arg.isErased then - return .erased - let .fvar fvarId := arg | return .top - let some val := (← read).assignment.find? fvarId | return .top - return val +def evalArg (arg : Arg) : FixParamM AbsValue := do + match arg with + | .erased => return .erased + | .type _ => return .top + | .fvar fvarId => + let some val := (← read).assignment.find? fvarId | return .top + return val def inMutualBlock (declName : Name) : FixParamM Bool := return (← read).decls.any (·.name == declName) -def mkAssignment (decl : Decl) (values : Array Value) : FVarIdMap Value := Id.run do +def mkAssignment (decl : Decl) (values : Array AbsValue) : FVarIdMap AbsValue := Id.run do let mut assignment := {} for param in decl.params, value in values do assignment := assignment.insert param.fvarId value @@ -94,23 +95,19 @@ def mkAssignment (decl : Decl) (values : Array Value) : FVarIdMap Value := Id.ru mutual -partial def evalExpr (e : Expr) : FixParamM Unit := do +partial def evalLetExpr (e : LetExpr) : FixParamM Unit := do match e with - | .const declName _ => evalApp declName #[] - | .app .. => - let .const declName _ := e.getAppFn | return () - if (← inMutualBlock declName) then - evalApp declName e.getAppArgs + | .const declName _ args => evalApp declName args | _ => return () partial def evalCode (code : Code) : FixParamM Unit := do match code with - | .let decl k => evalExpr decl.value; evalCode k + | .let decl k => evalLetExpr decl.value; evalCode k | .fun decl k | .jp decl k => evalCode decl.value; evalCode k | .cases c => c.alts.forM fun alt => evalCode alt.getCode | .unreach .. | .jmp .. | .return .. => return () -partial def evalApp (declName : Name) (args : Array Expr) : FixParamM Unit := do +partial def evalApp (declName : Name) (args : Array Arg) : FixParamM Unit := do let main := (← read).main if declName == main.name then -- Recursive call to the function being analyzed @@ -145,7 +142,7 @@ partial def evalApp (declName : Name) (args : Array Expr) : FixParamM Unit := do end -def mkInitialValues (numParams : Nat) : Array Value := Id.run do +def mkInitialValues (numParams : Nat) : Array AbsValue := Id.run do let mut values := #[] for i in [:numParams] do values := values.push <| .val i diff --git a/src/Lean/Compiler/LCNF/ForEachExpr.lean b/src/Lean/Compiler/LCNF/ForEachExpr.lean index 1bc99d153e..79da9c1247 100644 --- a/src/Lean/Compiler/LCNF/ForEachExpr.lean +++ b/src/Lean/Compiler/LCNF/ForEachExpr.lean @@ -8,44 +8,6 @@ import Lean.Compiler.LCNF.Basic namespace Lean.Compiler.LCNF -partial def Code.forEachExpr [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m] (f : Expr → m Unit) (c : Code) (skipTypes := false) : m Unit := do - visit c |>.run -where - visit (c : Code) : MonadCacheT Expr Unit m Unit := do - match c with - | .let decl k => - visitType decl.type - visitExpr decl.value - visit k - | .jp decl k | .fun decl k => - visitType decl.type - decl.params.forM visitParam - visit decl.value - visit k - | .unreach type => visitType type - | .return .. => return () - | .jmp _ args => args.forM visitExpr - | .cases c => visitType c.resultType; c.alts.forM fun alt => do - match alt with - | .default k => visit k - | .alt _ ps k => ps.forM visitParam; visit k - - visitParam (p : Param) : MonadCacheT Expr Unit m Unit := - visitType p.type - - visitExpr (e : Expr) : MonadCacheT Expr Unit m Unit := - ForEachExpr.visit (fun e => f e *> return true) e - - visitType (e : Expr) : MonadCacheT Expr Unit m Unit := - unless skipTypes do - visitExpr e - -def Decl.forEachExpr [STWorld ω m] [MonadLiftT (ST ω) m] [Monad m] (f : Expr → m Unit) (decl : Decl) (skipTypes := false) : m Unit := do - visit |>.run -where - visit : MonadCacheT Expr Unit m Unit := do - Code.forEachExpr.visitType f decl.type (skipTypes := skipTypes) - decl.params.forM (Code.forEachExpr.visitParam f (skipTypes := skipTypes)) - Code.forEachExpr.visit f decl.value (skipTypes := skipTypes) +-- TODO: delete end Lean.Compiler.LCNF \ No newline at end of file diff --git a/src/Lean/Compiler/LCNF/InferType.lean b/src/Lean/Compiler/LCNF/InferType.lean index 4a4f68a59c..830c0dbc06 100644 --- a/src/Lean/Compiler/LCNF/InferType.lean +++ b/src/Lean/Compiler/LCNF/InferType.lean @@ -101,21 +101,22 @@ def inferConstType (declName : Name) (us : List Level) : CompilerM Expr := do getOtherDeclType declName us mutual + partial def inferArgType (arg : Arg) : InferTypeM Expr := + match arg with + | .erased => return erasedExpr + | .type e => inferType e + | .fvar fvarId => LCNF.getType fvarId +-- TODO: stopped here partial def inferType (e : Expr) : InferTypeM Expr := match e with | .const c us => inferConstType c us - | .proj n i s => inferProjType n i s | .app .. => inferAppType e - | .mvar .. => throwError "unexpected metavariable {e}" | .fvar fvarId => InferType.getType fvarId - | .bvar .. => throwError "unexpected bound variable {e}" - | .mdata _ e => inferType e - | .lit v => return v.type | .sort lvl => return .sort (mkLevelSucc lvl) | .forallE .. => inferForallType e | .lam .. => inferLambdaType e - | .letE .. => inferLambdaType e + | .letE .. | .mvar .. | .mdata .. | .lit .. | .bvar .. | .proj .. => unreachable! partial def inferAppTypeCore (f : Expr) (args : Array Expr) : InferTypeM Expr := do let mut j := 0 diff --git a/src/Lean/Compiler/LCNF/LCtx.lean b/src/Lean/Compiler/LCNF/LCtx.lean index 571493e810..ba50bf037d 100644 --- a/src/Lean/Compiler/LCNF/LCtx.lean +++ b/src/Lean/Compiler/LCNF/LCtx.lean @@ -65,7 +65,7 @@ def LCtx.toLocalContext (lctx : LCtx) : LocalContext := Id.run do for (_, param) in lctx.params.toArray do result := result.addDecl (.cdecl 0 param.fvarId param.binderName param.type .default .default) for (_, decl) in lctx.letDecls.toArray do - result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value true .default) + result := result.addDecl (.ldecl 0 decl.fvarId decl.binderName decl.type decl.value.toExpr true .default) for (_, decl) in lctx.funDecls.toArray do result := result.addDecl (.cdecl 0 decl.fvarId decl.binderName decl.type .default .default) return result diff --git a/src/Lean/Compiler/LCNF/Level.lean b/src/Lean/Compiler/LCNF/Level.lean index c0d4d1e9b3..e37d502a2f 100644 --- a/src/Lean/Compiler/LCNF/Level.lean +++ b/src/Lean/Compiler/LCNF/Level.lean @@ -97,8 +97,25 @@ See `Decl.setLevelParams`. -/ open Lean.CollectLevelParams +abbrev visitType (type : Expr) : Visitor := + visitExpr type + +def visitArg (arg : Arg) : Visitor := + match arg with + | .erased | .fvar .. => id + | .type e => visitType e + +def visitArgs (args : Array Arg) : Visitor := + fun s => args.foldl (init := s) fun s arg => visitArg arg s + +def visitLetExpr (e : LetExpr) : Visitor := + match e with + | .erased | .value .. | .proj .. => id + | .const _ us args => visitLevels us ∘ visitArgs args + | .fvar _ args => visitArgs args + def visitParam (p : Param) : Visitor := - visitExpr p.type + visitType p.type def visitParams (ps : Array Param) : Visitor := fun s => ps.foldl (init := s) fun s p => visitParam p s @@ -113,12 +130,12 @@ mutual fun s => alts.foldl (init := s) fun s alt => visitAlt alt s partial def visitCode : Code → Visitor - | .let decl k => visitCode k ∘ visitExpr decl.value ∘ visitExpr decl.type - | .fun decl k | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitExpr decl.type - | .cases c => visitAlts c.alts ∘ visitExpr c.resultType - | .unreach type => visitExpr type + | .let decl k => visitCode k ∘ visitLetExpr decl.value ∘ visitType decl.type + | .fun decl k | .jp decl k => visitCode k ∘ visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type + | .cases c => visitAlts c.alts ∘ visitType c.resultType + | .unreach type => visitType type | .return _ => id - | .jmp _ args => fun s => args.foldl (init := s) fun s arg => visitExpr arg s + | .jmp _ args => visitArgs args end end CollectLevelParams @@ -131,7 +148,7 @@ Collect universe level parameters collecting in the type, parameters, and value, set `decl.levelParams` with the resulting value. -/ def Decl.setLevelParams (decl : Decl) : Decl := - let levelParams := (visitCode decl.value ∘ visitParams decl.params ∘ visitExpr decl.type) {} |>.params.toList + let levelParams := (visitCode decl.value ∘ visitParams decl.params ∘ visitType decl.type) {} |>.params.toList { decl with levelParams } end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/Simp/Basic.lean b/src/Lean/Compiler/LCNF/Simp/Basic.lean index 5e7f97b6ae..e29c71657f 100644 --- a/src/Lean/Compiler/LCNF/Simp/Basic.lean +++ b/src/Lean/Compiler/LCNF/Simp/Basic.lean @@ -11,25 +11,29 @@ import Lean.Compiler.LCNF.CompilerM namespace Lean.Compiler.LCNF namespace Simp -partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do +/- +-- TODO: cleanup +partial def findExpr (e : LetExpr) : CompilerM LetExpr := do match e with - | .fvar fvarId => - let some decl ← findLetDecl? fvarId | return e - findExpr decl.value - | .mdata _ e' => if skipMData then findExpr e' else return e + | .fvar fvarId args => + if args.isEmpty then + let some decl ← findLetDecl? fvarId | return e + findExpr decl.value + else + return e | _ => return e -partial def findFunDecl? (e : Expr) : CompilerM (Option FunDecl) := do +partial def findFunDecl? (e : LetExpr) : CompilerM (Option FunDecl) := do match e with - | .fvar fvarId => + | .fvar fvarId args => + if let some decl ← LCNF.findFunDecl? fvarId then return some decl else if let some decl ← findLetDecl? fvarId then findFunDecl? decl.value else return none - | .mdata _ e => findFunDecl? e | _ => return none - +-/ end Simp end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/ToExpr.lean b/src/Lean/Compiler/LCNF/ToExpr.lean index bad589b21a..9abc4fb15d 100644 --- a/src/Lean/Compiler/LCNF/ToExpr.lean +++ b/src/Lean/Compiler/LCNF/ToExpr.lean @@ -74,6 +74,9 @@ end ToExpr open ToExpr +private def Arg.toExprM (arg : Arg) : ToExprM Expr := + return arg.toExpr.abstract' (← read) (← get) + mutual partial def FunDeclCore.toExprM (decl : FunDecl) : ToExprM Expr := withParams decl.params do mkLambdaM decl.params (← decl.value.toExprM) @@ -82,7 +85,7 @@ partial def Code.toExprM (code : Code) : ToExprM Expr := do match code with | .let decl k => let type ← abstractM decl.type - let value ← abstractM decl.value + let value ← abstractM decl.value.toExpr let body ← withFVar decl.fvarId k.toExprM return .letE decl.binderName type value body true | .fun decl k | .jp decl k => @@ -91,7 +94,7 @@ partial def Code.toExprM (code : Code) : ToExprM Expr := do let body ← withFVar decl.fvarId k.toExprM return .letE decl.binderName type value body true | .return fvarId => fvarId.toExprM - | .jmp fvarId args => return mkAppN (← fvarId.toExprM) (← args.mapM abstractM) + | .jmp fvarId args => return mkAppN (← fvarId.toExprM) (← args.mapM Arg.toExprM) | .unreach type => return mkApp (mkConst ``lcUnreachable) (← abstractM type) | .cases c => let alts ← c.alts.mapM fun