From 82bce7ebecfbfc01ec5541e2e07297f3d5dbe2ec Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 12 Feb 2022 12:01:08 -0800 Subject: [PATCH] fix: declare local instaces occurring in patterns --- src/Lean/Elab/Match.lean | 11 ++++++----- src/Lean/Meta/Basic.lean | 27 +++++++++++++++++---------- tests/lean/run/instPatVar.lean | 9 +++++++++ 3 files changed, 32 insertions(+), 15 deletions(-) create mode 100644 tests/lean/run/instPatVar.lean diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index a08514a807..636edddb7d 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -562,11 +562,12 @@ private def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : ExceptT P trace[Elab.match] "patternVars: {patternVars}" withPatternVars patternVars fun patternVarDecls => do withElaboratedLHS alt.ref patternVarDecls alt.patterns matchType fun altLHS matchType => do - let rhs ← elabTermEnsuringType alt.rhs matchType - let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr - let rhs ← if xs.isEmpty then pure <| mkSimpleThunk rhs else mkLambdaFVars xs rhs - trace[Elab.match] "rhs: {rhs}" - return (altLHS, rhs) + withLocalInstances altLHS.fvarDecls do + let rhs ← elabTermEnsuringType alt.rhs matchType + let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr + let rhs ← if xs.isEmpty then pure <| mkSimpleThunk rhs else mkLambdaFVars xs rhs + trace[Elab.match] "rhs: {rhs}" + return (altLHS, rhs) /-- Collect problematic index for the "discriminant refinement feature". This method is invoked diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index b133ca6e8c..f87ca4dcff 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -946,21 +946,28 @@ private def withLetDeclImp (n : Name) (type : Expr) (val : Expr) (k : Expr → M def withLetDecl (name : Name) (type : Expr) (val : Expr) (k : Expr → n α) : n α := map1MetaM (fun k => withLetDeclImp name type val k) k +def withLocalInstancesImp (decls : List LocalDecl) (k : MetaM α) : MetaM α := do + let localInsts := (← read).localInstances + let size := localInsts.size + let localInstsNew ← decls.foldlM (init := localInsts) fun localInstsNew decl => do + match (← isClass? decl.type) with + | none => return localInstsNew + | some className => return localInstsNew.push { className, fvar := decl.toExpr } + if localInstsNew.size == size then + k + else + resettingSynthInstanceCache <| withReader (fun ctx => { ctx with localInstances := localInstsNew }) k + +/-- Register any local instance in `decls` -/ +def withLocalInstances (decls : List LocalDecl) : n α → n α := + mapMetaM <| withLocalInstancesImp decls + private def withExistingLocalDeclsImp (decls : List LocalDecl) (k : MetaM α) : MetaM α := do let ctx ← read let numLocalInstances := ctx.localInstances.size let lctx := decls.foldl (fun (lctx : LocalContext) decl => lctx.addDecl decl) ctx.lctx withReader (fun ctx => { ctx with lctx := lctx }) do - let newLocalInsts ← decls.foldlM - (fun (newlocalInsts : Array LocalInstance) (decl : LocalDecl) => (do { - match (← isClass? decl.type) with - | none => pure newlocalInsts - | some c => pure <| newlocalInsts.push { className := c, fvar := decl.toExpr } } : MetaM _)) - ctx.localInstances; - if newLocalInsts.size == numLocalInstances then - k - else - resettingSynthInstanceCache <| withReader (fun ctx => { ctx with localInstances := newLocalInsts }) k + withLocalInstancesImp decls k def withExistingLocalDecls (decls : List LocalDecl) : n α → n α := mapMetaM <| withExistingLocalDeclsImp decls diff --git a/tests/lean/run/instPatVar.lean b/tests/lean/run/instPatVar.lean new file mode 100644 index 0000000000..35592a8e2e --- /dev/null +++ b/tests/lean/run/instPatVar.lean @@ -0,0 +1,9 @@ +class Pretty (α : Type u) where + pretty : α → Std.Format + +export Pretty (pretty) + +def concat (xs : List ((α : Type u) × Pretty α × α)) : Std.Format := + match xs with + | [] => "" + | ⟨_, _, v⟩ :: xs => pretty v ++ concat xs