refactor: use an auxiliary environment extension to implement the mutual recursion between whnf, isDefEq and inferType

@Kha @dselsam I was experiencing an insane code explosion with the
previous approach. There were too many definitions marked with
`@[specialize]`. `Meta.c` was reaching 0.5 million lines of code.
We would need a more sophisticated code specializer cache to handle
this kind of code. The new approach is much simpler. I don't see any
major disadvantages.
This commit is contained in:
Leonardo de Moura 2019-11-20 16:03:45 -08:00
parent da8f9806a8
commit f2bb86f45c
9 changed files with 296 additions and 427 deletions

View file

@ -5,70 +5,8 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Lean.Meta.Basic
import Init.Lean.Meta.LevelDefEq
import Init.Lean.Meta.WHNF
import Init.Lean.Meta.InferType
import Init.Lean.Meta.FunInfo
import Init.Lean.Meta.LevelDefEq
import Init.Lean.Meta.ExprDefEq
namespace Lean
namespace Meta
/- =========================================== -/
/- BIG HACK until we add `mutual` keyword back -/
/- =========================================== -/
inductive MetaOp
| whnfOp | inferTypeOp | isDefEqOp | synthPendingOp
open MetaOp
private def exprToBool : Expr → Bool
| Expr.sort _ _ => false
| _ => true
private def boolToExpr : Bool → Expr
| false => mkSort levelZero
| true => mkBVar 0
private partial def auxFixpoint : MetaOp → Expr → Expr → MetaM Expr
| op, e₁, e₂ =>
let whnf := fun e => auxFixpoint whnfOp e e;
let inferType := fun e => auxFixpoint inferTypeOp e e;
let isDefEq := fun e₁ e₂ => exprToBool <$> auxFixpoint isDefEqOp e₁ e₂;
let synthPending := fun e => exprToBool <$> auxFixpoint synthPendingOp e e;
match op with
| whnfOp => whnfAux inferType isDefEq synthPending e₁
| inferTypeOp => inferTypeAux whnf e₁
-- | isDefEqOp => boolToExpr <$> isExprDefEqAux whnf synthPending e₁ e₂
| isDefEqOp => boolToExpr <$> pure false
| synthPendingOp => boolToExpr <$> pure false -- TODO
def whnf (e : Expr) : MetaM Expr :=
auxFixpoint whnfOp e e
def inferType (e : Expr) : MetaM Expr :=
auxFixpoint inferTypeOp e e
def isDefEq (e₁ e₂ : Expr) : MetaM Bool :=
try $ exprToBool <$> auxFixpoint isDefEqOp e₁ e₂
/- =========================================== -/
/- END OF BIG HACK -/
/- =========================================== -/
def isProp (e : Expr) : MetaM Bool :=
isPropAux whnf e
def getFunInfo (fn : Expr) : MetaM FunInfo :=
getFunInfoAux whnf fn
def getFunInfoNArgs (fn : Expr) (nargs : Nat) : MetaM FunInfo :=
getFunInfoNArgsAux whnf fn nargs
/-- Throws exception if `e` is not type correct. -/
def check (e : Expr) : MetaM Unit :=
checkAux whnf isDefEq e
def isTypeCorrect (e : Expr) : MetaM Bool :=
isTypeCorrectAux whnf isDefEq e
end Meta
end Lean

View file

