fix: forallMetaTelescope issue

This commit incorporates the fix at PR #612, and clean up
`Meta/Basic.lean` using Lean 4 features.
This commit is contained in:
Leonardo de Moura 2021-08-06 09:43:26 -07:00
parent 803b73e32d
commit a230fe2d06
4 changed files with 72 additions and 109 deletions

View file

@ -26,6 +26,10 @@ def toMonad [Monad m] [Alternative m] : Option α → m α
| some _ => false
| none => true
@[inline] def isEqSome [BEq α] : Option αα → Bool
| some a, b => a == b
| none, _ => false
@[inline] protected def bind : Option α → (α → Option β) → Option β
| none, b => none
| some a, b => b a

View file

@ -292,8 +292,7 @@ def mkFreshExprMVarAt
(lctx : LocalContext) (localInsts : LocalInstances) (type : Expr)
(kind : MetavarKind := MetavarKind.natural) (userName : Name := Name.anonymous) (numScopeArgs : Nat := 0)
: MetaM Expr := do
let mvarId ← mkFreshId
mkFreshExprMVarAtCore mvarId lctx localInsts type kind userName numScopeArgs
mkFreshExprMVarAtCore (← mkFreshId) lctx localInsts type kind userName numScopeArgs
def mkFreshLevelMVar : MetaM Level := do
let mvarId ← mkFreshId
@ -301,9 +300,7 @@ def mkFreshLevelMVar : MetaM Level := do
return mkLevelMVar mvarId
private def mkFreshExprMVarCore (type : Expr) (kind : MetavarKind) (userName : Name) : MetaM Expr := do
let lctx ← getLCtx
let localInsts ← getLocalInstances
mkFreshExprMVarAt lctx localInsts type kind userName
mkFreshExprMVarAt (← getLCtx) (← getLocalInstances) type kind userName
private def mkFreshExprMVarImpl (type? : Option Expr) (kind : MetavarKind) (userName : Name) : MetaM Expr :=
match type? with
@ -325,9 +322,7 @@ def mkFreshTypeMVar (kind := MetavarKind.natural) (userName := Name.anonymous) :
private def mkFreshExprMVarWithIdCore (mvarId : MVarId) (type : Expr)
(kind : MetavarKind := MetavarKind.natural) (userName : Name := Name.anonymous) (numScopeArgs : Nat := 0)
: MetaM Expr := do
let lctx ← getLCtx
let localInsts ← getLocalInstances
mkFreshExprMVarAtCore mvarId lctx localInsts type kind userName numScopeArgs
mkFreshExprMVarAtCore mvarId (← getLCtx) (← getLocalInstances) type kind userName numScopeArgs
def mkFreshExprMVarWithId (mvarId : MVarId) (type? : Option Expr := none) (kind : MetavarKind := MetavarKind.natural) (userName := Name.anonymous) : MetaM Expr :=
match type? with
@ -358,8 +353,7 @@ def shouldReduceReducibleOnly : MetaM Bool :=
return (← getTransparency) == TransparencyMode.reducible
def getMVarDecl (mvarId : MVarId) : MetaM MetavarDecl := do
let mctx ← getMCtx
match mctx.findDecl? mvarId with
match (← getMCtx).findDecl? mvarId with
| some d => pure d
| none => throwError "unknown metavariable '?{mvarId}'"
@ -372,9 +366,7 @@ def setMVarType (mvarId : MVarId) (type : Expr) : MetaM Unit := do
modifyMCtx fun mctx => mctx.setMVarType mvarId type
def isReadOnlyExprMVar (mvarId : MVarId) : MetaM Bool := do
let mvarDecl ← getMVarDecl mvarId
let mctx ← getMCtx
return mvarDecl.depth != mctx.depth
return (← getMVarDecl mvarId).depth != (← getMCtx).depth
def isReadOnlyOrSyntheticOpaqueExprMVar (mvarId : MVarId) : MetaM Bool := do
let mvarDecl ← getMVarDecl mvarId
@ -438,15 +430,12 @@ def instantiateLevelMVars (u : Level) : MetaM Level :=
def instantiateMVars (e : Expr) : MetaM Expr :=
(MetavarContext.instantiateExprMVars e).run
def instantiateLocalDeclMVars (localDecl : LocalDecl) : MetaM LocalDecl := do
def instantiateLocalDeclMVars (localDecl : LocalDecl) : MetaM LocalDecl :=
match localDecl with
| LocalDecl.cdecl idx id n type bi =>
let type ← instantiateMVars type
return LocalDecl.cdecl idx id n type bi
return LocalDecl.cdecl idx id n (← instantiateMVars type) bi
| LocalDecl.ldecl idx id n type val nonDep =>
let type ← instantiateMVars type
let val ← instantiateMVars val
return LocalDecl.ldecl idx id n type val nonDep
return LocalDecl.ldecl idx id n (← instantiateMVars type) (← instantiateMVars val) nonDep
@[inline] def liftMkBindingM (x : MetavarContext.MkBindingM α) : MetaM α := do
match x (← getLCtx) { mctx := (← getMCtx), ngen := (← getNGen) } with
@ -468,9 +457,8 @@ def mkLambdaFVars (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedL
def mkLetFVars (xs : Array Expr) (e : Expr) (usedLetOnly := true) : MetaM Expr :=
mkLambdaFVars xs e (usedLetOnly := usedLetOnly)
def mkArrow (d b : Expr) : MetaM Expr := do
let n ← mkFreshUserName `x
return Lean.mkForall n BinderInfo.default d b
def mkArrow (d b : Expr) : MetaM Expr :=
return Lean.mkForall (← mkFreshUserName `x) BinderInfo.default d b
def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool := false) : MetaM Expr :=
if xs.isEmpty then pure e else liftMkBindingM <| MetavarContext.elimMVarDeps xs e preserveOrder
@ -510,8 +498,7 @@ def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool := false) :
/-- Save cache, execute `x`, restore cache -/
@[inline] private def savingCacheImpl (x : MetaM α) : MetaM α := do
let s ← get
let savedCache := s.cache
let savedCache := (← get).cache
try x finally modify fun s => { s with cache := savedCache }
@[inline] def savingCache : n α → n α :=
@ -538,16 +525,14 @@ private def getDefInfoTemp (info : ConstantInfo) : MetaM (Option ConstantInfo) :
It is very similar to `getConst?`, but it returns none when `TransparencyMode.instances` and
`constName` is an instance. This difference should be irrelevant for `isClassQuickConst?`. -/
private def getConstTemp? (constName : Name) : MetaM (Option ConstantInfo) := do
let env ← getEnv
match env.find? constName with
match (← getEnv).find? constName with
| some (info@(ConstantInfo.thmInfo _)) => getTheoremInfo info
| some (info@(ConstantInfo.defnInfo _)) => getDefInfoTemp info
| some info => pure (some info)
| none => throwUnknownConstant constName
private def isClassQuickConst? (constName : Name) : MetaM (LOption Name) := do
let env ← getEnv
if isClass env constName then
if isClass (← getEnv) constName then
pure (LOption.some constName)
else
match (← getConstTemp? constName) with
@ -576,8 +561,7 @@ private partial def isClassQuick? : Expr → MetaM (LOption Name)
| _ => pure LOption.none
def saveAndResetSynthInstanceCache : MetaM SynthInstanceCache := do
let s ← get
let savedSythInstance := s.cache.synthInstance
let savedSythInstance := (← get).cache.synthInstance
modifyCache fun c => { c with synthInstance := {} }
pure savedSythInstance
@ -769,31 +753,10 @@ private def forallBoundedTelescopeImp (type : Expr) (maxFVars? : Option Nat) (k
def forallBoundedTelescope (type : Expr) (maxFVars? : Option Nat) (k : Array Expr → Expr → n α) : n α :=
map2MetaM (fun k => forallBoundedTelescopeImp type maxFVars? k) k
/-- Similar to `forallTelescopeAuxAux` but for lambda and let expressions. -/
private partial def lambdaTelescopeAux
(k : Array Expr → Expr → MetaM α)
: Bool → LocalContext → Array Expr → Nat → Expr → MetaM α
| consumeLet, lctx, fvars, j, Expr.lam n d b c => do
let d := d.instantiateRevRange j fvars.size fvars
let fvarId ← mkFreshId
let lctx := lctx.mkLocalDecl fvarId n d c.binderInfo
let fvar := mkFVar fvarId
lambdaTelescopeAux k consumeLet lctx (fvars.push fvar) j b
| true, lctx, fvars, j, Expr.letE n t v b _ => do
let t := t.instantiateRevRange j fvars.size fvars
let v := v.instantiateRevRange j fvars.size fvars
let fvarId ← mkFreshId
let lctx := lctx.mkLetDecl fvarId n t v
let fvar := mkFVar fvarId
lambdaTelescopeAux k true lctx (fvars.push fvar) j b
| _, lctx, fvars, j, e =>
let e := e.instantiateRevRange j fvars.size fvars;
withReader (fun ctx => { ctx with lctx := lctx }) do
withNewLocalInstancesImp fvars j do
k fvars e
private partial def lambdaTelescopeImp (e : Expr) (consumeLet : Bool) (k : Array Expr → Expr → MetaM α) : MetaM α := do
let rec process (consumeLet : Bool) (lctx : LocalContext) (fvars : Array Expr) (j : Nat) (e : Expr) : MetaM α := do
process consumeLet (← getLCtx) #[] 0 e
where
process (consumeLet : Bool) (lctx : LocalContext) (fvars : Array Expr) (j : Nat) (e : Expr) : MetaM α := do
match consumeLet, e with
| _, Expr.lam n d b c =>
let d := d.instantiateRevRange j fvars.size fvars
@ -813,7 +776,6 @@ private partial def lambdaTelescopeImp (e : Expr) (consumeLet : Bool) (k : Array
withReader (fun ctx => { ctx with lctx := lctx }) do
withNewLocalInstancesImp fvars j do
k fvars e
process consumeLet (← getLCtx) #[] 0 e
/-- Similar to `forallTelescope` but for lambda and let expressions. -/
def lambdaLetTelescope (type : Expr) (k : Array Expr → Expr → n α) : n α :=
@ -825,8 +787,7 @@ def lambdaTelescope (type : Expr) (k : Array Expr → Expr → n α) : n α :=
/-- Return the parameter names for the givel global declaration. -/
def getParamNames (declName : Name) : MetaM (Array Name) := do
let cinfo ← getConstInfo declName
forallTelescopeReducing cinfo.type fun xs _ => do
forallTelescopeReducing (← getConstInfo declName).type fun xs _ => do
xs.mapM fun x => do
let localDecl ← getLocalDecl x.fvarId!
pure localDecl.userName
@ -834,35 +795,31 @@ def getParamNames (declName : Name) : MetaM (Array Name) := do
-- `kind` specifies the metavariable kind for metavariables not corresponding to instance implicit `[ ... ]` arguments.
private partial def forallMetaTelescopeReducingAux
(e : Expr) (reducing : Bool) (maxMVars? : Option Nat) (kind : MetavarKind) : MetaM (Array Expr × Array BinderInfo × Expr) :=
let rec process (mvars : Array Expr) (bis : Array BinderInfo) (j : Nat) (type : Expr) : MetaM (Array Expr × Array BinderInfo × Expr) := do
match type with
| Expr.forallE n d b c =>
let cont : Unit → MetaM (Array Expr × Array BinderInfo × Expr) := fun _ => do
process #[] #[] 0 e
where
process (mvars : Array Expr) (bis : Array BinderInfo) (j : Nat) (type : Expr) : MetaM (Array Expr × Array BinderInfo × Expr) := do
if maxMVars?.isEqSome mvars.size then
let type := type.instantiateRevRange j mvars.size mvars;
return (mvars, bis, type)
else
match type with
| Expr.forallE n d b c =>
let d := d.instantiateRevRange j mvars.size mvars
let k := if c.binderInfo.isInstImplicit then MetavarKind.synthetic else kind
let mvar ← mkFreshExprMVar d k n
let mvars := mvars.push mvar
let bis := bis.push c.binderInfo
process mvars bis j b
match maxMVars? with
| none => cont ()
| some maxMVars =>
if mvars.size < maxMVars then
cont ()
| _ =>
let type := type.instantiateRevRange j mvars.size mvars;
if reducing then do
let newType ← whnf type;
if newType.isForall then
process mvars bis mvars.size newType
else
return (mvars, bis, type)
else
let type := type.instantiateRevRange j mvars.size mvars;
pure (mvars, bis, type)
| _ =>
let type := type.instantiateRevRange j mvars.size mvars;
if reducing then do
let newType ← whnf type;
if newType.isForall then
process mvars bis mvars.size newType
else
pure (mvars, bis, type)
else
pure (mvars, bis, type)
process #[] #[] 0 e
return (mvars, bis, type)
/-- Similar to `forallTelescope`, but creates metavariables instead of free variables. -/
def forallMetaTelescope (e : Expr) (kind := MetavarKind.natural) : MetaM (Array Expr × Array BinderInfo × Expr) :=
@ -878,11 +835,15 @@ def forallMetaBoundedTelescope (e : Expr) (maxMVars : Nat) (kind : MetavarKind :
/-- Similar to `forallMetaTelescopeReducingAux` but for lambda expressions. -/
partial def lambdaMetaTelescope (e : Expr) (maxMVars? : Option Nat := none) : MetaM (Array Expr × Array BinderInfo × Expr) :=
let rec process (mvars : Array Expr) (bis : Array BinderInfo) (j : Nat) (type : Expr) : MetaM (Array Expr × Array BinderInfo × Expr) := do
process #[] #[] 0 e
where
process (mvars : Array Expr) (bis : Array BinderInfo) (j : Nat) (type : Expr) : MetaM (Array Expr × Array BinderInfo × Expr) := do
let finalize : Unit → MetaM (Array Expr × Array BinderInfo × Expr) := fun _ => do
let type := type.instantiateRevRange j mvars.size mvars
pure (mvars, bis, type)
let cont : Unit → MetaM (Array Expr × Array BinderInfo × Expr) := fun _ => do
if maxMVars?.isEqSome mvars.size then
finalize ()
else
match type with
| Expr.lam n d b c =>
let d := d.instantiateRevRange j mvars.size mvars
@ -891,14 +852,6 @@ partial def lambdaMetaTelescope (e : Expr) (maxMVars? : Option Nat := none) : Me
let bis := bis.push c.binderInfo
process mvars bis j b
| _ => finalize ()
match maxMVars? with
| none => cont ()
| some maxMVars =>
if mvars.size < maxMVars then
cont ()
else
finalize ()
process #[] #[] 0 e
private def withNewFVar (fvar fvarType : Expr) (k : Expr → MetaM α) : MetaM α := do
match (← isClass? fvarType) with
@ -924,21 +877,16 @@ partial def withLocalDecls
(declInfos : Array (Name × BinderInfo × (Array Expr → n Expr)))
(k : (xs : Array Expr) → n α)
: n α :=
let rec loop
[Inhabited α]
(acc : Array Expr) : n α := do
loop #[]
where
loop [Inhabited α] (acc : Array Expr) : n α := do
if acc.size < declInfos.size then
let (name, bi, typeCtor) := declInfos[acc.size]
withLocalDecl name bi (←typeCtor acc) fun x => loop (acc.push x)
else k acc
else
k acc
loop #[]
def withLocalDeclsD
[Inhabited α]
(declInfos : Array (Name × (Array Expr → n Expr)))
(k : (xs : Array Expr) → n α)
: n α :=
def withLocalDeclsD [Inhabited α] (declInfos : Array (Name × (Array Expr → n Expr))) (k : (xs : Array Expr) → n α) : n α :=
withLocalDecls
(declInfos.map (fun (name, typeCtor) => (name, BinderInfo.default, typeCtor))) k
@ -1069,8 +1017,7 @@ def setInlineAttribute (declName : Name) (kind := Compiler.InlineAttributeKind.i
private partial def instantiateForallAux (ps : Array Expr) (i : Nat) (e : Expr) : MetaM Expr := do
if h : i < ps.size then
let p := ps.get ⟨i, h⟩
let e ← whnf e
match e with
match (← whnf e) with
| Expr.forallE _ _ b _ => instantiateForallAux ps (i+1) (b.instantiate1 p)
| _ => throwError "invalid instantiateForall, too many parameters"
else
@ -1083,8 +1030,7 @@ def instantiateForall (e : Expr) (ps : Array Expr) : MetaM Expr :=
private partial def instantiateLambdaAux (ps : Array Expr) (i : Nat) (e : Expr) : MetaM Expr := do
if h : i < ps.size then
let p := ps.get ⟨i, h⟩
let e ← whnf e
match e with
match (← whnf e) with
| Expr.lam _ _ b _ => instantiateLambdaAux ps (i+1) (b.instantiate1 p)
| _ => throwError "invalid instantiateLambda, too many parameters"
else
@ -1100,12 +1046,8 @@ def dependsOn (e : Expr) (fvarId : FVarId) : MetaM Bool :=
return (← getMCtx).exprDependsOn e fvarId
def ppExpr (e : Expr) : MetaM Format := do
let env ← getEnv
let mctx ← getMCtx
let lctx ← getLCtx
let opts ← getOptions
let ctxCore ← readThe Core.Context
Lean.ppExpr { env := env, mctx := mctx, lctx := lctx, opts := opts, currNamespace := ctxCore.currNamespace, openDecls := ctxCore.openDecls } e
Lean.ppExpr { env := (← getEnv), mctx := (← getMCtx), lctx := (← getLCtx), opts := (← getOptions), currNamespace := ctxCore.currNamespace, openDecls := ctxCore.openDecls } e
@[inline] protected def orelse (x y : MetaM α) : MetaM α := do
let env ← getEnv

View file

@ -0,0 +1,16 @@
import Lean
open Lean Lean.Meta
def Set (α : Type) : Type :=
α → Prop
def Set.empty {α : Type} : Set α :=
fun a => False
def Set.insert (s : Set α) (a : α) : Set α :=
fun x => x = a s a
#eval show MetaM Unit from do
let insertType ← inferType (mkConst `Set.insert)
let ⟨mvars, bInfos, resultType⟩ ← forallMetaBoundedTelescope insertType 3
println! "{resultType}"

View file

@ -0,0 +1 @@
Set ?_uniq.1