From 6fc03d0f298ff12e75a2ebf6a386f37d4de74bda Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Tue, 8 Dec 2020 11:53:49 +0100 Subject: [PATCH] feat: quotation scopes in `match_syntax` --- src/Init/Prelude.lean | 11 ++++ src/Lean/Elab/Quotation.lean | 118 +++++++++++++++++++++++++++-------- tests/lean/StxQuot.lean | 13 ++-- 3 files changed, 111 insertions(+), 31 deletions(-) diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index 1f80379d62..70cf2dc202 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -1138,6 +1138,17 @@ instance {α : Type u} {m : Type u → Type v} [Monad m] : Inhabited (α → m instance {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α] : Inhabited (m α) where default := pure arbitrary +-- A fusion of Haskell's `sequence` and `map` +def Array.sequenceMap {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (f : α → m β) : m (Array β) := + let rec loop (i : Nat) (j : Nat) (bs : Array β) : m (Array β) := + dite (Less j as.size) + (fun hlt => + match i with + | 0 => pure bs + | Nat.succ i' => Bind.bind (f (as.get ⟨j, hlt⟩)) fun b => loop i' (hAdd j 1) (bs.push b)) + (fun _ => bs) + loop as.size 0 Array.empty + /-- A Function for lifting a computation from an inner Monad to an outer Monad. Like [MonadTrans](https://hackage.haskell.org/package/transformers-0.5.5.0/docs/Control-Monad-Trans-Class.html), but `n` does not have to be a monad transformer. diff --git a/src/Lean/Elab/Quotation.lean b/src/Lean/Elab/Quotation.lean index 7030349626..8cc0a45796 100644 --- a/src/Lean/Elab/Quotation.lean +++ b/src/Lean/Elab/Quotation.lean @@ -197,7 +197,7 @@ private abbrev Alt := List Syntax × Syntax /-- Information on a pattern's head that influences the compilation of a single match step. -/ -structure HeadInfo where +structure BasicHeadInfo where -- Node kind to match, if any kind : Option SyntaxNodeKind := none -- Nested patterns for each argument, if any. In a single match step, we only @@ -208,21 +208,34 @@ structure HeadInfo where -- bind pattern variables. rhsFn : Syntax → TermElabM Syntax := pure -instance : Inhabited HeadInfo := ⟨{}⟩ +inductive HeadInfo where + | basic (bhi : BasicHeadInfo) + | antiquotScope (stx : Syntax) + +open HeadInfo + +instance : Inhabited HeadInfo := ⟨basic {}⟩ /-- `h1.generalizes h2` iff h1 is equal to or more general than h2, i.e. it matches all nodes h2 matches. This induces a partial ordering. -/ def HeadInfo.generalizes : HeadInfo → HeadInfo → Bool - | { kind := none, .. }, _ => true - | { kind := some k1, argPats := none, .. }, - { kind := some k2, .. } => k1 == k2 - | { kind := some k1, argPats := some ps1, .. }, - { kind := some k2, argPats := some ps2, .. } => k1 == k2 && ps1.size == ps2.size - | _, _ => false + | basic { kind := none, .. }, _ => true + | basic { kind := some k1, argPats := none, .. }, + basic { kind := some k2, .. } => k1 == k2 + | basic { kind := some k1, argPats := some ps1, .. }, + basic { kind := some k2, argPats := some ps2, .. } => k1 == k2 && ps1.size == ps2.size + -- roughmost approximation for now + | antiquotScope stx1, antiquotScope stx2 => stx1 == stx2 + | _, _ => false + +def mkTuple : Array Syntax → TermElabM Syntax + | #[] => `(()) + | #[e] => e + | es => `(($(es[0]), $(es.eraseIdx 0)*)) private def getHeadInfo (alt : Alt) : HeadInfo := let pat := alt.fst.head!; - let unconditional (rhsFn) := { rhsFn := rhsFn : HeadInfo }; + let unconditional (rhsFn) := basic { rhsFn := rhsFn }; -- variable pattern if pat.isIdent then unconditional $ fun rhs => `(let $pat := discr; $rhs) -- wildcard pattern @@ -250,18 +263,21 @@ private def getHeadInfo (alt : Alt) : HeadInfo := let anti := getAntiquotTerm quoted -- Splices should only appear inside a nullKind node, see next case if isAntiquotSplice quoted then unconditional $ fun _ => throwErrorAt quoted "unexpected antiquotation splice" - else if anti.isIdent then { kind := kind, rhsFn := fun rhs => `(let $anti := discr; $rhs) } + else if isAntiquotScope quoted then unconditional $ fun _ => throwErrorAt quoted "unexpected antiquotation scope" + else if anti.isIdent then basic { kind := kind, rhsFn := fun rhs => `(let $anti := discr; $rhs) } else unconditional fun _ => throwErrorAt! anti "match_syntax: antiquotation must be variable {anti}" else if isAntiquotSplicePat quoted && quoted.getArgs.size == 1 then -- quotation is a single antiquotation splice => bind args array let anti := getAntiquotTerm quoted[0] unconditional fun rhs => `(let $anti := Syntax.getArgs discr; $rhs) -- TODO: support for more complex antiquotation splices + else if quoted.getArgs.size == 1 && isAntiquotScope quoted[0] then + antiquotScope quoted[0] else -- not an antiquotation or escaped antiquotation: match head shape let quoted := unescapeAntiquot quoted let argPats := quoted.getArgs.map (pat.setArg 1); - { kind := quoted.getKind, argPats := argPats } + basic { kind := quoted.getKind, argPats := argPats } else unconditional $ fun _ => throwErrorAt! pat "match_syntax: unexpected pattern kind {pat}" @@ -270,15 +286,18 @@ private def getHeadInfo (alt : Alt) : HeadInfo := -- Ex: `($a + (- $b)) => `($a), `(+), `(- $b) -- Note: The atom pattern `(+) will be discarded in a later step private def explodeHeadPat (numArgs : Nat) : HeadInfo × Alt → TermElabM Alt - | (info, (pat::pats, rhs)) => do + | (basic info, (pat::pats, rhs)) => do let newPats := match info.argPats with | some argPats => argPats.toList | none => List.replicate numArgs $ Unhygienic.run `(_) let rhs ← info.rhsFn rhs pure (newPats ++ pats, rhs) + | (antiquotScope _, (pat::pats, rhs)) => (pats, rhs) | _ => unreachable! -private partial def compileStxMatch : List Syntax → List Alt → TermElabM Syntax +private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : TermElabM Syntax := do + trace[Elab.match_syntax]! "match_syntax {discrs} with {alts}" + match discrs, alts with | [], ([], rhs)::_ => pure rhs -- nothing left to match | _, [] => throwError "non-exhaustive 'match_syntax'" | discr::discrs, alts => do @@ -287,24 +306,65 @@ private partial def compileStxMatch : List Syntax → List Alt → TermElabM Syn -- If there are multiple minimal elements, the choice does not matter. let (info, alt) := alts.tail!.foldl (fun (min : HeadInfo × Alt) (alt : HeadInfo × Alt) => if min.1.generalizes alt.1 then alt else min) alts.head!; -- introduce pattern matches on the discriminant's children if there are any nested patterns - let newDiscrs ← match info.argPats with - | some pats => (List.range pats.size).mapM fun i => `(Syntax.getArg discr $(quote i)) - | none => pure [] + let newDiscrs ← match info with + | basic { argPats := some pats, .. } => (List.range pats.size).mapM fun i => `(Syntax.getArg discr $(quote i)) + | _ => pure [] -- collect matching alternatives and explode them let yesAlts := alts.filter fun (alt : HeadInfo × Alt) => alt.1.generalizes info let yesAlts ← yesAlts.mapM $ explodeHeadPat newDiscrs.length -- NOTE: use fresh macro scopes for recursive call so that different `discr`s introduced by the quotations below do not collide let yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts - let some kind ← pure info.kind - -- unconditional match step - | `(let discr := $discr; $yes) + let mkNo := do + let noAlts := (alts.filter $ fun (alt : HeadInfo × Alt) => !info.generalizes alt.1).map (·.2) + withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts + match info with + -- unconditional match step + | basic { kind := none, .. } => `(let discr := $discr; $yes) -- conditional match step - let noAlts := (alts.filter $ fun (alt : HeadInfo × Alt) => !info.generalizes alt.1).map (·.2) - let no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts - let cond ← match info.argPats with - | some pats => `(and (Syntax.isOfKind discr $(quote kind)) (BEq.beq (Array.size (Syntax.getArgs discr)) $(quote pats.size))) - | none => `(Syntax.isOfKind discr $(quote kind)) - `(let discr := $discr; ite (Eq $cond true) $yes $no) + | basic { kind := some kind, argPats := pats, .. } => + let cond ← match pats with + | some pats => `(and (Syntax.isOfKind discr $(quote kind)) (BEq.beq (Array.size (Syntax.getArgs discr)) $(quote pats.size))) + | none => `(Syntax.isOfKind discr $(quote kind)) + let no ← mkNo + `(let discr := $discr; ite (Eq $cond true) $yes $no) + -- terrifying match step + | antiquotScope scope => + let k := antiquotScopeKind? scope + let contents := getAntiquotScopeContents scope + let ids ← getAntiquotationIds scope + let no ← mkNo + match k with + | `optional => + let mut yesMatch := yes + for id in ids do + yesMatch ← `(let $id := some $id; $yesMatch) + let mut yesNoMatch := yes + for id in ids do + yesNoMatch ← `(let $id := none; $yesNoMatch) + `(let discr := $discr; + if discr.isNone then $yesNoMatch + else match_syntax discr with + | `($(mkNullNode contents)) => $yesMatch + | _ => $no) + | _ => + let mut discrs ← `(Syntax.getArgs $discr) + if k == `sepBy then + discrs ← `(Array.getSepElems $discrs) + let ids := ids.toArray + let tuple ← mkTuple ids + let mut yes := yes + let resId ← match ids with + | #[id] => id + | _ => + for i in [:ids.size] do + let idx := Syntax.mkLit fieldIdxKind (toString (i + 1)); + yes ← `(let $(ids[i]) := tuples.map (·.$idx:fieldIdx); $yes) + `(tuples) + `(match ($(discrs).sequenceMap fun discr => match_syntax discr with + | `($(contents[0])) => some $tuple + | _ => none) with + | some $resId => $yes + | none => $no) | _, _ => unreachable! -- Get all pattern vars (as `Syntax.ident`s) in `stx` @@ -352,9 +412,15 @@ def match_syntax.expand (stx : Syntax) : TermElabM Syntax := do let rhs := alt.getArg 2 pure ([pat], rhs) -- letBindRhss (compileStxMatch stx [discr]) alts.toList [] - compileStxMatch [discr] alts.toList + let stx ← compileStxMatch [discr] alts.toList + trace[Elab.match_syntax.result]! "{stx}" + stx @[builtinTermElab «match_syntax»] def elabMatchSyntax : TermElab := adaptExpander match_syntax.expand +builtin_initialize + registerTraceClass `Elab.match_syntax + registerTraceClass `Elab.match_syntax.result + end Lean.Elab.Term.Quotation diff --git a/tests/lean/StxQuot.lean b/tests/lean/StxQuot.lean index 45ec3aa6f0..999dbc185f 100644 --- a/tests/lean/StxQuot.lean +++ b/tests/lean/StxQuot.lean @@ -38,11 +38,14 @@ end Syntax #eval run $ do let a ← `(a.{0}); match_syntax a with `($id:ident) => pure id | _ => pure a #eval run $ do let a ← `(match a with | a => 1 | _ => 2); match_syntax a with `(match $e with $eqns:matchAlt*) => pure eqns | _ => pure #[] -#eval run do let a ← some <$> `(a); `({ a := a $[: $(id a)]?}) -#eval run do let a ← pure none; `({ a := a $[: $a]?}) +def f (stx : Syntax) : Unhygienic Syntax := match_syntax stx with + | `({ a := a $[: $a]?}) => `({ a := a $[: $(id a)]?}) + | _ => unreachable! +#eval run do f (← `({ a := a : a })) +#eval run do f (← `({ a := a })) #eval run do - let pats := #[← `(a), ← `(a + 1)] - let rhss := #[← `(b), ← `(b + 1)] - `(match a with $[$pats => $rhss]|*) + match_syntax ← `(match a with a => b | a + 1 => b + 1) with + | `(match a with $[$pats => $rhss]|*) => `(match a with $[$pats => $rhss]|*) + | _ => unreachable! end Lean