feat: quotation scopes in match_syntax
This commit is contained in:
parent
c63b770a7c
commit
6fc03d0f29
3 changed files with 111 additions and 31 deletions
|
|
@ -1138,6 +1138,17 @@ instance {α : Type u} {m : Type u → Type v} [Monad m] : Inhabited (α → m
|
|||
instance {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α] : Inhabited (m α) where
|
||||
default := pure arbitrary
|
||||
|
||||
-- A fusion of Haskell's `sequence` and `map`
|
||||
def Array.sequenceMap {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (f : α → m β) : m (Array β) :=
|
||||
let rec loop (i : Nat) (j : Nat) (bs : Array β) : m (Array β) :=
|
||||
dite (Less j as.size)
|
||||
(fun hlt =>
|
||||
match i with
|
||||
| 0 => pure bs
|
||||
| Nat.succ i' => Bind.bind (f (as.get ⟨j, hlt⟩)) fun b => loop i' (hAdd j 1) (bs.push b))
|
||||
(fun _ => bs)
|
||||
loop as.size 0 Array.empty
|
||||
|
||||
/-- A Function for lifting a computation from an inner Monad to an outer Monad.
|
||||
Like [MonadTrans](https://hackage.haskell.org/package/transformers-0.5.5.0/docs/Control-Monad-Trans-Class.html),
|
||||
but `n` does not have to be a monad transformer.
|
||||
|
|
|
|||
|
|
@ -197,7 +197,7 @@ private abbrev Alt := List Syntax × Syntax
|
|||
|
||||
/-- Information on a pattern's head that influences the compilation of a single
|
||||
match step. -/
|
||||
structure HeadInfo where
|
||||
structure BasicHeadInfo where
|
||||
-- Node kind to match, if any
|
||||
kind : Option SyntaxNodeKind := none
|
||||
-- Nested patterns for each argument, if any. In a single match step, we only
|
||||
|
|
@ -208,21 +208,34 @@ structure HeadInfo where
|
|||
-- bind pattern variables.
|
||||
rhsFn : Syntax → TermElabM Syntax := pure
|
||||
|
||||
instance : Inhabited HeadInfo := ⟨{}⟩
|
||||
inductive HeadInfo where
|
||||
| basic (bhi : BasicHeadInfo)
|
||||
| antiquotScope (stx : Syntax)
|
||||
|
||||
open HeadInfo
|
||||
|
||||
instance : Inhabited HeadInfo := ⟨basic {}⟩
|
||||
|
||||
/-- `h1.generalizes h2` iff h1 is equal to or more general than h2, i.e. it matches all nodes
|
||||
h2 matches. This induces a partial ordering. -/
|
||||
def HeadInfo.generalizes : HeadInfo → HeadInfo → Bool
|
||||
| { kind := none, .. }, _ => true
|
||||
| { kind := some k1, argPats := none, .. },
|
||||
{ kind := some k2, .. } => k1 == k2
|
||||
| { kind := some k1, argPats := some ps1, .. },
|
||||
{ kind := some k2, argPats := some ps2, .. } => k1 == k2 && ps1.size == ps2.size
|
||||
| _, _ => false
|
||||
| basic { kind := none, .. }, _ => true
|
||||
| basic { kind := some k1, argPats := none, .. },
|
||||
basic { kind := some k2, .. } => k1 == k2
|
||||
| basic { kind := some k1, argPats := some ps1, .. },
|
||||
basic { kind := some k2, argPats := some ps2, .. } => k1 == k2 && ps1.size == ps2.size
|
||||
-- roughmost approximation for now
|
||||
| antiquotScope stx1, antiquotScope stx2 => stx1 == stx2
|
||||
| _, _ => false
|
||||
|
||||
def mkTuple : Array Syntax → TermElabM Syntax
|
||||
| #[] => `(())
|
||||
| #[e] => e
|
||||
| es => `(($(es[0]), $(es.eraseIdx 0)*))
|
||||
|
||||
private def getHeadInfo (alt : Alt) : HeadInfo :=
|
||||
let pat := alt.fst.head!;
|
||||
let unconditional (rhsFn) := { rhsFn := rhsFn : HeadInfo };
|
||||
let unconditional (rhsFn) := basic { rhsFn := rhsFn };
|
||||
-- variable pattern
|
||||
if pat.isIdent then unconditional $ fun rhs => `(let $pat := discr; $rhs)
|
||||
-- wildcard pattern
|
||||
|
|
@ -250,18 +263,21 @@ private def getHeadInfo (alt : Alt) : HeadInfo :=
|
|||
let anti := getAntiquotTerm quoted
|
||||
-- Splices should only appear inside a nullKind node, see next case
|
||||
if isAntiquotSplice quoted then unconditional $ fun _ => throwErrorAt quoted "unexpected antiquotation splice"
|
||||
else if anti.isIdent then { kind := kind, rhsFn := fun rhs => `(let $anti := discr; $rhs) }
|
||||
else if isAntiquotScope quoted then unconditional $ fun _ => throwErrorAt quoted "unexpected antiquotation scope"
|
||||
else if anti.isIdent then basic { kind := kind, rhsFn := fun rhs => `(let $anti := discr; $rhs) }
|
||||
else unconditional fun _ => throwErrorAt! anti "match_syntax: antiquotation must be variable {anti}"
|
||||
else if isAntiquotSplicePat quoted && quoted.getArgs.size == 1 then
|
||||
-- quotation is a single antiquotation splice => bind args array
|
||||
let anti := getAntiquotTerm quoted[0]
|
||||
unconditional fun rhs => `(let $anti := Syntax.getArgs discr; $rhs)
|
||||
-- TODO: support for more complex antiquotation splices
|
||||
else if quoted.getArgs.size == 1 && isAntiquotScope quoted[0] then
|
||||
antiquotScope quoted[0]
|
||||
else
|
||||
-- not an antiquotation or escaped antiquotation: match head shape
|
||||
let quoted := unescapeAntiquot quoted
|
||||
let argPats := quoted.getArgs.map (pat.setArg 1);
|
||||
{ kind := quoted.getKind, argPats := argPats }
|
||||
basic { kind := quoted.getKind, argPats := argPats }
|
||||
else
|
||||
unconditional $ fun _ => throwErrorAt! pat "match_syntax: unexpected pattern kind {pat}"
|
||||
|
||||
|
|
@ -270,15 +286,18 @@ private def getHeadInfo (alt : Alt) : HeadInfo :=
|
|||
-- Ex: `($a + (- $b)) => `($a), `(+), `(- $b)
|
||||
-- Note: The atom pattern `(+) will be discarded in a later step
|
||||
private def explodeHeadPat (numArgs : Nat) : HeadInfo × Alt → TermElabM Alt
|
||||
| (info, (pat::pats, rhs)) => do
|
||||
| (basic info, (pat::pats, rhs)) => do
|
||||
let newPats := match info.argPats with
|
||||
| some argPats => argPats.toList
|
||||
| none => List.replicate numArgs $ Unhygienic.run `(_)
|
||||
let rhs ← info.rhsFn rhs
|
||||
pure (newPats ++ pats, rhs)
|
||||
| (antiquotScope _, (pat::pats, rhs)) => (pats, rhs)
|
||||
| _ => unreachable!
|
||||
|
||||
private partial def compileStxMatch : List Syntax → List Alt → TermElabM Syntax
|
||||
private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : TermElabM Syntax := do
|
||||
trace[Elab.match_syntax]! "match_syntax {discrs} with {alts}"
|
||||
match discrs, alts with
|
||||
| [], ([], rhs)::_ => pure rhs -- nothing left to match
|
||||
| _, [] => throwError "non-exhaustive 'match_syntax'"
|
||||
| discr::discrs, alts => do
|
||||
|
|
@ -287,24 +306,65 @@ private partial def compileStxMatch : List Syntax → List Alt → TermElabM Syn
|
|||
-- If there are multiple minimal elements, the choice does not matter.
|
||||
let (info, alt) := alts.tail!.foldl (fun (min : HeadInfo × Alt) (alt : HeadInfo × Alt) => if min.1.generalizes alt.1 then alt else min) alts.head!;
|
||||
-- introduce pattern matches on the discriminant's children if there are any nested patterns
|
||||
let newDiscrs ← match info.argPats with
|
||||
| some pats => (List.range pats.size).mapM fun i => `(Syntax.getArg discr $(quote i))
|
||||
| none => pure []
|
||||
let newDiscrs ← match info with
|
||||
| basic { argPats := some pats, .. } => (List.range pats.size).mapM fun i => `(Syntax.getArg discr $(quote i))
|
||||
| _ => pure []
|
||||
-- collect matching alternatives and explode them
|
||||
let yesAlts := alts.filter fun (alt : HeadInfo × Alt) => alt.1.generalizes info
|
||||
let yesAlts ← yesAlts.mapM $ explodeHeadPat newDiscrs.length
|
||||
-- NOTE: use fresh macro scopes for recursive call so that different `discr`s introduced by the quotations below do not collide
|
||||
let yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts
|
||||
let some kind ← pure info.kind
|
||||
-- unconditional match step
|
||||
| `(let discr := $discr; $yes)
|
||||
let mkNo := do
|
||||
let noAlts := (alts.filter $ fun (alt : HeadInfo × Alt) => !info.generalizes alt.1).map (·.2)
|
||||
withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts
|
||||
match info with
|
||||
-- unconditional match step
|
||||
| basic { kind := none, .. } => `(let discr := $discr; $yes)
|
||||
-- conditional match step
|
||||
let noAlts := (alts.filter $ fun (alt : HeadInfo × Alt) => !info.generalizes alt.1).map (·.2)
|
||||
let no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts
|
||||
let cond ← match info.argPats with
|
||||
| some pats => `(and (Syntax.isOfKind discr $(quote kind)) (BEq.beq (Array.size (Syntax.getArgs discr)) $(quote pats.size)))
|
||||
| none => `(Syntax.isOfKind discr $(quote kind))
|
||||
`(let discr := $discr; ite (Eq $cond true) $yes $no)
|
||||
| basic { kind := some kind, argPats := pats, .. } =>
|
||||
let cond ← match pats with
|
||||
| some pats => `(and (Syntax.isOfKind discr $(quote kind)) (BEq.beq (Array.size (Syntax.getArgs discr)) $(quote pats.size)))
|
||||
| none => `(Syntax.isOfKind discr $(quote kind))
|
||||
let no ← mkNo
|
||||
`(let discr := $discr; ite (Eq $cond true) $yes $no)
|
||||
-- terrifying match step
|
||||
| antiquotScope scope =>
|
||||
let k := antiquotScopeKind? scope
|
||||
let contents := getAntiquotScopeContents scope
|
||||
let ids ← getAntiquotationIds scope
|
||||
let no ← mkNo
|
||||
match k with
|
||||
| `optional =>
|
||||
let mut yesMatch := yes
|
||||
for id in ids do
|
||||
yesMatch ← `(let $id := some $id; $yesMatch)
|
||||
let mut yesNoMatch := yes
|
||||
for id in ids do
|
||||
yesNoMatch ← `(let $id := none; $yesNoMatch)
|
||||
`(let discr := $discr;
|
||||
if discr.isNone then $yesNoMatch
|
||||
else match_syntax discr with
|
||||
| `($(mkNullNode contents)) => $yesMatch
|
||||
| _ => $no)
|
||||
| _ =>
|
||||
let mut discrs ← `(Syntax.getArgs $discr)
|
||||
if k == `sepBy then
|
||||
discrs ← `(Array.getSepElems $discrs)
|
||||
let ids := ids.toArray
|
||||
let tuple ← mkTuple ids
|
||||
let mut yes := yes
|
||||
let resId ← match ids with
|
||||
| #[id] => id
|
||||
| _ =>
|
||||
for i in [:ids.size] do
|
||||
let idx := Syntax.mkLit fieldIdxKind (toString (i + 1));
|
||||
yes ← `(let $(ids[i]) := tuples.map (·.$idx:fieldIdx); $yes)
|
||||
`(tuples)
|
||||
`(match ($(discrs).sequenceMap fun discr => match_syntax discr with
|
||||
| `($(contents[0])) => some $tuple
|
||||
| _ => none) with
|
||||
| some $resId => $yes
|
||||
| none => $no)
|
||||
| _, _ => unreachable!
|
||||
|
||||
-- Get all pattern vars (as `Syntax.ident`s) in `stx`
|
||||
|
|
@ -352,9 +412,15 @@ def match_syntax.expand (stx : Syntax) : TermElabM Syntax := do
|
|||
let rhs := alt.getArg 2
|
||||
pure ([pat], rhs)
|
||||
-- letBindRhss (compileStxMatch stx [discr]) alts.toList []
|
||||
compileStxMatch [discr] alts.toList
|
||||
let stx ← compileStxMatch [discr] alts.toList
|
||||
trace[Elab.match_syntax.result]! "{stx}"
|
||||
stx
|
||||
|
||||
@[builtinTermElab «match_syntax»] def elabMatchSyntax : TermElab :=
|
||||
adaptExpander match_syntax.expand
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Elab.match_syntax
|
||||
registerTraceClass `Elab.match_syntax.result
|
||||
|
||||
end Lean.Elab.Term.Quotation
|
||||
|
|
|
|||
|
|
@ -38,11 +38,14 @@ end Syntax
|
|||
#eval run $ do let a ← `(a.{0}); match_syntax a with `($id:ident) => pure id | _ => pure a
|
||||
#eval run $ do let a ← `(match a with | a => 1 | _ => 2); match_syntax a with `(match $e with $eqns:matchAlt*) => pure eqns | _ => pure #[]
|
||||
|
||||
#eval run do let a ← some <$> `(a); `({ a := a $[: $(id a)]?})
|
||||
#eval run do let a ← pure none; `({ a := a $[: $a]?})
|
||||
def f (stx : Syntax) : Unhygienic Syntax := match_syntax stx with
|
||||
| `({ a := a $[: $a]?}) => `({ a := a $[: $(id a)]?})
|
||||
| _ => unreachable!
|
||||
#eval run do f (← `({ a := a : a }))
|
||||
#eval run do f (← `({ a := a }))
|
||||
|
||||
#eval run do
|
||||
let pats := #[← `(a), ← `(a + 1)]
|
||||
let rhss := #[← `(b), ← `(b + 1)]
|
||||
`(match a with $[$pats => $rhss]|*)
|
||||
match_syntax ← `(match a with a => b | a + 1 => b + 1) with
|
||||
| `(match a with $[$pats => $rhss]|*) => `(match a with $[$pats => $rhss]|*)
|
||||
| _ => unreachable!
|
||||
end Lean
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue