diff --git a/library/Init/Lean/Meta.lean b/library/Init/Lean/Meta.lean index 9de4136181..0edb61a761 100644 --- a/library/Init/Lean/Meta.lean +++ b/library/Init/Lean/Meta.lean @@ -5,70 +5,8 @@ Authors: Leonardo de Moura -/ prelude import Init.Lean.Meta.Basic +import Init.Lean.Meta.LevelDefEq import Init.Lean.Meta.WHNF import Init.Lean.Meta.InferType import Init.Lean.Meta.FunInfo -import Init.Lean.Meta.LevelDefEq import Init.Lean.Meta.ExprDefEq - -namespace Lean -namespace Meta - -/- =========================================== -/ -/- BIG HACK until we add `mutual` keyword back -/ -/- =========================================== -/ -inductive MetaOp -| whnfOp | inferTypeOp | isDefEqOp | synthPendingOp - -open MetaOp - -private def exprToBool : Expr → Bool -| Expr.sort _ _ => false -| _ => true - -private def boolToExpr : Bool → Expr -| false => mkSort levelZero -| true => mkBVar 0 -private partial def auxFixpoint : MetaOp → Expr → Expr → MetaM Expr -| op, e₁, e₂ => - let whnf := fun e => auxFixpoint whnfOp e e; - let inferType := fun e => auxFixpoint inferTypeOp e e; - let isDefEq := fun e₁ e₂ => exprToBool <$> auxFixpoint isDefEqOp e₁ e₂; - let synthPending := fun e => exprToBool <$> auxFixpoint synthPendingOp e e; - match op with - | whnfOp => whnfAux inferType isDefEq synthPending e₁ - | inferTypeOp => inferTypeAux whnf e₁ - -- | isDefEqOp => boolToExpr <$> isExprDefEqAux whnf synthPending e₁ e₂ - | isDefEqOp => boolToExpr <$> pure false - | synthPendingOp => boolToExpr <$> pure false -- TODO - -def whnf (e : Expr) : MetaM Expr := -auxFixpoint whnfOp e e - -def inferType (e : Expr) : MetaM Expr := -auxFixpoint inferTypeOp e e - -def isDefEq (e₁ e₂ : Expr) : MetaM Bool := -try $ exprToBool <$> auxFixpoint isDefEqOp e₁ e₂ -/- =========================================== -/ -/- END OF BIG HACK -/ -/- =========================================== -/ - -def isProp (e : Expr) : MetaM Bool := -isPropAux whnf e - -def getFunInfo (fn : Expr) : MetaM FunInfo := -getFunInfoAux whnf fn - -def getFunInfoNArgs (fn : Expr) (nargs : Nat) : MetaM FunInfo := -getFunInfoNArgsAux whnf fn nargs - -/-- Throws exception if `e` is not type correct. -/ -def check (e : Expr) : MetaM Unit := -checkAux whnf isDefEq e - -def isTypeCorrect (e : Expr) : MetaM Bool := -isTypeCorrectAux whnf isDefEq e - -end Meta -end Lean diff --git a/library/Init/Lean/Meta/Basic.lean b/library/Init/Lean/Meta/Basic.lean index 6116d6ec44..66c522025c 100644 --- a/library/Init/Lean/Meta/Basic.lean +++ b/library/Init/Lean/Meta/Basic.lean @@ -142,6 +142,62 @@ do s ← get; pure s.mctx @[inline] def getEnv : MetaM Environment := do s ← get; pure s.env +def mkWHNFRef : IO (IO.Ref (Expr → MetaM Expr)) := +IO.mkRef $ fun _ => throw $ Exception.other "whnf implementation was not set" + +@[init mkWHNFRef] def whnfRef : IO.Ref (Expr → MetaM Expr) := arbitrary _ + +def mkInferTypeRef : IO (IO.Ref (Expr → MetaM Expr)) := +IO.mkRef $ fun _ => throw $ Exception.other "inferType implementation was not set" + +@[init mkInferTypeRef] def inferTypeRef : IO.Ref (Expr → MetaM Expr) := arbitrary _ + +def mkIsExprDefEqAuxRef : IO (IO.Ref (Expr → Expr → MetaM Bool)) := +IO.mkRef $ fun _ _ => throw $ Exception.other "isDefEq implementation was not set" + +@[init mkIsExprDefEqAuxRef] def isExprDefEqAuxRef : IO.Ref (Expr → Expr → MetaM Bool) := arbitrary _ + +def mkSynthPendingRef : IO (IO.Ref (Expr → MetaM Bool)) := +IO.mkRef $ fun _ => pure false + +@[init mkSynthPendingRef] def synthPendingRef : IO.Ref (Expr → MetaM Bool) := arbitrary _ + +structure MetaExtState := +(whnf : Expr → MetaM Expr) +(inferType : Expr → MetaM Expr) +(isDefEqAux : Expr → Expr → MetaM Bool) +(synthPending : Expr → MetaM Bool) + +instance MetaExtState.inhabited : Inhabited MetaExtState := +⟨{ whnf := arbitrary _, inferType := arbitrary _, isDefEqAux := arbitrary _, synthPending := arbitrary _ }⟩ + +def mkMetaExtension : IO (EnvExtension MetaExtState) := +registerEnvExtension $ do + whnf ← whnfRef.get; + inferType ← inferTypeRef.get; + isDefEqAux ← isExprDefEqAuxRef.get; + synthPending ← synthPendingRef.get; + pure { whnf := whnf, inferType := inferType, isDefEqAux := isDefEqAux, synthPending := synthPending } + +@[init mkMetaExtension] +constant metaExt : EnvExtension MetaExtState := arbitrary _ + +def whnf (e : Expr) : MetaM Expr := +do env ← getEnv; + (metaExt.getState env).whnf e + +def inferType (e : Expr) : MetaM Expr := +do env ← getEnv; + (metaExt.getState env).inferType e + +def isExprDefEqAux (t s : Expr) : MetaM Bool := +do env ← getEnv; + (metaExt.getState env).isDefEqAux t s + +def synthPending (e : Expr) : MetaM Bool := +do env ← getEnv; + (metaExt.getState env).synthPending e + @[inline] def throwEx {α} (f : ExceptionContext → Exception) : MetaM α := do ctx ← read; s ← get; @@ -401,7 +457,6 @@ resettingTypeClassCache $ when `fvars.size == max` -/ @[specialize] private partial def forallTelescopeReducingAuxAux {α} - (whnf : Expr → MetaM Expr) (isClassExpensive : Expr → MetaM (Option Name)) (reducing? : Bool) (maxFVars? : Option Nat) (k : Array Expr → Expr → MetaM α) @@ -438,22 +493,19 @@ resettingTypeClassCache $ /- We need this auxiliary definition because it depends on `isClassExpensive`, and `isClassExpensive` depends on it. -/ @[specialize] private def forallTelescopeReducingAux {α} - (whnf : Expr → MetaM Expr) (isClassExpensive : Expr → MetaM (Option Name)) (type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → MetaM α) : MetaM α := do newType ← whnf type; if newType.isForall then savingCache $ do lctx ← getLCtx; - forallTelescopeReducingAuxAux whnf isClassExpensive true maxFVars? k lctx #[] 0 newType + forallTelescopeReducingAuxAux isClassExpensive true maxFVars? k lctx #[] 0 newType else k #[] type -@[specialize] partial def isClassExpensive - (whnf : Expr → MetaM Expr) - : Expr → MetaM (Option Name) +partial def isClassExpensive : Expr → MetaM (Option Name) | type => usingTransparency TransparencyMode.reducible $ -- when testing whether a type is a type class, we only unfold reducible constants. - forallTelescopeReducingAux whnf isClassExpensive type none $ fun xs type => do + forallTelescopeReducingAux isClassExpensive type none $ fun xs type => do match type.getAppFn with | Expr.const c _ _ => do env ← getEnv; @@ -464,42 +516,33 @@ do newType ← whnf type; Given `type` of the form `forall xs, A`, execute `k xs A`. This combinator will declare local declarations, create free variables for them, execute `k` with updated local context, and make sure the cache is restored after executing `k`. -/ -@[inline] def forallTelescope {α} - (whnf : Expr → MetaM Expr) - (type : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α := +@[inline] def forallTelescope {α} (type : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α := savingCache $ do lctx ← getLCtx; - forallTelescopeReducingAuxAux whnf (isClassExpensive whnf) false none k lctx #[] 0 type + forallTelescopeReducingAuxAux isClassExpensive false none k lctx #[] 0 type /-- Similar to `forallTelescope`, but given `type` of the form `forall xs, A`, it reduces `A` and continues bulding the telescope if it is a `forall`. -/ -@[specialize] def forallTelescopeReducing {α} - (whnf : Expr → MetaM Expr) - (type : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α := -forallTelescopeReducingAux whnf (isClassExpensive whnf) type none k +@[inline] def forallTelescopeReducing {α} (type : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α := +forallTelescopeReducingAux isClassExpensive type none k /-- Similar to `forallTelescopeReducing`, stops constructing the telescope when it reaches size `maxFVars`. -/ -@[specialize] def forallBoundedTelescope {α} - (whnf : Expr → MetaM Expr) - (type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → MetaM α) : MetaM α := -forallTelescopeReducingAux whnf (isClassExpensive whnf) type maxFVars? k +@[inline] def forallBoundedTelescope {α} (type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → MetaM α) : MetaM α := +forallTelescopeReducingAux isClassExpensive type maxFVars? k -@[specialize] def isClass - (whnf : Expr → MetaM Expr) - (type : Expr) : MetaM (Option Name) := +def isClass (type : Expr) : MetaM (Option Name) := do c? ← isClassQuick type; match c? with | LOption.none => pure none | LOption.some c => pure (some c) - | LOption.undef => isClassExpensive whnf type + | LOption.undef => isClassExpensive type /-- Similar to `forallTelescopeAuxAux` but for lambda and let expressions. -/ @[specialize] private partial def lambdaTelescopeAux {α} - (whnf : Expr → MetaM Expr) - (k : Array Expr → Expr → MetaM α) + (k : Array Expr → Expr → MetaM α) : LocalContext → Array Expr → Nat → Expr → MetaM α | lctx, fvars, j, Expr.lam n d b c => do let d := d.instantiateRevRange j fvars.size fvars; @@ -517,16 +560,15 @@ do c? ← isClassQuick type; | lctx, fvars, j, e => let e := e.instantiateRevRange j fvars.size fvars; adaptReader (fun (ctx : Context) => { lctx := lctx, .. ctx }) $ - withNewLocalInstances (isClassExpensive whnf) fvars j $ do + withNewLocalInstances isClassExpensive fvars j $ do k fvars e /-- Similar to `forallTelescope` but for lambda and let expressions. -/ @[specialize] def lambdaTelescope {α} - (whnf : Expr → MetaM Expr) (e : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α := savingCache $ do lctx ← getLCtx; - lambdaTelescopeAux whnf k lctx #[] 0 e + lambdaTelescopeAux k lctx #[] 0 e @[inline] def liftStateMCtx {α} (x : StateM MetavarContext α) : MetaM α := fun _ s => @@ -544,7 +586,7 @@ do mvarId ← mkFreshId; modify $ fun s => { mctx := s.mctx.addLevelMVarDecl mvarId, .. s }; pure mvarId -@[inline] def usingDefault (whnf : Expr → MetaM Expr) : Expr → MetaM Expr := +def whnfUsingDefault : Expr → MetaM Expr := fun e => usingTransparency TransparencyMode.default $ whnf e end Meta diff --git a/library/Init/Lean/Meta/Check.lean b/library/Init/Lean/Meta/Check.lean index ad422338e0..353626996c 100644 --- a/library/Init/Lean/Meta/Check.lean +++ b/library/Init/Lean/Meta/Check.lean @@ -14,94 +14,75 @@ whether terms produced by tactics and `isDefEq` are type correct. namespace Lean namespace Meta -@[specialize] private def ensureType - (whnf : Expr → MetaM Expr) - (e : Expr) : MetaM Unit := -do getLevelAux whnf (inferTypeAux whnf) e; - pure () +private def ensureType (e : Expr) : MetaM Unit := +do getLevel e; pure () @[specialize] private def checkLambdaLet - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) (check : Expr → MetaM Unit) (e : Expr) : MetaM Unit := -lambdaTelescope whnf e $ fun xs b => do +lambdaTelescope e $ fun xs b => do xs.forM $ fun x => do { xDecl ← getFVarLocalDecl x; match xDecl with | LocalDecl.cdecl _ _ _ t _ => do - ensureType whnf t; + ensureType t; check t | LocalDecl.ldecl _ _ _ t v => do - ensureType whnf t; + ensureType t; check t; - vType ← inferTypeAux whnf v; - unlessM (isDefEq t vType) $ throwEx $ Exception.letTypeMismatch x.fvarId!; + vType ← inferType v; + unlessM (isExprDefEqAux t vType) $ throwEx $ Exception.letTypeMismatch x.fvarId!; check v }; check b @[specialize] private def checkForall - (whnf : Expr → MetaM Expr) (check : Expr → MetaM Unit) (e : Expr) : MetaM Unit := -forallTelescope whnf e $ fun xs b => do +forallTelescope e $ fun xs b => do xs.forM $ fun x => do { xDecl ← getFVarLocalDecl x; - ensureType whnf xDecl.type; + ensureType xDecl.type; check xDecl.type }; - ensureType whnf b; + ensureType b; check b -@[specialize] private def checkConstant - (c : Name) (lvls : List Level) : MetaM Unit := +private def checkConstant (c : Name) (lvls : List Level) : MetaM Unit := do env ← getEnv; match env.find c with | none => throwEx $ Exception.unknownConst c | some cinfo => unless (lvls.length != cinfo.lparams.length) $ throwEx $ Exception.incorrectNumOfLevels c lvls @[specialize] private def checkApp - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) (check : Expr → MetaM Unit) (f a : Expr) : MetaM Unit := do check f; check a; - fType ← inferTypeAux whnf f; + fType ← inferType f; fType ← whnf fType; match fType with | Expr.forallE _ d _ _ => do - aType ← inferTypeAux whnf a; - unlessM (isDefEq d aType) $ throwEx $ Exception.appTypeMismatch f a + aType ← inferType a; + unlessM (isExprDefEqAux d aType) $ throwEx $ Exception.appTypeMismatch f a | _ => unless fType.isForall $ throwEx $ Exception.functionExpected f a -@[specialize] private partial def checkAuxAux - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - : Expr → MetaM Unit -| e@(Expr.forallE _ _ _ _) => checkForall whnf checkAuxAux e -| e@(Expr.lam _ _ _ _) => checkLambdaLet whnf isDefEq checkAuxAux e -| e@(Expr.letE _ _ _ _ _) => checkLambdaLet whnf isDefEq checkAuxAux e +private partial def checkAux : Expr → MetaM Unit +| e@(Expr.forallE _ _ _ _) => checkForall checkAux e +| e@(Expr.lam _ _ _ _) => checkLambdaLet checkAux e +| e@(Expr.letE _ _ _ _ _) => checkLambdaLet checkAux e | Expr.const c lvls _ => checkConstant c lvls -| Expr.app f a _ => checkApp whnf isDefEq checkAuxAux f a -| Expr.mdata _ e _ => checkAuxAux e -| Expr.proj _ _ e _ => checkAuxAux e +| Expr.app f a _ => checkApp checkAux f a +| Expr.mdata _ e _ => checkAux e +| Expr.proj _ _ e _ => checkAux e | _ => pure () -@[specialize] def checkAux - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (e : Expr) : MetaM Unit := -usingTransparency TransparencyMode.all $ - checkAuxAux whnf isDefEq e +def check (e : Expr) : MetaM Unit := +usingTransparency TransparencyMode.all $ checkAux e -@[specialize] def isTypeCorrectAux - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (e : Expr) : MetaM Bool := +def isTypeCorrect (e : Expr) : MetaM Bool := catch - (do checkAux whnf isDefEq e; pure true) + (do checkAux e; pure true) (fun _ => pure false) end Meta diff --git a/library/Init/Lean/Meta/ExprDefEq.lean b/library/Init/Lean/Meta/ExprDefEq.lean index 5e32febeee..88ad759256 100644 --- a/library/Init/Lean/Meta/ExprDefEq.lean +++ b/library/Init/Lean/Meta/ExprDefEq.lean @@ -23,17 +23,14 @@ namespace Meta (fun x : A => f ?m) =?= f ``` The left-hand side of the constraint above it not eta-reduced because `?m` is a metavariable. -/ -@[specialize] private def isDefEqEta - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (a b : Expr) : MetaM Bool := +private def isDefEqEta (a b : Expr) : MetaM Bool := if a.isLambda && !b.isLambda then do - bType ← inferTypeAux whnf b; - bType ← usingDefault whnf bType; + bType ← inferType b; + bType ← whnfUsingDefault bType; match bType with | Expr.forallE n d b c => let b' := Lean.mkLambda n c.binderInfo d (mkApp b (mkBVar 0)); - try $ isDefEq a b' + try $ isExprDefEqAux a b' | _ => pure false else pure false @@ -81,8 +78,7 @@ match e.etaExpanded? with Pre: `paramInfo.size <= args₁.size = args₂.size` -/ -@[specialize] private partial def isDefEqArgsFirstPass - (isDefEq : Expr → Expr → MetaM Bool) +private partial def isDefEqArgsFirstPass (paramInfo : Array ParamInfo) (args₁ args₂ : Array Expr) : Nat → Array Nat → MetaM (Option (Array Nat)) | i, postponed => if h : i < paramInfo.size then @@ -91,39 +87,33 @@ match e.etaExpanded? with let a₂ := args₂.get! i; if info.implicit || info.instImplicit then condM (isEtaUnassignedMVar a₁ <||> isEtaUnassignedMVar a₂) - (condM (isDefEq a₁ a₂) + (condM (isExprDefEqAux a₁ a₂) (isDefEqArgsFirstPass (i+1) postponed) (pure none)) (isDefEqArgsFirstPass (i+1) (postponed.push i)) else - condM (isDefEq a₁ a₂) + condM (isExprDefEqAux a₁ a₂) (isDefEqArgsFirstPass (i+1) postponed) (pure none) else pure (some postponed) -@[specialize] private partial def isDefEqArgsAux - (isDefEq : Expr → Expr → MetaM Bool) - (args₁ args₂ : Array Expr) (h : args₁.size = args₂.size) : Nat → MetaM Bool +private partial def isDefEqArgsAux (args₁ args₂ : Array Expr) (h : args₁.size = args₂.size) : Nat → MetaM Bool | i => if h₁ : i < args₁.size then let a₁ := args₁.get ⟨i, h₁⟩; let a₂ := args₂.get ⟨i, h ▸ h₁⟩; - condM (isDefEq a₁ a₂) + condM (isExprDefEqAux a₁ a₂) (isDefEqArgsAux (i+1)) (pure false) else pure true -@[specialize] private def isDefEqArgs - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - (f : Expr) (args₁ args₂ : Array Expr) : MetaM Bool := +private def isDefEqArgs (f : Expr) (args₁ args₂ : Array Expr) : MetaM Bool := if h : args₁.size = args₂.size then do - finfo ← getFunInfoNArgsAux whnf f args₁.size; - isDefEqArgsAux isDefEq args₁ args₂ h finfo.paramInfo.size; - (some postponed) ← isDefEqArgsFirstPass isDefEq finfo.paramInfo args₁ args₂ 0 #[] | pure false; + finfo ← getFunInfoNArgs f args₁.size; + isDefEqArgsAux args₁ args₂ h finfo.paramInfo.size; + (some postponed) ← isDefEqArgsFirstPass finfo.paramInfo args₁ args₂ 0 #[] | pure false; /- Second pass: unify implicit arguments. In the second pass, we make sure we are unfolding at least non reducible definitions (default setting). -/ @@ -132,11 +122,11 @@ if h : args₁.size = args₂.size then do let a₂ := args₂.get! i; let info := finfo.paramInfo.get! i; when info.instImplicit $ do { - synthesizePending a₁; - synthesizePending a₂; + synthPending a₁; + synthPending a₂; pure () }; - usingAtLeastTransparency TransparencyMode.default $ isDefEq a₁ a₂ + usingAtLeastTransparency TransparencyMode.default $ isExprDefEqAux a₁ a₂ else pure false @@ -151,18 +141,15 @@ else We can't use `withNewLocalInstances` because the `isDeq fvarType d₂` may use local instances. -/ -@[specialize] partial def isDefEqBindingDomain - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (fvars : Array Expr) (ds₂ : Array Expr) : Nat → MetaM Bool → MetaM Bool +@[specialize] partial def isDefEqBindingDomain (fvars : Array Expr) (ds₂ : Array Expr) : Nat → MetaM Bool → MetaM Bool | i, k => if h : i < fvars.size then do let fvar := fvars.get ⟨i, h⟩; fvarDecl ← getFVarLocalDecl fvar; let fvarType := fvarDecl.type; let d₂ := ds₂.get! i; - condM (isDefEq fvarType d₂) - (do c? ← isClass whnf fvarType; + condM (isExprDefEqAux fvarType d₂) + (do c? ← isClass fvarType; match c? with | some className => withNewLocalInstance className fvar $ isDefEqBindingDomain (i+1) k | none => isDefEqBindingDomain (i+1) k) @@ -174,10 +161,7 @@ else It accumulates the new free variables in `fvars`, and declare them at `lctx`. We use the domain types of `e₁` to create the new free variables. We store the domain types of `e₂` at `ds₂`. -/ -@[specialize] private partial def isDefEqBindingAux - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - : LocalContext → Array Expr → Expr → Expr → Array Expr → MetaM Bool +private partial def isDefEqBindingAux : LocalContext → Array Expr → Expr → Expr → Array Expr → MetaM Bool | lctx, fvars, e₁, e₂, ds₂ => let process (n : Name) (d₁ d₂ b₁ b₂ : Expr) : MetaM Bool := do { let d₁ := d₁.instantiateRev fvars; @@ -192,15 +176,12 @@ else | Expr.lam n d₁ b₁ _, Expr.lam _ d₂ b₂ _ => process n d₁ d₂ b₁ b₂ | _, _ => adaptReader (fun (ctx : Context) => { lctx := lctx, .. ctx }) $ - isDefEqBindingDomain whnf isDefEq fvars ds₂ 0 $ - isDefEq (e₁.instantiateRev fvars) (e₂.instantiateRev fvars) + isDefEqBindingDomain fvars ds₂ 0 $ + isExprDefEqAux (e₁.instantiateRev fvars) (e₂.instantiateRev fvars) -@[inline] private def isDefEqBinding - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (a b : Expr) : MetaM Bool := +@[inline] private def isDefEqBinding (a b : Expr) : MetaM Bool := do lctx ← getLCtx; - isDefEqBindingAux whnf isDefEq lctx #[] a b #[] + isDefEqBindingAux lctx #[] a b #[] /- Each metavariable is declared in a particular local context. @@ -389,7 +370,7 @@ instance : MonadCache Expr Expr CheckAssignmentM := @[inline] private def visit (f : Expr → CheckAssignmentM Expr) (e : Expr) : CheckAssignmentM Expr := if !e.hasExprMVar && !e.hasFVar then pure e else checkCache e f -@[inline] def checkFVar (check : Expr → CheckAssignmentM Expr) (fvar : Expr) : CheckAssignmentM Expr := +@[specialize] def checkFVar (check : Expr → CheckAssignmentM Expr) (fvar : Expr) : CheckAssignmentM Expr := do ctx ← read; if ctx.mvarDecl.lctx.containsFVar fvar then pure fvar else do @@ -409,7 +390,7 @@ do s ← get; modify $ fun s => { ngen := s.ngen.next, mctx := s.mctx.addExprMVarDecl mvarId Name.anonymous lctx type, .. s }; pure (mkMVar mvarId) -@[inline] def checkMVar (check : Expr → CheckAssignmentM Expr) (mvar : Expr) : CheckAssignmentM Expr := +@[specialize] def checkMVar (check : Expr → CheckAssignmentM Expr) (mvar : Expr) : CheckAssignmentM Expr := do let mvarId := mvar.mvarId!; ctx ← read; mctx ← getMCtx; @@ -550,22 +531,18 @@ fun ctx s => if !v.hasExprMVar && !v.hasFVar then EStateM.Result.ok (some v) s e We try to unify arguments before we try to unify the functions. The motivation is the following: the universe constraints in the arguments propagate to the function. -/ -@[specialize] private partial def isDefEqFOApprox - (isDefEq : Expr → Expr → MetaM Bool) - (f₁ f₂ : Expr) (args₁ args₂ : Array Expr) : Nat → Nat → MetaM Bool +private partial def isDefEqFOApprox (f₁ f₂ : Expr) (args₁ args₂ : Array Expr) : Nat → Nat → MetaM Bool | i₁, i₂ => if h : i₂ < args₂.size then let arg₁ := args₁.get! i₁; let arg₂ := args₂.get ⟨i₂, h⟩; - condM (isDefEq arg₁ arg₂) + condM (isExprDefEqAux arg₁ arg₂) (isDefEqFOApprox (i₁+1) (i₂+1)) (pure false) else - isDefEq f₁ f₂ + isExprDefEqAux f₁ f₂ -@[specialize] private def processAssignmentFOApproxAux - (isDefEq : Expr → Expr → MetaM Bool) - (mvar : Expr) (args : Array Expr) (v : Expr) : MetaM Bool := +private def processAssignmentFOApproxAux (mvar : Expr) (args : Array Expr) (v : Expr) : MetaM Bool := let vArgs := v.getAppArgs; if vArgs.isEmpty then /- ?m a_1 ... a_k =?= t, where t is not an application -/ @@ -583,7 +560,7 @@ else if args.size > vArgs.size then -/ let f₁ := mkAppRange mvar 0 (args.size - vArgs.size) args; let i₁ := args.size - vArgs.size; - isDefEqFOApprox isDefEq f₁ v.getAppFn args vArgs i₁ 0 + isDefEqFOApprox f₁ v.getAppFn args vArgs i₁ 0 else if args.size < vArgs.size then /- ?m a_1 ... a_k =?= f b_1 ... b_i b_{i+1} ... b_{i+k} @@ -597,7 +574,7 @@ else if args.size < vArgs.size then -/ let vFn := mkAppRange v.getAppFn 0 (vArgs.size - args.size) vArgs; let i₂ := vArgs.size - args.size; - isDefEqFOApprox isDefEq mvar vFn args vArgs 0 i₂ + isDefEqFOApprox mvar vFn args vArgs 0 i₂ else /- ?m a_1 ... a_k =?= f b_1 ... b_k @@ -609,7 +586,7 @@ else ... a_k =?= b_k -/ - isDefEqFOApprox isDefEq mvar v.getAppFn args vArgs 0 0 + isDefEqFOApprox mvar v.getAppFn args vArgs 0 0 /- Auxiliary method for applying first-order unification. It is an approximation. @@ -627,15 +604,11 @@ else def ITactic := Tactic Unit -/ -@[specialize] private partial def processAssignmentFOApprox - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - (mvar : Expr) (args : Array Expr) : Expr → MetaM Bool +private partial def processAssignmentFOApprox (mvar : Expr) (args : Array Expr) : Expr → MetaM Bool | v => - condM (try $ processAssignmentFOApproxAux isDefEq mvar args v) + condM (try $ processAssignmentFOApproxAux mvar args v) (pure true) - (unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending v (pure false) processAssignmentFOApprox) + (unfoldDefinitionAux v (pure false) processAssignmentFOApprox) private partial def simpAssignmentArgAux : Expr → MetaM Expr | Expr.mdata _ e _ => simpAssignmentArgAux e @@ -653,14 +626,7 @@ private def simpAssignmentArg (arg : Expr) : MetaM Expr := do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg; simpAssignmentArgAux arg -@[specialize] private partial def processAssignmentAux - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - (mvar : Expr) - (mvarDecl : MetavarDecl) - (v : Expr) - : Nat → Array Expr → MetaM Bool +private partial def processAssignmentAux (mvar : Expr) (mvarDecl : MetavarDecl) (v : Expr) : Nat → Array Expr → MetaM Bool | i, args => if h : i < args.size then do cfg ← getConfig; @@ -669,7 +635,7 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg; let args := args.set ⟨i, h⟩ arg; let useFOApprox : Unit → MetaM Bool := fun _ => if cfg.foApprox then - processAssignmentFOApprox whnf isDefEq synthesizePending mvar args v + processAssignmentFOApprox mvar args v else pure false; match arg with @@ -686,10 +652,10 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg; cfg ← getConfig; v ← instantiateMVars v; -- enforce A4 if cfg.foApprox && args.isEmpty && v.getAppFn == mvar then - processAssignmentFOApprox whnf isDefEq synthesizePending mvar args v + processAssignmentFOApprox mvar args v else do let useFOApprox : Unit → MetaM Bool := fun _ => - if cfg.foApprox then processAssignmentFOApprox whnf isDefEq synthesizePending mvar args v + if cfg.foApprox then processAssignmentFOApprox mvar args v else pure false; let mvarId := mvar.mvarId!; v? ← checkAssignment mvarId args v; @@ -699,9 +665,9 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg; v ← mkLambda args v; let finalize : Unit → MetaM Bool := fun _ => do { -- must check whether types are definitionally equal or not, before assigning and returning true - mvarType ← inferTypeAux whnf mvar; - vType ← inferTypeAux whnf v; - condM (usingTransparency TransparencyMode.default $ isDefEq mvarType vType) + mvarType ← inferType mvar; + vType ← inferType v; + condM (usingTransparency TransparencyMode.default $ isExprDefEqAux mvarType vType) (do assignExprMVar mvarId v; pure true) (do trace! `Meta.isDefEq.assignment.typeMismatch (mvar ++ " : " ++ mvarType ++ " := " ++ v ++ " : " ++ vType); pure false) @@ -709,7 +675,7 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg; if args.any (fun arg => mvarDecl.lctx.containsFVar arg) then /- We need to type check `v` because abstraction using `mkLambda` may have produced a type incorrect term. See discussion at A2 -/ - condM (isTypeCorrectAux whnf isDefEq v) + condM (isTypeCorrect v) (finalize ()) (useFOApprox ()) else @@ -717,22 +683,10 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg; /-- Tries to solve `?m a₁ ... aₙ =?= v` by assigning `?m`. It assumes `?m` is unassigned. -/ -@[specialize] private def processAssignment - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - (mvarApp : Expr) (v : Expr) : MetaM Bool := +private def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool := do let mvar := mvarApp.getAppFn; mvarDecl ← getMVarDecl mvar.mvarId!; - processAssignmentAux whnf isDefEq synthesizePending mvar mvarDecl v 0 mvarApp.getAppArgs - -@[specialize] private def unfold {α} - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - (e : Expr) - (failK : MetaM α) (successK : Expr → MetaM α) : MetaM α := -unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK + processAssignmentAux mvar mvarDecl v 0 mvarApp.getAppArgs private def isDeltaCandidate (t : Expr) : MetaM (Option ConstantInfo) := match t.getAppFn with @@ -744,64 +698,60 @@ match t.getAppFn, s.getAppFn with | Expr.const c₁ _ _, Expr.const c₂ _ _ => true | _, _ => false -@[specialize] private def isDefEqDelta - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - (t s : Expr) - (k : MetaM Bool) -- continuation when `isDefEqDelta` could not decide - : MetaM Bool := -do let isDefEqLeft (fn : Name) (t s : Expr) : MetaM Bool := do { +@[specialize] private def isDefEqDelta (t s : Expr) : MetaM LBool := +do let isListLevelDefEqAuxL (us vs) : MetaM LBool := toLBoolM $ isListLevelDefEqAux us vs; + let isDefEqL (t s) : MetaM LBool := toLBoolM $ isExprDefEqAux t s; + let isDefEqLeft (fn : Name) (t s : Expr) : MetaM LBool := do { trace! `Meta.isDefEq.delta.unfoldLeft fn; - isDefEq t s + isDefEqL t s }; - let isDefEqRight (fn : Name) (t s : Expr) : MetaM Bool := do { + let isDefEqRight (fn : Name) (t s : Expr) : MetaM LBool := do { trace! `Meta.isDefEq.delta.unfoldRight fn; - isDefEq t s + isDefEqL t s }; - let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM Bool := do { + let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM LBool := do { trace! `Meta.isDefEq.delta.unfoldLeftRight fn; - isDefEq t s + isDefEqL t s }; - let unfold (e failK successK) : MetaM Bool := unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK; + let unfold (e failK successK) : MetaM LBool := unfoldDefinitionAux e failK successK; let tryHeuristic : MetaM Bool := /- Try to solve `f a₁ ... aₙ =?= f b₁ ... bₙ` by solving `a₁ =?= b₁, ..., aₙ =?= bₙ` -/ let tFn := t.getAppFn; let sFn := s.getAppFn; traceCtx `Meta.isDefEq.delta $ try $ - isDefEqArgs whnf isDefEq synthesizePending tFn t.getAppArgs s.getAppArgs + isDefEqArgs tFn t.getAppArgs s.getAppArgs <&&> isListLevelDefEqAux tFn.constLevels! sFn.constLevels!; tInfo? ← isDeltaCandidate t.getAppFn; sInfo? ← isDeltaCandidate s.getAppFn; match tInfo?, sInfo? with - | none, none => k - | some tInfo, none => unfold t k $ fun t => isDefEqLeft tInfo.name t s - | none, some sInfo => unfold s k $ fun s => isDefEqRight sInfo.name t s + | none, none => pure LBool.undef + | some tInfo, none => unfold t (pure LBool.undef) $ fun t => isDefEqLeft tInfo.name t s + | none, some sInfo => unfold s (pure LBool.undef) $ fun s => isDefEqRight sInfo.name t s | some tInfo, some sInfo => let isDefEqLeft (t s) := isDefEqLeft tInfo.name t s; let isDefEqRight (t s) := isDefEqRight sInfo.name t s; let isDefEqLeftRight (t s) := isDefEqLeftRight tInfo.name t s; if tInfo.name == sInfo.name then match t, s with - | Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqAux ls₁ ls₂ + | Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqAuxL ls₁ ls₂ | Expr.app _ _ _, Expr.app _ _ _ => condM tryHeuristic - (pure true) + (pure LBool.true) (unfold t - (unfold s (pure false) (fun s => isDefEqRight t s)) + (unfold s (pure LBool.false) (fun s => isDefEqRight t s)) (fun t => unfold s (isDefEqLeft t s) (fun s => isDefEqLeftRight t s))) - | _, _ => pure false + | _, _ => pure LBool.false else - let unfoldComparingHeads : Unit → MetaM Bool := fun _ => + let unfoldComparingHeads : Unit → MetaM LBool := fun _ => /- - If headSymbol (unfold t) == headSymbol s, then unfold t - If headSymbol (unfold s) == headSymbol t, then unfold s - Otherwise unfold t and s if possible. -/ unfold t (unfold s - k -- `t` and `s` failed to be unfolded + (pure LBool.undef) -- `t` and `s` failed to be unfolded (fun s => isDefEqRight t s)) (fun tNew => if sameHeadSymbol tNew s then @@ -810,7 +760,7 @@ do let isDefEqLeft (fn : Name) (t s : Expr) : MetaM Bool := do { unfold s (isDefEqLeft tNew s) (fun sNew => if sameHeadSymbol t sNew then isDefEqRight t sNew else isDefEqLeftRight tNew sNew)); - let kernelLikeUnfolding : Unit → MetaM Bool := fun _ => + let kernelLikeUnfolding : Unit → MetaM LBool := fun _ => if !t.hasExprMVar && !s.hasExprMVar then /- If `t` and `s` do not contain metavariables, we simulate strategy used in the kernel. -/ @@ -858,42 +808,38 @@ private def isLetFVar (fvarId : Name) : MetaM Bool := do decl ← getLocalDecl fvarId; pure decl.isLet -@[specialize] private partial def isDefEqQuick - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - : Expr → Expr → MetaM Bool → MetaM Bool -| Expr.lit l₁ _, Expr.lit l₂ _, _ => pure (l₁ == l₂) -| Expr.sort u _, Expr.sort v _, _ => isLevelDefEqAux u v -| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _), _ => isDefEqBinding whnf isDefEq t s -| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _), _ => isDefEqBinding whnf isDefEq t s -| Expr.mdata _ t _, s, k => isDefEqQuick t s k -| t, Expr.mdata _ s _, k => isDefEqQuick t s k -| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _, k => +private partial def isDefEqQuick : Expr → Expr → MetaM LBool +| Expr.lit l₁ _, Expr.lit l₂ _ => pure (l₁ == l₂).toLBool +| Expr.sort u _, Expr.sort v _ => toLBoolM $ isLevelDefEqAux u v +| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _) => toLBoolM $ isDefEqBinding t s +| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _) => toLBoolM $ isDefEqBinding t s +| Expr.mdata _ t _, s => isDefEqQuick t s +| t, Expr.mdata _ s _ => isDefEqQuick t s +| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _ => condM (isLetFVar fvarId₁ <||> isLetFVar fvarId₂) - k - (pure (fvarId₁ == fvarId₂)) -| t, s, k => - cond (t == s) (pure true) $ - cond (etaEq t s || etaEq s t) (pure true) $ -- t =?= (fun xs => t xs) + (pure LBool.undef) + (pure (fvarId₁ == fvarId₂).toLBool) +| t, s => + cond (t == s) (pure LBool.true) $ + cond (etaEq t s || etaEq s t) (pure LBool.true) $ -- t =?= (fun xs => t xs) let tFn := t.getAppFn; let sFn := s.getAppFn; - cond (!tFn.isMVar && !sFn.isMVar) k $ - condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $ - condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $ - condM (isSynthetic tFn <&&> synthesizePending tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $ - condM (isSynthetic sFn <&&> synthesizePending sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $ do + cond (!tFn.isMVar && !sFn.isMVar) (pure LBool.undef) $ + condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $ + condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ + condM (isSynthetic tFn <&&> synthPending tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $ + condM (isSynthetic sFn <&&> synthPending sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ do tAssign? ← isAssignable tFn; sAssign? ← isAssignable sFn; - let assign (t s : Expr) : MetaM Bool := processAssignment whnf isDefEq synthesizePending t s; + let assign (t s : Expr) : MetaM LBool := toLBoolM $ processAssignment t s; cond (tAssign? && !sAssign?) (assign t s) $ cond (!tAssign? && sAssign?) (assign s t) $ cond (!tAssign? && !sAssign?) (if tFn.isMVar || sFn.isMVar then do ctx ← read; if ctx.config.isDefEqStuckEx then throwEx $ Exception.isDefEqStuck t s - else pure false - else k) $ do + else pure LBool.false + else pure LBool.undef) $ do -- Both `t` and `s` are terms of the form `?m ...` tMVarDecl ← getMVarDecl tFn.mvarId!; sMVarDecl ← getMVarDecl sFn.mvarId!; @@ -912,80 +858,68 @@ do decl ← getLocalDecl fvarId; cond (!s.isApp && t.isApp && tMVarDecl.lctx.isSubPrefixOf sMVarDecl.lctx) (assign s t) $ assign t s -@[specialize] private def isDefEqProofIrrel - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (t s : Expr) - (k : MetaM Bool) -- continuation when `isDefEqQuick` could not decide - : MetaM Bool := -do tType ← inferTypeAux whnf t; - condM (isPropAux whnf tType) - (do sType ← inferTypeAux whnf s; isDefEq tType sType) - k +private def isDefEqProofIrrel (t s : Expr) : MetaM LBool := +do tType ← inferType t; + condM (isProp tType) + (do sType ← inferType s; toLBoolM $ isExprDefEqAux tType sType) + (pure LBool.undef) -@[specialize] private partial def whnfCoreAux - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - (e : Expr) : MetaM Expr := -Lean.whnfCore getConstNoEx isAuxDef? whnf (inferTypeAux whnf) isDefEq getLocalDecl getExprMVarAssignment e +private def whnfCoreAux (e : Expr) : MetaM Expr := +Lean.whnfCore getConstNoEx isAuxDef? whnf inferType isExprDefEqAux getLocalDecl getExprMVarAssignment e + +@[inline] def tryL (x : MetaM LBool) (k : MetaM Bool) : MetaM Bool := +do status ← x; + match status with + | LBool.true => pure true + | LBool.false => pure false + | LBool.undef => k @[specialize] private partial def isDefEqWHNF - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) (t s : Expr) (k : Expr → Expr → MetaM Bool) : MetaM Bool := -do t' ← whnfCoreAux whnf isDefEq synthesizePending t; - s' ← whnfCoreAux whnf isDefEq synthesizePending s; +do t' ← whnfCoreAux t; + s' ← whnfCoreAux s; if t == t' && s == s' then - k t s + k t' s' else - isDefEqQuick whnf isDefEq synthesizePending t' s' $ k t' s' + tryL (isDefEqQuick t' s') $ k t' s' @[specialize] private def unstuckMVar - (whnf : Expr → MetaM Expr) - (synthesizePending : Expr → MetaM Bool) (e : Expr) (successK : Expr → MetaM Bool) (failK : MetaM Bool): MetaM Bool := do s? ← getStuckMVar getConst whnf e; match s? with | some s => - condM (synthesizePending s) + condM (synthPending s) (do e ← instantiateMVars e; successK e) failK | none => failK -@[specialize] private def isDefEqOnFailure - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - (t s : Expr) : MetaM Bool := -unstuckMVar whnf synthesizePending t (fun t => isDefEq t s) $ -unstuckMVar whnf synthesizePending s (fun s => isDefEq t s) $ +private def isDefEqOnFailure (t s : Expr) : MetaM Bool := +unstuckMVar t (fun t => isExprDefEqAux t s) $ +unstuckMVar s (fun s => isExprDefEqAux t s) $ pure false -@[specialize] partial def isExprDefEqAux - (whnf : Expr → MetaM Expr) - (synthesizePending : Expr → MetaM Bool) - : Expr → Expr → MetaM Bool +partial def isExprDefEqAuxImpl : Expr → Expr → MetaM Bool | t, s => do trace! `Meta.isDefEq.step (t ++ " =?= " ++ s); - isDefEqQuick whnf isExprDefEqAux synthesizePending t s $ - isDefEqProofIrrel whnf isExprDefEqAux t s $ - isDefEqWHNF whnf isExprDefEqAux synthesizePending t s $ fun t s => - isDefEqOffset whnf isExprDefEqAux t s $ - isDefEqDelta whnf isExprDefEqAux synthesizePending t s $ - condM (isDefEqEta whnf isExprDefEqAux t s <||> isDefEqEta whnf isExprDefEqAux s t) (pure true) $ + tryL (isDefEqQuick t s) $ + tryL (isDefEqProofIrrel t s) $ + isDefEqWHNF t s $ fun t s => + tryL (isDefEqOffset t s) $ + tryL (isDefEqDelta t s) $ + condM (isDefEqEta t s <||> isDefEqEta s t) (pure true) $ match t, s with | Expr.const _ us _, Expr.const _ vs _ => isListLevelDefEqAux us vs | Expr.app _ _ _, Expr.app _ _ _ => let tFn := t.getAppFn; - condM (try (isExprDefEqAux tFn s.getAppFn <&&> - isDefEqArgs whnf isExprDefEqAux synthesizePending tFn t.getAppArgs s.getAppArgs)) + condM (try (isExprDefEqAux tFn s.getAppFn <&&> isDefEqArgs tFn t.getAppArgs s.getAppArgs)) (pure true) - (isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s) - | _, _ => isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s + (isDefEqOnFailure t s) + | _, _ => isDefEqOnFailure t s + +@[init] def setIsExprDefEqAuxRef : IO Unit := +isExprDefEqAuxRef.set isExprDefEqAuxImpl end Meta end Lean diff --git a/library/Init/Lean/Meta/FunInfo.lean b/library/Init/Lean/Meta/FunInfo.lean index af979a04f1..56bc2c9099 100644 --- a/library/Init/Lean/Meta/FunInfo.lean +++ b/library/Init/Lean/Meta/FunInfo.lean @@ -20,10 +20,10 @@ do s ← get; modify $ fun s => { cache := { funInfo := s.cache.funInfo.insert ⟨t, fn, maxArgs?⟩ finfo, .. s.cache }, .. s }; pure finfo -@[inline] def whenHasVar {α} (e : Expr) (deps : α) (k : α → α) : α := +@[inline] private def whenHasVar {α} (e : Expr) (deps : α) (k : α → α) : α := if e.hasFVar then k deps else deps -def collectDepsAux (fvars : Array Expr) : Expr → Array Nat → Array Nat +private def collectDepsAux (fvars : Array Expr) : Expr → Array Nat → Array Nat | e@(Expr.app f a _), deps => whenHasVar e deps (collectDepsAux a ∘ collectDepsAux f) | e@(Expr.forallE _ d b _), deps => whenHasVar e deps (collectDepsAux b ∘ collectDepsAux d) | e@(Expr.lam _ d b _), deps => whenHasVar e deps (collectDepsAux b ∘ collectDepsAux d) @@ -36,7 +36,7 @@ def collectDepsAux (fvars : Array Expr) : Expr → Array Nat → Array Nat | some i => if deps.contains i.val then deps else deps.push i.val | _, deps => deps -def collectDeps (fvars : Array Expr) (e : Expr) : Array Nat := +private def collectDeps (fvars : Array Expr) (e : Expr) : Array Nat := let deps := collectDepsAux fvars e #[]; deps.qsort (fun i j => i < j) @@ -53,38 +53,33 @@ else else info -@[specialize] def getFunInfoAuxAux - (whnf : Expr → MetaM Expr) - (fn : Expr) (maxArgs? : Option Nat) : MetaM FunInfo := +private def getFunInfoAux (fn : Expr) (maxArgs? : Option Nat) : MetaM FunInfo := checkFunInfoCache fn maxArgs? $ do - fnType ← inferTypeAux whnf fn; - forallBoundedTelescope (usingDefault whnf) fnType maxArgs? $ fun fvars type => do - pinfo ← fvars.size.foldM - (fun (i : Nat) (pinfo : Array ParamInfo) => do - let fvar := fvars.get! i; - decl ← getFVarLocalDecl fvar; - prop ← isPropAux whnf decl.type; - let backDeps := collectDeps fvars decl.type; - let pinfo := updateHasFwdDeps pinfo backDeps; - pure $ pinfo.push { - backDeps := backDeps, - prop := prop, - implicit := decl.binderInfo == BinderInfo.implicit, - instImplicit := decl.binderInfo == BinderInfo.instImplicit }) - #[]; - let resultDeps := collectDeps fvars type; - let pinfo := updateHasFwdDeps pinfo resultDeps; - pure { resultDeps := resultDeps, paramInfo := pinfo } + fnType ← inferType fn; + usingTransparency TransparencyMode.default $ + forallBoundedTelescope fnType maxArgs? $ fun fvars type => do + pinfo ← fvars.size.foldM + (fun (i : Nat) (pinfo : Array ParamInfo) => do + let fvar := fvars.get! i; + decl ← getFVarLocalDecl fvar; + prop ← isProp decl.type; + let backDeps := collectDeps fvars decl.type; + let pinfo := updateHasFwdDeps pinfo backDeps; + pure $ pinfo.push { + backDeps := backDeps, + prop := prop, + implicit := decl.binderInfo == BinderInfo.implicit, + instImplicit := decl.binderInfo == BinderInfo.instImplicit }) + #[]; + let resultDeps := collectDeps fvars type; + let pinfo := updateHasFwdDeps pinfo resultDeps; + pure { resultDeps := resultDeps, paramInfo := pinfo } -@[inline] def getFunInfoAux - (whnf : Expr → MetaM Expr) - (fn : Expr) : MetaM FunInfo := -getFunInfoAuxAux whnf fn none +def getFunInfo (fn : Expr) : MetaM FunInfo := +getFunInfoAux fn none -@[inline] def getFunInfoNArgsAux - (whnf : Expr → MetaM Expr) - (fn : Expr) (nargs : Nat) : MetaM FunInfo := -getFunInfoAuxAux whnf fn (some nargs) +def getFunInfoNArgs (fn : Expr) (nargs : Nat) : MetaM FunInfo := +getFunInfoAux fn (some nargs) end Meta end Lean diff --git a/library/Init/Lean/Meta/InferType.lean b/library/Init/Lean/Meta/InferType.lean index d78c976dbf..d1ecc809ff 100644 --- a/library/Init/Lean/Meta/InferType.lean +++ b/library/Init/Lean/Meta/InferType.lean @@ -10,10 +10,7 @@ import Init.Lean.Meta.Basic namespace Lean namespace Meta -@[specialize] private def inferAppType - (whnf : Expr → MetaM Expr) - (inferType : Expr → MetaM Expr) - (f : Expr) (args : Array Expr) : MetaM Expr := +private def inferAppType (f : Expr) (args : Array Expr) : MetaM Expr := do fType ← inferType f; (j, fType) ← args.size.foldM (fun i (acc : Nat × Expr) => @@ -39,10 +36,7 @@ do env ← getEnv; | none => throwEx $ Exception.unknownConst c -@[specialize] private def inferProjType - (whnf : Expr → MetaM Expr) - (inferType : Expr → MetaM Expr) - (structName : Name) (idx : Nat) (e : Expr) : MetaM Expr := +private def inferProjType (structName : Name) (idx : Nat) (e : Expr) : MetaM Expr := do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProjection structName idx e; structType ← inferType e; structType ← whnf structType; @@ -55,7 +49,7 @@ do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProje else match env.find ctor with | none => failed () | some (ctorInfo) => do - ctorType ← inferAppType whnf inferType (mkConst ctor structLvls) structParams; + ctorType ← inferAppType (mkConst ctor structLvls) structParams; ctorType ← idx.foldM (fun i ctorType => do ctorType ← whnf ctorType; @@ -73,10 +67,7 @@ do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProje | _ => failed () | _ => failed () -@[specialize] def getLevelAux - (whnf : Expr → MetaM Expr) - (inferType : Expr → MetaM Expr) - (type : Expr) : MetaM Level := +def getLevel (type : Expr) : MetaM Level := do typeType ← inferType type; typeType ← whnf typeType; match typeType with @@ -90,26 +81,20 @@ do typeType ← inferType type; pure lvl) | _ => throwEx $ Exception.typeExpected type -@[specialize] private def inferForallType - (whnf : Expr → MetaM Expr) - (inferType : Expr → MetaM Expr) - (e : Expr) : MetaM Expr := -forallTelescope whnf e $ fun xs e => do - lvl ← getLevelAux whnf inferType e; +private def inferForallType (e : Expr) : MetaM Expr := +forallTelescope e $ fun xs e => do + lvl ← getLevel e; lvl ← xs.foldrM (fun x lvl => do xType ← inferType x; - xTypeLvl ← getLevelAux whnf inferType xType; + xTypeLvl ← getLevel xType; pure $ mkLevelIMax xTypeLvl lvl) lvl; pure $ mkSort lvl.normalize /- Infer type of lambda and let expressions -/ -@[specialize] private def inferLambdaType - (whnf : Expr → MetaM Expr) - (inferType : Expr → MetaM Expr) - (e : Expr) : MetaM Expr := -lambdaTelescope whnf e $ fun xs e => do +private def inferLambdaType (e : Expr) : MetaM Expr := +lambdaTelescope e $ fun xs e => do type ← inferType e; mkForall xs type @@ -140,27 +125,26 @@ do s ← get; modify $ fun s => { cache := { inferType := s.cache.inferType.insert e type, .. s.cache }, .. s }; pure type -@[specialize] partial def inferTypeAuxAux - (whnf : Expr → MetaM Expr) - : Expr → MetaM Expr +private partial def inferTypeAux : Expr → MetaM Expr | Expr.const c lvls _ => inferConstType c lvls -| e@(Expr.proj n i s _) => checkInferTypeCache e (inferProjType whnf inferTypeAuxAux n i s) -| e@(Expr.app f _ _) => checkInferTypeCache e (inferAppType whnf inferTypeAuxAux f.getAppFn e.getAppArgs) +| e@(Expr.proj n i s _) => checkInferTypeCache e (inferProjType n i s) +| e@(Expr.app f _ _) => checkInferTypeCache e (inferAppType f.getAppFn e.getAppArgs) | Expr.mvar mvarId _ => inferMVarType mvarId | Expr.fvar fvarId _ => inferFVarType fvarId | Expr.bvar bidx _ => throw $ Exception.unexpectedBVar bidx -| Expr.mdata _ e _ => inferTypeAuxAux e +| Expr.mdata _ e _ => inferTypeAux e | Expr.lit v _ => pure v.type | Expr.sort lvl _ => pure $ mkSort (mkLevelSucc lvl) -| e@(Expr.forallE _ _ _ _) => checkInferTypeCache e (inferForallType whnf inferTypeAuxAux e) -| e@(Expr.lam _ _ _ _) => checkInferTypeCache e (inferLambdaType whnf inferTypeAuxAux e) -| e@(Expr.letE _ _ _ _ _) => checkInferTypeCache e (inferLambdaType whnf inferTypeAuxAux e) +| e@(Expr.forallE _ _ _ _) => checkInferTypeCache e (inferForallType e) +| e@(Expr.lam _ _ _ _) => checkInferTypeCache e (inferLambdaType e) +| e@(Expr.letE _ _ _ _ _) => checkInferTypeCache e (inferLambdaType e) | Expr.localE _ _ _ _ => unreachable! -@[inline] def inferTypeAux - (whnf : Expr → MetaM Expr) - (e : Expr) : MetaM Expr := -inferTypeAuxAux (usingDefault whnf) e +def inferTypeImpl (e : Expr) : MetaM Expr := +usingTransparency TransparencyMode.default (inferTypeAux e) + +@[init] def setInferTypeRef : IO Unit := +inferTypeRef.set inferTypeImpl /-- Return `LBool.true` if given level is always equivalent to universe level zero. @@ -177,7 +161,7 @@ private def isAlwaysZero : Level → Bool `isArrowProp type n` is an "approximate" predicate which returns `LBool.true` if `type` is of the form `A_1 -> ... -> A_n -> Prop`. Remark: `type` can be a dependent arrow. -/ -@[specialize] private partial def isArrowProp : Expr → Nat → MetaM LBool +private partial def isArrowProp : Expr → Nat → MetaM LBool | Expr.sort u _, 0 => do u ← instantiateLevelMVars u; pure $ (isAlwaysZero u).toLBool | Expr.forallE _ _ _ _, 0 => pure LBool.false | Expr.forallE _ _ b _, n+1 => isArrowProp b n @@ -188,7 +172,7 @@ private def isAlwaysZero : Level → Bool /-- `isPropQuickApp f n` is an "approximate" predicate which returns `LBool.true` if `f` applied to `n` arguments is a proposition. -/ -@[specialize] private partial def isPropQuickApp : Expr → Nat → MetaM LBool +private partial def isPropQuickApp : Expr → Nat → MetaM LBool | Expr.const c lvls _, arity => do constType ← inferConstType c lvls; isArrowProp constType arity | Expr.fvar fvarId _, arity => do fvarType ← inferFVarType fvarId; isArrowProp fvarType arity | Expr.mvar mvarId _, arity => do mvarType ← inferMVarType mvarId; isArrowProp mvarType arity @@ -202,7 +186,7 @@ private def isAlwaysZero : Level → Bool /-- `isPropQuick e` is an "approximate" predicate which returns `LBool.true` if `e` is a proposition. -/ -@[specialize] private partial def isPropQuick : Expr → MetaM LBool +private partial def isPropQuick : Expr → MetaM LBool | Expr.bvar _ _ => pure LBool.undef | Expr.lit _ _ => pure LBool.false | Expr.sort _ _ => pure LBool.false @@ -223,15 +207,15 @@ private def isAlwaysZero : Level → Bool to decide whether is a proposition or not. We return `false` in this case. We considered using `LBool` and retuning `LBool.undef`, but we have no applications for it. -/ -@[specialize] def isPropAux (whnf : Expr → MetaM Expr) (e : Expr) : MetaM Bool := +def isProp (e : Expr) : MetaM Bool := do r ← isPropQuick e; match r with | LBool.true => pure true | LBool.false => pure false | LBool.undef => do -- dbgTrace ("PropQuick failed " ++ toString e); - type ← inferTypeAux whnf e; - type ← usingDefault whnf type; + type ← inferType e; + type ← whnfUsingDefault type; match type with | Expr.sort u _ => do u ← instantiateLevelMVars u; pure $ isAlwaysZero u | _ => pure false diff --git a/library/Init/Lean/Meta/LevelDefEq.lean b/library/Init/Lean/Meta/LevelDefEq.lean index 5e55e5112c..70a3277b4d 100644 --- a/library/Init/Lean/Meta/LevelDefEq.lean +++ b/library/Init/Lean/Meta/LevelDefEq.lean @@ -169,5 +169,8 @@ do s ← get; def isLevelDefEq (u v : Level) : MetaM Bool := try $ isLevelDefEqAux u v +def isExprDefEq (e₁ e₂ : Expr) : MetaM Bool := +try $ isExprDefEqAux e₁ e₂ + end Meta end Lean diff --git a/library/Init/Lean/Meta/Offset.lean b/library/Init/Lean/Meta/Offset.lean index 668a034d2b..fb2f89d46a 100644 --- a/library/Init/Lean/Meta/Offset.lean +++ b/library/Init/Lean/Meta/Offset.lean @@ -89,12 +89,8 @@ private partial def isOffset : Expr → Option (Expr × Nat) | _ => none | _ => none -@[specialize] def isDefEqOffset - (whnf : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (s t : Expr) - (k : MetaM Bool) -- continuation when `isDefEqOffset` could not decide - : MetaM Bool := +def isDefEqOffset (s t : Expr) : MetaM LBool := +let isDefEq (s t) : MetaM LBool := toLBoolM $ isExprDefEqAux s t; match isOffset s with | some (s, k₁) => match isOffset t with | some (t, k₂) => -- s+k₁ =?= t+k₂ @@ -103,16 +99,16 @@ match isOffset s with else isDefEq (mkCAppB `Nat.add s (mkNatLit $ k₁ - k₂)) t | none => match evalNat t with | some v₂ => -- s+k₁ =?= v₂ - if v₂ ≥ k₁ then isDefEq s (mkNatLit $ v₂ - k₁) else pure false - | none => k + if v₂ ≥ k₁ then isDefEq s (mkNatLit $ v₂ - k₁) else pure LBool.false + | none => pure LBool.undef | none => match evalNat s with | some v₁ => match isOffset t with | some (t, k₂) => -- v₁ =?= t+k₂ - if v₁ ≥ k₂ then isDefEq s (mkNatLit $ v₁ - k₂) else pure false + if v₁ ≥ k₂ then isDefEq s (mkNatLit $ v₁ - k₂) else pure LBool.false | none => match evalNat t with - | some v₂ => pure (v₁ == v₂) -- v₁ =?= v₂ - | none => k - | none => k + | some v₂ => pure (v₁ == v₂).toLBool -- v₁ =?= v₂ + | none => pure LBool.false + | none => pure LBool.false end Meta end Lean diff --git a/library/Init/Lean/Meta/WHNF.lean b/library/Init/Lean/Meta/WHNF.lean index 8ffe89fd9b..461caceb06 100644 --- a/library/Init/Lean/Meta/WHNF.lean +++ b/library/Init/Lean/Meta/WHNF.lean @@ -7,6 +7,7 @@ prelude import Init.Lean.AuxRecursor import Init.Lean.WHNF import Init.Lean.Meta.Basic +import Init.Lean.Meta.LevelDefEq namespace Lean namespace Meta @@ -15,23 +16,18 @@ def isAuxDef? (constName : Name) : MetaM Bool := do env ← getEnv; pure (isAuxRecursor env constName || isNoConfusion env constName) @[specialize] def unfoldDefinitionAux {α} - (whnf : Expr → MetaM Expr) - (inferType : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - (e : Expr) - (failK : MetaM α) (successK : Expr → MetaM α) : MetaM α := -Lean.unfoldDefinitionAux getConstNoEx isAuxDef? whnf inferType isDefEq synthesizePending getLocalDecl + (e : Expr) (failK : MetaM α) (successK : Expr → MetaM α) : MetaM α := +Lean.unfoldDefinitionAux getConstNoEx isAuxDef? whnf inferType isExprDefEq synthPending getLocalDecl getExprMVarAssignment e (fun _ => failK) successK -@[specialize] partial def whnfAux - (inferType : Expr → MetaM Expr) - (isDefEq : Expr → Expr → MetaM Bool) - (synthesizePending : Expr → MetaM Bool) - : Expr → MetaM Expr +partial def whnfImpl : Expr → MetaM Expr | e => whnfEasyCases getLocalDecl getExprMVarAssignment e $ fun e => do - e ← whnfCore getConstNoEx isAuxDef? whnfAux inferType isDefEq getLocalDecl getExprMVarAssignment e; - unfoldDefinitionAux whnfAux inferType isDefEq synthesizePending e (pure e) whnfAux + e ← whnfCore getConstNoEx isAuxDef? whnfImpl inferType isExprDefEqAux getLocalDecl getExprMVarAssignment e; + Lean.unfoldDefinitionAux getConstNoEx isAuxDef? whnf inferType isExprDefEq synthPending getLocalDecl + getExprMVarAssignment e (fun _ => pure e) whnfImpl + +@[init] def setWHNFRef : IO Unit := +whnfRef.set whnfImpl end Meta end Lean