From fcf4df2f5cbf04c23c352a5590aea2ed493f886a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 12 Aug 2020 13:58:27 -0700 Subject: [PATCH] fix: do not use named holes for representing `_` in patterns --- src/Lean/Elab/Match.lean | 53 ++++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 4bd0236430..c1dc72b2ba 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -105,6 +105,24 @@ instance PatternVar.hasToString : HasToString PatternVar := @[init] private def registerAuxiliaryNodeKind : IO Unit := Parser.registerBuiltinNodeKind `MVarWithIdKind +/-- + Create an auxiliary Syntax node wrapping a fresh metavariable id. + We use this kind of Syntax for representing `_` occurring in patterns. + The metavariables are created before we elaborate the patterns into `Expr`s. -/ +private def mkMVarSyntax : TermElabM Syntax := do +mvarId ← mkFreshId; +pure $ Syntax.node `MVarWithIdKind #[Syntax.node mvarId #[]] + +/-- Given a syntax node constructed using `mkMVarSyntax`, return its MVarId -/ +private def getMVarSyntaxMVarId (stx : Syntax) : MVarId := +(stx.getArg 0).getKind + +/-- + The elaboration function for `Syntax` created using `mkMVarSyntax`. + It just converts the metavariable id wrapped by the Syntax into an `Expr`. -/ +@[builtinTermElab MVarWithIdKind] def elabMVarWithIdKind : TermElab := +fun stx expectedType? => pure $ mkMVar (getMVarSyntaxMVarId stx) + /- Patterns define new local variables. This module collect them and preprocess `_` occurring in patterns. @@ -223,8 +241,8 @@ private partial def collect : Syntax → M Syntax }; pure $ Syntax.node k $ args.set! 2 $ mkNullNode fields else if k == `Lean.Parser.Term.hole then do - r ← `(?x); - modify fun s => { s with vars := s.vars.push $ PatternVar.anonymousVar $ (r.getArg 1).getId }; + r ← liftM mkMVarSyntax; + modify fun s => { s with vars := s.vars.push $ PatternVar.anonymousVar $ getMVarSyntaxMVarId r }; pure r else if k == `Lean.Parser.Term.paren then let arg := args.get! 1; @@ -299,19 +317,25 @@ private def collectPatternVars (alt : MatchAltView) : TermElabM (Array PatternVa (alt, s) ← (CollectPatternVars.main alt).run {}; pure (s.vars, alt) -private partial def withPatternVarsAux {α} (ref : Syntax) (pVars : Array PatternVar) (k : TermElabM α) : Nat → TermElabM α -| i => +private partial def withPatternVarsAux {α} (ref : Syntax) (pVars : Array PatternVar) (k : Array Expr → TermElabM α) : Nat → Array Expr → TermElabM α +| i, xs => if h : i < pVars.size then match pVars.get ⟨i, h⟩ with - | PatternVar.anonymousVar _ => withPatternVarsAux (i+1) - | PatternVar.localVar userName => do + | PatternVar.anonymousVar mvarId => do + withPatternVarsAux (i+1) (xs.push (mkMVar mvarId)) + | PatternVar.localVar userName => do type ← mkFreshTypeMVar ref; - withLocalDecl ref userName BinderInfo.default type fun _ => withPatternVarsAux (i+1) - else - k + withLocalDecl ref userName BinderInfo.default type fun x => withPatternVarsAux (i+1) (xs.push x) + else do + /- We must create the metavariables for `PatternVar.anonymousVar` AFTER we create the new local decls using `withLocalDecl`. + Reason: their scope must include the new local decls since some of them will be assigned by typing constraints. -/ + pVars.forM fun pvar => match pvar with + | PatternVar.anonymousVar mvarId => do _ ← mkFreshExprMVarWithId ref mvarId; pure () + | _ => pure (); + k xs -private def withPatternVars {α} (ref : Syntax) (pVars : Array PatternVar) (k : TermElabM α) : TermElabM α := -withPatternVarsAux ref pVars k 0 +private def withPatternVars {α} (ref : Syntax) (pVars : Array PatternVar) (k : Array Expr → TermElabM α) : TermElabM α := +withPatternVarsAux ref pVars k 0 #[] private partial def elabPatternsAux (ref : Syntax) (patternStxs : Array Syntax) : Nat → Expr → Array Expr → TermElabM (Array Expr) | i, matchType, patterns => @@ -319,7 +343,9 @@ private partial def elabPatternsAux (ref : Syntax) (patternStxs : Array Syntax) matchType ← whnf ref matchType; match matchType with | Expr.forallE _ d b _ => do - pattern ← elabTerm (patternStxs.get ⟨i, h⟩) d; + let patternStx := patternStxs.get ⟨i, h⟩; + pattern ← elabTerm patternStx d; + pattern ← ensureHasType patternStx d pattern; elabPatternsAux (i+1) (b.instantiate1 pattern) (patterns.push pattern) | _ => throwError ref "unexpected match type" else @@ -333,7 +359,8 @@ pure patterns def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (Meta.DepElim.AltLHS × Expr) := do (patternVars, alt) ← collectPatternVars alt; trace `Elab.match alt.ref fun _ => "patternVars: " ++ toString patternVars; -withPatternVars alt.ref patternVars do +withPatternVars alt.ref patternVars fun xs => do + trace `Elab.match alt.ref fun _ => "xs: " ++ xs; ps ← elabPatterns alt.ref alt.patterns matchType; -- TODO pure (⟨[], []⟩, arbitrary _)