chore: cleanup
This commit is contained in:
parent
91d51d06e0
commit
5481999560
8 changed files with 1280 additions and 1282 deletions
|
|
@ -9,60 +9,60 @@ namespace Lean.Meta
|
|||
namespace AbstractNestedProofs
|
||||
|
||||
def isNonTrivialProof (e : Expr) : MetaM Bool := do
|
||||
if ! (← isProof e) then
|
||||
pure false
|
||||
else
|
||||
e.withApp fun f args =>
|
||||
pure $ !f.isAtomic || args.any fun arg => !arg.isAtomic
|
||||
if !(← isProof e) then
|
||||
pure false
|
||||
else
|
||||
e.withApp fun f args =>
|
||||
pure $ !f.isAtomic || args.any fun arg => !arg.isAtomic
|
||||
|
||||
structure Context :=
|
||||
(baseName : Name)
|
||||
(baseName : Name)
|
||||
|
||||
structure State :=
|
||||
(nextIdx : Nat := 1)
|
||||
(nextIdx : Nat := 1)
|
||||
|
||||
abbrev M := ReaderT Context $ MonadCacheT Expr Expr $ StateRefT State $ MetaM
|
||||
|
||||
private def mkAuxLemma (e : Expr) : M Expr := do
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
let lemmaName ← mkAuxName (ctx.baseName ++ `proof) s.nextIdx
|
||||
modify fun s => { s with nextIdx := s.nextIdx + 1 }
|
||||
mkAuxDefinitionFor lemmaName e
|
||||
let ctx ← read
|
||||
let s ← get
|
||||
let lemmaName ← mkAuxName (ctx.baseName ++ `proof) s.nextIdx
|
||||
modify fun s => { s with nextIdx := s.nextIdx + 1 }
|
||||
mkAuxDefinitionFor lemmaName e
|
||||
|
||||
partial def visit (e : Expr) : M Expr := do
|
||||
if e.isAtomic then
|
||||
pure e
|
||||
else
|
||||
let visitBinders (xs : Array Expr) (k : M Expr) : M Expr := do
|
||||
let localInstances ← getLocalInstances
|
||||
let lctx ← getLCtx
|
||||
for x in xs do
|
||||
let xFVarId := x.fvarId!
|
||||
let localDecl ← getLocalDecl xFVarId
|
||||
let type ← visit localDecl.type
|
||||
let localDecl := localDecl.setType type
|
||||
let localDecl ← match localDecl.value? with
|
||||
| some value => do let value ← visit value; pure $ localDecl.setValue value
|
||||
| none => pure localDecl
|
||||
lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl
|
||||
withLCtx lctx localInstances k
|
||||
checkCache e fun e =>
|
||||
if (← isNonTrivialProof e) then
|
||||
mkAuxLemma e
|
||||
else match e with
|
||||
| Expr.lam _ _ _ _ => lambdaLetTelescope e fun xs b => visitBinders xs do mkLambdaFVars xs (← visit b)
|
||||
| Expr.letE _ _ _ _ _ => lambdaLetTelescope e fun xs b => visitBinders xs do mkLambdaFVars xs (← visit b)
|
||||
| Expr.forallE _ _ _ _ => forallTelescope e fun xs b => visitBinders xs do mkForallFVars xs (← visit b)
|
||||
| Expr.mdata _ b _ => do pure $ e.updateMData! (← visit b)
|
||||
| Expr.proj _ _ b _ => do pure $ e.updateProj! (← visit b)
|
||||
| Expr.app _ _ _ => e.withApp fun f args => do pure $ mkAppN f (← args.mapM visit)
|
||||
| _ => pure e
|
||||
if e.isAtomic then
|
||||
pure e
|
||||
else
|
||||
let visitBinders (xs : Array Expr) (k : M Expr) : M Expr := do
|
||||
let localInstances ← getLocalInstances
|
||||
let lctx ← getLCtx
|
||||
for x in xs do
|
||||
let xFVarId := x.fvarId!
|
||||
let localDecl ← getLocalDecl xFVarId
|
||||
let type ← visit localDecl.type
|
||||
let localDecl := localDecl.setType type
|
||||
let localDecl ← match localDecl.value? with
|
||||
| some value => do let value ← visit value; pure $ localDecl.setValue value
|
||||
| none => pure localDecl
|
||||
lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl
|
||||
withLCtx lctx localInstances k
|
||||
checkCache e fun e => do
|
||||
if (← isNonTrivialProof e) then
|
||||
mkAuxLemma e
|
||||
else match e with
|
||||
| Expr.lam _ _ _ _ => lambdaLetTelescope e fun xs b => visitBinders xs do mkLambdaFVars xs (← visit b)
|
||||
| Expr.letE _ _ _ _ _ => lambdaLetTelescope e fun xs b => visitBinders xs do mkLambdaFVars xs (← visit b)
|
||||
| Expr.forallE _ _ _ _ => forallTelescope e fun xs b => visitBinders xs do mkForallFVars xs (← visit b)
|
||||
| Expr.mdata _ b _ => return e.updateMData! (← visit b)
|
||||
| Expr.proj _ _ b _ => return e.updateProj! (← visit b)
|
||||
| Expr.app _ _ _ => e.withApp fun f args => do return mkAppN f (← args.mapM visit)
|
||||
| _ => pure e
|
||||
|
||||
end AbstractNestedProofs
|
||||
|
||||
/-- Replace proofs nested in `e` with new lemmas. The new lemmas have names of the form `mainDeclName.proof_<idx>` -/
|
||||
def abstractNestedProofs (mainDeclName : Name) (e : Expr) : MetaM Expr :=
|
||||
(((AbstractNestedProofs.visit e).run { baseName := mainDeclName }).run).run' { nextIdx := 1 }
|
||||
AbstractNestedProofs.visit e $.run { baseName := mainDeclName } $.run $.run' { nextIdx := 1 }
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
|
|
@ -14,100 +14,100 @@ whether terms produced by tactics and `isDefEq` are type correct.
|
|||
namespace Lean.Meta
|
||||
|
||||
private def ensureType (e : Expr) : MetaM Unit := do
|
||||
getLevel e
|
||||
pure ()
|
||||
getLevel e
|
||||
pure ()
|
||||
|
||||
def throwLetTypeMismatchMessage {α} (fvarId : FVarId) : MetaM α := do
|
||||
let lctx ← getLCtx
|
||||
match lctx.find? fvarId with
|
||||
| some (LocalDecl.ldecl _ _ n t v _) => do
|
||||
let vType ← inferType v
|
||||
throwError! "invalid let declaration, term{indentExpr v}\nhas type{indentExpr vType}\nbut is expected to have type{indentExpr t}"
|
||||
| _ => unreachable!
|
||||
let lctx ← getLCtx
|
||||
match lctx.find? fvarId with
|
||||
| some (LocalDecl.ldecl _ _ n t v _) => do
|
||||
let vType ← inferType v
|
||||
throwError! "invalid let declaration, term{indentExpr v}\nhas type{indentExpr vType}\nbut is expected to have type{indentExpr t}"
|
||||
| _ => unreachable!
|
||||
|
||||
@[specialize] private def checkLambdaLet
|
||||
(check : Expr → MetaM Unit)
|
||||
(e : Expr) : MetaM Unit :=
|
||||
lambdaLetTelescope e fun xs b => do
|
||||
xs.forM fun x => do
|
||||
let xDecl ← getFVarLocalDecl x;
|
||||
match xDecl with
|
||||
| LocalDecl.cdecl _ _ _ t _ =>
|
||||
ensureType t
|
||||
check t
|
||||
| LocalDecl.ldecl _ _ _ t v _ =>
|
||||
ensureType t
|
||||
check t
|
||||
let vType ← inferType v
|
||||
unless (← isDefEq t vType) do throwLetTypeMismatchMessage x.fvarId!
|
||||
check v
|
||||
check b
|
||||
lambdaLetTelescope e fun xs b => do
|
||||
xs.forM fun x => do
|
||||
let xDecl ← getFVarLocalDecl x;
|
||||
match xDecl with
|
||||
| LocalDecl.cdecl _ _ _ t _ =>
|
||||
ensureType t
|
||||
check t
|
||||
| LocalDecl.ldecl _ _ _ t v _ =>
|
||||
ensureType t
|
||||
check t
|
||||
let vType ← inferType v
|
||||
unless (← isDefEq t vType) do throwLetTypeMismatchMessage x.fvarId!
|
||||
check v
|
||||
check b
|
||||
|
||||
@[specialize] private def checkForall
|
||||
(check : Expr → MetaM Unit)
|
||||
(e : Expr) : MetaM Unit :=
|
||||
forallTelescope e fun xs b => do
|
||||
xs.forM fun x => do
|
||||
let xDecl ← getFVarLocalDecl x
|
||||
ensureType xDecl.type
|
||||
check xDecl.type
|
||||
ensureType b
|
||||
check b
|
||||
forallTelescope e fun xs b => do
|
||||
xs.forM fun x => do
|
||||
let xDecl ← getFVarLocalDecl x
|
||||
ensureType xDecl.type
|
||||
check xDecl.type
|
||||
ensureType b
|
||||
check b
|
||||
|
||||
private def checkConstant (constName : Name) (us : List Level) : MetaM Unit := do
|
||||
let cinfo ← getConstInfo constName
|
||||
unless us.length == cinfo.lparams.length do
|
||||
throwIncorrectNumberOfLevels constName us
|
||||
let cinfo ← getConstInfo constName
|
||||
unless us.length == cinfo.lparams.length do
|
||||
throwIncorrectNumberOfLevels constName us
|
||||
|
||||
private def getFunctionDomain (f : Expr) : MetaM Expr := do
|
||||
let fType ← inferType f
|
||||
let fType ← whnfD fType
|
||||
match fType with
|
||||
| Expr.forallE _ d _ _ => pure d
|
||||
| _ => throwFunctionExpected f
|
||||
let fType ← inferType f
|
||||
let fType ← whnfD fType
|
||||
match fType with
|
||||
| Expr.forallE _ d _ _ => pure d
|
||||
| _ => throwFunctionExpected f
|
||||
|
||||
def throwAppTypeMismatch {α} {m} [Monad m] [MonadExceptOf Exception m] [Ref m] [AddErrorMessageContext m] [MonadLiftT MetaM m]
|
||||
(f a : Expr) (extraMsg : MessageData := Format.nil) : m α := do
|
||||
let e := mkApp f a
|
||||
let aType ← inferType a
|
||||
let expectedType ← liftM $ getFunctionDomain f
|
||||
throwError! "application type mismatch{indentExpr e}\nargument{indentExpr a}\nhas type{indentExpr aType}\nbut is expected to have type{indentExpr expectedType}{extraMsg}"
|
||||
let e := mkApp f a
|
||||
let aType ← inferType a
|
||||
let expectedType ← liftM $ getFunctionDomain f
|
||||
throwError! "application type mismatch{indentExpr e}\nargument{indentExpr a}\nhas type{indentExpr aType}\nbut is expected to have type{indentExpr expectedType}{extraMsg}"
|
||||
|
||||
@[specialize] private def checkApp
|
||||
(check : Expr → MetaM Unit)
|
||||
(f a : Expr) : MetaM Unit := do
|
||||
check f
|
||||
check a
|
||||
let fType ← inferType f
|
||||
let fType ← whnf fType
|
||||
match fType with
|
||||
| Expr.forallE _ d _ _ =>
|
||||
let aType ← inferType a
|
||||
unless (← isDefEq d aType) do
|
||||
throwAppTypeMismatch f a
|
||||
| _ => throwFunctionExpected (mkApp f a)
|
||||
check f
|
||||
check a
|
||||
let fType ← inferType f
|
||||
let fType ← whnf fType
|
||||
match fType with
|
||||
| Expr.forallE _ d _ _ =>
|
||||
let aType ← inferType a
|
||||
unless (← isDefEq d aType) do
|
||||
throwAppTypeMismatch f a
|
||||
| _ => throwFunctionExpected (mkApp f a)
|
||||
|
||||
private partial def checkAux : Expr → MetaM Unit
|
||||
| e@(Expr.forallE ..) => checkForall checkAux e
|
||||
| e@(Expr.lam ..) => checkLambdaLet checkAux e
|
||||
| e@(Expr.letE ..) => checkLambdaLet checkAux e
|
||||
| Expr.const c lvls _ => checkConstant c lvls
|
||||
| Expr.app f a _ => checkApp checkAux f a
|
||||
| Expr.mdata _ e _ => checkAux e
|
||||
| Expr.proj _ _ e _ => checkAux e
|
||||
| _ => pure ()
|
||||
| e@(Expr.forallE ..) => checkForall checkAux e
|
||||
| e@(Expr.lam ..) => checkLambdaLet checkAux e
|
||||
| e@(Expr.letE ..) => checkLambdaLet checkAux e
|
||||
| Expr.const c lvls _ => checkConstant c lvls
|
||||
| Expr.app f a _ => checkApp checkAux f a
|
||||
| Expr.mdata _ e _ => checkAux e
|
||||
| Expr.proj _ _ e _ => checkAux e
|
||||
| _ => pure ()
|
||||
|
||||
def check (e : Expr) : MetaM Unit :=
|
||||
traceCtx `Meta.check $
|
||||
withTransparency TransparencyMode.all $ checkAux e
|
||||
traceCtx `Meta.check do
|
||||
withTransparency TransparencyMode.all $ checkAux e
|
||||
|
||||
def isTypeCorrect (e : Expr) : MetaM Bool := do
|
||||
try
|
||||
check e
|
||||
pure true
|
||||
catch ex =>
|
||||
trace[Meta.typeError]! ex.toMessageData
|
||||
pure false
|
||||
try
|
||||
check e
|
||||
pure true
|
||||
catch ex =>
|
||||
trace[Meta.typeError]! ex.toMessageData
|
||||
pure false
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Meta.check
|
||||
|
|
|
|||
|
|
@ -98,196 +98,197 @@ namespace Lean.Meta
|
|||
namespace Closure
|
||||
|
||||
structure ToProcessElement :=
|
||||
(fvarId : FVarId) (newFVarId : FVarId)
|
||||
(fvarId : FVarId)
|
||||
(newFVarId : FVarId)
|
||||
|
||||
instance : Inhabited ToProcessElement :=
|
||||
⟨⟨arbitrary _, arbitrary _⟩⟩
|
||||
⟨⟨arbitrary _, arbitrary _⟩⟩
|
||||
|
||||
structure Context :=
|
||||
(zeta : Bool)
|
||||
(zeta : Bool)
|
||||
|
||||
structure State :=
|
||||
(visitedLevel : LevelMap Level := {})
|
||||
(visitedExpr : ExprStructMap Expr := {})
|
||||
(levelParams : Array Name := #[])
|
||||
(nextLevelIdx : Nat := 1)
|
||||
(levelArgs : Array Level := #[])
|
||||
(newLocalDecls : Array LocalDecl := #[])
|
||||
(newLocalDeclsForMVars : Array LocalDecl := #[])
|
||||
(newLetDecls : Array LocalDecl := #[])
|
||||
(nextExprIdx : Nat := 1)
|
||||
(exprMVarArgs : Array Expr := #[])
|
||||
(exprFVarArgs : Array Expr := #[])
|
||||
(toProcess : Array ToProcessElement := #[])
|
||||
(visitedLevel : LevelMap Level := {})
|
||||
(visitedExpr : ExprStructMap Expr := {})
|
||||
(levelParams : Array Name := #[])
|
||||
(nextLevelIdx : Nat := 1)
|
||||
(levelArgs : Array Level := #[])
|
||||
(newLocalDecls : Array LocalDecl := #[])
|
||||
(newLocalDeclsForMVars : Array LocalDecl := #[])
|
||||
(newLetDecls : Array LocalDecl := #[])
|
||||
(nextExprIdx : Nat := 1)
|
||||
(exprMVarArgs : Array Expr := #[])
|
||||
(exprFVarArgs : Array Expr := #[])
|
||||
(toProcess : Array ToProcessElement := #[])
|
||||
|
||||
abbrev ClosureM := ReaderT Context $ StateRefT State MetaM
|
||||
|
||||
@[inline] def visitLevel (f : Level → ClosureM Level) (u : Level) : ClosureM Level := do
|
||||
if !u.hasMVar && !u.hasParam then
|
||||
pure u
|
||||
else
|
||||
let s ← get
|
||||
match s.visitedLevel.find? u with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let v ← f u
|
||||
modify fun s => { s with visitedLevel := s.visitedLevel.insert u v }
|
||||
pure v
|
||||
if !u.hasMVar && !u.hasParam then
|
||||
pure u
|
||||
else
|
||||
let s ← get
|
||||
match s.visitedLevel.find? u with
|
||||
| some v => pure v
|
||||
| none => do
|
||||
let v ← f u
|
||||
modify fun s => { s with visitedLevel := s.visitedLevel.insert u v }
|
||||
pure v
|
||||
|
||||
@[inline] def visitExpr (f : Expr → ClosureM Expr) (e : Expr) : ClosureM Expr := do
|
||||
if !e.hasLevelParam && !e.hasFVar && !e.hasMVar then
|
||||
pure e
|
||||
else
|
||||
let s ← get
|
||||
match s.visitedExpr.find? e with
|
||||
| some r => pure r
|
||||
| none =>
|
||||
let r ← f e
|
||||
modify fun s => { s with visitedExpr := s.visitedExpr.insert e r }
|
||||
pure r
|
||||
if !e.hasLevelParam && !e.hasFVar && !e.hasMVar then
|
||||
pure e
|
||||
else
|
||||
let s ← get
|
||||
match s.visitedExpr.find? e with
|
||||
| some r => pure r
|
||||
| none =>
|
||||
let r ← f e
|
||||
modify fun s => { s with visitedExpr := s.visitedExpr.insert e r }
|
||||
pure r
|
||||
|
||||
def mkNewLevelParam (u : Level) : ClosureM Level := do
|
||||
let s ← get
|
||||
let p := (`u).appendIndexAfter s.nextLevelIdx
|
||||
modify fun s => { s with levelParams := s.levelParams.push p, nextLevelIdx := s.nextLevelIdx + 1, levelArgs := s.levelArgs.push u }
|
||||
pure $ mkLevelParam p
|
||||
let s ← get
|
||||
let p := (`u).appendIndexAfter s.nextLevelIdx
|
||||
modify fun s => { s with levelParams := s.levelParams.push p, nextLevelIdx := s.nextLevelIdx + 1, levelArgs := s.levelArgs.push u }
|
||||
pure $ mkLevelParam p
|
||||
|
||||
partial def collectLevelAux : Level → ClosureM Level
|
||||
| u@(Level.succ v _) => do return u.updateSucc! (← visitLevel collectLevelAux v)
|
||||
| u@(Level.max v w _) => do return u.updateMax! (← visitLevel collectLevelAux v) (← visitLevel collectLevelAux w)
|
||||
| u@(Level.imax v w _) => do return u.updateIMax! (← visitLevel collectLevelAux v) (← visitLevel collectLevelAux w)
|
||||
| u@(Level.mvar mvarId _) => mkNewLevelParam u
|
||||
| u@(Level.param _ _) => mkNewLevelParam u
|
||||
| u@(Level.zero _) => pure u
|
||||
| u@(Level.succ v _) => do return u.updateSucc! (← visitLevel collectLevelAux v)
|
||||
| u@(Level.max v w _) => do return u.updateMax! (← visitLevel collectLevelAux v) (← visitLevel collectLevelAux w)
|
||||
| u@(Level.imax v w _) => do return u.updateIMax! (← visitLevel collectLevelAux v) (← visitLevel collectLevelAux w)
|
||||
| u@(Level.mvar mvarId _) => mkNewLevelParam u
|
||||
| u@(Level.param _ _) => mkNewLevelParam u
|
||||
| u@(Level.zero _) => pure u
|
||||
|
||||
def collectLevel (u : Level) : ClosureM Level := do
|
||||
-- u ← instantiateLevelMVars u
|
||||
visitLevel collectLevelAux u
|
||||
-- u ← instantiateLevelMVars u
|
||||
visitLevel collectLevelAux u
|
||||
|
||||
def preprocess (e : Expr) : ClosureM Expr := do
|
||||
let e ← instantiateMVars e
|
||||
let ctx ← read
|
||||
-- If we are not zeta-expanding let-decls, then we use `check` to find
|
||||
-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect.
|
||||
if !ctx.zeta then
|
||||
check e
|
||||
pure e
|
||||
let e ← instantiateMVars e
|
||||
let ctx ← read
|
||||
-- If we are not zeta-expanding let-decls, then we use `check` to find
|
||||
-- which let-decls are dependent. We say a let-decl is dependent if its lambda abstraction is type incorrect.
|
||||
if !ctx.zeta then
|
||||
check e
|
||||
pure e
|
||||
|
||||
/--
|
||||
Remark: This method does not guarantee unique user names.
|
||||
The correctness of the procedure does not rely on unique user names.
|
||||
Recall that the pretty printer takes care of unintended collisions. -/
|
||||
def mkNextUserName : ClosureM Name := do
|
||||
let s ← get
|
||||
let n := (`_x).appendIndexAfter s.nextExprIdx
|
||||
modify fun s => { s with nextExprIdx := s.nextExprIdx + 1 }
|
||||
pure n
|
||||
let s ← get
|
||||
let n := (`_x).appendIndexAfter s.nextExprIdx
|
||||
modify fun s => { s with nextExprIdx := s.nextExprIdx + 1 }
|
||||
pure n
|
||||
|
||||
def pushToProcess (elem : ToProcessElement) : ClosureM Unit :=
|
||||
modify fun s => { s with toProcess := s.toProcess.push elem }
|
||||
modify fun s => { s with toProcess := s.toProcess.push elem }
|
||||
|
||||
partial def collectExprAux (e : Expr) : ClosureM Expr := do
|
||||
let collect (e : Expr) := visitExpr collectExprAux e
|
||||
match e with
|
||||
| Expr.proj _ _ s _ => return e.updateProj! (← collect s)
|
||||
| Expr.forallE _ d b _ => return e.updateForallE! (← collect d) (← collect b)
|
||||
| Expr.lam _ d b _ => return e.updateLambdaE! (← collect d) (← collect b)
|
||||
| Expr.letE _ t v b _ => return e.updateLet! (← collect t) (← collect v) (← collect b)
|
||||
| Expr.app f a _ => return e.updateApp! (← collect f) (← collect a)
|
||||
| Expr.mdata _ b _ => return e.updateMData! (← collect b)
|
||||
| Expr.sort u _ => return e.updateSort! (← collectLevel u)
|
||||
| Expr.const c us _ => return e.updateConst! (← us.mapM collectLevel)
|
||||
| Expr.mvar mvarId _ =>
|
||||
let mvarDecl ← getMVarDecl mvarId
|
||||
let type ← preprocess mvarDecl.type
|
||||
let type ← collect type
|
||||
let newFVarId ← mkFreshFVarId
|
||||
let userName ← mkNextUserName
|
||||
modify fun s => { s with
|
||||
newLocalDeclsForMVars := s.newLocalDeclsForMVars.push $ LocalDecl.cdecl (arbitrary _) newFVarId userName type BinderInfo.default,
|
||||
exprMVarArgs := s.exprMVarArgs.push e
|
||||
}
|
||||
return mkFVar newFVarId
|
||||
| Expr.fvar fvarId _ =>
|
||||
match (← read).zeta, (← getLocalDecl fvarId).value? with
|
||||
| true, some value => collect (← preprocess value)
|
||||
| _, _ =>
|
||||
let collect (e : Expr) := visitExpr collectExprAux e
|
||||
match e with
|
||||
| Expr.proj _ _ s _ => return e.updateProj! (← collect s)
|
||||
| Expr.forallE _ d b _ => return e.updateForallE! (← collect d) (← collect b)
|
||||
| Expr.lam _ d b _ => return e.updateLambdaE! (← collect d) (← collect b)
|
||||
| Expr.letE _ t v b _ => return e.updateLet! (← collect t) (← collect v) (← collect b)
|
||||
| Expr.app f a _ => return e.updateApp! (← collect f) (← collect a)
|
||||
| Expr.mdata _ b _ => return e.updateMData! (← collect b)
|
||||
| Expr.sort u _ => return e.updateSort! (← collectLevel u)
|
||||
| Expr.const c us _ => return e.updateConst! (← us.mapM collectLevel)
|
||||
| Expr.mvar mvarId _ =>
|
||||
let mvarDecl ← getMVarDecl mvarId
|
||||
let type ← preprocess mvarDecl.type
|
||||
let type ← collect type
|
||||
let newFVarId ← mkFreshFVarId
|
||||
pushToProcess ⟨fvarId, newFVarId⟩
|
||||
let userName ← mkNextUserName
|
||||
modify fun s => { s with
|
||||
newLocalDeclsForMVars := s.newLocalDeclsForMVars.push $ LocalDecl.cdecl (arbitrary _) newFVarId userName type BinderInfo.default,
|
||||
exprMVarArgs := s.exprMVarArgs.push e
|
||||
}
|
||||
return mkFVar newFVarId
|
||||
| e => pure e
|
||||
| Expr.fvar fvarId _ =>
|
||||
match (← read).zeta, (← getLocalDecl fvarId).value? with
|
||||
| true, some value => collect (← preprocess value)
|
||||
| _, _ =>
|
||||
let newFVarId ← mkFreshFVarId
|
||||
pushToProcess ⟨fvarId, newFVarId⟩
|
||||
return mkFVar newFVarId
|
||||
| e => pure e
|
||||
|
||||
def collectExpr (e : Expr) : ClosureM Expr := do
|
||||
let e ← preprocess e
|
||||
visitExpr collectExprAux e
|
||||
let e ← preprocess e
|
||||
visitExpr collectExprAux e
|
||||
|
||||
partial def pickNextToProcessAux (lctx : LocalContext) (i : Nat) (toProcess : Array ToProcessElement) (elem : ToProcessElement)
|
||||
: ToProcessElement × Array ToProcessElement :=
|
||||
if h : i < toProcess.size then
|
||||
let elem' := toProcess.get ⟨i, h⟩
|
||||
if (lctx.get! elem.fvarId).index < (lctx.get! elem'.fvarId).index then
|
||||
pickNextToProcessAux lctx (i+1) (toProcess.set ⟨i, h⟩ elem) elem'
|
||||
if h : i < toProcess.size then
|
||||
let elem' := toProcess.get ⟨i, h⟩
|
||||
if (lctx.get! elem.fvarId).index < (lctx.get! elem'.fvarId).index then
|
||||
pickNextToProcessAux lctx (i+1) (toProcess.set ⟨i, h⟩ elem) elem'
|
||||
else
|
||||
pickNextToProcessAux lctx (i+1) toProcess elem
|
||||
else
|
||||
pickNextToProcessAux lctx (i+1) toProcess elem
|
||||
else
|
||||
(elem, toProcess)
|
||||
(elem, toProcess)
|
||||
|
||||
def pickNextToProcess? : ClosureM (Option ToProcessElement) := do
|
||||
let lctx ← getLCtx
|
||||
let s ← get
|
||||
if s.toProcess.isEmpty then pure none
|
||||
else
|
||||
modifyGet fun s =>
|
||||
let elem := s.toProcess.back
|
||||
let toProcess := s.toProcess.pop
|
||||
let (elem, toProcess) := pickNextToProcessAux lctx 0 toProcess elem
|
||||
(some elem, { s with toProcess := toProcess })
|
||||
let lctx ← getLCtx
|
||||
let s ← get
|
||||
if s.toProcess.isEmpty then
|
||||
pure none
|
||||
else
|
||||
modifyGet fun s =>
|
||||
let elem := s.toProcess.back
|
||||
let toProcess := s.toProcess.pop
|
||||
let (elem, toProcess) := pickNextToProcessAux lctx 0 toProcess elem
|
||||
(some elem, { s with toProcess := toProcess })
|
||||
|
||||
def pushFVarArg (e : Expr) : ClosureM Unit :=
|
||||
modify fun s => { s with exprFVarArgs := s.exprFVarArgs.push e }
|
||||
modify fun s => { s with exprFVarArgs := s.exprFVarArgs.push e }
|
||||
|
||||
def pushLocalDecl (newFVarId : FVarId) (userName : Name) (type : Expr) (bi := BinderInfo.default) : ClosureM Unit := do
|
||||
let type ← collectExpr type
|
||||
modify fun s => { s with newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) newFVarId userName type bi }
|
||||
let type ← collectExpr type
|
||||
modify fun s => { s with newLocalDecls := s.newLocalDecls.push $ LocalDecl.cdecl (arbitrary _) newFVarId userName type bi }
|
||||
|
||||
partial def process : ClosureM Unit := do
|
||||
match (← pickNextToProcess?) with
|
||||
| none => pure ()
|
||||
| some ⟨fvarId, newFVarId⟩ =>
|
||||
let localDecl ← getLocalDecl fvarId
|
||||
match localDecl with
|
||||
| LocalDecl.cdecl _ _ userName type bi =>
|
||||
pushLocalDecl newFVarId userName type bi
|
||||
pushFVarArg (mkFVar fvarId)
|
||||
process
|
||||
| LocalDecl.ldecl _ _ userName type val _ =>
|
||||
let zetaFVarIds ← getZetaFVarIds
|
||||
if !zetaFVarIds.contains fvarId then
|
||||
/- Non-dependent let-decl
|
||||
|
||||
Recall that if `fvarId` is in `zetaFVarIds`, then we zeta-expanded it
|
||||
during type checking (see `check` at `collectExpr`).
|
||||
|
||||
Our type checker may zeta-expand declarations that are not needed, but this
|
||||
check is conservative, and seems to work well in practice. -/
|
||||
pushLocalDecl newFVarId userName type
|
||||
match (← pickNextToProcess?) with
|
||||
| none => pure ()
|
||||
| some ⟨fvarId, newFVarId⟩ =>
|
||||
let localDecl ← getLocalDecl fvarId
|
||||
match localDecl with
|
||||
| LocalDecl.cdecl _ _ userName type bi =>
|
||||
pushLocalDecl newFVarId userName type bi
|
||||
pushFVarArg (mkFVar fvarId)
|
||||
process
|
||||
else
|
||||
/- Dependent let-decl -/
|
||||
let type ← collectExpr type
|
||||
let val ← collectExpr val
|
||||
modify fun s => { s with newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) newFVarId userName type val false }
|
||||
/- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of newFVarId
|
||||
at `newLocalDecls` -/
|
||||
modify fun s => { s with newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl newFVarId val) }
|
||||
process
|
||||
| LocalDecl.ldecl _ _ userName type val _ =>
|
||||
let zetaFVarIds ← getZetaFVarIds
|
||||
if !zetaFVarIds.contains fvarId then
|
||||
/- Non-dependent let-decl
|
||||
|
||||
Recall that if `fvarId` is in `zetaFVarIds`, then we zeta-expanded it
|
||||
during type checking (see `check` at `collectExpr`).
|
||||
|
||||
Our type checker may zeta-expand declarations that are not needed, but this
|
||||
check is conservative, and seems to work well in practice. -/
|
||||
pushLocalDecl newFVarId userName type
|
||||
pushFVarArg (mkFVar fvarId)
|
||||
process
|
||||
else
|
||||
/- Dependent let-decl -/
|
||||
let type ← collectExpr type
|
||||
let val ← collectExpr val
|
||||
modify fun s => { s with newLetDecls := s.newLetDecls.push $ LocalDecl.ldecl (arbitrary _) newFVarId userName type val false }
|
||||
/- We don't want to interleave let and lambda declarations in our closure. So, we expand any occurrences of newFVarId
|
||||
at `newLocalDecls` -/
|
||||
modify fun s => { s with newLocalDecls := s.newLocalDecls.map (replaceFVarIdAtLocalDecl newFVarId val) }
|
||||
process
|
||||
|
||||
@[inline] def mkBinding (isLambda : Bool) (decls : Array LocalDecl) (b : Expr) : Expr :=
|
||||
let xs := decls.map LocalDecl.toExpr
|
||||
let b := b.abstract xs
|
||||
decls.size.foldRev
|
||||
(fun i b =>
|
||||
let xs := decls.map LocalDecl.toExpr
|
||||
let b := b.abstract xs
|
||||
decls.size.foldRev (init := b) fun i b =>
|
||||
let decl := decls[i]
|
||||
match decl with
|
||||
| LocalDecl.cdecl _ _ n ty bi =>
|
||||
|
|
@ -302,64 +303,63 @@ decls.size.foldRev
|
|||
let val := val.abstractRange i xs
|
||||
mkLet n ty val b nonDep
|
||||
else
|
||||
b.lowerLooseBVars 1 1)
|
||||
b
|
||||
b.lowerLooseBVars 1 1
|
||||
|
||||
def mkLambda (decls : Array LocalDecl) (b : Expr) : Expr :=
|
||||
mkBinding true decls b
|
||||
mkBinding true decls b
|
||||
|
||||
def mkForall (decls : Array LocalDecl) (b : Expr) : Expr :=
|
||||
mkBinding false decls b
|
||||
mkBinding false decls b
|
||||
|
||||
structure MkValueTypeClosureResult :=
|
||||
(levelParams : Array Name)
|
||||
(type : Expr)
|
||||
(value : Expr)
|
||||
(levelArgs : Array Level)
|
||||
(exprArgs : Array Expr)
|
||||
(levelParams : Array Name)
|
||||
(type : Expr)
|
||||
(value : Expr)
|
||||
(levelArgs : Array Level)
|
||||
(exprArgs : Array Expr)
|
||||
|
||||
def mkValueTypeClosureAux (type : Expr) (value : Expr) : ClosureM (Expr × Expr) := do
|
||||
resetZetaFVarIds
|
||||
withTrackingZeta do
|
||||
let type ← collectExpr type
|
||||
let value ← collectExpr value
|
||||
process
|
||||
pure (type, value)
|
||||
resetZetaFVarIds
|
||||
withTrackingZeta do
|
||||
let type ← collectExpr type
|
||||
let value ← collectExpr value
|
||||
process
|
||||
pure (type, value)
|
||||
|
||||
def mkValueTypeClosure (type : Expr) (value : Expr) (zeta : Bool) : MetaM MkValueTypeClosureResult := do
|
||||
let ((type, value), s) ← ((mkValueTypeClosureAux type value).run { zeta := zeta }).run {}
|
||||
let newLocalDecls := s.newLocalDecls.reverse ++ s.newLocalDeclsForMVars
|
||||
let newLetDecls := s.newLetDecls.reverse
|
||||
let type := mkForall newLocalDecls (mkForall newLetDecls type)
|
||||
let value := mkLambda newLocalDecls (mkLambda newLetDecls value)
|
||||
pure {
|
||||
type := type,
|
||||
value := value,
|
||||
levelParams := s.levelParams,
|
||||
levelArgs := s.levelArgs,
|
||||
exprArgs := s.exprFVarArgs.reverse ++ s.exprMVarArgs
|
||||
}
|
||||
let ((type, value), s) ← ((mkValueTypeClosureAux type value).run { zeta := zeta }).run {}
|
||||
let newLocalDecls := s.newLocalDecls.reverse ++ s.newLocalDeclsForMVars
|
||||
let newLetDecls := s.newLetDecls.reverse
|
||||
let type := mkForall newLocalDecls (mkForall newLetDecls type)
|
||||
let value := mkLambda newLocalDecls (mkLambda newLetDecls value)
|
||||
pure {
|
||||
type := type,
|
||||
value := value,
|
||||
levelParams := s.levelParams,
|
||||
levelArgs := s.levelArgs,
|
||||
exprArgs := s.exprFVarArgs.reverse ++ s.exprMVarArgs
|
||||
}
|
||||
|
||||
end Closure
|
||||
|
||||
variables {m : Type → Type} [MonadLiftT MetaM m]
|
||||
|
||||
private def mkAuxDefinitionImp (name : Name) (type : Expr) (value : Expr) (zeta : Bool) (compile : Bool) : MetaM Expr := do
|
||||
let result ← Closure.mkValueTypeClosure type value zeta
|
||||
let env ← getEnv
|
||||
let decl := Declaration.defnDecl {
|
||||
name := name,
|
||||
lparams := result.levelParams.toList,
|
||||
type := result.type,
|
||||
value := result.value,
|
||||
hints := ReducibilityHints.regular (getMaxHeight env result.value + 1),
|
||||
isUnsafe := env.hasUnsafe result.type || env.hasUnsafe result.value
|
||||
}
|
||||
trace[Meta.debug]! "{name} : {result.type} := {result.value}"
|
||||
addDecl decl
|
||||
if compile then
|
||||
compileDecl decl
|
||||
pure $ mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
|
||||
let result ← Closure.mkValueTypeClosure type value zeta
|
||||
let env ← getEnv
|
||||
let decl := Declaration.defnDecl {
|
||||
name := name,
|
||||
lparams := result.levelParams.toList,
|
||||
type := result.type,
|
||||
value := result.value,
|
||||
hints := ReducibilityHints.regular (getMaxHeight env result.value + 1),
|
||||
isUnsafe := env.hasUnsafe result.type || env.hasUnsafe result.value
|
||||
}
|
||||
trace[Meta.debug]! "{name} : {result.type} := {result.value}"
|
||||
addDecl decl
|
||||
if compile then
|
||||
compileDecl decl
|
||||
pure $ mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
|
||||
|
||||
/--
|
||||
Create an auxiliary definition with the given name, type and value.
|
||||
|
|
@ -368,13 +368,13 @@ pure $ mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
|
|||
returned where `u_i`s are universe parameters and metavariables `type` and `value` depend on,
|
||||
and `t_j`s are free and meta variables `type` and `value` depend on. -/
|
||||
def mkAuxDefinition (name : Name) (type : Expr) (value : Expr) (zeta := false) (compile := true) : m Expr := liftMetaM do
|
||||
trace[Meta.debug]! "{name} : {type} := {value}"
|
||||
mkAuxDefinitionImp name type value zeta compile
|
||||
trace[Meta.debug]! "{name} : {type} := {value}"
|
||||
mkAuxDefinitionImp name type value zeta compile
|
||||
|
||||
/-- Similar to `mkAuxDefinition`, but infers the type of `value`. -/
|
||||
def mkAuxDefinitionFor (name : Name) (value : Expr) : m Expr := liftMetaM do
|
||||
let type ← inferType value
|
||||
let type := type.headBeta
|
||||
mkAuxDefinition name type value
|
||||
let type ← inferType value
|
||||
let type := type.headBeta
|
||||
mkAuxDefinition name type value
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
|
|
@ -18,42 +18,42 @@ namespace Lean.Meta
|
|||
we collect `?m` and unassigned metavariables occurring in `t`.
|
||||
We collect `?m` because it has not been assigned yet. -/
|
||||
partial def collectMVars (e : Expr) : StateRefT CollectMVars.State MetaM Unit := do
|
||||
let e ← instantiateMVars e
|
||||
let s ← get
|
||||
let resultSavedSize := s.result.size
|
||||
let s := e.collectMVars s
|
||||
set s
|
||||
for mvarId in s.result[resultSavedSize:] do
|
||||
match (← getDelayedAssignment? mvarId) with
|
||||
| none => pure ()
|
||||
| some d => collectMVars d.val
|
||||
let e ← instantiateMVars e
|
||||
let s ← get
|
||||
let resultSavedSize := s.result.size
|
||||
let s := e.collectMVars s
|
||||
set s
|
||||
for mvarId in s.result[resultSavedSize:] do
|
||||
match (← getDelayedAssignment? mvarId) with
|
||||
| none => pure ()
|
||||
| some d => collectMVars d.val
|
||||
|
||||
variables {m : Type → Type} [MonadLiftT MetaM m]
|
||||
|
||||
def getMVarsImp (e : Expr) : MetaM (Array MVarId) := do
|
||||
let (_, s) ← (collectMVars e).run {}
|
||||
pure s.result
|
||||
let (_, s) ← (collectMVars e).run {}
|
||||
pure s.result
|
||||
|
||||
/-- Return metavariables in occuring the given expression. See `collectMVars` -/
|
||||
def getMVars (e : Expr) : m (Array MVarId) :=
|
||||
liftM $ getMVarsImp e
|
||||
liftM $ getMVarsImp e
|
||||
|
||||
def getMVarsNoDelayedImp (e : Expr) : MetaM (Array MVarId) := do
|
||||
let mvarIds ← getMVars e
|
||||
mvarIds.filterM fun mvarId => not <$> isDelayedAssigned mvarId
|
||||
let mvarIds ← getMVars e
|
||||
mvarIds.filterM fun mvarId => not <$> isDelayedAssigned mvarId
|
||||
|
||||
/-- Similar to getMVars, but removes delayed assignments. -/
|
||||
def getMVarsNoDelayed (e : Expr) : m (Array MVarId) :=
|
||||
liftM $ getMVarsNoDelayedImp e
|
||||
liftM $ getMVarsNoDelayedImp e
|
||||
|
||||
def collectMVarsAtDecl (d : Declaration) : StateRefT CollectMVars.State MetaM Unit :=
|
||||
d.forExprM collectMVars
|
||||
d.forExprM collectMVars
|
||||
|
||||
def getMVarsAtDeclImp (d : Declaration) : MetaM (Array MVarId) := do
|
||||
let (_, s) ← (collectMVarsAtDecl d).run {}
|
||||
pure s.result
|
||||
let (_, s) ← (collectMVarsAtDecl d).run {}
|
||||
pure s.result
|
||||
|
||||
def getMVarsAtDecl (d : Declaration) : m (Array MVarId) :=
|
||||
liftM $ getMVarsAtDeclImp d
|
||||
liftM $ getMVarsAtDeclImp d
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
|
|
@ -49,35 +49,35 @@ namespace Lean.Meta.DiscrTree
|
|||
-/
|
||||
|
||||
def Key.ctorIdx : Key → Nat
|
||||
| Key.star => 0
|
||||
| Key.other => 1
|
||||
| Key.lit _ => 2
|
||||
| Key.fvar _ _ => 3
|
||||
| Key.const _ _ => 4
|
||||
| Key.star => 0
|
||||
| Key.other => 1
|
||||
| Key.lit _ => 2
|
||||
| Key.fvar _ _ => 3
|
||||
| Key.const _ _ => 4
|
||||
|
||||
def Key.lt : Key → Key → Bool
|
||||
| Key.lit v₁, Key.lit v₂ => v₁ < v₂
|
||||
| Key.fvar n₁ a₁, Key.fvar n₂ a₂ => Name.quickLt n₁ n₂ || (n₁ == n₂ && a₁ < a₂)
|
||||
| Key.const n₁ a₁, Key.const n₂ a₂ => Name.quickLt n₁ n₂ || (n₁ == n₂ && a₁ < a₂)
|
||||
| k₁, k₂ => k₁.ctorIdx < k₂.ctorIdx
|
||||
| Key.lit v₁, Key.lit v₂ => v₁ < v₂
|
||||
| Key.fvar n₁ a₁, Key.fvar n₂ a₂ => Name.quickLt n₁ n₂ || (n₁ == n₂ && a₁ < a₂)
|
||||
| Key.const n₁ a₁, Key.const n₂ a₂ => Name.quickLt n₁ n₂ || (n₁ == n₂ && a₁ < a₂)
|
||||
| k₁, k₂ => k₁.ctorIdx < k₂.ctorIdx
|
||||
|
||||
instance : HasLess Key := ⟨fun a b => Key.lt a b⟩
|
||||
instance (a b : Key) : Decidable (a < b) := inferInstanceAs (Decidable (Key.lt a b))
|
||||
|
||||
def Key.format : Key → Format
|
||||
| Key.star => "*"
|
||||
| Key.other => "◾"
|
||||
| Key.lit (Literal.natVal v) => fmt v
|
||||
| Key.lit (Literal.strVal v) => repr v
|
||||
| Key.const k _ => fmt k
|
||||
| Key.fvar k _ => fmt k
|
||||
| Key.star => "*"
|
||||
| Key.other => "◾"
|
||||
| Key.lit (Literal.natVal v) => fmt v
|
||||
| Key.lit (Literal.strVal v) => repr v
|
||||
| Key.const k _ => fmt k
|
||||
| Key.fvar k _ => fmt k
|
||||
|
||||
instance : HasFormat Key := ⟨Key.format⟩
|
||||
|
||||
def Key.arity : Key → Nat
|
||||
| Key.const _ a => a
|
||||
| Key.fvar _ a => a
|
||||
| _ => 0
|
||||
| Key.const _ a => a
|
||||
| Key.fvar _ a => a
|
||||
| _ => 0
|
||||
|
||||
instance {α} : Inhabited (Trie α) := ⟨Trie.node #[] #[]⟩
|
||||
|
||||
|
|
@ -121,27 +121,26 @@ instance {α} : Inhabited (DiscrTree α) := ⟨{}⟩
|
|||
and `ignoreArg` would return true for any term of the form `noIndexing t`.
|
||||
-/
|
||||
private def ignoreArg (a : Expr) (i : Nat) (infos : Array ParamInfo) : MetaM Bool :=
|
||||
if h : i < infos.size then
|
||||
let info := infos.get ⟨i, h⟩
|
||||
if info.instImplicit then
|
||||
pure true
|
||||
else if info.implicit then
|
||||
not <$> isType a
|
||||
if h : i < infos.size then
|
||||
let info := infos.get ⟨i, h⟩
|
||||
if info.instImplicit then
|
||||
pure true
|
||||
else if info.implicit then
|
||||
not <$> isType a
|
||||
else
|
||||
isProof a
|
||||
else
|
||||
isProof a
|
||||
else
|
||||
isProof a
|
||||
|
||||
private partial def pushArgsAux (infos : Array ParamInfo) : Nat → Expr → Array Expr → MetaM (Array Expr)
|
||||
| i, Expr.app f a _, todo => do
|
||||
if (← ignoreArg a i infos) then
|
||||
pushArgsAux infos (i-1) f (todo.push tmpStar)
|
||||
else
|
||||
pushArgsAux infos (i-1) f (todo.push a)
|
||||
| _, _, todo => pure todo
|
||||
| i, Expr.app f a _, todo => do
|
||||
if (← ignoreArg a i infos) then
|
||||
pushArgsAux infos (i-1) f (todo.push tmpStar)
|
||||
else
|
||||
pushArgsAux infos (i-1) f (todo.push a)
|
||||
| _, _, todo => pure todo
|
||||
|
||||
private partial def whnfEta : Expr → MetaM Expr
|
||||
| e => do
|
||||
private partial def whnfEta (e : Expr) : MetaM Expr := do
|
||||
let e ← whnf e
|
||||
match e.etaExpandedStrict? with
|
||||
| some e => whnfEta e
|
||||
|
|
@ -152,38 +151,37 @@ private partial def whnfEta : Expr → MetaM Expr
|
|||
Then, `DiscrTree` users may control which symbols should be treated as wildcards.
|
||||
Different `DiscrTree` users may populate this set using, for example, attributes. -/
|
||||
private def shouldAddAsStar (constName : Name) : Bool :=
|
||||
constName == `Nat.zero || constName == `Nat.succ || constName == `Nat.add || constName == `HasAdd.add
|
||||
constName == `Nat.zero || constName == `Nat.succ || constName == `Nat.add || constName == `HasAdd.add
|
||||
|
||||
private def pushArgs (todo : Array Expr) (e : Expr) : MetaM (Key × Array Expr) := do
|
||||
let e ← whnfEta e
|
||||
let fn := e.getAppFn
|
||||
let push (k : Key) (nargs : Nat) : MetaM (Key × Array Expr) := do
|
||||
let info ← getFunInfoNArgs fn nargs
|
||||
let todo ← pushArgsAux info.paramInfo (nargs-1) e todo
|
||||
pure (k, todo)
|
||||
match fn with
|
||||
| Expr.lit v _ => pure (Key.lit v, todo)
|
||||
| Expr.const c _ _ =>
|
||||
if shouldAddAsStar c then
|
||||
pure (Key.star, todo)
|
||||
else
|
||||
let e ← whnfEta e
|
||||
let fn := e.getAppFn
|
||||
let push (k : Key) (nargs : Nat) : MetaM (Key × Array Expr) := do
|
||||
let info ← getFunInfoNArgs fn nargs
|
||||
let todo ← pushArgsAux info.paramInfo (nargs-1) e todo
|
||||
pure (k, todo)
|
||||
match fn with
|
||||
| Expr.lit v _ => pure (Key.lit v, todo)
|
||||
| Expr.const c _ _ =>
|
||||
if shouldAddAsStar c then
|
||||
pure (Key.star, todo)
|
||||
else
|
||||
let nargs := e.getAppNumArgs
|
||||
push (Key.const c nargs) nargs
|
||||
| Expr.fvar fvarId _ =>
|
||||
let nargs := e.getAppNumArgs
|
||||
push (Key.const c nargs) nargs
|
||||
| Expr.fvar fvarId _ =>
|
||||
let nargs := e.getAppNumArgs
|
||||
push (Key.fvar fvarId nargs) nargs
|
||||
| Expr.mvar mvarId _ =>
|
||||
if mvarId == tmpMVarId then
|
||||
-- We use `tmp to mark implicit arguments and proofs
|
||||
pure (Key.star, todo)
|
||||
else if (← isReadOnlyOrSyntheticOpaqueExprMVar mvarId) then
|
||||
pure (Key.other, todo)
|
||||
else
|
||||
pure (Key.star, todo)
|
||||
| _ => pure (Key.other, todo)
|
||||
push (Key.fvar fvarId nargs) nargs
|
||||
| Expr.mvar mvarId _ =>
|
||||
if mvarId == tmpMVarId then
|
||||
-- We use `tmp to mark implicit arguments and proofs
|
||||
pure (Key.star, todo)
|
||||
else if (← isReadOnlyOrSyntheticOpaqueExprMVar mvarId) then
|
||||
pure (Key.other, todo)
|
||||
else
|
||||
pure (Key.star, todo)
|
||||
| _ => pure (Key.other, todo)
|
||||
|
||||
partial def mkPathAux : Array Expr → Array Key → MetaM (Array Key)
|
||||
| todo, keys => do
|
||||
partial def mkPathAux (todo : Array Expr) (keys : Array Key) : MetaM (Array Key) := do
|
||||
if todo.isEmpty then
|
||||
pure keys
|
||||
else
|
||||
|
|
@ -195,13 +193,12 @@ partial def mkPathAux : Array Expr → Array Key → MetaM (Array Key)
|
|||
private def initCapacity := 8
|
||||
|
||||
def mkPath (e : Expr) : MetaM (Array Key) :=
|
||||
withReducible do
|
||||
let todo : Array Expr := Array.mkEmpty initCapacity
|
||||
let keys : Array Key := Array.mkEmpty initCapacity
|
||||
mkPathAux (todo.push e) keys
|
||||
withReducible do
|
||||
let todo : Array Expr := Array.mkEmpty initCapacity
|
||||
let keys : Array Key := Array.mkEmpty initCapacity
|
||||
mkPathAux (todo.push e) keys
|
||||
|
||||
private partial def createNodes {α} (keys : Array Key) (v : α) : Nat → Trie α
|
||||
| i =>
|
||||
private partial def createNodes {α} (keys : Array Key) (v : α) (i : Nat) : Trie α :=
|
||||
if h : i < keys.size then
|
||||
let k := keys.get ⟨i, h⟩
|
||||
let c := createNodes keys v (i+1)
|
||||
|
|
@ -210,168 +207,168 @@ private partial def createNodes {α} (keys : Array Key) (v : α) : Nat → Trie
|
|||
Trie.node #[v] #[]
|
||||
|
||||
private def insertVal {α} [HasBeq α] (vs : Array α) (v : α) : Array α :=
|
||||
if vs.contains v then vs else vs.push v
|
||||
if vs.contains v then vs else vs.push v
|
||||
|
||||
private partial def insertAux {α} [HasBeq α] (keys : Array Key) (v : α) : Nat → Trie α → Trie α
|
||||
| i, Trie.node vs cs =>
|
||||
if h : i < keys.size then
|
||||
let k := keys.get ⟨i, h⟩
|
||||
let c := Id.run $ cs.binInsertM
|
||||
(fun a b => a.1 < b.1)
|
||||
(fun ⟨_, s⟩ => let c := insertAux keys v (i+1) s; (k, c)) -- merge with existing
|
||||
(fun _ => let c := createNodes keys v (i+1); (k, c))
|
||||
(k, arbitrary _)
|
||||
Trie.node vs c
|
||||
else
|
||||
Trie.node (insertVal vs v) cs
|
||||
| i, Trie.node vs cs =>
|
||||
if h : i < keys.size then
|
||||
let k := keys.get ⟨i, h⟩
|
||||
let c := Id.run $ cs.binInsertM
|
||||
(fun a b => a.1 < b.1)
|
||||
(fun ⟨_, s⟩ => let c := insertAux keys v (i+1) s; (k, c)) -- merge with existing
|
||||
(fun _ => let c := createNodes keys v (i+1); (k, c))
|
||||
(k, arbitrary _)
|
||||
Trie.node vs c
|
||||
else
|
||||
Trie.node (insertVal vs v) cs
|
||||
|
||||
def insertCore {α} [HasBeq α] (d : DiscrTree α) (keys : Array Key) (v : α) : DiscrTree α :=
|
||||
if keys.isEmpty then panic! "invalid key sequence"
|
||||
else
|
||||
let k := keys[0]
|
||||
match d.root.find? k with
|
||||
| none =>
|
||||
let c := createNodes keys v 1
|
||||
{ root := d.root.insert k c }
|
||||
| some c =>
|
||||
let c := insertAux keys v 1 c
|
||||
{ root := d.root.insert k c }
|
||||
if keys.isEmpty then panic! "invalid key sequence"
|
||||
else
|
||||
let k := keys[0]
|
||||
match d.root.find? k with
|
||||
| none =>
|
||||
let c := createNodes keys v 1
|
||||
{ root := d.root.insert k c }
|
||||
| some c =>
|
||||
let c := insertAux keys v 1 c
|
||||
{ root := d.root.insert k c }
|
||||
|
||||
def insert {α} [HasBeq α] (d : DiscrTree α) (e : Expr) (v : α) : MetaM (DiscrTree α) := do
|
||||
let keys ← mkPath e
|
||||
pure $ d.insertCore keys v
|
||||
let keys ← mkPath e
|
||||
pure $ d.insertCore keys v
|
||||
|
||||
partial def Trie.format {α} [HasFormat α] : Trie α → Format
|
||||
| Trie.node vs cs => Format.group $ Format.paren $
|
||||
"node" ++ (if vs.isEmpty then Format.nil else " " ++ fmt vs)
|
||||
++ Format.join (cs.toList.map $ fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ format c))
|
||||
| Trie.node vs cs => Format.group $ Format.paren $
|
||||
"node" ++ (if vs.isEmpty then Format.nil else " " ++ fmt vs)
|
||||
++ Format.join (cs.toList.map $ fun ⟨k, c⟩ => Format.line ++ Format.paren (fmt k ++ " => " ++ format c))
|
||||
|
||||
instance {α} [HasFormat α] : HasFormat (Trie α) := ⟨Trie.format⟩
|
||||
|
||||
partial def format {α} [HasFormat α] (d : DiscrTree α) : Format :=
|
||||
let (_, r) := d.root.foldl
|
||||
(fun (p : Bool × Format) k c =>
|
||||
(false, p.2 ++ (if p.1 == true then Format.nil else Format.line) ++ Format.paren (fmt k ++ " => " ++ fmt c))) -- TODO: fix p.1 == true
|
||||
(true, Format.nil)
|
||||
Format.group r
|
||||
let (_, r) := d.root.foldl
|
||||
(fun (p : Bool × Format) k c =>
|
||||
(false, p.2 ++ (if p.1 == true then Format.nil else Format.line) ++ Format.paren (fmt k ++ " => " ++ fmt c))) -- TODO: fix p.1 == true
|
||||
(true, Format.nil)
|
||||
Format.group r
|
||||
|
||||
instance {α} [HasFormat α] : HasFormat (DiscrTree α) := ⟨format⟩
|
||||
|
||||
private def getKeyArgs (e : Expr) (isMatch? : Bool) : MetaM (Key × Array Expr) := do
|
||||
let e ← whnfEta e
|
||||
match e.getAppFn with
|
||||
| Expr.lit v _ => pure (Key.lit v, #[])
|
||||
| Expr.const c _ _ =>
|
||||
let nargs := e.getAppNumArgs
|
||||
pure (Key.const c nargs, e.getAppRevArgs)
|
||||
| Expr.fvar fvarId _ =>
|
||||
let nargs := e.getAppNumArgs
|
||||
pure (Key.fvar fvarId nargs, e.getAppRevArgs)
|
||||
| Expr.mvar mvarId _ =>
|
||||
if isMatch? then
|
||||
pure (Key.other, #[])
|
||||
else do
|
||||
let ctx ← read
|
||||
if ctx.config.isDefEqStuckEx then
|
||||
/-
|
||||
When the configuration flag `isDefEqStuckEx` is set to true,
|
||||
we want `isDefEq` to throw an exception whenever it tries to assign
|
||||
a read-only metavariable.
|
||||
This feature is useful for type class resolution where
|
||||
we may want to notify the caller that the TC problem may be solveable
|
||||
later after it assigns `?m`.
|
||||
The method `DiscrTree.getUnify e` returns candidates `c` that may "unify" with `e`.
|
||||
That is, `isDefEq c e` may return true. Now, consider `DiscrTree.getUnify d (HasAdd ?m)`
|
||||
where `?m` is a read-only metavariable, and the discrimination tree contains the keys
|
||||
`HadAdd Nat` and `HasAdd Int`. If `isDefEqStuckEx` is set to true, we must treat `?m` as
|
||||
a regular metavariable here, otherwise we return the empty set of candidates.
|
||||
This is incorrect because it is equivalent to saying that there is no solution even if
|
||||
the caller assigns `?m` and try again. -/
|
||||
pure (Key.star, #[])
|
||||
else if (← isReadOnlyOrSyntheticOpaqueExprMVar mvarId) then
|
||||
let e ← whnfEta e
|
||||
match e.getAppFn with
|
||||
| Expr.lit v _ => pure (Key.lit v, #[])
|
||||
| Expr.const c _ _ =>
|
||||
let nargs := e.getAppNumArgs
|
||||
pure (Key.const c nargs, e.getAppRevArgs)
|
||||
| Expr.fvar fvarId _ =>
|
||||
let nargs := e.getAppNumArgs
|
||||
pure (Key.fvar fvarId nargs, e.getAppRevArgs)
|
||||
| Expr.mvar mvarId _ =>
|
||||
if isMatch? then
|
||||
pure (Key.other, #[])
|
||||
else
|
||||
pure (Key.star, #[])
|
||||
| _ => pure (Key.other, #[])
|
||||
else do
|
||||
let ctx ← read
|
||||
if ctx.config.isDefEqStuckEx then
|
||||
/-
|
||||
When the configuration flag `isDefEqStuckEx` is set to true,
|
||||
we want `isDefEq` to throw an exception whenever it tries to assign
|
||||
a read-only metavariable.
|
||||
This feature is useful for type class resolution where
|
||||
we may want to notify the caller that the TC problem may be solveable
|
||||
later after it assigns `?m`.
|
||||
The method `DiscrTree.getUnify e` returns candidates `c` that may "unify" with `e`.
|
||||
That is, `isDefEq c e` may return true. Now, consider `DiscrTree.getUnify d (HasAdd ?m)`
|
||||
where `?m` is a read-only metavariable, and the discrimination tree contains the keys
|
||||
`HadAdd Nat` and `HasAdd Int`. If `isDefEqStuckEx` is set to true, we must treat `?m` as
|
||||
a regular metavariable here, otherwise we return the empty set of candidates.
|
||||
This is incorrect because it is equivalent to saying that there is no solution even if
|
||||
the caller assigns `?m` and try again. -/
|
||||
pure (Key.star, #[])
|
||||
else if (← isReadOnlyOrSyntheticOpaqueExprMVar mvarId) then
|
||||
pure (Key.other, #[])
|
||||
else
|
||||
pure (Key.star, #[])
|
||||
| _ => pure (Key.other, #[])
|
||||
|
||||
private abbrev getMatchKeyArgs (e : Expr) : MetaM (Key × Array Expr) :=
|
||||
getKeyArgs e true
|
||||
getKeyArgs e true
|
||||
|
||||
private abbrev getUnifyKeyArgs (e : Expr) : MetaM (Key × Array Expr) :=
|
||||
getKeyArgs e false
|
||||
getKeyArgs e false
|
||||
|
||||
private partial def getMatchAux {α} : Array Expr → Trie α → Array α → MetaM (Array α)
|
||||
| todo, Trie.node vs cs, result =>
|
||||
if todo.isEmpty then pure $ result ++ vs
|
||||
else if cs.isEmpty then pure result
|
||||
else do
|
||||
let e := todo.back
|
||||
let todo := todo.pop
|
||||
let first := cs[0] /- Recall that `Key.star` is the minimal key -/
|
||||
let (k, args) ← getMatchKeyArgs e
|
||||
/- We must always visit `Key.star` edges since they are wildcards.
|
||||
Thus, `todo` is not used linearly when there is `Key.star` edge
|
||||
and there is an edge for `k` and `k != Key.star`. -/
|
||||
let visitStarChild (result : Array α) : MetaM (Array α) :=
|
||||
if first.1 == Key.star then getMatchAux todo first.2 result else pure result
|
||||
match k with
|
||||
| Key.star => visitStarChild result
|
||||
| _ =>
|
||||
match cs.binSearch (k, arbitrary _) (fun a b => a.1 < b.1) with
|
||||
| none => visitStarChild result
|
||||
| some c =>
|
||||
let result ← visitStarChild result
|
||||
getMatchAux (todo ++ args) c.2 result
|
||||
| todo, Trie.node vs cs, result =>
|
||||
if todo.isEmpty then pure $ result ++ vs
|
||||
else if cs.isEmpty then pure result
|
||||
else do
|
||||
let e := todo.back
|
||||
let todo := todo.pop
|
||||
let first := cs[0] /- Recall that `Key.star` is the minimal key -/
|
||||
let (k, args) ← getMatchKeyArgs e
|
||||
/- We must always visit `Key.star` edges since they are wildcards.
|
||||
Thus, `todo` is not used linearly when there is `Key.star` edge
|
||||
and there is an edge for `k` and `k != Key.star`. -/
|
||||
let visitStarChild (result : Array α) : MetaM (Array α) :=
|
||||
if first.1 == Key.star then getMatchAux todo first.2 result else pure result
|
||||
match k with
|
||||
| Key.star => visitStarChild result
|
||||
| _ =>
|
||||
match cs.binSearch (k, arbitrary _) (fun a b => a.1 < b.1) with
|
||||
| none => visitStarChild result
|
||||
| some c =>
|
||||
let result ← visitStarChild result
|
||||
getMatchAux (todo ++ args) c.2 result
|
||||
|
||||
private def getStarResult {α} (d : DiscrTree α) : Array α :=
|
||||
let result : Array α := Array.mkEmpty initCapacity
|
||||
match d.root.find? Key.star with
|
||||
| none => result
|
||||
| some (Trie.node vs _) => result ++ vs
|
||||
let result : Array α := Array.mkEmpty initCapacity
|
||||
match d.root.find? Key.star with
|
||||
| none => result
|
||||
| some (Trie.node vs _) => result ++ vs
|
||||
|
||||
def getMatch {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
|
||||
withReducible do
|
||||
let result := getStarResult d
|
||||
let (k, args) ← getMatchKeyArgs e
|
||||
match k with
|
||||
| Key.star => pure result
|
||||
| _ =>
|
||||
match d.root.find? k with
|
||||
| none => pure result
|
||||
| some c => getMatchAux args c result
|
||||
withReducible do
|
||||
let result := getStarResult d
|
||||
let (k, args) ← getMatchKeyArgs e
|
||||
match k with
|
||||
| Key.star => pure result
|
||||
| _ =>
|
||||
match d.root.find? k with
|
||||
| none => pure result
|
||||
| some c => getMatchAux args c result
|
||||
|
||||
private partial def getUnifyAux {α} : Nat → Array Expr → Trie α → (Array α) → MetaM (Array α)
|
||||
| skip+1, todo, Trie.node vs cs, result =>
|
||||
if cs.isEmpty then pure result
|
||||
else cs.foldlM (fun result ⟨k, c⟩ => getUnifyAux (skip + k.arity) todo c result) result
|
||||
| 0, todo, Trie.node vs cs, result => do
|
||||
if todo.isEmpty then pure (result ++ vs)
|
||||
else if cs.isEmpty then pure result
|
||||
else
|
||||
let e := todo.back
|
||||
let todo := todo.pop
|
||||
let (k, args) ← getUnifyKeyArgs e
|
||||
match k with
|
||||
| Key.star => cs.foldlM (fun result ⟨k, c⟩ => getUnifyAux k.arity todo c result) result
|
||||
| _ =>
|
||||
let first := cs[0]
|
||||
let visitStarChild (result : Array α) : MetaM (Array α) :=
|
||||
if first.1 == Key.star then getUnifyAux 0 todo first.2 result else pure result
|
||||
match cs.binSearch (k, arbitrary _) (fun a b => a.1 < b.1) with
|
||||
| none => visitStarChild result
|
||||
| some c =>
|
||||
let result ← visitStarChild result
|
||||
getUnifyAux 0 (todo ++ args) c.2 result
|
||||
| skip+1, todo, Trie.node vs cs, result =>
|
||||
if cs.isEmpty then pure result
|
||||
else cs.foldlM (fun result ⟨k, c⟩ => getUnifyAux (skip + k.arity) todo c result) result
|
||||
| 0, todo, Trie.node vs cs, result => do
|
||||
if todo.isEmpty then pure (result ++ vs)
|
||||
else if cs.isEmpty then pure result
|
||||
else
|
||||
let e := todo.back
|
||||
let todo := todo.pop
|
||||
let (k, args) ← getUnifyKeyArgs e
|
||||
match k with
|
||||
| Key.star => cs.foldlM (fun result ⟨k, c⟩ => getUnifyAux k.arity todo c result) result
|
||||
| _ =>
|
||||
let first := cs[0]
|
||||
let visitStarChild (result : Array α) : MetaM (Array α) :=
|
||||
if first.1 == Key.star then getUnifyAux 0 todo first.2 result else pure result
|
||||
match cs.binSearch (k, arbitrary _) (fun a b => a.1 < b.1) with
|
||||
| none => visitStarChild result
|
||||
| some c =>
|
||||
let result ← visitStarChild result
|
||||
getUnifyAux 0 (todo ++ args) c.2 result
|
||||
|
||||
def getUnify {α} (d : DiscrTree α) (e : Expr) : MetaM (Array α) :=
|
||||
withReducible do
|
||||
let (k, args) ← getUnifyKeyArgs e
|
||||
match k with
|
||||
| Key.star => d.root.foldlM (fun result k c => getUnifyAux k.arity #[] c result) #[]
|
||||
| _ =>
|
||||
let result := getStarResult d
|
||||
match d.root.find? k with
|
||||
| none => pure result
|
||||
| some c => getUnifyAux 0 args c result
|
||||
withReducible do
|
||||
let (k, args) ← getUnifyKeyArgs e
|
||||
match k with
|
||||
| Key.star => d.root.foldlM (fun result k c => getUnifyAux k.arity #[] c result) #[]
|
||||
| _ =>
|
||||
let result := getStarResult d
|
||||
match d.root.find? k with
|
||||
| none => pure result
|
||||
| some c => getUnifyAux 0 args c result
|
||||
|
||||
end Lean.Meta.DiscrTree
|
||||
|
|
|
|||
|
|
@ -12,35 +12,35 @@ namespace Lean.Meta
|
|||
namespace DiscrTree
|
||||
|
||||
inductive Key
|
||||
| const : Name → Nat → Key
|
||||
| fvar : FVarId → Nat → Key
|
||||
| lit : Literal → Key
|
||||
| star : Key
|
||||
| other : Key
|
||||
| const : Name → Nat → Key
|
||||
| fvar : FVarId → Nat → Key
|
||||
| lit : Literal → Key
|
||||
| star : Key
|
||||
| other : Key
|
||||
|
||||
instance : Inhabited Key := ⟨Key.star⟩
|
||||
|
||||
protected def Key.hash : Key → USize
|
||||
| Key.const n a => mixHash 5237 $ mixHash (hash n) (hash a)
|
||||
| Key.fvar n a => mixHash 3541 $ mixHash (hash n) (hash a)
|
||||
| Key.lit v => mixHash 1879 $ hash v
|
||||
| Key.star => 7883
|
||||
| Key.other => 2411
|
||||
| Key.const n a => mixHash 5237 $ mixHash (hash n) (hash a)
|
||||
| Key.fvar n a => mixHash 3541 $ mixHash (hash n) (hash a)
|
||||
| Key.lit v => mixHash 1879 $ hash v
|
||||
| Key.star => 7883
|
||||
| Key.other => 2411
|
||||
|
||||
instance : Hashable Key := ⟨Key.hash⟩
|
||||
|
||||
protected def Key.beq : Key → Key → Bool
|
||||
| Key.const c₁ a₁, Key.const c₂ a₂ => c₁ == c₂ && a₁ == a₂
|
||||
| Key.fvar c₁ a₁, Key.fvar c₂ a₂ => c₁ == c₂ && a₁ == a₂
|
||||
| Key.lit v₁, Key.lit v₂ => v₁ == v₂
|
||||
| Key.star, Key.star => true
|
||||
| Key.other, Key.other => true
|
||||
| _, _ => false
|
||||
| Key.const c₁ a₁, Key.const c₂ a₂ => c₁ == c₂ && a₁ == a₂
|
||||
| Key.fvar c₁ a₁, Key.fvar c₂ a₂ => c₁ == c₂ && a₁ == a₂
|
||||
| Key.lit v₁, Key.lit v₂ => v₁ == v₂
|
||||
| Key.star, Key.star => true
|
||||
| Key.other, Key.other => true
|
||||
| _, _ => false
|
||||
|
||||
instance : HasBeq Key := ⟨Key.beq⟩
|
||||
|
||||
inductive Trie (α : Type)
|
||||
| node (vs : Array α) (children : Array (Key × Trie α)) : Trie α
|
||||
| node (vs : Array α) (children : Array (Key × Trie α)) : Trie α
|
||||
|
||||
end DiscrTree
|
||||
|
||||
|
|
@ -48,6 +48,6 @@ open DiscrTree
|
|||
open Std (PersistentHashMap)
|
||||
|
||||
structure DiscrTree (α : Type) :=
|
||||
(root : PersistentHashMap Key (Trie α) := {})
|
||||
(root : PersistentHashMap Key (Trie α) := {})
|
||||
|
||||
end Lean.Meta
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -9,169 +9,168 @@ import Lean.Meta.InferType
|
|||
namespace Lean.Meta
|
||||
|
||||
private partial def decAux? : Level → MetaM (Option Level)
|
||||
| Level.zero _ => pure none
|
||||
| Level.param _ _ => pure none
|
||||
| Level.mvar mvarId _ => do
|
||||
let mctx ← getMCtx
|
||||
match mctx.getLevelAssignment? mvarId with
|
||||
| some u => decAux? u
|
||||
| none =>
|
||||
if (← isReadOnlyLevelMVar mvarId) then
|
||||
pure none
|
||||
else
|
||||
let u ← mkFreshLevelMVar
|
||||
assignLevelMVar mvarId (mkLevelSucc u)
|
||||
pure u
|
||||
| Level.succ u _ => pure u
|
||||
| u =>
|
||||
let process (u v : Level) : MetaM (Option Level) := do
|
||||
match (← decAux? u) with
|
||||
| none => pure none
|
||||
| some u => do
|
||||
match (← decAux? v) with
|
||||
| Level.zero _ => pure none
|
||||
| Level.param _ _ => pure none
|
||||
| Level.mvar mvarId _ => do
|
||||
let mctx ← getMCtx
|
||||
match mctx.getLevelAssignment? mvarId with
|
||||
| some u => decAux? u
|
||||
| none =>
|
||||
if (← isReadOnlyLevelMVar mvarId) then
|
||||
pure none
|
||||
else
|
||||
let u ← mkFreshLevelMVar
|
||||
assignLevelMVar mvarId (mkLevelSucc u)
|
||||
pure u
|
||||
| Level.succ u _ => pure u
|
||||
| u =>
|
||||
let process (u v : Level) : MetaM (Option Level) := do
|
||||
match (← decAux? u) with
|
||||
| none => pure none
|
||||
| some v => pure $ mkLevelMax u v
|
||||
match u with
|
||||
| Level.max u v _ => process u v
|
||||
/- Remark: If `decAux? v` returns `some ...`, then `imax u v` is equivalent to `max u v`. -/
|
||||
| Level.imax u v _ => process u v
|
||||
| _ => unreachable!
|
||||
| some u => do
|
||||
match (← decAux? v) with
|
||||
| none => pure none
|
||||
| some v => pure $ mkLevelMax u v
|
||||
match u with
|
||||
| Level.max u v _ => process u v
|
||||
/- Remark: If `decAux? v` returns `some ...`, then `imax u v` is equivalent to `max u v`. -/
|
||||
| Level.imax u v _ => process u v
|
||||
| _ => unreachable!
|
||||
|
||||
variables {m : Type → Type} [MonadLiftT MetaM m]
|
||||
|
||||
private def decLevelImp (u : Level) : MetaM (Option Level) := do
|
||||
let mctx ← getMCtx
|
||||
match (← decAux? u) with
|
||||
| some v => pure $ some v
|
||||
| none => do
|
||||
modify fun s => { s with mctx := mctx }
|
||||
pure none
|
||||
let mctx ← getMCtx
|
||||
match (← decAux? u) with
|
||||
| some v => pure $ some v
|
||||
| none => do
|
||||
modify fun s => { s with mctx := mctx }
|
||||
pure none
|
||||
|
||||
def decLevel? (u : Level) : m (Option Level) :=
|
||||
liftMetaM $ decLevelImp u
|
||||
liftMetaM $ decLevelImp u
|
||||
|
||||
def decLevel (u : Level) : m Level := liftMetaM do
|
||||
match (← decLevel? u) with
|
||||
| some u => pure u
|
||||
| none => throwError! "invalid universe level, {u} is not greater than 0"
|
||||
match (← decLevel? u) with
|
||||
| some u => pure u
|
||||
| none => throwError! "invalid universe level, {u} is not greater than 0"
|
||||
|
||||
/- This method is useful for inferring universe level parameters for function that take arguments such as `{α : Type u}`.
|
||||
Recall that `Type u` is `Sort (u+1)` in Lean. Thus, given `α`, we must infer its universe level,
|
||||
and then decrement 1 to obtain `u`. -/
|
||||
def getDecLevel (type : Expr) : m Level := liftMetaM do
|
||||
let u ← getLevel type
|
||||
decLevel u
|
||||
let u ← getLevel type
|
||||
decLevel u
|
||||
|
||||
private def strictOccursMaxAux (lvl : Level) : Level → Bool
|
||||
| Level.max u v _ => strictOccursMaxAux lvl u || strictOccursMaxAux lvl v
|
||||
| u => u != lvl && lvl.occurs u
|
||||
| Level.max u v _ => strictOccursMaxAux lvl u || strictOccursMaxAux lvl v
|
||||
| u => u != lvl && lvl.occurs u
|
||||
|
||||
/--
|
||||
Return true iff `lvl` occurs in `max u_1 ... u_n` and `lvl != u_i` for all `i in [1, n]`.
|
||||
That is, `lvl` is a proper level subterm of some `u_i`. -/
|
||||
private def strictOccursMax (lvl : Level) : Level → Bool
|
||||
| Level.max u v _ => strictOccursMaxAux lvl u || strictOccursMaxAux lvl v
|
||||
| _ => false
|
||||
| Level.max u v _ => strictOccursMaxAux lvl u || strictOccursMaxAux lvl v
|
||||
| _ => false
|
||||
|
||||
/-- `mkMaxArgsDiff mvarId (max u_1 ... (mvar mvarId) ... u_n) v` => `max v u_1 ... u_n` -/
|
||||
private def mkMaxArgsDiff (mvarId : MVarId) : Level → Level → Level
|
||||
| Level.max u v _, acc => mkMaxArgsDiff mvarId v $ mkMaxArgsDiff mvarId u acc
|
||||
| l@(Level.mvar id _), acc => if id != mvarId then mkLevelMax acc l else acc
|
||||
| l, acc => mkLevelMax acc l
|
||||
| Level.max u v _, acc => mkMaxArgsDiff mvarId v $ mkMaxArgsDiff mvarId u acc
|
||||
| l@(Level.mvar id _), acc => if id != mvarId then mkLevelMax acc l else acc
|
||||
| l, acc => mkLevelMax acc l
|
||||
|
||||
/--
|
||||
Solve `?m =?= max ?m v` by creating a fresh metavariable `?n`
|
||||
and assigning `?m := max ?n v` -/
|
||||
private def solveSelfMax (mvarId : MVarId) (v : Level) : MetaM Unit := do
|
||||
let n ← mkFreshLevelMVar
|
||||
assignLevelMVar mvarId $ mkMaxArgsDiff mvarId v n
|
||||
let n ← mkFreshLevelMVar
|
||||
assignLevelMVar mvarId $ mkMaxArgsDiff mvarId v n
|
||||
|
||||
private def postponeIsLevelDefEq (lhs : Level) (rhs : Level) : DefEqM Unit :=
|
||||
modify fun postponed => postponed.push { lhs := lhs, rhs := rhs }
|
||||
modify fun postponed => postponed.push { lhs := lhs, rhs := rhs }
|
||||
|
||||
mutual
|
||||
private partial def solve (u v : Level) : DefEqM LBool := do
|
||||
match u, v with
|
||||
| Level.mvar mvarId _, _ =>
|
||||
if (← isReadOnlyLevelMVar mvarId) then
|
||||
pure LBool.undef
|
||||
else if !u.occurs v then
|
||||
assignLevelMVar u.mvarId! v
|
||||
pure LBool.true
|
||||
else if !strictOccursMax u v then
|
||||
solveSelfMax u.mvarId! v
|
||||
pure LBool.true
|
||||
else
|
||||
pure LBool.undef
|
||||
| Level.zero _, Level.max v₁ v₂ _ =>
|
||||
Bool.toLBool <$> (isLevelDefEqAux levelZero v₁ <&&> isLevelDefEqAux levelZero v₂)
|
||||
| Level.zero _, Level.imax _ v₂ _ =>
|
||||
Bool.toLBool <$> isLevelDefEqAux levelZero v₂
|
||||
| Level.succ u _, v =>
|
||||
match (← Meta.decLevel? v) with
|
||||
| some v => Bool.toLBool <$> isLevelDefEqAux u v
|
||||
| none => pure LBool.undef
|
||||
| _, _ => pure LBool.undef
|
||||
match u, v with
|
||||
| Level.mvar mvarId _, _ =>
|
||||
if (← isReadOnlyLevelMVar mvarId) then
|
||||
pure LBool.undef
|
||||
else if !u.occurs v then
|
||||
assignLevelMVar u.mvarId! v
|
||||
pure LBool.true
|
||||
else if !strictOccursMax u v then
|
||||
solveSelfMax u.mvarId! v
|
||||
pure LBool.true
|
||||
else
|
||||
pure LBool.undef
|
||||
| Level.zero _, Level.max v₁ v₂ _ =>
|
||||
Bool.toLBool <$> (isLevelDefEqAux levelZero v₁ <&&> isLevelDefEqAux levelZero v₂)
|
||||
| Level.zero _, Level.imax _ v₂ _ =>
|
||||
Bool.toLBool <$> isLevelDefEqAux levelZero v₂
|
||||
| Level.succ u _, v =>
|
||||
match (← Meta.decLevel? v) with
|
||||
| some v => Bool.toLBool <$> isLevelDefEqAux u v
|
||||
| none => pure LBool.undef
|
||||
| _, _ => pure LBool.undef
|
||||
|
||||
partial def isLevelDefEqAux : Level → Level → DefEqM Bool
|
||||
| Level.succ lhs _, Level.succ rhs _ => isLevelDefEqAux lhs rhs
|
||||
| lhs, rhs => do
|
||||
if lhs == rhs then
|
||||
pure true
|
||||
else
|
||||
trace[Meta.isLevelDefEq.step]! "{lhs} =?= {rhs}"
|
||||
let lhs' ← instantiateLevelMVars lhs
|
||||
let lhs' := lhs'.normalize
|
||||
let rhs' ← instantiateLevelMVars rhs
|
||||
let rhs' := rhs'.normalize
|
||||
if lhs != lhs' || rhs != rhs' then
|
||||
isLevelDefEqAux lhs' rhs'
|
||||
| Level.succ lhs _, Level.succ rhs _ => isLevelDefEqAux lhs rhs
|
||||
| lhs, rhs => do
|
||||
if lhs == rhs then
|
||||
pure true
|
||||
else
|
||||
let r ← solve lhs rhs;
|
||||
if r != LBool.undef then
|
||||
pure $ r == LBool.true
|
||||
trace[Meta.isLevelDefEq.step]! "{lhs} =?= {rhs}"
|
||||
let lhs' ← instantiateLevelMVars lhs
|
||||
let lhs' := lhs'.normalize
|
||||
let rhs' ← instantiateLevelMVars rhs
|
||||
let rhs' := rhs'.normalize
|
||||
if lhs != lhs' || rhs != rhs' then
|
||||
isLevelDefEqAux lhs' rhs'
|
||||
else
|
||||
let r ← solve rhs lhs;
|
||||
let r ← solve lhs rhs;
|
||||
if r != LBool.undef then
|
||||
pure $ r == LBool.true
|
||||
else do
|
||||
let mctx ← getMCtx
|
||||
if !mctx.hasAssignableLevelMVar lhs && !mctx.hasAssignableLevelMVar rhs then
|
||||
let ctx ← read
|
||||
if ctx.config.isDefEqStuckEx && (lhs.isMVar || rhs.isMVar) then do
|
||||
trace[Meta.isLevelDefEq.stuck]! "{lhs} =?= {rhs}"
|
||||
Meta.throwIsDefEqStuck
|
||||
else
|
||||
let r ← solve rhs lhs;
|
||||
if r != LBool.undef then
|
||||
pure $ r == LBool.true
|
||||
else do
|
||||
let mctx ← getMCtx
|
||||
if !mctx.hasAssignableLevelMVar lhs && !mctx.hasAssignableLevelMVar rhs then
|
||||
let ctx ← read
|
||||
if ctx.config.isDefEqStuckEx && (lhs.isMVar || rhs.isMVar) then do
|
||||
trace[Meta.isLevelDefEq.stuck]! "{lhs} =?= {rhs}"
|
||||
Meta.throwIsDefEqStuck
|
||||
else
|
||||
pure false
|
||||
else
|
||||
pure false
|
||||
else
|
||||
postponeIsLevelDefEq lhs rhs; pure true
|
||||
postponeIsLevelDefEq lhs rhs; pure true
|
||||
end
|
||||
|
||||
def isListLevelDefEqAux : List Level → List Level → DefEqM Bool
|
||||
| [], [] => pure true
|
||||
| u::us, v::vs => isLevelDefEqAux u v <&&> isListLevelDefEqAux us vs
|
||||
| _, _ => pure false
|
||||
| [], [] => pure true
|
||||
| u::us, v::vs => isLevelDefEqAux u v <&&> isListLevelDefEqAux us vs
|
||||
| _, _ => pure false
|
||||
|
||||
private def getNumPostponed : DefEqM Nat := do
|
||||
pure (← get).size
|
||||
pure (← get).size
|
||||
|
||||
open Std (PersistentArray)
|
||||
|
||||
private def getResetPostponed : DefEqM (PersistentArray PostponedEntry) := do
|
||||
let ps ← get
|
||||
modify fun _ => {}
|
||||
pure ps
|
||||
let ps ← get
|
||||
modify fun _ => {}
|
||||
pure ps
|
||||
|
||||
private def processPostponedStep : DefEqM Bool :=
|
||||
traceCtx `Meta.isLevelDefEq.postponed.step do
|
||||
let ps ← getResetPostponed
|
||||
for p in ps do
|
||||
unless (← isLevelDefEqAux p.lhs p.rhs) do
|
||||
return false
|
||||
return true
|
||||
traceCtx `Meta.isLevelDefEq.postponed.step do
|
||||
let ps ← getResetPostponed
|
||||
for p in ps do
|
||||
unless (← isLevelDefEqAux p.lhs p.rhs) do
|
||||
return false
|
||||
return true
|
||||
|
||||
private partial def processPostponedAux : Unit → DefEqM Bool
|
||||
| _ => do
|
||||
private partial def processPostponedAux : DefEqM Bool := do
|
||||
let numPostponed ← getNumPostponed
|
||||
if numPostponed == 0 then
|
||||
pure true
|
||||
|
|
@ -184,20 +183,22 @@ private partial def processPostponedAux : Unit → DefEqM Bool
|
|||
if numPostponed' == 0 then
|
||||
pure true
|
||||
else if numPostponed' < numPostponed then
|
||||
processPostponedAux ()
|
||||
processPostponedAux
|
||||
else do
|
||||
trace[Meta.isLevelDefEq.postponed]! "no progress solving pending is-def-eq level constraints"
|
||||
pure false
|
||||
|
||||
private def processPostponed : DefEqM Bool := do
|
||||
let numPostponed ← getNumPostponed
|
||||
if numPostponed == 0 then pure true
|
||||
else traceCtx `Meta.isLevelDefEq.postponed $ processPostponedAux ()
|
||||
if (← getNumPostponed) == 0 then
|
||||
pure true
|
||||
else
|
||||
traceCtx `Meta.isLevelDefEq.postponed do
|
||||
processPostponedAux
|
||||
|
||||
private def restore (env : Environment) (mctx : MetavarContext) (postponed : PersistentArray PostponedEntry) : DefEqM Unit := do
|
||||
setEnv env
|
||||
setMCtx mctx
|
||||
set postponed
|
||||
setEnv env
|
||||
setMCtx mctx
|
||||
set postponed
|
||||
|
||||
/--
|
||||
`commitWhen x` executes `x` and process all postponed universe level constraints produced by `x`.
|
||||
|
|
@ -206,49 +207,49 @@ set postponed
|
|||
Remark: postponed universe level constraints must be solved before returning. Otherwise,
|
||||
we don't know whether `x` really succeeded. -/
|
||||
@[specialize] def commitWhen (x : DefEqM Bool) : DefEqM Bool := do
|
||||
let env ← getEnv
|
||||
let mctx ← getMCtx
|
||||
let postponed ← getResetPostponed
|
||||
try
|
||||
if (← x) then
|
||||
if (← processPostponed) then
|
||||
pure true
|
||||
let env ← getEnv
|
||||
let mctx ← getMCtx
|
||||
let postponed ← getResetPostponed
|
||||
try
|
||||
if (← x) then
|
||||
if (← processPostponed) then
|
||||
pure true
|
||||
else
|
||||
restore env mctx postponed
|
||||
pure false
|
||||
else
|
||||
restore env mctx postponed
|
||||
pure false
|
||||
else
|
||||
catch ex =>
|
||||
restore env mctx postponed
|
||||
pure false
|
||||
catch ex =>
|
||||
restore env mctx postponed
|
||||
throw ex
|
||||
throw ex
|
||||
|
||||
private def runDefEqM (x : DefEqM Bool) : MetaM Bool :=
|
||||
(commitWhen x).run' {}
|
||||
(commitWhen x).run' {}
|
||||
|
||||
def isLevelDefEq (u v : Level) : m Bool := liftMetaM do
|
||||
traceCtx `Meta.isLevelDefEq do
|
||||
let b ← runDefEqM $ Meta.isLevelDefEqAux u v
|
||||
trace[Meta.isLevelDefEq]! "{u} =?= {v} ... {if b then "success" else "failure"}"
|
||||
pure b
|
||||
traceCtx `Meta.isLevelDefEq do
|
||||
let b ← runDefEqM $ Meta.isLevelDefEqAux u v
|
||||
trace[Meta.isLevelDefEq]! "{u} =?= {v} ... {if b then "success" else "failure"}"
|
||||
pure b
|
||||
|
||||
def isExprDefEq (t s : Expr) : m Bool := liftMetaM do
|
||||
traceCtx `Meta.isDefEq $ do
|
||||
let b ← runDefEqM $ Meta.isExprDefEqAux t s
|
||||
trace[Meta.isDefEq]! "{t} =?= {s} ... {if b then "success" else "failure"}"
|
||||
pure b
|
||||
traceCtx `Meta.isDefEq $ do
|
||||
let b ← runDefEqM $ Meta.isExprDefEqAux t s
|
||||
trace[Meta.isDefEq]! "{t} =?= {s} ... {if b then "success" else "failure"}"
|
||||
pure b
|
||||
|
||||
abbrev isDefEq (t s : Expr) : m Bool :=
|
||||
isExprDefEq t s
|
||||
isExprDefEq t s
|
||||
|
||||
def isExprDefEqGuarded (a b : Expr) : m Bool := liftMetaM do
|
||||
try isExprDefEq a b catch _ => pure false
|
||||
try isExprDefEq a b catch _ => pure false
|
||||
|
||||
abbrev isDefEqGuarded (t s : Expr) : m Bool :=
|
||||
isExprDefEqGuarded t s
|
||||
isExprDefEqGuarded t s
|
||||
|
||||
def isDefEqNoConstantApprox (t s : Expr) : m Bool := liftMetaM do
|
||||
approxDefEq $ isDefEq t s
|
||||
approxDefEq $ isDefEq t s
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Meta.isLevelDefEq
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue