feat: monadic match_expr

This commit is contained in:
Leonardo de Moura 2024-03-01 22:07:07 -08:00 committed by Leonardo de Moura
parent f777e0cc85
commit 7a27b04d50
4 changed files with 180 additions and 21 deletions

View file

@ -131,12 +131,31 @@ abbrev Var := Syntax -- TODO: should be `Ident`
/-- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/
structure Alt (σ : Type) where
ref : Syntax
vars : Array Var
ref : Syntax
vars : Array Var
patterns : Syntax
rhs : σ
rhs : σ
deriving Inhabited
/-- A `doMatchExpr` alternative. -/
structure AltExpr (σ : Type) where
ref : Syntax
var? : Option Var
funName : Syntax
pvars : Array Syntax
rhs : σ
deriving Inhabited
def AltExpr.vars (alt : AltExpr σ) : Array Var := Id.run do
let mut vars := #[]
if let some var := alt.var? then
vars := vars.push var
for pvar in alt.pvars do
match pvar with
| `(_) => pure ()
| _ => vars := vars.push pvar
return vars
/--
Auxiliary datastructure for representing a `do` code block, and compiling "reassignments" (e.g., `x := x + 1`).
We convert `Code` into a `Syntax` term representing the:
@ -198,6 +217,7 @@ inductive Code where
/-- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/
| ite (ref : Syntax) (h? : Option Var) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
| match (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt Code))
| matchExpr (ref : Syntax) (meta : Bool) (discr : Syntax) (alts : Array (AltExpr Code)) (elseBranch : Code)
| jmp (ref : Syntax) (jpName : Name) (args : Array Syntax)
deriving Inhabited
@ -212,6 +232,7 @@ def Code.getRef? : Code → Option Syntax
| .return ref _ => ref
| .ite ref .. => ref
| .match ref .. => ref
| .matchExpr ref .. => ref
| .jmp ref .. => ref
abbrev VarSet := RBMap Name Syntax Name.cmp
@ -243,19 +264,28 @@ partial def CodeBlocl.toMessageData (codeBlock : CodeBlock) : MessageData :=
| .match _ _ ds _ alts =>
m!"match {ds} with"
++ alts.foldl (init := m!"") fun acc alt => acc ++ m!"\n| {alt.patterns} => {loop alt.rhs}"
| .matchExpr _ meta d alts elseCode =>
let r := m!"match_expr {if meta then "" else "(meta := false)"} {d} with"
let r := r ++ alts.foldl (init := m!"") fun acc alt =>
let acc := acc ++ m!"\n| {if let some var := alt.var? then m!"{var}@" else ""}"
let acc := acc ++ m!"{alt.funName}"
let acc := acc ++ alt.pvars.foldl (init := m!"") fun acc pvar => acc ++ m!" {pvar}"
acc ++ m!" => {loop alt.rhs}"
r ++ m!"| _ => {loop elseCode}"
loop codeBlock.code
/-- Return true if the give code contains an exit point that satisfies `p` -/
partial def hasExitPointPred (c : Code) (p : Code → Bool) : Bool :=
let rec loop : Code → Bool
| .decl _ _ k => loop k
| .reassign _ _ k => loop k
| .joinpoint _ _ b k => loop b || loop k
| .seq _ k => loop k
| .ite _ _ _ _ t e => loop t || loop e
| .match _ _ _ _ alts => alts.any (loop ·.rhs)
| .jmp .. => false
| c => p c
| .decl _ _ k => loop k
| .reassign _ _ k => loop k
| .joinpoint _ _ b k => loop b || loop k
| .seq _ k => loop k
| .ite _ _ _ _ t e => loop t || loop e
| .match _ _ _ _ alts => alts.any (loop ·.rhs)
| .matchExpr _ _ _ alts e => alts.any (loop ·.rhs) || loop e
| .jmp .. => false
| c => p c
loop c
def hasExitPoint (c : Code) : Bool :=
@ -300,13 +330,18 @@ partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array V
| .joinpoint n ps b k => return .joinpoint n ps (← loop b) (← loop k)
| .seq e k => return .seq e (← loop k)
| .ite ref x? h c t e => return .ite ref x? h c (← loop t) (← loop e)
| .match ref g ds t alts => return .match ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← loop alt.rhs) })
| .action e => mkAuxDeclFor e fun y =>
let ref := e
-- We jump to `jp` with xs **and** y
let jmpArgs := xs.push y
return Code.jmp ref jp jmpArgs
| c => return c
| .match ref g ds t alts =>
return .match ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← loop alt.rhs) })
| .matchExpr ref meta d alts e => do
let alts ← alts.mapM fun alt => do pure { alt with rhs := (← loop alt.rhs) }
let e ← loop e
return .matchExpr ref meta d alts e
| c => return c
loop code
structure JPDecl where
@ -372,14 +407,13 @@ def mkJmp (ref : Syntax) (rs : VarSet) (val : Syntax) (mkJPBody : Syntax → Mac
return Code.jmp ref jp args
/-- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/
partial def pullExitPointsAux (rs : VarSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code :=
partial def pullExitPointsAux (rs : VarSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := do
match c with
| .decl xs stx k => return .decl xs stx (← pullExitPointsAux (eraseVars rs xs) k)
| .reassign xs stx k => return .reassign xs stx (← pullExitPointsAux (insertVars rs xs) k)
| .joinpoint j ps b k => return .joinpoint j ps (← pullExitPointsAux rs b) (← pullExitPointsAux rs k)
| .seq e k => return .seq e (← pullExitPointsAux rs k)
| .ite ref x? o c t e => return .ite ref x? o c (← pullExitPointsAux (eraseOptVar rs x?) t) (← pullExitPointsAux (eraseOptVar rs x?) e)
| .match ref g ds t alts => return .match ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs) })
| .jmp .. => return c
| .break ref => mkSimpleJmp ref rs (.break ref)
| .continue ref => mkSimpleJmp ref rs (.continue ref)
@ -389,6 +423,13 @@ partial def pullExitPointsAux (rs : VarSet) (c : Code) : StateRefT (Array JPDecl
mkAuxDeclFor e fun y =>
let ref := e
mkJmp ref rs y (fun yFresh => return .action (← ``(Pure.pure $yFresh)))
| .match ref g ds t alts =>
let alts ← alts.mapM fun alt => do pure { alt with rhs := (← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs) }
return .match ref g ds t alts
| .matchExpr ref meta d alts e =>
let alts ← alts.mapM fun alt => do pure { alt with rhs := (← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs) }
let e ← pullExitPointsAux rs e
return .matchExpr ref meta d alts e
/--
Auxiliary operation for adding new variables to the collection of updated variables in a CodeBlock.
@ -457,6 +498,14 @@ partial def extendUpdatedVarsAux (c : Code) (ws : VarSet) : TermElabM Code :=
pullExitPoints c
else
return .match ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← update alt.rhs) })
| .matchExpr ref meta d alts e =>
if alts.any fun alt => alt.vars.any fun x => ws.contains x.getId then
-- If a pattern variable is shadowing a variable in ws, we `pullExitPoints`
pullExitPoints c
else
let alts ← alts.mapM fun alt => do pure { alt with rhs := (← update alt.rhs) }
let e ← update e
return .matchExpr ref meta d alts e
| .ite ref none o c t e => return .ite ref none o c (← update t) (← update e)
| .ite ref (some h) o cond t e =>
if ws.contains h.getId then
@ -570,6 +619,16 @@ def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optMotive : Sy
return { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code }
return { code := .match ref genParam discrs optMotive alts, uvars := ws }
def mkMatchExpr (ref : Syntax) (meta : Bool) (discr : Syntax) (alts : Array (AltExpr CodeBlock)) (elseBranch : CodeBlock) : TermElabM CodeBlock := do
-- nary version of homogenize
let ws := alts.foldl (union · ·.rhs.uvars) {}
let ws := union ws elseBranch.uvars
let alts ← alts.mapM fun alt => do
let rhs ← extendUpdatedVars alt.rhs ws
return { alt with rhs := rhs.code : AltExpr Code }
let elseBranch ← extendUpdatedVars elseBranch ws
return { code := .matchExpr ref meta discr alts elseBranch.code, uvars := ws }
/-- Return a code block that executes `terminal` and then `k` with the value produced by `terminal`.
This method assumes `terminal` is a terminal -/
def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Var) (k : CodeBlock) : TermElabM CodeBlock := do
@ -1077,10 +1136,25 @@ where
let mut termAlts := #[]
for alt in alts do
let rhs ← toTerm alt.rhs
let termAlt := mkNode `Lean.Parser.Term.matchAlt #[mkAtomFrom alt.ref "|", mkNullNode #[alt.patterns], mkAtomFrom alt.ref "=>", rhs]
let termAlt := mkNode ``Parser.Term.matchAlt #[mkAtomFrom alt.ref "|", mkNullNode #[alt.patterns], mkAtomFrom alt.ref "=>", rhs]
termAlts := termAlts.push termAlt
let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts]
return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts]
let termMatchAlts := mkNode ``Parser.Term.matchAlts #[mkNullNode termAlts]
return mkNode ``Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts]
| .matchExpr ref meta d alts elseBranch => withFreshMacroScope do
let d' ← `(discr)
let mut termAlts := #[]
for alt in alts do
let rhs ← toTerm alt.rhs
let optVar := if let some var := alt.var? then mkNullNode #[var, mkAtomFrom var "@"] else mkNullNode #[]
let termAlt := mkNode ``Parser.Term.matchExprAlt #[mkAtomFrom alt.ref "|", optVar, alt.funName, mkNullNode alt.pvars, mkAtomFrom alt.ref "=>", rhs]
termAlts := termAlts.push termAlt
let elseBranch := mkNode ``Parser.Term.matchExprElseAlt #[mkAtomFrom ref "|", mkHole ref, mkAtomFrom ref "=>", (← toTerm elseBranch)]
let termMatchExprAlts := mkNode ``Parser.Term.matchExprAlts #[mkNullNode termAlts, elseBranch]
let body := mkNode ``Parser.Term.matchExpr #[mkAtomFrom ref "match_expr", d', mkAtomFrom ref "with", termMatchExprAlts]
if meta then
`(Bind.bind (instantiateMVarsIfMVarApp $d) fun discr => $body)
else
`(let discr := $d; $body)
def run (code : Code) (m : Syntax) (returnType : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax :=
toTerm code { m, returnType, kind, uvars }
@ -1533,6 +1607,23 @@ mutual
let matchCode ← mkMatch ref genParam discrs optMotive alts
concatWith matchCode doElems
/-- Generate `CodeBlock` for `doMatchExpr; doElems` -/
partial def doMatchExprToCode (doMatchExpr : Syntax) (doElems: List Syntax) : M CodeBlock := do
let ref := doMatchExpr
let meta := doMatchExpr[1].isNone
let discr := doMatchExpr[2]
let alts := doMatchExpr[4][0].getArgs -- Array of `doMatchExprAlt`
let alts ← alts.mapM fun alt => do
let var? := if alt[1].isNone then none else some alt[1][0]
let funName := alt[2]
let pvars := alt[3].getArgs
let rhs := alt[5]
let rhs ← doSeqToCode (getDoSeqElems rhs)
pure { ref, var?, funName, pvars, rhs }
let elseBranch ← doSeqToCode (getDoSeqElems doMatchExpr[4][1][3])
let matchCode ← mkMatchExpr ref meta discr alts elseBranch
concatWith matchCode doElems
/--
Generate `CodeBlock` for `doTry; doElems`
```
@ -1640,6 +1731,8 @@ mutual
doForToCode doElem doElems
else if k == ``Parser.Term.doMatch then
doMatchToCode doElem doElems
else if k == ``Parser.Term.doMatchExpr then
doMatchExprToCode doElem doElems
else if k == ``Parser.Term.doTry then
doTryToCode doElem doElems
else if k == ``Parser.Term.doBreak then

View file

@ -63,9 +63,10 @@ def toAlt? (stx : Syntax) : Option Alt :=
let optVar := stx.getArg 1
if optVar.isNone then none else some ⟨optVar.getArg 0⟩
let funName := ⟨stx.getArg 2⟩
let pvars := stx.getArg 3 |>.getArgs.toList.map fun
| `($arg:ident) => some arg
| _ => none
let pvars := stx.getArg 3 |>.getArgs.toList.reverse.map fun arg =>
match arg with
| `(_) => none
| _ => some ⟨arg⟩
let rhs := stx.getArg 5
some { var?, funName, pvars, rhs }

View file

@ -1737,6 +1737,15 @@ def isDefEqNoConstantApprox (t s : Expr) : MetaM Bool :=
def etaExpand (e : Expr) : MetaM Expr :=
withDefault do forallTelescopeReducing (← inferType e) fun xs _ => mkLambdaFVars xs (mkAppN e xs)
/--
If `e` is of the form `?m ...` instantiate metavars
-/
def instantiateMVarsIfMVarApp (e : Expr) : MetaM Expr := do
if e.getAppFn.isMVar then
instantiateMVars e
else
return e
end Meta
builtin_initialize

View file

@ -0,0 +1,56 @@
import Lean
open Lean Meta
def test1 (e : Expr) : Option Expr :=
match_expr e with
| c@Eq α a b => some (mkApp3 c α b a)
| Eq.refl _ a => some a
| _ => none
/--
info: 3 = 1
---
info: 3
---
info: 4 = 2
-/
#guard_msgs in
run_meta do
let eq ← mkEq (toExpr 1) (toExpr 3)
let eq := mkAnnotation `foo eq
let some eq := test1 eq | throwError "unexpected"
logInfo eq
let rfl ← mkEqRefl (toExpr 3)
let some n := test1 rfl | throwError "unexpected"
logInfo n
let eq := mkAnnotation `boo <| mkApp (mkAnnotation `bla (mkApp (mkAnnotation `foo eq.appFn!.appFn!) (toExpr 2))) (toExpr 4)
let some eq := test1 eq | throwError "unexpected"
logInfo eq
def test2 (e : Expr) : MetaM Expr := do
match_expr e with
| HAdd.hAdd _ _ _ _ a b => mkSub a b
| HMul.hMul _ _ _ _ a b => mkAdd b a
| _ => return e
/--
info: 2 - 5
---
info: 5 + 2
---
info: 5 - 2
-/
#guard_msgs in
run_meta do
let e ← mkAdd (toExpr 2) (toExpr 5)
let e ← test2 e
logInfo e
let e ← mkMul (toExpr 2) (toExpr 5)
let e ← test2 e
logInfo e
let m ← mkFreshExprMVar none
let m ← test2 m
assert! m.isMVar
discard <| isDefEq m e
let m ← test2 m
logInfo m