From a230fe2d06cf61b7aa86f7b26bdf2ff13d9a2dc6 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 6 Aug 2021 09:43:26 -0700 Subject: [PATCH] fix: forallMetaTelescope issue This commit incorporates the fix at PR #612, and clean up `Meta/Basic.lean` using Lean 4 features. --- src/Init/Data/Option/Basic.lean | 4 + src/Lean/Meta/Basic.lean | 160 ++++++------------ tests/lean/forallMetaBounded.lean | 16 ++ .../lean/forallMetaBounded.lean.expected.out | 1 + 4 files changed, 72 insertions(+), 109 deletions(-) create mode 100644 tests/lean/forallMetaBounded.lean create mode 100644 tests/lean/forallMetaBounded.lean.expected.out diff --git a/src/Init/Data/Option/Basic.lean b/src/Init/Data/Option/Basic.lean index 9952cf4269..50d1f3bdc9 100644 --- a/src/Init/Data/Option/Basic.lean +++ b/src/Init/Data/Option/Basic.lean @@ -26,6 +26,10 @@ def toMonad [Monad m] [Alternative m] : Option α → m α | some _ => false | none => true +@[inline] def isEqSome [BEq α] : Option α → α → Bool + | some a, b => a == b + | none, _ => false + @[inline] protected def bind : Option α → (α → Option β) → Option β | none, b => none | some a, b => b a diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index 5dc60f7770..691b587732 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -292,8 +292,7 @@ def mkFreshExprMVarAt (lctx : LocalContext) (localInsts : LocalInstances) (type : Expr) (kind : MetavarKind := MetavarKind.natural) (userName : Name := Name.anonymous) (numScopeArgs : Nat := 0) : MetaM Expr := do - let mvarId ← mkFreshId - mkFreshExprMVarAtCore mvarId lctx localInsts type kind userName numScopeArgs + mkFreshExprMVarAtCore (← mkFreshId) lctx localInsts type kind userName numScopeArgs def mkFreshLevelMVar : MetaM Level := do let mvarId ← mkFreshId @@ -301,9 +300,7 @@ def mkFreshLevelMVar : MetaM Level := do return mkLevelMVar mvarId private def mkFreshExprMVarCore (type : Expr) (kind : MetavarKind) (userName : Name) : MetaM Expr := do - let lctx ← getLCtx - let localInsts ← getLocalInstances - mkFreshExprMVarAt lctx localInsts type kind userName + mkFreshExprMVarAt (← getLCtx) (← getLocalInstances) type kind userName private def mkFreshExprMVarImpl (type? : Option Expr) (kind : MetavarKind) (userName : Name) : MetaM Expr := match type? with @@ -325,9 +322,7 @@ def mkFreshTypeMVar (kind := MetavarKind.natural) (userName := Name.anonymous) : private def mkFreshExprMVarWithIdCore (mvarId : MVarId) (type : Expr) (kind : MetavarKind := MetavarKind.natural) (userName : Name := Name.anonymous) (numScopeArgs : Nat := 0) : MetaM Expr := do - let lctx ← getLCtx - let localInsts ← getLocalInstances - mkFreshExprMVarAtCore mvarId lctx localInsts type kind userName numScopeArgs + mkFreshExprMVarAtCore mvarId (← getLCtx) (← getLocalInstances) type kind userName numScopeArgs def mkFreshExprMVarWithId (mvarId : MVarId) (type? : Option Expr := none) (kind : MetavarKind := MetavarKind.natural) (userName := Name.anonymous) : MetaM Expr := match type? with @@ -358,8 +353,7 @@ def shouldReduceReducibleOnly : MetaM Bool := return (← getTransparency) == TransparencyMode.reducible def getMVarDecl (mvarId : MVarId) : MetaM MetavarDecl := do - let mctx ← getMCtx - match mctx.findDecl? mvarId with + match (← getMCtx).findDecl? mvarId with | some d => pure d | none => throwError "unknown metavariable '?{mvarId}'" @@ -372,9 +366,7 @@ def setMVarType (mvarId : MVarId) (type : Expr) : MetaM Unit := do modifyMCtx fun mctx => mctx.setMVarType mvarId type def isReadOnlyExprMVar (mvarId : MVarId) : MetaM Bool := do - let mvarDecl ← getMVarDecl mvarId - let mctx ← getMCtx - return mvarDecl.depth != mctx.depth + return (← getMVarDecl mvarId).depth != (← getMCtx).depth def isReadOnlyOrSyntheticOpaqueExprMVar (mvarId : MVarId) : MetaM Bool := do let mvarDecl ← getMVarDecl mvarId @@ -438,15 +430,12 @@ def instantiateLevelMVars (u : Level) : MetaM Level := def instantiateMVars (e : Expr) : MetaM Expr := (MetavarContext.instantiateExprMVars e).run -def instantiateLocalDeclMVars (localDecl : LocalDecl) : MetaM LocalDecl := do +def instantiateLocalDeclMVars (localDecl : LocalDecl) : MetaM LocalDecl := match localDecl with | LocalDecl.cdecl idx id n type bi => - let type ← instantiateMVars type - return LocalDecl.cdecl idx id n type bi + return LocalDecl.cdecl idx id n (← instantiateMVars type) bi | LocalDecl.ldecl idx id n type val nonDep => - let type ← instantiateMVars type - let val ← instantiateMVars val - return LocalDecl.ldecl idx id n type val nonDep + return LocalDecl.ldecl idx id n (← instantiateMVars type) (← instantiateMVars val) nonDep @[inline] def liftMkBindingM (x : MetavarContext.MkBindingM α) : MetaM α := do match x (← getLCtx) { mctx := (← getMCtx), ngen := (← getNGen) } with @@ -468,9 +457,8 @@ def mkLambdaFVars (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedL def mkLetFVars (xs : Array Expr) (e : Expr) (usedLetOnly := true) : MetaM Expr := mkLambdaFVars xs e (usedLetOnly := usedLetOnly) -def mkArrow (d b : Expr) : MetaM Expr := do - let n ← mkFreshUserName `x - return Lean.mkForall n BinderInfo.default d b +def mkArrow (d b : Expr) : MetaM Expr := + return Lean.mkForall (← mkFreshUserName `x) BinderInfo.default d b def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool := false) : MetaM Expr := if xs.isEmpty then pure e else liftMkBindingM <| MetavarContext.elimMVarDeps xs e preserveOrder @@ -510,8 +498,7 @@ def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool := false) : /-- Save cache, execute `x`, restore cache -/ @[inline] private def savingCacheImpl (x : MetaM α) : MetaM α := do - let s ← get - let savedCache := s.cache + let savedCache := (← get).cache try x finally modify fun s => { s with cache := savedCache } @[inline] def savingCache : n α → n α := @@ -538,16 +525,14 @@ private def getDefInfoTemp (info : ConstantInfo) : MetaM (Option ConstantInfo) : It is very similar to `getConst?`, but it returns none when `TransparencyMode.instances` and `constName` is an instance. This difference should be irrelevant for `isClassQuickConst?`. -/ private def getConstTemp? (constName : Name) : MetaM (Option ConstantInfo) := do - let env ← getEnv - match env.find? constName with + match (← getEnv).find? constName with | some (info@(ConstantInfo.thmInfo _)) => getTheoremInfo info | some (info@(ConstantInfo.defnInfo _)) => getDefInfoTemp info | some info => pure (some info) | none => throwUnknownConstant constName private def isClassQuickConst? (constName : Name) : MetaM (LOption Name) := do - let env ← getEnv - if isClass env constName then + if isClass (← getEnv) constName then pure (LOption.some constName) else match (← getConstTemp? constName) with @@ -576,8 +561,7 @@ private partial def isClassQuick? : Expr → MetaM (LOption Name) | _ => pure LOption.none def saveAndResetSynthInstanceCache : MetaM SynthInstanceCache := do - let s ← get - let savedSythInstance := s.cache.synthInstance + let savedSythInstance := (← get).cache.synthInstance modifyCache fun c => { c with synthInstance := {} } pure savedSythInstance @@ -769,31 +753,10 @@ private def forallBoundedTelescopeImp (type : Expr) (maxFVars? : Option Nat) (k def forallBoundedTelescope (type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → n α) : n α := map2MetaM (fun k => forallBoundedTelescopeImp type maxFVars? k) k -/-- Similar to `forallTelescopeAuxAux` but for lambda and let expressions. -/ -private partial def lambdaTelescopeAux - (k : Array Expr → Expr → MetaM α) - : Bool → LocalContext → Array Expr → Nat → Expr → MetaM α - | consumeLet, lctx, fvars, j, Expr.lam n d b c => do - let d := d.instantiateRevRange j fvars.size fvars - let fvarId ← mkFreshId - let lctx := lctx.mkLocalDecl fvarId n d c.binderInfo - let fvar := mkFVar fvarId - lambdaTelescopeAux k consumeLet lctx (fvars.push fvar) j b - | true, lctx, fvars, j, Expr.letE n t v b _ => do - let t := t.instantiateRevRange j fvars.size fvars - let v := v.instantiateRevRange j fvars.size fvars - let fvarId ← mkFreshId - let lctx := lctx.mkLetDecl fvarId n t v - let fvar := mkFVar fvarId - lambdaTelescopeAux k true lctx (fvars.push fvar) j b - | _, lctx, fvars, j, e => - let e := e.instantiateRevRange j fvars.size fvars; - withReader (fun ctx => { ctx with lctx := lctx }) do - withNewLocalInstancesImp fvars j do - k fvars e - private partial def lambdaTelescopeImp (e : Expr) (consumeLet : Bool) (k : Array Expr → Expr → MetaM α) : MetaM α := do - let rec process (consumeLet : Bool) (lctx : LocalContext) (fvars : Array Expr) (j : Nat) (e : Expr) : MetaM α := do + process consumeLet (← getLCtx) #[] 0 e +where + process (consumeLet : Bool) (lctx : LocalContext) (fvars : Array Expr) (j : Nat) (e : Expr) : MetaM α := do match consumeLet, e with | _, Expr.lam n d b c => let d := d.instantiateRevRange j fvars.size fvars @@ -813,7 +776,6 @@ private partial def lambdaTelescopeImp (e : Expr) (consumeLet : Bool) (k : Array withReader (fun ctx => { ctx with lctx := lctx }) do withNewLocalInstancesImp fvars j do k fvars e - process consumeLet (← getLCtx) #[] 0 e /-- Similar to `forallTelescope` but for lambda and let expressions. -/ def lambdaLetTelescope (type : Expr) (k : Array Expr → Expr → n α) : n α := @@ -825,8 +787,7 @@ def lambdaTelescope (type : Expr) (k : Array Expr → Expr → n α) : n α := /-- Return the parameter names for the givel global declaration. -/ def getParamNames (declName : Name) : MetaM (Array Name) := do - let cinfo ← getConstInfo declName - forallTelescopeReducing cinfo.type fun xs _ => do + forallTelescopeReducing (← getConstInfo declName).type fun xs _ => do xs.mapM fun x => do let localDecl ← getLocalDecl x.fvarId! pure localDecl.userName @@ -834,35 +795,31 @@ def getParamNames (declName : Name) : MetaM (Array Name) := do -- `kind` specifies the metavariable kind for metavariables not corresponding to instance implicit `[ ... ]` arguments. private partial def forallMetaTelescopeReducingAux (e : Expr) (reducing : Bool) (maxMVars? : Option Nat) (kind : MetavarKind) : MetaM (Array Expr × Array BinderInfo × Expr) := - let rec process (mvars : Array Expr) (bis : Array BinderInfo) (j : Nat) (type : Expr) : MetaM (Array Expr × Array BinderInfo × Expr) := do - match type with - | Expr.forallE n d b c => - let cont : Unit → MetaM (Array Expr × Array BinderInfo × Expr) := fun _ => do + process #[] #[] 0 e +where + process (mvars : Array Expr) (bis : Array BinderInfo) (j : Nat) (type : Expr) : MetaM (Array Expr × Array BinderInfo × Expr) := do + if maxMVars?.isEqSome mvars.size then + let type := type.instantiateRevRange j mvars.size mvars; + return (mvars, bis, type) + else + match type with + | Expr.forallE n d b c => let d := d.instantiateRevRange j mvars.size mvars let k := if c.binderInfo.isInstImplicit then MetavarKind.synthetic else kind let mvar ← mkFreshExprMVar d k n let mvars := mvars.push mvar let bis := bis.push c.binderInfo process mvars bis j b - match maxMVars? with - | none => cont () - | some maxMVars => - if mvars.size < maxMVars then - cont () + | _ => + let type := type.instantiateRevRange j mvars.size mvars; + if reducing then do + let newType ← whnf type; + if newType.isForall then + process mvars bis mvars.size newType + else + return (mvars, bis, type) else - let type := type.instantiateRevRange j mvars.size mvars; - pure (mvars, bis, type) - | _ => - let type := type.instantiateRevRange j mvars.size mvars; - if reducing then do - let newType ← whnf type; - if newType.isForall then - process mvars bis mvars.size newType - else - pure (mvars, bis, type) - else - pure (mvars, bis, type) - process #[] #[] 0 e + return (mvars, bis, type) /-- Similar to `forallTelescope`, but creates metavariables instead of free variables. -/ def forallMetaTelescope (e : Expr) (kind := MetavarKind.natural) : MetaM (Array Expr × Array BinderInfo × Expr) := @@ -878,11 +835,15 @@ def forallMetaBoundedTelescope (e : Expr) (maxMVars : Nat) (kind : MetavarKind : /-- Similar to `forallMetaTelescopeReducingAux` but for lambda expressions. -/ partial def lambdaMetaTelescope (e : Expr) (maxMVars? : Option Nat := none) : MetaM (Array Expr × Array BinderInfo × Expr) := - let rec process (mvars : Array Expr) (bis : Array BinderInfo) (j : Nat) (type : Expr) : MetaM (Array Expr × Array BinderInfo × Expr) := do + process #[] #[] 0 e +where + process (mvars : Array Expr) (bis : Array BinderInfo) (j : Nat) (type : Expr) : MetaM (Array Expr × Array BinderInfo × Expr) := do let finalize : Unit → MetaM (Array Expr × Array BinderInfo × Expr) := fun _ => do let type := type.instantiateRevRange j mvars.size mvars pure (mvars, bis, type) - let cont : Unit → MetaM (Array Expr × Array BinderInfo × Expr) := fun _ => do + if maxMVars?.isEqSome mvars.size then + finalize () + else match type with | Expr.lam n d b c => let d := d.instantiateRevRange j mvars.size mvars @@ -891,14 +852,6 @@ partial def lambdaMetaTelescope (e : Expr) (maxMVars? : Option Nat := none) : Me let bis := bis.push c.binderInfo process mvars bis j b | _ => finalize () - match maxMVars? with - | none => cont () - | some maxMVars => - if mvars.size < maxMVars then - cont () - else - finalize () - process #[] #[] 0 e private def withNewFVar (fvar fvarType : Expr) (k : Expr → MetaM α) : MetaM α := do match (← isClass? fvarType) with @@ -924,21 +877,16 @@ partial def withLocalDecls (declInfos : Array (Name × BinderInfo × (Array Expr → n Expr))) (k : (xs : Array Expr) → n α) : n α := - let rec loop - [Inhabited α] - (acc : Array Expr) : n α := do + loop #[] +where + loop [Inhabited α] (acc : Array Expr) : n α := do if acc.size < declInfos.size then let (name, bi, typeCtor) := declInfos[acc.size] withLocalDecl name bi (←typeCtor acc) fun x => loop (acc.push x) - else k acc + else + k acc - loop #[] - -def withLocalDeclsD - [Inhabited α] - (declInfos : Array (Name × (Array Expr → n Expr))) - (k : (xs : Array Expr) → n α) - : n α := +def withLocalDeclsD [Inhabited α] (declInfos : Array (Name × (Array Expr → n Expr))) (k : (xs : Array Expr) → n α) : n α := withLocalDecls (declInfos.map (fun (name, typeCtor) => (name, BinderInfo.default, typeCtor))) k @@ -1069,8 +1017,7 @@ def setInlineAttribute (declName : Name) (kind := Compiler.InlineAttributeKind.i private partial def instantiateForallAux (ps : Array Expr) (i : Nat) (e : Expr) : MetaM Expr := do if h : i < ps.size then let p := ps.get ⟨i, h⟩ - let e ← whnf e - match e with + match (← whnf e) with | Expr.forallE _ _ b _ => instantiateForallAux ps (i+1) (b.instantiate1 p) | _ => throwError "invalid instantiateForall, too many parameters" else @@ -1083,8 +1030,7 @@ def instantiateForall (e : Expr) (ps : Array Expr) : MetaM Expr := private partial def instantiateLambdaAux (ps : Array Expr) (i : Nat) (e : Expr) : MetaM Expr := do if h : i < ps.size then let p := ps.get ⟨i, h⟩ - let e ← whnf e - match e with + match (← whnf e) with | Expr.lam _ _ b _ => instantiateLambdaAux ps (i+1) (b.instantiate1 p) | _ => throwError "invalid instantiateLambda, too many parameters" else @@ -1100,12 +1046,8 @@ def dependsOn (e : Expr) (fvarId : FVarId) : MetaM Bool := return (← getMCtx).exprDependsOn e fvarId def ppExpr (e : Expr) : MetaM Format := do - let env ← getEnv - let mctx ← getMCtx - let lctx ← getLCtx - let opts ← getOptions let ctxCore ← readThe Core.Context - Lean.ppExpr { env := env, mctx := mctx, lctx := lctx, opts := opts, currNamespace := ctxCore.currNamespace, openDecls := ctxCore.openDecls } e + Lean.ppExpr { env := (← getEnv), mctx := (← getMCtx), lctx := (← getLCtx), opts := (← getOptions), currNamespace := ctxCore.currNamespace, openDecls := ctxCore.openDecls } e @[inline] protected def orelse (x y : MetaM α) : MetaM α := do let env ← getEnv diff --git a/tests/lean/forallMetaBounded.lean b/tests/lean/forallMetaBounded.lean new file mode 100644 index 0000000000..78cfc9cd7d --- /dev/null +++ b/tests/lean/forallMetaBounded.lean @@ -0,0 +1,16 @@ +import Lean +open Lean Lean.Meta + +def Set (α : Type) : Type := + α → Prop + +def Set.empty {α : Type} : Set α := + fun a => False + +def Set.insert (s : Set α) (a : α) : Set α := + fun x => x = a ∨ s a + +#eval show MetaM Unit from do + let insertType ← inferType (mkConst `Set.insert) + let ⟨mvars, bInfos, resultType⟩ ← forallMetaBoundedTelescope insertType 3 + println! "{resultType}" diff --git a/tests/lean/forallMetaBounded.lean.expected.out b/tests/lean/forallMetaBounded.lean.expected.out new file mode 100644 index 0000000000..5e9817c51b --- /dev/null +++ b/tests/lean/forallMetaBounded.lean.expected.out @@ -0,0 +1 @@ +Set ?_uniq.1