perf: let*-bind syntax match RHSs before duplicating them
This commit is contained in:
parent
e797ce3fb7
commit
93518d4e42
2 changed files with 33 additions and 33 deletions
|
|
@ -280,15 +280,11 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo :=
|
|||
let no ← no
|
||||
match k with
|
||||
| `optional =>
|
||||
let mut yesMatch := yes
|
||||
for id in ids do
|
||||
yesMatch ← `(let $id := some $id; $yesMatch)
|
||||
let mut yesNoMatch := yes
|
||||
for id in ids do
|
||||
yesNoMatch ← `(let $id := none; $yesNoMatch)
|
||||
`(if discr.isNone then $yesNoMatch
|
||||
let nones := mkArray ids.size (← `(none))
|
||||
`(let* yes _ $ids* := $yes;
|
||||
if discr.isNone then yes () $[ $nones]*
|
||||
else match discr with
|
||||
| `($(mkNullNode contents)) => $yesMatch
|
||||
| `($(mkNullNode contents)) => yes () $[ (some $ids)]*
|
||||
| _ => $no)
|
||||
| _ =>
|
||||
let mut discrs ← `(Syntax.getArgs discr)
|
||||
|
|
@ -368,6 +364,23 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo :=
|
|||
| r => r }
|
||||
| _ => throwErrorAt! pat "match_syntax: unexpected pattern kind {pat}"
|
||||
|
||||
-- Bind right-hand side to new `let*` decl in order to prevent code duplication
|
||||
private def deduplicate (floatedLetDecls : Array Syntax) : Alt → TermElabM (Array Syntax × Alt)
|
||||
-- NOTE: new macro scope so that introduced bindings do not collide
|
||||
| (pats, rhs) => do
|
||||
if let `($f:ident $[ $args:ident]*) := rhs then
|
||||
-- looks simple enough/created by this function, skip
|
||||
return (floatedLetDecls, (pats, rhs))
|
||||
withFreshMacroScope do
|
||||
match ← getPatternsVars pats.toArray with
|
||||
| #[] =>
|
||||
-- no antiquotations => introduce Unit parameter to preserve evaluation order
|
||||
let rhs' ← `(rhs Unit.unit)
|
||||
(floatedLetDecls.push (← `(letDecl|rhs _ := $rhs)), (pats, rhs'))
|
||||
| vars =>
|
||||
let rhs' ← `(rhs $vars*)
|
||||
(floatedLetDecls.push (← `(letDecl|rhs $vars:ident* := $rhs)), (pats, rhs'))
|
||||
|
||||
private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : TermElabM Syntax := do
|
||||
trace[Elab.match_syntax]! "match {discrs} with {alts}"
|
||||
match discrs, alts with
|
||||
|
|
@ -380,7 +393,10 @@ private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : T
|
|||
let mut yesAlts := #[]
|
||||
let mut undecidedAlts := #[]
|
||||
let mut nonExhaustiveAlts := #[]
|
||||
for alt in alts do match alt with
|
||||
let mut floatedLetDecls := #[]
|
||||
for alt in alts do
|
||||
let mut alt := alt
|
||||
match alt with
|
||||
| (covered f exh, alt) =>
|
||||
-- we can only factor out a common check if there are no undecided patterns in between;
|
||||
-- otherwise we would change the order of alternatives
|
||||
|
|
@ -389,14 +405,16 @@ private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : T
|
|||
if !exh then
|
||||
nonExhaustiveAlts := nonExhaustiveAlts.push alt
|
||||
else
|
||||
(floatedLetDecls, alt) ← deduplicate floatedLetDecls alt
|
||||
undecidedAlts := undecidedAlts.push alt
|
||||
nonExhaustiveAlts := nonExhaustiveAlts.push alt
|
||||
| (undecided, alt) =>
|
||||
(floatedLetDecls, alt) ← deduplicate floatedLetDecls alt
|
||||
undecidedAlts := undecidedAlts.push alt
|
||||
nonExhaustiveAlts := nonExhaustiveAlts.push alt
|
||||
| (uncovered, alt) =>
|
||||
nonExhaustiveAlts := nonExhaustiveAlts.push alt
|
||||
let m ← info.doMatch
|
||||
let mut stx ← info.doMatch
|
||||
(yes := fun newDiscrs => do
|
||||
let mut yesAlts := yesAlts
|
||||
if !undecidedAlts.isEmpty then
|
||||
|
|
@ -408,33 +426,14 @@ private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : T
|
|||
yesAlts := yesAlts.push (pats, rhs)
|
||||
withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts.toList)
|
||||
(no := withFreshMacroScope $ compileStxMatch (discr::discrs) nonExhaustiveAlts.toList)
|
||||
`(let discr := $discr; $m)
|
||||
for d in floatedLetDecls do
|
||||
stx ← `(let* $d:letDecl; $stx)
|
||||
`(let discr := $discr; $stx)
|
||||
| _, _ => unreachable!
|
||||
|
||||
-- Transform alternatives by binding all right-hand sides to outside the match in order to prevent
|
||||
-- code duplication during match compilation
|
||||
private def letBindRhss (cont : List Alt → TermElabM Syntax) : List Alt → List Alt → TermElabM Syntax
|
||||
| [], altsRev' => cont altsRev'.reverse
|
||||
| (pats, rhs)::alts, altsRev' => do
|
||||
match ← getPatternsVars pats.toArray with
|
||||
-- no antiquotations => introduce Unit parameter to preserve evaluation order
|
||||
| #[] =>
|
||||
-- NOTE: references binding below
|
||||
let rhs' ← `(rhs ())
|
||||
-- NOTE: new macro scope so that introduced bindings do not collide
|
||||
let stx ← withFreshMacroScope $ letBindRhss cont alts ((pats, rhs')::altsRev')
|
||||
`(let rhs := fun _ => $rhs; $stx)
|
||||
| vars =>
|
||||
-- rhs ← `(fun $vars* => $rhs)
|
||||
let rhs := Syntax.node `Lean.Parser.Term.fun #[mkAtom "fun", Syntax.node `null vars, mkAtom "=>", rhs]
|
||||
let rhs' ← `(rhs)
|
||||
let stx ← withFreshMacroScope $ letBindRhss cont alts ((pats, rhs')::altsRev')
|
||||
`(let rhs := $rhs; $stx)
|
||||
|
||||
def match_syntax.expand (stx : Syntax) : TermElabM Syntax := do
|
||||
match stx with
|
||||
| `(match $[$discrs:term],* with $[| $[$patss],* => $rhss]*) => do
|
||||
-- letBindRhss ...
|
||||
if !patss.any (·.any (fun
|
||||
| `($id@$pat) => pat.isQuot
|
||||
| pat => pat.isQuot)) then
|
||||
|
|
|
|||
|
|
@ -25,9 +25,10 @@ partial def getPatternVars (stx : Syntax) : TermElabM (Array Syntax) :=
|
|||
if stx.isQuot then
|
||||
getAntiquotationIds stx
|
||||
else match stx with
|
||||
| `(_) => #[]
|
||||
| `($id:ident) => #[id]
|
||||
| `($id:ident@$e) => do (← getPatternVars e).push id
|
||||
| _ => throwErrorAt stx "unsupported pattern in syntax match"
|
||||
| _ => throwErrorAt! stx "unsupported pattern in syntax match{indentD stx}"
|
||||
|
||||
partial def getPatternsVars (pats : Array Syntax) : TermElabM (Array Syntax) :=
|
||||
pats.foldlM (fun vars pat => do return vars ++ (← getPatternVars pat)) #[]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue