fix: use Arg in LCNF FVarSubst rather than Expr (#8729)
This PR changes LCNF's `FVarSubst` to use `Arg` rather than `Expr`. This enforces the requirements on substitutions, which match the requirements on `Arg`.
This commit is contained in:
parent
77fd1ba6b9
commit
39cbe04946
6 changed files with 25 additions and 27 deletions
|
|
@ -166,7 +166,7 @@ it is a free variable, a type (or type former), or `lcErased`.
|
|||
|
||||
`Check.lean` contains a substitution validator.
|
||||
-/
|
||||
abbrev FVarSubst := Std.HashMap FVarId Expr
|
||||
abbrev FVarSubst := Std.HashMap FVarId Arg
|
||||
|
||||
/--
|
||||
Replace the free variables in `e` using the given substitution.
|
||||
|
|
@ -191,7 +191,9 @@ where
|
|||
if e.hasFVar then
|
||||
match e with
|
||||
| .fvar fvarId => match s[fvarId]? with
|
||||
| some e => if translator then e else go e
|
||||
| some (.fvar fvarId') => if translator then .fvar fvarId' else go (.fvar fvarId')
|
||||
| some (.type e) => if translator then e else go e
|
||||
| some .erased => erasedExpr
|
||||
| none => e
|
||||
| .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
|
||||
| .app f a => e.updateApp! (goApp f) (go a) |>.headBeta
|
||||
|
|
@ -230,11 +232,9 @@ private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator :
|
|||
.fvar fvarId'
|
||||
else
|
||||
normFVarImp s fvarId' translator
|
||||
| some e =>
|
||||
if e.isErased then
|
||||
.erased
|
||||
else
|
||||
panic! s!"invalid LCNF substitution of free variable with expression {e}"
|
||||
-- Types and type formers are only preserved as hints and
|
||||
-- are erased in computationally relevant contexts.
|
||||
| some .erased | some (.type _) => .erased
|
||||
| none => .fvar fvarId
|
||||
|
||||
/--
|
||||
|
|
@ -247,10 +247,9 @@ private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) :
|
|||
| .erased => arg
|
||||
| .fvar fvarId =>
|
||||
match s[fvarId]? with
|
||||
| some (.fvar fvarId') =>
|
||||
let arg' := .fvar fvarId'
|
||||
| some (arg'@(.fvar _)) =>
|
||||
if translator then arg' else normArgImp s arg' translator
|
||||
| some e => if e.isErased then .erased else .type e
|
||||
| some (arg'@.erased) | some (arg'@(.type _)) => arg'
|
||||
| none => arg
|
||||
| .type e => arg.updateType! (normExprImp s e translator)
|
||||
|
||||
|
|
@ -292,21 +291,20 @@ export MonadFVarSubstState (modifySubst)
|
|||
instance (m n) [MonadLift m n] [MonadFVarSubstState m] : MonadFVarSubstState n where
|
||||
modifySubst f := liftM (modifySubst f : m _)
|
||||
|
||||
/--
|
||||
Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF `Arg`.
|
||||
|
||||
See `Check.lean` for the free variable substitution checker.
|
||||
-/
|
||||
@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (arg : Arg) : m Unit :=
|
||||
modifySubst fun s => s.insert fvarId arg
|
||||
|
||||
/--
|
||||
Add the entry `fvarId ↦ fvarId'` to the free variable substitution.
|
||||
-/
|
||||
@[inline] def addFVarSubst [MonadFVarSubstState m] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
|
||||
modifySubst fun s => s.insert fvarId (.fvar fvarId')
|
||||
|
||||
/--
|
||||
Add the substitution `fvarId ↦ e`, `e` must be a valid LCNF argument.
|
||||
That is, it must be a free variable, type (or type former), or `lcErased`.
|
||||
|
||||
See `Check.lean` for the free variable substitution checker.
|
||||
-/
|
||||
@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (e : Expr) : m Unit :=
|
||||
modifySubst fun s => s.insert fvarId e
|
||||
|
||||
@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m t] [Monad m] (fvarId : FVarId) : m NormFVarResult :=
|
||||
return normFVarImp (← getSubst) fvarId t
|
||||
|
||||
|
|
|
|||
|
|
@ -546,13 +546,13 @@ where
|
|||
let mut newArgs := knownArgs
|
||||
for (param, arg) in decl.params.zip args do
|
||||
if let some knownVal := newArgs[param.fvarId]? then
|
||||
if arg.toExpr != knownVal then
|
||||
if arg != knownVal then
|
||||
newArgs := newArgs.erase param.fvarId
|
||||
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
|
||||
else
|
||||
let folder := fun acc (param, arg) => do
|
||||
if (← allFVarM (isInJpScope fn) arg) then
|
||||
return acc.insert param.fvarId arg.toExpr
|
||||
return acc.insert param.fvarId arg
|
||||
else
|
||||
return acc
|
||||
let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ We use this function to inline/specialize a partial application of a local funct
|
|||
def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
|
||||
let mut subst := {}
|
||||
for param in info.params, arg in info.args do
|
||||
subst := subst.insert param.fvarId arg.toExpr
|
||||
subst := subst.insert param.fvarId arg
|
||||
let mut paramsNew := #[]
|
||||
for param in info.params[info.args.size:] do
|
||||
let type ← replaceExprFVars param.type subst (translator := true)
|
||||
|
|
@ -201,7 +201,7 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
|
|||
| .ctor ctorVal ctorArgs =>
|
||||
let fields := ctorArgs[ctorVal.numParams:]
|
||||
for param in params, field in fields do
|
||||
addSubst param.fvarId field.toExpr
|
||||
addSubst param.fvarId field
|
||||
let k ← simp k
|
||||
eraseParams params
|
||||
return k
|
||||
|
|
@ -231,7 +231,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
|
|||
-- and `FVarId` rather than `Arg`, and the substitution will end up
|
||||
-- creating a new erased let decl in that case.
|
||||
if decl.type.isErased && decl.value != .erased then
|
||||
modifySubst fun s => s.insert decl.fvarId (.const ``lcErased [])
|
||||
addSubst decl.fvarId .erased
|
||||
eraseLetDecl decl
|
||||
simp k
|
||||
else if let some decls ← ConstantFold.foldConstants decl then
|
||||
|
|
|
|||
|
|
@ -212,7 +212,7 @@ See comment at `updateFunDeclInfo`.
|
|||
def betaReduce (params : Array Param) (code : Code) (args : Array Arg) (mustInline := false) : SimpM Code := do
|
||||
let mut subst := {}
|
||||
for param in params, arg in args do
|
||||
subst := subst.insert param.fvarId arg.toExpr
|
||||
subst := subst.insert param.fvarId arg
|
||||
let code ← code.internalize subst
|
||||
updateFunDeclInfo code mustInline
|
||||
return code
|
||||
|
|
|
|||
|
|
@ -238,7 +238,7 @@ where
|
|||
for param in decl.params, arg in argMask do
|
||||
if let some arg := arg then
|
||||
let arg ← normArg arg
|
||||
modify fun s => s.insert param.fvarId arg.toExpr
|
||||
modify fun s => s.insert param.fvarId arg
|
||||
else
|
||||
-- Keep the parameter
|
||||
let param := { param with type := param.type.instantiateLevelParamsNoCache decl.levelParams us }
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ where
|
|||
let type ← replaceExprFVars param.type subst (translator := true)
|
||||
let paramNew ← mkAuxParam type
|
||||
jpParams := jpParams.push paramNew
|
||||
subst := subst.insert param.fvarId (Expr.fvar paramNew.fvarId)
|
||||
subst := subst.insert param.fvarId (.fvar paramNew.fvarId)
|
||||
jpArgs := jpArgs.push (Arg.fvar paramNew.fvarId)
|
||||
let letDecl ← mkAuxLetDecl (.fvar f jpArgs)
|
||||
let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId])
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue