diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index 9acb4cca34..1737e91307 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -166,7 +166,7 @@ it is a free variable, a type (or type former), or `lcErased`. `Check.lean` contains a substitution validator. -/ -abbrev FVarSubst := Std.HashMap FVarId Expr +abbrev FVarSubst := Std.HashMap FVarId Arg /-- Replace the free variables in `e` using the given substitution. @@ -191,7 +191,9 @@ where if e.hasFVar then match e with | .fvar fvarId => match s[fvarId]? with - | some e => if translator then e else go e + | some (.fvar fvarId') => if translator then .fvar fvarId' else go (.fvar fvarId') + | some (.type e) => if translator then e else go e + | some .erased => erasedExpr | none => e | .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e | .app f a => e.updateApp! (goApp f) (go a) |>.headBeta @@ -230,11 +232,9 @@ private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : .fvar fvarId' else normFVarImp s fvarId' translator - | some e => - if e.isErased then - .erased - else - panic! s!"invalid LCNF substitution of free variable with expression {e}" + -- Types and type formers are only preserved as hints and + -- are erased in computationally relevant contexts. + | some .erased | some (.type _) => .erased | none => .fvar fvarId /-- @@ -247,10 +247,9 @@ private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) : | .erased => arg | .fvar fvarId => match s[fvarId]? with - | some (.fvar fvarId') => - let arg' := .fvar fvarId' + | some (arg'@(.fvar _)) => if translator then arg' else normArgImp s arg' translator - | some e => if e.isErased then .erased else .type e + | some (arg'@.erased) | some (arg'@(.type _)) => arg' | none => arg | .type e => arg.updateType! (normExprImp s e translator) @@ -292,21 +291,20 @@ export MonadFVarSubstState (modifySubst) instance (m n) [MonadLift m n] [MonadFVarSubstState m] : MonadFVarSubstState n where modifySubst f := liftM (modifySubst f : m _) +/-- +Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF `Arg`. + +See `Check.lean` for the free variable substitution checker. +-/ +@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (arg : Arg) : m Unit := + modifySubst fun s => s.insert fvarId arg + /-- Add the entry `fvarId ↦ fvarId'` to the free variable substitution. -/ @[inline] def addFVarSubst [MonadFVarSubstState m] (fvarId : FVarId) (fvarId' : FVarId) : m Unit := modifySubst fun s => s.insert fvarId (.fvar fvarId') -/-- -Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF argument. -That is, it must be a free variable, type (or type former), or `lcErased`. - -See `Check.lean` for the free variable substitution checker. --/ -@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (e : Expr) : m Unit := - modifySubst fun s => s.insert fvarId e - @[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m t] [Monad m] (fvarId : FVarId) : m NormFVarResult := return normFVarImp (← getSubst) fvarId t diff --git a/src/Lean/Compiler/LCNF/JoinPoints.lean b/src/Lean/Compiler/LCNF/JoinPoints.lean index c6e3879430..8ecd79f24e 100644 --- a/src/Lean/Compiler/LCNF/JoinPoints.lean +++ b/src/Lean/Compiler/LCNF/JoinPoints.lean @@ -546,13 +546,13 @@ where let mut newArgs := knownArgs for (param, arg) in decl.params.zip args do if let some knownVal := newArgs[param.fvarId]? then - if arg.toExpr != knownVal then + if arg != knownVal then newArgs := newArgs.erase param.fvarId modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs } else let folder := fun acc (param, arg) => do if (← allFVarM (isInJpScope fn) arg) then - return acc.insert param.fvarId arg.toExpr + return acc.insert param.fvarId arg else return acc let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder diff --git a/src/Lean/Compiler/LCNF/Simp/Main.lean b/src/Lean/Compiler/LCNF/Simp/Main.lean index 649d6dbd1d..e956ec4906 100644 --- a/src/Lean/Compiler/LCNF/Simp/Main.lean +++ b/src/Lean/Compiler/LCNF/Simp/Main.lean @@ -46,7 +46,7 @@ We use this function to inline/specialize a partial application of a local funct def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do let mut subst := {} for param in info.params, arg in info.args do - subst := subst.insert param.fvarId arg.toExpr + subst := subst.insert param.fvarId arg let mut paramsNew := #[] for param in info.params[info.args.size:] do let type ← replaceExprFVars param.type subst (translator := true) @@ -201,7 +201,7 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do | .ctor ctorVal ctorArgs => let fields := ctorArgs[ctorVal.numParams:] for param in params, field in fields do - addSubst param.fvarId field.toExpr + addSubst param.fvarId field let k ← simp k eraseParams params return k @@ -231,7 +231,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do -- and `FVarId` rather than `Arg`, and the substitution will end up -- creating a new erased let decl in that case. if decl.type.isErased && decl.value != .erased then - modifySubst fun s => s.insert decl.fvarId (.const ``lcErased []) + addSubst decl.fvarId .erased eraseLetDecl decl simp k else if let some decls ← ConstantFold.foldConstants decl then diff --git a/src/Lean/Compiler/LCNF/Simp/SimpM.lean b/src/Lean/Compiler/LCNF/Simp/SimpM.lean index fbfbe770cb..d0cf683979 100644 --- a/src/Lean/Compiler/LCNF/Simp/SimpM.lean +++ b/src/Lean/Compiler/LCNF/Simp/SimpM.lean @@ -212,7 +212,7 @@ See comment at `updateFunDeclInfo`. def betaReduce (params : Array Param) (code : Code) (args : Array Arg) (mustInline := false) : SimpM Code := do let mut subst := {} for param in params, arg in args do - subst := subst.insert param.fvarId arg.toExpr + subst := subst.insert param.fvarId arg let code ← code.internalize subst updateFunDeclInfo code mustInline return code diff --git a/src/Lean/Compiler/LCNF/Specialize.lean b/src/Lean/Compiler/LCNF/Specialize.lean index cc0072c662..8b77861b16 100644 --- a/src/Lean/Compiler/LCNF/Specialize.lean +++ b/src/Lean/Compiler/LCNF/Specialize.lean @@ -238,7 +238,7 @@ where for param in decl.params, arg in argMask do if let some arg := arg then let arg ← normArg arg - modify fun s => s.insert param.fvarId arg.toExpr + modify fun s => s.insert param.fvarId arg else -- Keep the parameter let param := { param with type := param.type.instantiateLevelParamsNoCache decl.levelParams us } diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index eaa3de81ca..0d1b4937b4 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -120,7 +120,7 @@ where let type ← replaceExprFVars param.type subst (translator := true) let paramNew ← mkAuxParam type jpParams := jpParams.push paramNew - subst := subst.insert param.fvarId (Expr.fvar paramNew.fvarId) + subst := subst.insert param.fvarId (.fvar paramNew.fvarId) jpArgs := jpArgs.push (Arg.fvar paramNew.fvarId) let letDecl ← mkAuxLetDecl (.fvar f jpArgs) let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId])