fix: use Arg in LCNF FVarSubst rather than Expr (#8729)

This PR changes LCNF's `FVarSubst` to use `Arg` rather than `Expr`. This
enforces the requirements on substitutions, which match the requirements
on `Arg`.
This commit is contained in:
Cameron Zwarich 2025-06-11 11:08:30 -07:00 committed by GitHub
parent 77fd1ba6b9
commit 39cbe04946
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 25 additions and 27 deletions

View file

@ -166,7 +166,7 @@ it is a free variable, a type (or type former), or `lcErased`.
`Check.lean` contains a substitution validator.
-/
abbrev FVarSubst := Std.HashMap FVarId Expr
abbrev FVarSubst := Std.HashMap FVarId Arg
/--
Replace the free variables in `e` using the given substitution.
@ -191,7 +191,9 @@ where
if e.hasFVar then
match e with
| .fvar fvarId => match s[fvarId]? with
| some e => if translator then e else go e
| some (.fvar fvarId') => if translator then .fvar fvarId' else go (.fvar fvarId')
| some (.type e) => if translator then e else go e
| some .erased => erasedExpr
| none => e
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
| .app f a => e.updateApp! (goApp f) (go a) |>.headBeta
@ -230,11 +232,9 @@ private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator :
.fvar fvarId'
else
normFVarImp s fvarId' translator
| some e =>
if e.isErased then
.erased
else
panic! s!"invalid LCNF substitution of free variable with expression {e}"
-- Types and type formers are only preserved as hints and
-- are erased in computationally relevant contexts.
| some .erased | some (.type _) => .erased
| none => .fvar fvarId
/--
@ -247,10 +247,9 @@ private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) :
| .erased => arg
| .fvar fvarId =>
match s[fvarId]? with
| some (.fvar fvarId') =>
let arg' := .fvar fvarId'
| some (arg'@(.fvar _)) =>
if translator then arg' else normArgImp s arg' translator
| some e => if e.isErased then .erased else .type e
| some (arg'@.erased) | some (arg'@(.type _)) => arg'
| none => arg
| .type e => arg.updateType! (normExprImp s e translator)
@ -292,21 +291,20 @@ export MonadFVarSubstState (modifySubst)
instance (m n) [MonadLift m n] [MonadFVarSubstState m] : MonadFVarSubstState n where
modifySubst f := liftM (modifySubst f : m _)
/--
Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF `Arg`.
See `Check.lean` for the free variable substitution checker.
-/
@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (arg : Arg) : m Unit :=
modifySubst fun s => s.insert fvarId arg
/--
Add the entry `fvarId ↦ fvarId'` to the free variable substitution.
-/
@[inline] def addFVarSubst [MonadFVarSubstState m] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
modifySubst fun s => s.insert fvarId (.fvar fvarId')
/--
Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF argument.
That is, it must be a free variable, type (or type former), or `lcErased`.
See `Check.lean` for the free variable substitution checker.
-/
@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (e : Expr) : m Unit :=
modifySubst fun s => s.insert fvarId e
@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m t] [Monad m] (fvarId : FVarId) : m NormFVarResult :=
return normFVarImp (← getSubst) fvarId t

View file

@ -546,13 +546,13 @@ where
let mut newArgs := knownArgs
for (param, arg) in decl.params.zip args do
if let some knownVal := newArgs[param.fvarId]? then
if arg.toExpr != knownVal then
if arg != knownVal then
newArgs := newArgs.erase param.fvarId
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
else
let folder := fun acc (param, arg) => do
if (← allFVarM (isInJpScope fn) arg) then
return acc.insert param.fvarId arg.toExpr
return acc.insert param.fvarId arg
else
return acc
let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder

View file

@ -46,7 +46,7 @@ We use this function to inline/specialize a partial application of a local funct
def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
let mut subst := {}
for param in info.params, arg in info.args do
subst := subst.insert param.fvarId arg.toExpr
subst := subst.insert param.fvarId arg
let mut paramsNew := #[]
for param in info.params[info.args.size:] do
let type ← replaceExprFVars param.type subst (translator := true)
@ -201,7 +201,7 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
| .ctor ctorVal ctorArgs =>
let fields := ctorArgs[ctorVal.numParams:]
for param in params, field in fields do
addSubst param.fvarId field.toExpr
addSubst param.fvarId field
let k ← simp k
eraseParams params
return k
@ -231,7 +231,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
-- and `FVarId` rather than `Arg`, and the substitution will end up
-- creating a new erased let decl in that case.
if decl.type.isErased && decl.value != .erased then
modifySubst fun s => s.insert decl.fvarId (.const ``lcErased [])
addSubst decl.fvarId .erased
eraseLetDecl decl
simp k
else if let some decls ← ConstantFold.foldConstants decl then

View file

@ -212,7 +212,7 @@ See comment at `updateFunDeclInfo`.
def betaReduce (params : Array Param) (code : Code) (args : Array Arg) (mustInline := false) : SimpM Code := do
let mut subst := {}
for param in params, arg in args do
subst := subst.insert param.fvarId arg.toExpr
subst := subst.insert param.fvarId arg
let code ← code.internalize subst
updateFunDeclInfo code mustInline
return code

View file

@ -238,7 +238,7 @@ where
for param in decl.params, arg in argMask do
if let some arg := arg then
let arg ← normArg arg
modify fun s => s.insert param.fvarId arg.toExpr
modify fun s => s.insert param.fvarId arg
else
-- Keep the parameter
let param := { param with type := param.type.instantiateLevelParamsNoCache decl.levelParams us }

View file

@ -120,7 +120,7 @@ where
let type ← replaceExprFVars param.type subst (translator := true)
let paramNew ← mkAuxParam type
jpParams := jpParams.push paramNew
subst := subst.insert param.fvarId (Expr.fvar paramNew.fvarId)
subst := subst.insert param.fvarId (.fvar paramNew.fvarId)
jpArgs := jpArgs.push (Arg.fvar paramNew.fvarId)
let letDecl ← mkAuxLetDecl (.fvar f jpArgs)
let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId])