perf: let*-bind syntax match RHSs before duplicating them

This commit is contained in:
Sebastian Ullrich 2020-12-22 15:46:55 +01:00
parent e797ce3fb7
commit 93518d4e42
2 changed files with 33 additions and 33 deletions

View file

@ -280,15 +280,11 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo :=
let no ← no
match k with
| `optional =>
let mut yesMatch := yes
for id in ids do
yesMatch ← `(let $id := some $id; $yesMatch)
let mut yesNoMatch := yes
for id in ids do
yesNoMatch ← `(let $id := none; $yesNoMatch)
`(if discr.isNone then $yesNoMatch
let nones := mkArray ids.size (← `(none))
`(let* yes _ $ids* := $yes;
if discr.isNone then yes () $[ $nones]*
else match discr with
| `($(mkNullNode contents)) => $yesMatch
| `($(mkNullNode contents)) => yes () $[ (some $ids)]*
| _ => $no)
| _ =>
let mut discrs ← `(Syntax.getArgs discr)
@ -368,6 +364,23 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo :=
| r => r }
| _ => throwErrorAt! pat "match_syntax: unexpected pattern kind {pat}"
-- Bind right-hand side to new `let*` decl in order to prevent code duplication
private def deduplicate (floatedLetDecls : Array Syntax) : Alt → TermElabM (Array Syntax × Alt)
-- NOTE: new macro scope so that introduced bindings do not collide
| (pats, rhs) => do
if let `($f:ident $[ $args:ident]*) := rhs then
-- looks simple enough/created by this function, skip
return (floatedLetDecls, (pats, rhs))
withFreshMacroScope do
match ← getPatternsVars pats.toArray with
| #[] =>
-- no antiquotations => introduce Unit parameter to preserve evaluation order
let rhs' ← `(rhs Unit.unit)
(floatedLetDecls.push (← `(letDecl|rhs _ := $rhs)), (pats, rhs'))
| vars =>
let rhs' ← `(rhs $vars*)
(floatedLetDecls.push (← `(letDecl|rhs $vars:ident* := $rhs)), (pats, rhs'))
private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : TermElabM Syntax := do
trace[Elab.match_syntax]! "match {discrs} with {alts}"
match discrs, alts with
@ -380,7 +393,10 @@ private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : T
let mut yesAlts := #[]
let mut undecidedAlts := #[]
let mut nonExhaustiveAlts := #[]
for alt in alts do match alt with
let mut floatedLetDecls := #[]
for alt in alts do
let mut alt := alt
match alt with
| (covered f exh, alt) =>
-- we can only factor out a common check if there are no undecided patterns in between;
-- otherwise we would change the order of alternatives
@ -389,14 +405,16 @@ private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : T
if !exh then
nonExhaustiveAlts := nonExhaustiveAlts.push alt
else
(floatedLetDecls, alt) ← deduplicate floatedLetDecls alt
undecidedAlts := undecidedAlts.push alt
nonExhaustiveAlts := nonExhaustiveAlts.push alt
| (undecided, alt) =>
(floatedLetDecls, alt) ← deduplicate floatedLetDecls alt
undecidedAlts := undecidedAlts.push alt
nonExhaustiveAlts := nonExhaustiveAlts.push alt
| (uncovered, alt) =>
nonExhaustiveAlts := nonExhaustiveAlts.push alt
let m ← info.doMatch
let mut stx ← info.doMatch
(yes := fun newDiscrs => do
let mut yesAlts := yesAlts
if !undecidedAlts.isEmpty then
@ -408,33 +426,14 @@ private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : T
yesAlts := yesAlts.push (pats, rhs)
withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts.toList)
(no := withFreshMacroScope $ compileStxMatch (discr::discrs) nonExhaustiveAlts.toList)
`(let discr := $discr; $m)
for d in floatedLetDecls do
stx ← `(let* $d:letDecl; $stx)
`(let discr := $discr; $stx)
| _, _ => unreachable!
-- Transform alternatives by binding all right-hand sides to outside the match in order to prevent
-- code duplication during match compilation
private def letBindRhss (cont : List Alt → TermElabM Syntax) : List Alt → List Alt → TermElabM Syntax
| [], altsRev' => cont altsRev'.reverse
| (pats, rhs)::alts, altsRev' => do
match ← getPatternsVars pats.toArray with
-- no antiquotations => introduce Unit parameter to preserve evaluation order
| #[] =>
-- NOTE: references binding below
let rhs' ← `(rhs ())
-- NOTE: new macro scope so that introduced bindings do not collide
let stx ← withFreshMacroScope $ letBindRhss cont alts ((pats, rhs')::altsRev')
`(let rhs := fun _ => $rhs; $stx)
| vars =>
-- rhs ← `(fun $vars* => $rhs)
let rhs := Syntax.node `Lean.Parser.Term.fun #[mkAtom "fun", Syntax.node `null vars, mkAtom "=>", rhs]
let rhs' ← `(rhs)
let stx ← withFreshMacroScope $ letBindRhss cont alts ((pats, rhs')::altsRev')
`(let rhs := $rhs; $stx)
def match_syntax.expand (stx : Syntax) : TermElabM Syntax := do
match stx with
| `(match $[$discrs:term],* with $[| $[$patss],* => $rhss]*) => do
-- letBindRhss ...
if !patss.any (·.any (fun
| `($id@$pat) => pat.isQuot
| pat => pat.isQuot)) then

View file

@ -25,9 +25,10 @@ partial def getPatternVars (stx : Syntax) : TermElabM (Array Syntax) :=
if stx.isQuot then
getAntiquotationIds stx
else match stx with
| `(_) => #[]
| `($id:ident) => #[id]
| `($id:ident@$e) => do (← getPatternVars e).push id
| _ => throwErrorAt stx "unsupported pattern in syntax match"
| _ => throwErrorAt! stx "unsupported pattern in syntax match{indentD stx}"
partial def getPatternsVars (pats : Array Syntax) : TermElabM (Array Syntax) :=
pats.foldlM (fun vars pat => do return vars ++ (← getPatternVars pat)) #[]