@ -142,6 +142,62 @@ do s ← get; pure s.mctx
@[inline] def getEnv : MetaM Environment :=
do s ← get; pure s.env
def mkWHNFRef : IO (IO.Ref (Expr → MetaM Expr)) :=
IO.mkRef $ fun _ => throw $ Exception.other "whnf implementation was not set"
@[init mkWHNFRef] def whnfRef : IO.Ref (Expr → MetaM Expr) := arbitrary _
def mkInferTypeRef : IO (IO.Ref (Expr → MetaM Expr)) :=
IO.mkRef $ fun _ => throw $ Exception.other "inferType implementation was not set"
@[init mkInferTypeRef] def inferTypeRef : IO.Ref (Expr → MetaM Expr) := arbitrary _
def mkIsExprDefEqAuxRef : IO (IO.Ref (Expr → Expr → MetaM Bool)) :=
IO.mkRef $ fun _ _ => throw $ Exception.other "isDefEq implementation was not set"
@[init mkIsExprDefEqAuxRef] def isExprDefEqAuxRef : IO.Ref (Expr → Expr → MetaM Bool) := arbitrary _
def mkSynthPendingRef : IO (IO.Ref (Expr → MetaM Bool)) :=
IO.mkRef $ fun _ => pure false
@[init mkSynthPendingRef] def synthPendingRef : IO.Ref (Expr → MetaM Bool) := arbitrary _
structure MetaExtState :=
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(isDefEqAux : Expr → Expr → MetaM Bool)
(synthPending : Expr → MetaM Bool)
instance MetaExtState.inhabited : Inhabited MetaExtState :=
⟨{ whnf := arbitrary _, inferType := arbitrary _, isDefEqAux := arbitrary _, synthPending := arbitrary _ }⟩
def mkMetaExtension : IO (EnvExtension MetaExtState) :=
registerEnvExtension $ do
whnf ← whnfRef.get;
inferType ← inferTypeRef.get;
isDefEqAux ← isExprDefEqAuxRef.get;
synthPending ← synthPendingRef.get;
pure { whnf := whnf, inferType := inferType, isDefEqAux := isDefEqAux, synthPending := synthPending }
@[init mkMetaExtension]
constant metaExt : EnvExtension MetaExtState := arbitrary _
def whnf (e : Expr) : MetaM Expr :=
do env ← getEnv;
(metaExt.getState env).whnf e
def inferType (e : Expr) : MetaM Expr :=
do env ← getEnv;
(metaExt.getState env).inferType e
def isExprDefEqAux (t s : Expr) : MetaM Bool :=
do env ← getEnv;
(metaExt.getState env).isDefEqAux t s
def synthPending (e : Expr) : MetaM Bool :=
do env ← getEnv;
(metaExt.getState env).synthPending e
@[inline] def throwEx {α} (f : ExceptionContext → Exception) : MetaM α :=
do ctx ← read;
s ← get;
@ -401,7 +457,6 @@ resettingTypeClassCache $
when `fvars.size == max`
-/
@[specialize] private partial def forallTelescopeReducingAuxAux {α}
(whnf : Expr → MetaM Expr)
(isClassExpensive : Expr → MetaM (Option Name))
(reducing? : Bool) (maxFVars? : Option Nat)
(k : Array Expr → Expr → MetaM α)
@ -438,22 +493,19 @@ resettingTypeClassCache $
/- We need this auxiliary definition because it depends on `isClassExpensive`,
and `isClassExpensive` depends on it. -/
@[specialize] private def forallTelescopeReducingAux {α}
(whnf : Expr → MetaM Expr)
(isClassExpensive : Expr → MetaM (Option Name))
(type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → MetaM α) : MetaM α :=
do newType ← whnf type;
if newType.isForall then
savingCache $ do
lctx ← getLCtx;
forallTelescopeReducingAuxAux whnf isClassExpensive true maxFVars? k lctx #[] 0 newType
forallTelescopeReducingAuxAux isClassExpensive true maxFVars? k lctx #[] 0 newType
else
k #[] type
@[specialize] partial def isClassExpensive
(whnf : Expr → MetaM Expr)
: Expr → MetaM (Option Name)
partial def isClassExpensive : Expr → MetaM (Option Name)
| type => usingTransparency TransparencyMode.reducible $ -- when testing whether a type is a type class, we only unfold reducible constants.
forallTelescopeReducingAux whnf isClassExpensive type none $ fun xs type => do
forallTelescopeReducingAux isClassExpensive type none $ fun xs type => do
match type.getAppFn with
| Expr.const c _ _ => do
env ← getEnv;
@ -464,42 +516,33 @@ do newType ← whnf type;
Given `type` of the form `forall xs, A`, execute `k xs A`.
This combinator will declare local declarations, create free variables for them,
execute `k` with updated local context, and make sure the cache is restored after executing `k`. -/
@[inline] def forallTelescope {α}
(whnf : Expr → MetaM Expr)
(type : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α :=
@[inline] def forallTelescope {α} (type : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α :=
savingCache $ do
lctx ← getLCtx;
forallTelescopeReducingAuxAux whnf (isClassExpensive whnf) false none k lctx #[] 0 type
forallTelescopeReducingAuxAux isClassExpensive false none k lctx #[] 0 type
/--
Similar to `forallTelescope`, but given `type` of the form `forall xs, A`,
it reduces `A` and continues bulding the telescope if it is a `forall`. -/
@[specialize] def forallTelescopeReducing {α}
(whnf : Expr → MetaM Expr)
(type : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α :=
forallTelescopeReducingAux whnf (isClassExpensive whnf) type none k
@[inline] def forallTelescopeReducing {α} (type : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α :=
forallTelescopeReducingAux isClassExpensive type none k
/--
Similar to `forallTelescopeReducing`, stops constructing the telescope when
it reaches size `maxFVars`. -/
@[specialize] def forallBoundedTelescope {α}
(whnf : Expr → MetaM Expr)
(type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → MetaM α) : MetaM α :=
forallTelescopeReducingAux whnf (isClassExpensive whnf) type maxFVars? k
@[inline] def forallBoundedTelescope {α} (type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → MetaM α) : MetaM α :=
forallTelescopeReducingAux isClassExpensive type maxFVars? k
@[specialize] def isClass
(whnf : Expr → MetaM Expr)
(type : Expr) : MetaM (Option Name) :=
def isClass (type : Expr) : MetaM (Option Name) :=
do c? ← isClassQuick type;
match c? with
| LOption.none => pure none
| LOption.some c => pure (some c)
| LOption.undef => isClassExpensive whnf type
| LOption.undef => isClassExpensive type
/-- Similar to `forallTelescopeAuxAux` but for lambda and let expressions. -/
@[specialize] private partial def lambdaTelescopeAux {α}
(whnf : Expr → MetaM Expr)
(k : Array Expr → Expr → MetaM α)
(k : Array Expr → Expr → MetaM α)
: LocalContext → Array Expr → Nat → Expr → MetaM α
| lctx, fvars, j, Expr.lam n d b c => do
let d := d.instantiateRevRange j fvars.size fvars;
@ -517,16 +560,15 @@ do c? ← isClassQuick type;
| lctx, fvars, j, e =>
let e := e.instantiateRevRange j fvars.size fvars;
adaptReader (fun (ctx : Context) => { lctx := lctx, .. ctx }) $
withNewLocalInstances (isClassExpensive whnf) fvars j $ do
withNewLocalInstances isClassExpensive fvars j $ do
k fvars e
/-- Similar to `forallTelescope` but for lambda and let expressions. -/
@[specialize] def lambdaTelescope {α}
(whnf : Expr → MetaM Expr)
(e : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α :=
savingCache $ do
lctx ← getLCtx;
lambdaTelescopeAux whnf k lctx #[] 0 e
lambdaTelescopeAux k lctx #[] 0 e
@[inline] def liftStateMCtx {α} (x : StateM MetavarContext α) : MetaM α :=
fun _ s =>
@ -544,7 +586,7 @@ do mvarId ← mkFreshId;
modify $ fun s => { mctx := s.mctx.addLevelMVarDecl mvarId, .. s };
pure mvarId
@[inline] def usingDefault (whnf : Expr → MetaM Expr) : Expr → MetaM Expr :=
def whnfUsingDefault : Expr → MetaM Expr :=
fun e => usingTransparency TransparencyMode.default $ whnf e
end Meta

View file

@ -14,94 +14,75 @@ whether terms produced by tactics and `isDefEq` are type correct.
namespace Lean
namespace Meta
@[specialize] private def ensureType
(whnf : Expr → MetaM Expr)
(e : Expr) : MetaM Unit :=
do getLevelAux whnf (inferTypeAux whnf) e;
pure ()
private def ensureType (e : Expr) : MetaM Unit :=
do getLevel e; pure ()
@[specialize] private def checkLambdaLet
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(check : Expr → MetaM Unit)
(e : Expr) : MetaM Unit :=
lambdaTelescope whnf e $ fun xs b => do
lambdaTelescope e $ fun xs b => do
xs.forM $ fun x => do {
xDecl ← getFVarLocalDecl x;
match xDecl with
| LocalDecl.cdecl _ _ _ t _ => do
ensureType whnf t;
ensureType t;
check t
| LocalDecl.ldecl _ _ _ t v => do
ensureType whnf t;
ensureType t;
check t;
vType ← inferTypeAux whnf v;
unlessM (isDefEq t vType) $ throwEx $ Exception.letTypeMismatch x.fvarId!;
vType ← inferType v;
unlessM (isExprDefEqAux t vType) $ throwEx $ Exception.letTypeMismatch x.fvarId!;
check v
};
check b
@[specialize] private def checkForall
(whnf : Expr → MetaM Expr)
(check : Expr → MetaM Unit)
(e : Expr) : MetaM Unit :=
forallTelescope whnf e $ fun xs b => do
forallTelescope e $ fun xs b => do
xs.forM $ fun x => do {
xDecl ← getFVarLocalDecl x;
ensureType whnf xDecl.type;
ensureType xDecl.type;
check xDecl.type
};
ensureType whnf b;
ensureType b;
check b
@[specialize] private def checkConstant
(c : Name) (lvls : List Level) : MetaM Unit :=
private def checkConstant (c : Name) (lvls : List Level) : MetaM Unit :=
do env ← getEnv;
match env.find c with
| none => throwEx $ Exception.unknownConst c
| some cinfo => unless (lvls.length != cinfo.lparams.length) $ throwEx $ Exception.incorrectNumOfLevels c lvls
@[specialize] private def checkApp
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(check : Expr → MetaM Unit)
(f a : Expr) : MetaM Unit :=
do check f;
check a;
fType ← inferTypeAux whnf f;
fType ← inferType f;
fType ← whnf fType;
match fType with
| Expr.forallE _ d _ _ => do
aType ← inferTypeAux whnf a;
unlessM (isDefEq d aType) $ throwEx $ Exception.appTypeMismatch f a
aType ← inferType a;
unlessM (isExprDefEqAux d aType) $ throwEx $ Exception.appTypeMismatch f a
| _ => unless fType.isForall $ throwEx $ Exception.functionExpected f a
@[specialize] private partial def checkAuxAux
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
: Expr → MetaM Unit
| e@(Expr.forallE _ _ _ _) => checkForall whnf checkAuxAux e
| e@(Expr.lam _ _ _ _) => checkLambdaLet whnf isDefEq checkAuxAux e
| e@(Expr.letE _ _ _ _ _) => checkLambdaLet whnf isDefEq checkAuxAux e
private partial def checkAux : Expr → MetaM Unit
| e@(Expr.forallE _ _ _ _) => checkForall checkAux e
| e@(Expr.lam _ _ _ _) => checkLambdaLet checkAux e
| e@(Expr.letE _ _ _ _ _) => checkLambdaLet checkAux e
| Expr.const c lvls _ => checkConstant c lvls
| Expr.app f a _ => checkApp whnf isDefEq checkAuxAux f a
| Expr.mdata _ e _ => checkAuxAux e
| Expr.proj _ _ e _ => checkAuxAux e
| Expr.app f a _ => checkApp checkAux f a
| Expr.mdata _ e _ => checkAux e
| Expr.proj _ _ e _ => checkAux e
| _ => pure ()
@[specialize] def checkAux
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(e : Expr) : MetaM Unit :=
usingTransparency TransparencyMode.all $
checkAuxAux whnf isDefEq e
def check (e : Expr) : MetaM Unit :=
usingTransparency TransparencyMode.all $ checkAux e
@[specialize] def isTypeCorrectAux
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(e : Expr) : MetaM Bool :=
def isTypeCorrect (e : Expr) : MetaM Bool :=
catch
(do checkAux whnf isDefEq e; pure true)
(do checkAux e; pure true)
(fun _ => pure false)
end Meta

View file

@ -23,17 +23,14 @@ namespace Meta
(fun x : A => f ?m) =?= f
```
The left-hand side of the constraint above it not eta-reduced because `?m` is a metavariable. -/
@[specialize] private def isDefEqEta
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(a b : Expr) : MetaM Bool :=
private def isDefEqEta (a b : Expr) : MetaM Bool :=
if a.isLambda && !b.isLambda then do
bType ← inferTypeAux whnf b;
bType ← usingDefault whnf bType;
bType ← inferType b;
bType ← whnfUsingDefault bType;
match bType with
| Expr.forallE n d b c =>
let b' := Lean.mkLambda n c.binderInfo d (mkApp b (mkBVar 0));
try $ isDefEq a b'
try $ isExprDefEqAux a b'
| _ => pure false
else
pure false
@ -81,8 +78,7 @@ match e.etaExpanded? with
Pre: `paramInfo.size <= args₁.size = args₂.size`
-/
@[specialize] private partial def isDefEqArgsFirstPass
(isDefEq : Expr → Expr → MetaM Bool)
private partial def isDefEqArgsFirstPass
(paramInfo : Array ParamInfo) (args₁ args₂ : Array Expr) : Nat → Array Nat → MetaM (Option (Array Nat))
| i, postponed =>
if h : i < paramInfo.size then
@ -91,39 +87,33 @@ match e.etaExpanded? with
let a₂ := args₂.get! i;
if info.implicit || info.instImplicit then
condM (isEtaUnassignedMVar a₁ <||> isEtaUnassignedMVar a₂)
(condM (isDefEq a₁ a₂)
(condM (isExprDefEqAux a₁ a₂)
(isDefEqArgsFirstPass (i+1) postponed)
(pure none))
(isDefEqArgsFirstPass (i+1) (postponed.push i))
else
condM (isDefEq a₁ a₂)
condM (isExprDefEqAux a₁ a₂)
(isDefEqArgsFirstPass (i+1) postponed)
(pure none)
else
pure (some postponed)
@[specialize] private partial def isDefEqArgsAux
(isDefEq : Expr → Expr → MetaM Bool)
(args₁ args₂ : Array Expr) (h : args₁.size = args₂.size) : Nat → MetaM Bool
private partial def isDefEqArgsAux (args₁ args₂ : Array Expr) (h : args₁.size = args₂.size) : Nat → MetaM Bool
| i =>
if h₁ : i < args₁.size then
let a₁ := args₁.get ⟨i, h₁⟩;
let a₂ := args₂.get ⟨i, h ▸ h₁⟩;
condM (isDefEq a₁ a₂)
condM (isExprDefEqAux a₁ a₂)
(isDefEqArgsAux (i+1))
(pure false)
else
pure true
@[specialize] private def isDefEqArgs
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(f : Expr) (args₁ args₂ : Array Expr) : MetaM Bool :=
private def isDefEqArgs (f : Expr) (args₁ args₂ : Array Expr) : MetaM Bool :=
if h : args₁.size = args₂.size then do
finfo ← getFunInfoNArgsAux whnf f args₁.size;
isDefEqArgsAux isDefEq args₁ args₂ h finfo.paramInfo.size;
(some postponed) ← isDefEqArgsFirstPass isDefEq finfo.paramInfo args₁ args₂ 0 #[] | pure false;
finfo ← getFunInfoNArgs f args₁.size;
isDefEqArgsAux args₁ args₂ h finfo.paramInfo.size;
(some postponed) ← isDefEqArgsFirstPass finfo.paramInfo args₁ args₂ 0 #[] | pure false;
/- Second pass: unify implicit arguments.
In the second pass, we make sure we are unfolding at
least non reducible definitions (default setting). -/
@ -132,11 +122,11 @@ if h : args₁.size = args₂.size then do
let a₂ := args₂.get! i;
let info := finfo.paramInfo.get! i;
when info.instImplicit $ do {
synthesizePending a₁;
synthesizePending a₂;
synthPending a₁;
synthPending a₂;
pure ()
};
usingAtLeastTransparency TransparencyMode.default $ isDefEq a₁ a₂
usingAtLeastTransparency TransparencyMode.default $ isExprDefEqAux a₁ a₂
else
pure false
@ -151,18 +141,15 @@ else
We can't use `withNewLocalInstances` because the `isDeq fvarType d₂`
may use local instances. -/
@[specialize] partial def isDefEqBindingDomain
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(fvars : Array Expr) (ds₂ : Array Expr) : Nat → MetaM Bool → MetaM Bool
@[specialize] partial def isDefEqBindingDomain (fvars : Array Expr) (ds₂ : Array Expr) : Nat → MetaM Bool → MetaM Bool
| i, k =>
if h : i < fvars.size then do
let fvar := fvars.get ⟨i, h⟩;
fvarDecl ← getFVarLocalDecl fvar;
let fvarType := fvarDecl.type;
let d₂ := ds₂.get! i;
condM (isDefEq fvarType d₂)
(do c? ← isClass whnf fvarType;
condM (isExprDefEqAux fvarType d₂)
(do c? ← isClass fvarType;
match c? with
| some className => withNewLocalInstance className fvar $ isDefEqBindingDomain (i+1) k
| none => isDefEqBindingDomain (i+1) k)
@ -174,10 +161,7 @@ else
It accumulates the new free variables in `fvars`, and declare them at `lctx`.
We use the domain types of `e₁` to create the new free variables.
We store the domain types of `e₂` at `ds₂`. -/
@[specialize] private partial def isDefEqBindingAux
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
: LocalContext → Array Expr → Expr → Expr → Array Expr → MetaM Bool
private partial def isDefEqBindingAux : LocalContext → Array Expr → Expr → Expr → Array Expr → MetaM Bool
| lctx, fvars, e₁, e₂, ds₂ =>
let process (n : Name) (d₁ d₂ b₁ b₂ : Expr) : MetaM Bool := do {
let d₁ := d₁.instantiateRev fvars;
@ -192,15 +176,12 @@ else
| Expr.lam n d₁ b₁ _, Expr.lam _ d₂ b₂ _ => process n d₁ d₂ b₁ b₂
| _, _ =>
adaptReader (fun (ctx : Context) => { lctx := lctx, .. ctx }) $
isDefEqBindingDomain whnf isDefEq fvars ds₂ 0 $
isDefEq (e₁.instantiateRev fvars) (e₂.instantiateRev fvars)
isDefEqBindingDomain fvars ds₂ 0 $
isExprDefEqAux (e₁.instantiateRev fvars) (e₂.instantiateRev fvars)
@[inline] private def isDefEqBinding
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(a b : Expr) : MetaM Bool :=
@[inline] private def isDefEqBinding (a b : Expr) : MetaM Bool :=
do lctx ← getLCtx;
isDefEqBindingAux whnf isDefEq lctx #[] a b #[]
isDefEqBindingAux lctx #[] a b #[]
/-
Each metavariable is declared in a particular local context.
@ -389,7 +370,7 @@ instance : MonadCache Expr Expr CheckAssignmentM :=
@[inline] private def visit (f : Expr → CheckAssignmentM Expr) (e : Expr) : CheckAssignmentM Expr :=
if !e.hasExprMVar && !e.hasFVar then pure e else checkCache e f
@[inline] def checkFVar (check : Expr → CheckAssignmentM Expr) (fvar : Expr) : CheckAssignmentM Expr :=
@[specialize] def checkFVar (check : Expr → CheckAssignmentM Expr) (fvar : Expr) : CheckAssignmentM Expr :=
do ctx ← read;
if ctx.mvarDecl.lctx.containsFVar fvar then pure fvar
else do
@ -409,7 +390,7 @@ do s ← get;
modify $ fun s => { ngen := s.ngen.next, mctx := s.mctx.addExprMVarDecl mvarId Name.anonymous lctx type, .. s };
pure (mkMVar mvarId)
@[inline] def checkMVar (check : Expr → CheckAssignmentM Expr) (mvar : Expr) : CheckAssignmentM Expr :=
@[specialize] def checkMVar (check : Expr → CheckAssignmentM Expr) (mvar : Expr) : CheckAssignmentM Expr :=
do let mvarId := mvar.mvarId!;
ctx ← read;
mctx ← getMCtx;
@ -550,22 +531,18 @@ fun ctx s => if !v.hasExprMVar && !v.hasFVar then EStateM.Result.ok (some v) s e
We try to unify arguments before we try to unify the functions.
The motivation is the following: the universe constraints in
the arguments propagate to the function. -/
@[specialize] private partial def isDefEqFOApprox
(isDefEq : Expr → Expr → MetaM Bool)
(f₁ f₂ : Expr) (args₁ args₂ : Array Expr) : Nat → Nat → MetaM Bool
private partial def isDefEqFOApprox (f₁ f₂ : Expr) (args₁ args₂ : Array Expr) : Nat → Nat → MetaM Bool
| i₁, i₂ =>
if h : i₂ < args₂.size then
let arg₁ := args₁.get! i₁;
let arg₂ := args₂.get ⟨i₂, h⟩;
condM (isDefEq arg₁ arg₂)
condM (isExprDefEqAux arg₁ arg₂)
(isDefEqFOApprox (i₁+1) (i₂+1))
(pure false)
else
isDefEq f₁ f₂
isExprDefEqAux f₁ f₂
@[specialize] private def processAssignmentFOApproxAux
(isDefEq : Expr → Expr → MetaM Bool)
(mvar : Expr) (args : Array Expr) (v : Expr) : MetaM Bool :=
private def processAssignmentFOApproxAux (mvar : Expr) (args : Array Expr) (v : Expr) : MetaM Bool :=
let vArgs := v.getAppArgs;
if vArgs.isEmpty then
/- ?m a_1 ... a_k =?= t, where t is not an application -/
@ -583,7 +560,7 @@ else if args.size > vArgs.size then
-/
let f₁ := mkAppRange mvar 0 (args.size - vArgs.size) args;
let i₁ := args.size - vArgs.size;
isDefEqFOApprox isDefEq f₁ v.getAppFn args vArgs i₁ 0
isDefEqFOApprox f₁ v.getAppFn args vArgs i₁ 0
else if args.size < vArgs.size then
/-
?m a_1 ... a_k =?= f b_1 ... b_i b_{i+1} ... b_{i+k}
@ -597,7 +574,7 @@ else if args.size < vArgs.size then
-/
let vFn := mkAppRange v.getAppFn 0 (vArgs.size - args.size) vArgs;
let i₂ := vArgs.size - args.size;
isDefEqFOApprox isDefEq mvar vFn args vArgs 0 i₂
isDefEqFOApprox mvar vFn args vArgs 0 i₂
else
/-
?m a_1 ... a_k =?= f b_1 ... b_k
@ -609,7 +586,7 @@ else
...
a_k =?= b_k
-/
isDefEqFOApprox isDefEq mvar v.getAppFn args vArgs 0 0
isDefEqFOApprox mvar v.getAppFn args vArgs 0 0
/-
Auxiliary method for applying first-order unification. It is an approximation.
@ -627,15 +604,11 @@ else
def ITactic := Tactic Unit
-/
@[specialize] private partial def processAssignmentFOApprox
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(mvar : Expr) (args : Array Expr) : Expr → MetaM Bool
private partial def processAssignmentFOApprox (mvar : Expr) (args : Array Expr) : Expr → MetaM Bool
| v =>
condM (try $ processAssignmentFOApproxAux isDefEq mvar args v)
condM (try $ processAssignmentFOApproxAux mvar args v)
(pure true)
(unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending v (pure false) processAssignmentFOApprox)
(unfoldDefinitionAux v (pure false) processAssignmentFOApprox)
private partial def simpAssignmentArgAux : Expr → MetaM Expr
| Expr.mdata _ e _ => simpAssignmentArgAux e
@ -653,14 +626,7 @@ private def simpAssignmentArg (arg : Expr) : MetaM Expr :=
do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg;
simpAssignmentArgAux arg
@[specialize] private partial def processAssignmentAux
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(mvar : Expr)
(mvarDecl : MetavarDecl)
(v : Expr)
: Nat → Array Expr → MetaM Bool
private partial def processAssignmentAux (mvar : Expr) (mvarDecl : MetavarDecl) (v : Expr) : Nat → Array Expr → MetaM Bool
| i, args =>
if h : i < args.size then do
cfg ← getConfig;
@ -669,7 +635,7 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg;
let args := args.set ⟨i, h⟩ arg;
let useFOApprox : Unit → MetaM Bool := fun _ =>
if cfg.foApprox then
processAssignmentFOApprox whnf isDefEq synthesizePending mvar args v
processAssignmentFOApprox mvar args v
else
pure false;
match arg with
@ -686,10 +652,10 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg;
cfg ← getConfig;
v ← instantiateMVars v; -- enforce A4
if cfg.foApprox && args.isEmpty && v.getAppFn == mvar then
processAssignmentFOApprox whnf isDefEq synthesizePending mvar args v
processAssignmentFOApprox mvar args v
else do
let useFOApprox : Unit → MetaM Bool := fun _ =>
if cfg.foApprox then processAssignmentFOApprox whnf isDefEq synthesizePending mvar args v
if cfg.foApprox then processAssignmentFOApprox mvar args v
else pure false;
let mvarId := mvar.mvarId!;
v? ← checkAssignment mvarId args v;
@ -699,9 +665,9 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg;
v ← mkLambda args v;
let finalize : Unit → MetaM Bool := fun _ => do {
-- must check whether types are definitionally equal or not, before assigning and returning true
mvarType ← inferTypeAux whnf mvar;
vType ← inferTypeAux whnf v;
condM (usingTransparency TransparencyMode.default $ isDefEq mvarType vType)
mvarType ← inferType mvar;
vType ← inferType v;
condM (usingTransparency TransparencyMode.default $ isExprDefEqAux mvarType vType)
(do assignExprMVar mvarId v; pure true)
(do trace! `Meta.isDefEq.assignment.typeMismatch (mvar ++ " : " ++ mvarType ++ " := " ++ v ++ " : " ++ vType);
pure false)
@ -709,7 +675,7 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg;
if args.any (fun arg => mvarDecl.lctx.containsFVar arg) then
/- We need to type check `v` because abstraction using `mkLambda` may have produced
a type incorrect term. See discussion at A2 -/
condM (isTypeCorrectAux whnf isDefEq v)
condM (isTypeCorrect v)
(finalize ())
(useFOApprox ())
else
@ -717,22 +683,10 @@ do arg ← if arg.getAppFn.hasExprMVar then instantiateMVars arg else pure arg;
/-- Tries to solve `?m a₁ ... aₙ =?= v` by assigning `?m`.
It assumes `?m` is unassigned. -/
@[specialize] private def processAssignment
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(mvarApp : Expr) (v : Expr) : MetaM Bool :=
private def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool :=
do let mvar := mvarApp.getAppFn;
mvarDecl ← getMVarDecl mvar.mvarId!;
processAssignmentAux whnf isDefEq synthesizePending mvar mvarDecl v 0 mvarApp.getAppArgs
@[specialize] private def unfold {α}
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(e : Expr)
(failK : MetaM α) (successK : Expr → MetaM α) : MetaM α :=
unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK
processAssignmentAux mvar mvarDecl v 0 mvarApp.getAppArgs
private def isDeltaCandidate (t : Expr) : MetaM (Option ConstantInfo) :=
match t.getAppFn with
@ -744,64 +698,60 @@ match t.getAppFn, s.getAppFn with
| Expr.const c₁ _ _, Expr.const c₂ _ _ => true
| _, _ => false
@[specialize] private def isDefEqDelta
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(t s : Expr)
(k : MetaM Bool) -- continuation when `isDefEqDelta` could not decide
: MetaM Bool :=
do let isDefEqLeft (fn : Name) (t s : Expr) : MetaM Bool := do {
@[specialize] private def isDefEqDelta (t s : Expr) : MetaM LBool :=
do let isListLevelDefEqAuxL (us vs) : MetaM LBool := toLBoolM $ isListLevelDefEqAux us vs;
let isDefEqL (t s) : MetaM LBool := toLBoolM $ isExprDefEqAux t s;
let isDefEqLeft (fn : Name) (t s : Expr) : MetaM LBool := do {
trace! `Meta.isDefEq.delta.unfoldLeft fn;
isDefEq t s
isDefEqL t s
};
let isDefEqRight (fn : Name) (t s : Expr) : MetaM Bool := do {
let isDefEqRight (fn : Name) (t s : Expr) : MetaM LBool := do {
trace! `Meta.isDefEq.delta.unfoldRight fn;
isDefEq t s
isDefEqL t s
};
let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM Bool := do {
let isDefEqLeftRight (fn : Name) (t s : Expr) : MetaM LBool := do {
trace! `Meta.isDefEq.delta.unfoldLeftRight fn;
isDefEq t s
isDefEqL t s
};
let unfold (e failK successK) : MetaM Bool := unfoldDefinitionAux whnf (inferTypeAux whnf) isDefEq synthesizePending e failK successK;
let unfold (e failK successK) : MetaM LBool := unfoldDefinitionAux e failK successK;
let tryHeuristic : MetaM Bool :=
/- Try to solve `f a₁ ... aₙ =?= f b₁ ... bₙ` by solving `a₁ =?= b₁, ..., aₙ =?= bₙ` -/
let tFn := t.getAppFn;
let sFn := s.getAppFn;
traceCtx `Meta.isDefEq.delta $
try $
isDefEqArgs whnf isDefEq synthesizePending tFn t.getAppArgs s.getAppArgs
isDefEqArgs tFn t.getAppArgs s.getAppArgs
<&&>
isListLevelDefEqAux tFn.constLevels! sFn.constLevels!;
tInfo? ← isDeltaCandidate t.getAppFn;
sInfo? ← isDeltaCandidate s.getAppFn;
match tInfo?, sInfo? with
| none, none => k
| some tInfo, none => unfold t k $ fun t => isDefEqLeft tInfo.name t s
| none, some sInfo => unfold s k $ fun s => isDefEqRight sInfo.name t s
| none, none => pure LBool.undef
| some tInfo, none => unfold t (pure LBool.undef) $ fun t => isDefEqLeft tInfo.name t s
| none, some sInfo => unfold s (pure LBool.undef) $ fun s => isDefEqRight sInfo.name t s
| some tInfo, some sInfo =>
let isDefEqLeft (t s) := isDefEqLeft tInfo.name t s;
let isDefEqRight (t s) := isDefEqRight sInfo.name t s;
let isDefEqLeftRight (t s) := isDefEqLeftRight tInfo.name t s;
if tInfo.name == sInfo.name then
match t, s with
| Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqAux ls₁ ls₂
| Expr.const _ ls₁ _, Expr.const _ ls₂ _ => isListLevelDefEqAuxL ls₁ ls₂
| Expr.app _ _ _, Expr.app _ _ _ =>
condM tryHeuristic
(pure true)
(pure LBool.true)
(unfold t
(unfold s (pure false) (fun s => isDefEqRight t s))
(unfold s (pure LBool.false) (fun s => isDefEqRight t s))
(fun t => unfold s (isDefEqLeft t s) (fun s => isDefEqLeftRight t s)))
| _, _ => pure false
| _, _ => pure LBool.false
else
let unfoldComparingHeads : Unit → MetaM Bool := fun _ =>
let unfoldComparingHeads : Unit → MetaM LBool := fun _ =>
/-
- If headSymbol (unfold t) == headSymbol s, then unfold t
- If headSymbol (unfold s) == headSymbol t, then unfold s
- Otherwise unfold t and s if possible. -/
unfold t
(unfold s
k -- `t` and `s` failed to be unfolded
(pure LBool.undef) -- `t` and `s` failed to be unfolded
(fun s => isDefEqRight t s))
(fun tNew =>
if sameHeadSymbol tNew s then
@ -810,7 +760,7 @@ do let isDefEqLeft (fn : Name) (t s : Expr) : MetaM Bool := do {
unfold s
(isDefEqLeft tNew s)
(fun sNew => if sameHeadSymbol t sNew then isDefEqRight t sNew else isDefEqLeftRight tNew sNew));
let kernelLikeUnfolding : Unit → MetaM Bool := fun _ =>
let kernelLikeUnfolding : Unit → MetaM LBool := fun _ =>
if !t.hasExprMVar && !s.hasExprMVar then
/- If `t` and `s` do not contain metavariables,
we simulate strategy used in the kernel. -/
@ -858,42 +808,38 @@ private def isLetFVar (fvarId : Name) : MetaM Bool :=
do decl ← getLocalDecl fvarId;
pure decl.isLet
@[specialize] private partial def isDefEqQuick
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
: Expr → Expr → MetaM Bool → MetaM Bool
| Expr.lit l₁ _, Expr.lit l₂ _, _ => pure (l₁ == l₂)
| Expr.sort u _, Expr.sort v _, _ => isLevelDefEqAux u v
| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _), _ => isDefEqBinding whnf isDefEq t s
| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _), _ => isDefEqBinding whnf isDefEq t s
| Expr.mdata _ t _, s, k => isDefEqQuick t s k
| t, Expr.mdata _ s _, k => isDefEqQuick t s k
| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _, k =>
private partial def isDefEqQuick : Expr → Expr → MetaM LBool
| Expr.lit l₁ _, Expr.lit l₂ _ => pure (l₁ == l₂).toLBool
| Expr.sort u _, Expr.sort v _ => toLBoolM $ isLevelDefEqAux u v
| t@(Expr.lam _ _ _ _), s@(Expr.lam _ _ _ _) => toLBoolM $ isDefEqBinding t s
| t@(Expr.forallE _ _ _ _), s@(Expr.forallE _ _ _ _) => toLBoolM $ isDefEqBinding t s
| Expr.mdata _ t _, s => isDefEqQuick t s
| t, Expr.mdata _ s _ => isDefEqQuick t s
| Expr.fvar fvarId₁ _, Expr.fvar fvarId₂ _ =>
condM (isLetFVar fvarId₁ <||> isLetFVar fvarId₂)
k
(pure (fvarId₁ == fvarId₂))
| t, s, k =>
cond (t == s) (pure true) $
cond (etaEq t s || etaEq s t) (pure true) $ -- t =?= (fun xs => t xs)
(pure LBool.undef)
(pure (fvarId₁ == fvarId₂).toLBool)
| t, s =>
cond (t == s) (pure LBool.true) $
cond (etaEq t s || etaEq s t) (pure LBool.true) $ -- t =?= (fun xs => t xs)
let tFn := t.getAppFn;
let sFn := s.getAppFn;
cond (!tFn.isMVar && !sFn.isMVar) k $
condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $
condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $
condM (isSynthetic tFn <&&> synthesizePending tFn) (do t ← instantiateMVars t; isDefEqQuick t s k) $
condM (isSynthetic sFn <&&> synthesizePending sFn) (do s ← instantiateMVars s; isDefEqQuick t s k) $ do
cond (!tFn.isMVar && !sFn.isMVar) (pure LBool.undef) $
condM (isAssigned tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $
condM (isAssigned sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $
condM (isSynthetic tFn <&&> synthPending tFn) (do t ← instantiateMVars t; isDefEqQuick t s) $
condM (isSynthetic sFn <&&> synthPending sFn) (do s ← instantiateMVars s; isDefEqQuick t s) $ do
tAssign? ← isAssignable tFn;
sAssign? ← isAssignable sFn;
let assign (t s : Expr) : MetaM Bool := processAssignment whnf isDefEq synthesizePending t s;
let assign (t s : Expr) : MetaM LBool := toLBoolM $ processAssignment t s;
cond (tAssign? && !sAssign?) (assign t s) $
cond (!tAssign? && sAssign?) (assign s t) $
cond (!tAssign? && !sAssign?)
(if tFn.isMVar || sFn.isMVar then do
ctx ← read;
if ctx.config.isDefEqStuckEx then throwEx $ Exception.isDefEqStuck t s
else pure false
else k) $ do
else pure LBool.false
else pure LBool.undef) $ do
-- Both `t` and `s` are terms of the form `?m ...`
tMVarDecl ← getMVarDecl tFn.mvarId!;
sMVarDecl ← getMVarDecl sFn.mvarId!;
@ -912,80 +858,68 @@ do decl ← getLocalDecl fvarId;
cond (!s.isApp && t.isApp && tMVarDecl.lctx.isSubPrefixOf sMVarDecl.lctx) (assign s t) $
assign t s
@[specialize] private def isDefEqProofIrrel
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(t s : Expr)
(k : MetaM Bool) -- continuation when `isDefEqQuick` could not decide
: MetaM Bool :=
do tType ← inferTypeAux whnf t;
condM (isPropAux whnf tType)
(do sType ← inferTypeAux whnf s; isDefEq tType sType)
k
private def isDefEqProofIrrel (t s : Expr) : MetaM LBool :=
do tType ← inferType t;
condM (isProp tType)
(do sType ← inferType s; toLBoolM $ isExprDefEqAux tType sType)
(pure LBool.undef)
@[specialize] private partial def whnfCoreAux
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(e : Expr) : MetaM Expr :=
Lean.whnfCore getConstNoEx isAuxDef? whnf (inferTypeAux whnf) isDefEq getLocalDecl getExprMVarAssignment e
private def whnfCoreAux (e : Expr) : MetaM Expr :=
Lean.whnfCore getConstNoEx isAuxDef? whnf inferType isExprDefEqAux getLocalDecl getExprMVarAssignment e
@[inline] def tryL (x : MetaM LBool) (k : MetaM Bool) : MetaM Bool :=
do status ← x;
match status with
| LBool.true => pure true
| LBool.false => pure false
| LBool.undef => k
@[specialize] private partial def isDefEqWHNF
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(t s : Expr)
(k : Expr → Expr → MetaM Bool) : MetaM Bool :=
do t' ← whnfCoreAux whnf isDefEq synthesizePending t;
s' ← whnfCoreAux whnf isDefEq synthesizePending s;
do t' ← whnfCoreAux t;
s' ← whnfCoreAux s;
if t == t' && s == s' then
k t s
k t' s'
else
isDefEqQuick whnf isDefEq synthesizePending t' s' $ k t' s'
tryL (isDefEqQuick t' s') $ k t' s'
@[specialize] private def unstuckMVar
(whnf : Expr → MetaM Expr)
(synthesizePending : Expr → MetaM Bool)
(e : Expr)
(successK : Expr → MetaM Bool) (failK : MetaM Bool): MetaM Bool :=
do s? ← getStuckMVar getConst whnf e;
match s? with
| some s =>
condM (synthesizePending s)
condM (synthPending s)
(do e ← instantiateMVars e; successK e)
failK
| none => failK
@[specialize] private def isDefEqOnFailure
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(t s : Expr) : MetaM Bool :=
unstuckMVar whnf synthesizePending t (fun t => isDefEq t s) $
unstuckMVar whnf synthesizePending s (fun s => isDefEq t s) $
private def isDefEqOnFailure (t s : Expr) : MetaM Bool :=
unstuckMVar t (fun t => isExprDefEqAux t s) $
unstuckMVar s (fun s => isExprDefEqAux t s) $
pure false
@[specialize] partial def isExprDefEqAux
(whnf : Expr → MetaM Expr)
(synthesizePending : Expr → MetaM Bool)
: Expr → Expr → MetaM Bool
partial def isExprDefEqAuxImpl : Expr → Expr → MetaM Bool
| t, s => do
trace! `Meta.isDefEq.step (t ++ " =?= " ++ s);
isDefEqQuick whnf isExprDefEqAux synthesizePending t s $
isDefEqProofIrrel whnf isExprDefEqAux t s $
isDefEqWHNF whnf isExprDefEqAux synthesizePending t s $ fun t s =>
isDefEqOffset whnf isExprDefEqAux t s $
isDefEqDelta whnf isExprDefEqAux synthesizePending t s $
condM (isDefEqEta whnf isExprDefEqAux t s <||> isDefEqEta whnf isExprDefEqAux s t) (pure true) $
tryL (isDefEqQuick t s) $
tryL (isDefEqProofIrrel t s) $
isDefEqWHNF t s $ fun t s =>
tryL (isDefEqOffset t s) $
tryL (isDefEqDelta t s) $
condM (isDefEqEta t s <||> isDefEqEta s t) (pure true) $
match t, s with
| Expr.const _ us _, Expr.const _ vs _ => isListLevelDefEqAux us vs
| Expr.app _ _ _, Expr.app _ _ _ =>
let tFn := t.getAppFn;
condM (try (isExprDefEqAux tFn s.getAppFn <&&>
isDefEqArgs whnf isExprDefEqAux synthesizePending tFn t.getAppArgs s.getAppArgs))
condM (try (isExprDefEqAux tFn s.getAppFn <&&> isDefEqArgs tFn t.getAppArgs s.getAppArgs))
(pure true)
(isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s)
| _, _ => isDefEqOnFailure whnf isExprDefEqAux synthesizePending t s
(isDefEqOnFailure t s)
| _, _ => isDefEqOnFailure t s
@[init] def setIsExprDefEqAuxRef : IO Unit :=
isExprDefEqAuxRef.set isExprDefEqAuxImpl
end Meta
end Lean

View file

@ -20,10 +20,10 @@ do s ← get;
modify $ fun s => { cache := { funInfo := s.cache.funInfo.insert ⟨t, fn, maxArgs?⟩ finfo, .. s.cache }, .. s };
pure finfo
@[inline] def whenHasVar {α} (e : Expr) (deps : α) (k : αα) : α :=
@[inline] private def whenHasVar {α} (e : Expr) (deps : α) (k : αα) : α :=
if e.hasFVar then k deps else deps
def collectDepsAux (fvars : Array Expr) : Expr → Array Nat → Array Nat
private def collectDepsAux (fvars : Array Expr) : Expr → Array Nat → Array Nat
| e@(Expr.app f a _), deps => whenHasVar e deps (collectDepsAux a ∘ collectDepsAux f)
| e@(Expr.forallE _ d b _), deps => whenHasVar e deps (collectDepsAux b ∘ collectDepsAux d)
| e@(Expr.lam _ d b _), deps => whenHasVar e deps (collectDepsAux b ∘ collectDepsAux d)
@ -36,7 +36,7 @@ def collectDepsAux (fvars : Array Expr) : Expr → Array Nat → Array Nat
| some i => if deps.contains i.val then deps else deps.push i.val
| _, deps => deps
def collectDeps (fvars : Array Expr) (e : Expr) : Array Nat :=
private def collectDeps (fvars : Array Expr) (e : Expr) : Array Nat :=
let deps := collectDepsAux fvars e #[];
deps.qsort (fun i j => i < j)
@ -53,38 +53,33 @@ else
else
info
@[specialize] def getFunInfoAuxAux
(whnf : Expr → MetaM Expr)
(fn : Expr) (maxArgs? : Option Nat) : MetaM FunInfo :=
private def getFunInfoAux (fn : Expr) (maxArgs? : Option Nat) : MetaM FunInfo :=
checkFunInfoCache fn maxArgs? $ do
fnType ← inferTypeAux whnf fn;
forallBoundedTelescope (usingDefault whnf) fnType maxArgs? $ fun fvars type => do
pinfo ← fvars.size.foldM
(fun (i : Nat) (pinfo : Array ParamInfo) => do
let fvar := fvars.get! i;
decl ← getFVarLocalDecl fvar;
prop ← isPropAux whnf decl.type;
let backDeps := collectDeps fvars decl.type;
let pinfo := updateHasFwdDeps pinfo backDeps;
pure $ pinfo.push {
backDeps := backDeps,
prop := prop,
implicit := decl.binderInfo == BinderInfo.implicit,
instImplicit := decl.binderInfo == BinderInfo.instImplicit })
#[];
let resultDeps := collectDeps fvars type;
let pinfo := updateHasFwdDeps pinfo resultDeps;
pure { resultDeps := resultDeps, paramInfo := pinfo }
fnType ← inferType fn;
usingTransparency TransparencyMode.default $
forallBoundedTelescope fnType maxArgs? $ fun fvars type => do
pinfo ← fvars.size.foldM
(fun (i : Nat) (pinfo : Array ParamInfo) => do
let fvar := fvars.get! i;
decl ← getFVarLocalDecl fvar;
prop ← isProp decl.type;
let backDeps := collectDeps fvars decl.type;
let pinfo := updateHasFwdDeps pinfo backDeps;
pure $ pinfo.push {
backDeps := backDeps,
prop := prop,
implicit := decl.binderInfo == BinderInfo.implicit,
instImplicit := decl.binderInfo == BinderInfo.instImplicit })
#[];
let resultDeps := collectDeps fvars type;
let pinfo := updateHasFwdDeps pinfo resultDeps;
pure { resultDeps := resultDeps, paramInfo := pinfo }
@[inline] def getFunInfoAux
(whnf : Expr → MetaM Expr)
(fn : Expr) : MetaM FunInfo :=
getFunInfoAuxAux whnf fn none
def getFunInfo (fn : Expr) : MetaM FunInfo :=
getFunInfoAux fn none
@[inline] def getFunInfoNArgsAux
(whnf : Expr → MetaM Expr)
(fn : Expr) (nargs : Nat) : MetaM FunInfo :=
getFunInfoAuxAux whnf fn (some nargs)
def getFunInfoNArgs (fn : Expr) (nargs : Nat) : MetaM FunInfo :=
getFunInfoAux fn (some nargs)
end Meta
end Lean

View file

@ -10,10 +10,7 @@ import Init.Lean.Meta.Basic
namespace Lean
namespace Meta
@[specialize] private def inferAppType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(f : Expr) (args : Array Expr) : MetaM Expr :=
private def inferAppType (f : Expr) (args : Array Expr) : MetaM Expr :=
do fType ← inferType f;
(j, fType) ← args.size.foldM
(fun i (acc : Nat × Expr) =>
@ -39,10 +36,7 @@ do env ← getEnv;
| none =>
throwEx $ Exception.unknownConst c
@[specialize] private def inferProjType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(structName : Name) (idx : Nat) (e : Expr) : MetaM Expr :=
private def inferProjType (structName : Name) (idx : Nat) (e : Expr) : MetaM Expr :=
do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProjection structName idx e;
structType ← inferType e;
structType ← whnf structType;
@ -55,7 +49,7 @@ do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProje
else match env.find ctor with
| none => failed ()
| some (ctorInfo) => do
ctorType ← inferAppType whnf inferType (mkConst ctor structLvls) structParams;
ctorType ← inferAppType (mkConst ctor structLvls) structParams;
ctorType ← idx.foldM
(fun i ctorType => do
ctorType ← whnf ctorType;
@ -73,10 +67,7 @@ do let failed : Unit → MetaM Expr := fun _ => throwEx $ Exception.invalidProje
| _ => failed ()
| _ => failed ()
@[specialize] def getLevelAux
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(type : Expr) : MetaM Level :=
def getLevel (type : Expr) : MetaM Level :=
do typeType ← inferType type;
typeType ← whnf typeType;
match typeType with
@ -90,26 +81,20 @@ do typeType ← inferType type;
pure lvl)
| _ => throwEx $ Exception.typeExpected type
@[specialize] private def inferForallType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(e : Expr) : MetaM Expr :=
forallTelescope whnf e $ fun xs e => do
lvl ← getLevelAux whnf inferType e;
private def inferForallType (e : Expr) : MetaM Expr :=
forallTelescope e $ fun xs e => do
lvl ← getLevel e;
lvl ← xs.foldrM
(fun x lvl => do
xType ← inferType x;
xTypeLvl ← getLevelAux whnf inferType xType;
xTypeLvl ← getLevel xType;
pure $ mkLevelIMax xTypeLvl lvl)
lvl;
pure $ mkSort lvl.normalize
/- Infer type of lambda and let expressions -/
@[specialize] private def inferLambdaType
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(e : Expr) : MetaM Expr :=
lambdaTelescope whnf e $ fun xs e => do
private def inferLambdaType (e : Expr) : MetaM Expr :=
lambdaTelescope e $ fun xs e => do
type ← inferType e;
mkForall xs type
@ -140,27 +125,26 @@ do s ← get;
modify $ fun s => { cache := { inferType := s.cache.inferType.insert e type, .. s.cache }, .. s };
pure type
@[specialize] partial def inferTypeAuxAux
(whnf : Expr → MetaM Expr)
: Expr → MetaM Expr
private partial def inferTypeAux : Expr → MetaM Expr
| Expr.const c lvls _ => inferConstType c lvls
| e@(Expr.proj n i s _) => checkInferTypeCache e (inferProjType whnf inferTypeAuxAux n i s)
| e@(Expr.app f _ _) => checkInferTypeCache e (inferAppType whnf inferTypeAuxAux f.getAppFn e.getAppArgs)
| e@(Expr.proj n i s _) => checkInferTypeCache e (inferProjType n i s)
| e@(Expr.app f _ _) => checkInferTypeCache e (inferAppType f.getAppFn e.getAppArgs)
| Expr.mvar mvarId _ => inferMVarType mvarId
| Expr.fvar fvarId _ => inferFVarType fvarId
| Expr.bvar bidx _ => throw $ Exception.unexpectedBVar bidx
| Expr.mdata _ e _ => inferTypeAuxAux e
| Expr.mdata _ e _ => inferTypeAux e
| Expr.lit v _ => pure v.type
| Expr.sort lvl _ => pure $ mkSort (mkLevelSucc lvl)
| e@(Expr.forallE _ _ _ _) => checkInferTypeCache e (inferForallType whnf inferTypeAuxAux e)
| e@(Expr.lam _ _ _ _) => checkInferTypeCache e (inferLambdaType whnf inferTypeAuxAux e)
| e@(Expr.letE _ _ _ _ _) => checkInferTypeCache e (inferLambdaType whnf inferTypeAuxAux e)
| e@(Expr.forallE _ _ _ _) => checkInferTypeCache e (inferForallType e)
| e@(Expr.lam _ _ _ _) => checkInferTypeCache e (inferLambdaType e)
| e@(Expr.letE _ _ _ _ _) => checkInferTypeCache e (inferLambdaType e)
| Expr.localE _ _ _ _ => unreachable!
@[inline] def inferTypeAux
(whnf : Expr → MetaM Expr)
(e : Expr) : MetaM Expr :=
inferTypeAuxAux (usingDefault whnf) e
def inferTypeImpl (e : Expr) : MetaM Expr :=
usingTransparency TransparencyMode.default (inferTypeAux e)
@[init] def setInferTypeRef : IO Unit :=
inferTypeRef.set inferTypeImpl
/--
Return `LBool.true` if given level is always equivalent to universe level zero.
@ -177,7 +161,7 @@ private def isAlwaysZero : Level → Bool
`isArrowProp type n` is an "approximate" predicate which returns `LBool.true`
if `type` is of the form `A_1 -> ... -> A_n -> Prop`.
Remark: `type` can be a dependent arrow. -/
@[specialize] private partial def isArrowProp : Expr → Nat → MetaM LBool
private partial def isArrowProp : Expr → Nat → MetaM LBool
| Expr.sort u _, 0 => do u ← instantiateLevelMVars u; pure $ (isAlwaysZero u).toLBool
| Expr.forallE _ _ _ _, 0 => pure LBool.false
| Expr.forallE _ _ b _, n+1 => isArrowProp b n
@ -188,7 +172,7 @@ private def isAlwaysZero : Level → Bool
/--
`isPropQuickApp f n` is an "approximate" predicate which returns `LBool.true`
if `f` applied to `n` arguments is a proposition. -/
@[specialize] private partial def isPropQuickApp : Expr → Nat → MetaM LBool
private partial def isPropQuickApp : Expr → Nat → MetaM LBool
| Expr.const c lvls _, arity => do constType ← inferConstType c lvls; isArrowProp constType arity
| Expr.fvar fvarId _, arity => do fvarType ← inferFVarType fvarId; isArrowProp fvarType arity
| Expr.mvar mvarId _, arity => do mvarType ← inferMVarType mvarId; isArrowProp mvarType arity
@ -202,7 +186,7 @@ private def isAlwaysZero : Level → Bool
/--
`isPropQuick e` is an "approximate" predicate which returns `LBool.true`
if `e` is a proposition. -/
@[specialize] private partial def isPropQuick : Expr → MetaM LBool
private partial def isPropQuick : Expr → MetaM LBool
| Expr.bvar _ _ => pure LBool.undef
| Expr.lit _ _ => pure LBool.false
| Expr.sort _ _ => pure LBool.false
@ -223,15 +207,15 @@ private def isAlwaysZero : Level → Bool
to decide whether is a proposition or not. We return `false` in this
case. We considered using `LBool` and retuning `LBool.undef`, but
we have no applications for it. -/
@[specialize] def isPropAux (whnf : Expr → MetaM Expr) (e : Expr) : MetaM Bool :=
def isProp (e : Expr) : MetaM Bool :=
do r ← isPropQuick e;
match r with
| LBool.true => pure true
| LBool.false => pure false
| LBool.undef => do
-- dbgTrace ("PropQuick failed " ++ toString e);
type ← inferTypeAux whnf e;
type ← usingDefault whnf type;
type ← inferType e;
type ← whnfUsingDefault type;
match type with
| Expr.sort u _ => do u ← instantiateLevelMVars u; pure $ isAlwaysZero u
| _ => pure false

View file

@ -169,5 +169,8 @@ do s ← get;
def isLevelDefEq (u v : Level) : MetaM Bool :=
try $ isLevelDefEqAux u v
def isExprDefEq (e₁ e₂ : Expr) : MetaM Bool :=
try $ isExprDefEqAux e₁ e₂
end Meta
end Lean

View file

@ -89,12 +89,8 @@ private partial def isOffset : Expr → Option (Expr × Nat)
| _ => none
| _ => none
@[specialize] def isDefEqOffset
(whnf : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(s t : Expr)
(k : MetaM Bool) -- continuation when `isDefEqOffset` could not decide
: MetaM Bool :=
def isDefEqOffset (s t : Expr) : MetaM LBool :=
let isDefEq (s t) : MetaM LBool := toLBoolM $ isExprDefEqAux s t;
match isOffset s with
| some (s, k₁) => match isOffset t with
| some (t, k₂) => -- s+k₁ =?= t+k₂
@ -103,16 +99,16 @@ match isOffset s with
else isDefEq (mkCAppB `Nat.add s (mkNatLit $ k₁ - k₂)) t
| none => match evalNat t with
| some v₂ => -- s+k₁ =?= v₂
if v₂ ≥ k₁ then isDefEq s (mkNatLit $ v₂ - k₁) else pure false
| none => k
if v₂ ≥ k₁ then isDefEq s (mkNatLit $ v₂ - k₁) else pure LBool.false
| none => pure LBool.undef
| none => match evalNat s with
| some v₁ => match isOffset t with
| some (t, k₂) => -- v₁ =?= t+k₂
if v₁ ≥ k₂ then isDefEq s (mkNatLit $ v₁ - k₂) else pure false
if v₁ ≥ k₂ then isDefEq s (mkNatLit $ v₁ - k₂) else pure LBool.false
| none => match evalNat t with
| some v₂ => pure (v₁ == v₂) -- v₁ =?= v₂
| none => k
| none => k
| some v₂ => pure (v₁ == v₂).toLBool -- v₁ =?= v₂
| none => pure LBool.false
| none => pure LBool.false
end Meta
end Lean

View file

@ -7,6 +7,7 @@ prelude
import Init.Lean.AuxRecursor
import Init.Lean.WHNF
import Init.Lean.Meta.Basic
import Init.Lean.Meta.LevelDefEq
namespace Lean
namespace Meta
@ -15,23 +16,18 @@ def isAuxDef? (constName : Name) : MetaM Bool :=
do env ← getEnv; pure (isAuxRecursor env constName || isNoConfusion env constName)
@[specialize] def unfoldDefinitionAux {α}
(whnf : Expr → MetaM Expr)
(inferType : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
(e : Expr)
(failK : MetaM α) (successK : Expr → MetaM α) : MetaM α :=
Lean.unfoldDefinitionAux getConstNoEx isAuxDef? whnf inferType isDefEq synthesizePending getLocalDecl
(e : Expr) (failK : MetaM α) (successK : Expr → MetaM α) : MetaM α :=
Lean.unfoldDefinitionAux getConstNoEx isAuxDef? whnf inferType isExprDefEq synthPending getLocalDecl
getExprMVarAssignment e (fun _ => failK) successK
@[specialize] partial def whnfAux
(inferType : Expr → MetaM Expr)
(isDefEq : Expr → Expr → MetaM Bool)
(synthesizePending : Expr → MetaM Bool)
: Expr → MetaM Expr
partial def whnfImpl : Expr → MetaM Expr
| e => whnfEasyCases getLocalDecl getExprMVarAssignment e $ fun e => do
e ← whnfCore getConstNoEx isAuxDef? whnfAux inferType isDefEq getLocalDecl getExprMVarAssignment e;
unfoldDefinitionAux whnfAux inferType isDefEq synthesizePending e (pure e) whnfAux
e ← whnfCore getConstNoEx isAuxDef? whnfImpl inferType isExprDefEqAux getLocalDecl getExprMVarAssignment e;
Lean.unfoldDefinitionAux getConstNoEx isAuxDef? whnf inferType isExprDefEq synthPending getLocalDecl
getExprMVarAssignment e (fun _ => pure e) whnfImpl
@[init] def setWHNFRef : IO Unit :=
whnfRef.set whnfImpl
end Meta
end Lean