diff --git a/src/Lean/Elab/Quotation.lean b/src/Lean/Elab/Quotation.lean index e741bd2758..a02747c653 100644 --- a/src/Lean/Elab/Quotation.lean +++ b/src/Lean/Elab/Quotation.lean @@ -280,15 +280,11 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo := let no ← no match k with | `optional => - let mut yesMatch := yes - for id in ids do - yesMatch ← `(let $id := some $id; $yesMatch) - let mut yesNoMatch := yes - for id in ids do - yesNoMatch ← `(let $id := none; $yesNoMatch) - `(if discr.isNone then $yesNoMatch + let nones := mkArray ids.size (← `(none)) + `(let* yes _ $ids* := $yes; + if discr.isNone then yes () $[ $nones]* else match discr with - | `($(mkNullNode contents)) => $yesMatch + | `($(mkNullNode contents)) => yes () $[ (some $ids)]* | _ => $no) | _ => let mut discrs ← `(Syntax.getArgs discr) @@ -368,6 +364,23 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo := | r => r } | _ => throwErrorAt! pat "match_syntax: unexpected pattern kind {pat}" +-- Bind right-hand side to new `let*` decl in order to prevent code duplication +private def deduplicate (floatedLetDecls : Array Syntax) : Alt → TermElabM (Array Syntax × Alt) + -- NOTE: new macro scope so that introduced bindings do not collide + | (pats, rhs) => do + if let `($f:ident $[ $args:ident]*) := rhs then + -- looks simple enough/created by this function, skip + return (floatedLetDecls, (pats, rhs)) + withFreshMacroScope do + match ← getPatternsVars pats.toArray with + | #[] => + -- no antiquotations => introduce Unit parameter to preserve evaluation order + let rhs' ← `(rhs Unit.unit) + (floatedLetDecls.push (← `(letDecl|rhs _ := $rhs)), (pats, rhs')) + | vars => + let rhs' ← `(rhs $vars*) + (floatedLetDecls.push (← `(letDecl|rhs $vars:ident* := $rhs)), (pats, rhs')) + private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : TermElabM Syntax := do trace[Elab.match_syntax]! "match {discrs} with {alts}" match discrs, alts with @@ -380,7 +393,10 @@ private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : T let mut yesAlts := #[] let mut undecidedAlts := #[] let mut nonExhaustiveAlts := #[] - for alt in alts do match alt with + let mut floatedLetDecls := #[] + for alt in alts do + let mut alt := alt + match alt with | (covered f exh, alt) => -- we can only factor out a common check if there are no undecided patterns in between; -- otherwise we would change the order of alternatives @@ -389,14 +405,16 @@ private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : T if !exh then nonExhaustiveAlts := nonExhaustiveAlts.push alt else + (floatedLetDecls, alt) ← deduplicate floatedLetDecls alt undecidedAlts := undecidedAlts.push alt nonExhaustiveAlts := nonExhaustiveAlts.push alt | (undecided, alt) => + (floatedLetDecls, alt) ← deduplicate floatedLetDecls alt undecidedAlts := undecidedAlts.push alt nonExhaustiveAlts := nonExhaustiveAlts.push alt | (uncovered, alt) => nonExhaustiveAlts := nonExhaustiveAlts.push alt - let m ← info.doMatch + let mut stx ← info.doMatch (yes := fun newDiscrs => do let mut yesAlts := yesAlts if !undecidedAlts.isEmpty then @@ -408,33 +426,14 @@ private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : T yesAlts := yesAlts.push (pats, rhs) withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts.toList) (no := withFreshMacroScope $ compileStxMatch (discr::discrs) nonExhaustiveAlts.toList) - `(let discr := $discr; $m) + for d in floatedLetDecls do + stx ← `(let* $d:letDecl; $stx) + `(let discr := $discr; $stx) | _, _ => unreachable! --- Transform alternatives by binding all right-hand sides to outside the match in order to prevent --- code duplication during match compilation -private def letBindRhss (cont : List Alt → TermElabM Syntax) : List Alt → List Alt → TermElabM Syntax - | [], altsRev' => cont altsRev'.reverse - | (pats, rhs)::alts, altsRev' => do - match ← getPatternsVars pats.toArray with - -- no antiquotations => introduce Unit parameter to preserve evaluation order - | #[] => - -- NOTE: references binding below - let rhs' ← `(rhs ()) - -- NOTE: new macro scope so that introduced bindings do not collide - let stx ← withFreshMacroScope $ letBindRhss cont alts ((pats, rhs')::altsRev') - `(let rhs := fun _ => $rhs; $stx) - | vars => - -- rhs ← `(fun $vars* => $rhs) - let rhs := Syntax.node `Lean.Parser.Term.fun #[mkAtom "fun", Syntax.node `null vars, mkAtom "=>", rhs] - let rhs' ← `(rhs) - let stx ← withFreshMacroScope $ letBindRhss cont alts ((pats, rhs')::altsRev') - `(let rhs := $rhs; $stx) - def match_syntax.expand (stx : Syntax) : TermElabM Syntax := do match stx with | `(match $[$discrs:term],* with $[| $[$patss],* => $rhss]*) => do - -- letBindRhss ... if !patss.any (·.any (fun | `($id@$pat) => pat.isQuot | pat => pat.isQuot)) then diff --git a/src/Lean/Elab/Quotation/Util.lean b/src/Lean/Elab/Quotation/Util.lean index d0122df901..7acaebf94d 100644 --- a/src/Lean/Elab/Quotation/Util.lean +++ b/src/Lean/Elab/Quotation/Util.lean @@ -25,9 +25,10 @@ partial def getPatternVars (stx : Syntax) : TermElabM (Array Syntax) := if stx.isQuot then getAntiquotationIds stx else match stx with + | `(_) => #[] | `($id:ident) => #[id] | `($id:ident@$e) => do (← getPatternVars e).push id - | _ => throwErrorAt stx "unsupported pattern in syntax match" + | _ => throwErrorAt! stx "unsupported pattern in syntax match{indentD stx}" partial def getPatternsVars (pats : Array Syntax) : TermElabM (Array Syntax) := pats.foldlM (fun vars pat => do return vars ++ (← getPatternVars pat)) #[]