diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index 333b548653..bf0dc06df5 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -131,12 +131,31 @@ abbrev Var := Syntax -- TODO: should be `Ident` /-- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/ structure Alt (σ : Type) where - ref : Syntax - vars : Array Var + ref : Syntax + vars : Array Var patterns : Syntax - rhs : σ + rhs : σ deriving Inhabited +/-- A `doMatchExpr` alternative. -/ +structure AltExpr (σ : Type) where + ref : Syntax + var? : Option Var + funName : Syntax + pvars : Array Syntax + rhs : σ + deriving Inhabited + +def AltExpr.vars (alt : AltExpr σ) : Array Var := Id.run do + let mut vars := #[] + if let some var := alt.var? then + vars := vars.push var + for pvar in alt.pvars do + match pvar with + | `(_) => pure () + | _ => vars := vars.push pvar + return vars + /-- Auxiliary datastructure for representing a `do` code block, and compiling "reassignments" (e.g., `x := x + 1`). We convert `Code` into a `Syntax` term representing the: @@ -198,6 +217,7 @@ inductive Code where /-- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/ | ite (ref : Syntax) (h? : Option Var) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) | match (ref : Syntax) (gen : Syntax) (discrs : Syntax) (optMotive : Syntax) (alts : Array (Alt Code)) + | matchExpr (ref : Syntax) (meta : Bool) (discr : Syntax) (alts : Array (AltExpr Code)) (elseBranch : Code) | jmp (ref : Syntax) (jpName : Name) (args : Array Syntax) deriving Inhabited @@ -212,6 +232,7 @@ def Code.getRef? : Code → Option Syntax | .return ref _ => ref | .ite ref .. => ref | .match ref .. => ref + | .matchExpr ref .. => ref | .jmp ref .. => ref abbrev VarSet := RBMap Name Syntax Name.cmp @@ -243,19 +264,28 @@ partial def CodeBlocl.toMessageData (codeBlock : CodeBlock) : MessageData := | .match _ _ ds _ alts => m!"match {ds} with" ++ alts.foldl (init := m!"") fun acc alt => acc ++ m!"\n| {alt.patterns} => {loop alt.rhs}" + | .matchExpr _ meta d alts elseCode => + let r := m!"match_expr {if meta then "" else "(meta := false)"} {d} with" + let r := r ++ alts.foldl (init := m!"") fun acc alt => + let acc := acc ++ m!"\n| {if let some var := alt.var? then m!"{var}@" else ""}" + let acc := acc ++ m!"{alt.funName}" + let acc := acc ++ alt.pvars.foldl (init := m!"") fun acc pvar => acc ++ m!" {pvar}" + acc ++ m!" => {loop alt.rhs}" + r ++ m!"| _ => {loop elseCode}" loop codeBlock.code /-- Return true if the give code contains an exit point that satisfies `p` -/ partial def hasExitPointPred (c : Code) (p : Code → Bool) : Bool := let rec loop : Code → Bool - | .decl _ _ k => loop k - | .reassign _ _ k => loop k - | .joinpoint _ _ b k => loop b || loop k - | .seq _ k => loop k - | .ite _ _ _ _ t e => loop t || loop e - | .match _ _ _ _ alts => alts.any (loop ·.rhs) - | .jmp .. => false - | c => p c + | .decl _ _ k => loop k + | .reassign _ _ k => loop k + | .joinpoint _ _ b k => loop b || loop k + | .seq _ k => loop k + | .ite _ _ _ _ t e => loop t || loop e + | .match _ _ _ _ alts => alts.any (loop ·.rhs) + | .matchExpr _ _ _ alts e => alts.any (loop ·.rhs) || loop e + | .jmp .. => false + | c => p c loop c def hasExitPoint (c : Code) : Bool := @@ -300,13 +330,18 @@ partial def convertTerminalActionIntoJmp (code : Code) (jp : Name) (xs : Array V | .joinpoint n ps b k => return .joinpoint n ps (← loop b) (← loop k) | .seq e k => return .seq e (← loop k) | .ite ref x? h c t e => return .ite ref x? h c (← loop t) (← loop e) - | .match ref g ds t alts => return .match ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← loop alt.rhs) }) | .action e => mkAuxDeclFor e fun y => let ref := e -- We jump to `jp` with xs **and** y let jmpArgs := xs.push y return Code.jmp ref jp jmpArgs - | c => return c + | .match ref g ds t alts => + return .match ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← loop alt.rhs) }) + | .matchExpr ref meta d alts e => do + let alts ← alts.mapM fun alt => do pure { alt with rhs := (← loop alt.rhs) } + let e ← loop e + return .matchExpr ref meta d alts e + | c => return c loop code structure JPDecl where @@ -372,14 +407,13 @@ def mkJmp (ref : Syntax) (rs : VarSet) (val : Syntax) (mkJPBody : Syntax → Mac return Code.jmp ref jp args /-- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/ -partial def pullExitPointsAux (rs : VarSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := +partial def pullExitPointsAux (rs : VarSet) (c : Code) : StateRefT (Array JPDecl) TermElabM Code := do match c with | .decl xs stx k => return .decl xs stx (← pullExitPointsAux (eraseVars rs xs) k) | .reassign xs stx k => return .reassign xs stx (← pullExitPointsAux (insertVars rs xs) k) | .joinpoint j ps b k => return .joinpoint j ps (← pullExitPointsAux rs b) (← pullExitPointsAux rs k) | .seq e k => return .seq e (← pullExitPointsAux rs k) | .ite ref x? o c t e => return .ite ref x? o c (← pullExitPointsAux (eraseOptVar rs x?) t) (← pullExitPointsAux (eraseOptVar rs x?) e) - | .match ref g ds t alts => return .match ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs) }) | .jmp .. => return c | .break ref => mkSimpleJmp ref rs (.break ref) | .continue ref => mkSimpleJmp ref rs (.continue ref) @@ -389,6 +423,13 @@ partial def pullExitPointsAux (rs : VarSet) (c : Code) : StateRefT (Array JPDecl mkAuxDeclFor e fun y => let ref := e mkJmp ref rs y (fun yFresh => return .action (← ``(Pure.pure $yFresh))) + | .match ref g ds t alts => + let alts ← alts.mapM fun alt => do pure { alt with rhs := (← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs) } + return .match ref g ds t alts + | .matchExpr ref meta d alts e => + let alts ← alts.mapM fun alt => do pure { alt with rhs := (← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs) } + let e ← pullExitPointsAux rs e + return .matchExpr ref meta d alts e /-- Auxiliary operation for adding new variables to the collection of updated variables in a CodeBlock. @@ -457,6 +498,14 @@ partial def extendUpdatedVarsAux (c : Code) (ws : VarSet) : TermElabM Code := pullExitPoints c else return .match ref g ds t (← alts.mapM fun alt => do pure { alt with rhs := (← update alt.rhs) }) + | .matchExpr ref meta d alts e => + if alts.any fun alt => alt.vars.any fun x => ws.contains x.getId then + -- If a pattern variable is shadowing a variable in ws, we `pullExitPoints` + pullExitPoints c + else + let alts ← alts.mapM fun alt => do pure { alt with rhs := (← update alt.rhs) } + let e ← update e + return .matchExpr ref meta d alts e | .ite ref none o c t e => return .ite ref none o c (← update t) (← update e) | .ite ref (some h) o cond t e => if ws.contains h.getId then @@ -570,6 +619,16 @@ def mkMatch (ref : Syntax) (genParam : Syntax) (discrs : Syntax) (optMotive : Sy return { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code } return { code := .match ref genParam discrs optMotive alts, uvars := ws } +def mkMatchExpr (ref : Syntax) (meta : Bool) (discr : Syntax) (alts : Array (AltExpr CodeBlock)) (elseBranch : CodeBlock) : TermElabM CodeBlock := do + -- nary version of homogenize + let ws := alts.foldl (union · ·.rhs.uvars) {} + let ws := union ws elseBranch.uvars + let alts ← alts.mapM fun alt => do + let rhs ← extendUpdatedVars alt.rhs ws + return { alt with rhs := rhs.code : AltExpr Code } + let elseBranch ← extendUpdatedVars elseBranch ws + return { code := .matchExpr ref meta discr alts elseBranch.code, uvars := ws } + /-- Return a code block that executes `terminal` and then `k` with the value produced by `terminal`. This method assumes `terminal` is a terminal -/ def concat (terminal : CodeBlock) (kRef : Syntax) (y? : Option Var) (k : CodeBlock) : TermElabM CodeBlock := do @@ -1077,10 +1136,25 @@ where let mut termAlts := #[] for alt in alts do let rhs ← toTerm alt.rhs - let termAlt := mkNode `Lean.Parser.Term.matchAlt #[mkAtomFrom alt.ref "|", mkNullNode #[alt.patterns], mkAtomFrom alt.ref "=>", rhs] + let termAlt := mkNode ``Parser.Term.matchAlt #[mkAtomFrom alt.ref "|", mkNullNode #[alt.patterns], mkAtomFrom alt.ref "=>", rhs] termAlts := termAlts.push termAlt - let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts] - return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts] + let termMatchAlts := mkNode ``Parser.Term.matchAlts #[mkNullNode termAlts] + return mkNode ``Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts] + | .matchExpr ref meta d alts elseBranch => withFreshMacroScope do + let d' ← `(discr) + let mut termAlts := #[] + for alt in alts do + let rhs ← toTerm alt.rhs + let optVar := if let some var := alt.var? then mkNullNode #[var, mkAtomFrom var "@"] else mkNullNode #[] + let termAlt := mkNode ``Parser.Term.matchExprAlt #[mkAtomFrom alt.ref "|", optVar, alt.funName, mkNullNode alt.pvars, mkAtomFrom alt.ref "=>", rhs] + termAlts := termAlts.push termAlt + let elseBranch := mkNode ``Parser.Term.matchExprElseAlt #[mkAtomFrom ref "|", mkHole ref, mkAtomFrom ref "=>", (← toTerm elseBranch)] + let termMatchExprAlts := mkNode ``Parser.Term.matchExprAlts #[mkNullNode termAlts, elseBranch] + let body := mkNode ``Parser.Term.matchExpr #[mkAtomFrom ref "match_expr", d', mkAtomFrom ref "with", termMatchExprAlts] + if meta then + `(Bind.bind (instantiateMVarsIfMVarApp $d) fun discr => $body) + else + `(let discr := $d; $body) def run (code : Code) (m : Syntax) (returnType : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax := toTerm code { m, returnType, kind, uvars } @@ -1533,6 +1607,23 @@ mutual let matchCode ← mkMatch ref genParam discrs optMotive alts concatWith matchCode doElems + /-- Generate `CodeBlock` for `doMatchExpr; doElems` -/ + partial def doMatchExprToCode (doMatchExpr : Syntax) (doElems: List Syntax) : M CodeBlock := do + let ref := doMatchExpr + let meta := doMatchExpr[1].isNone + let discr := doMatchExpr[2] + let alts := doMatchExpr[4][0].getArgs -- Array of `doMatchExprAlt` + let alts ← alts.mapM fun alt => do + let var? := if alt[1].isNone then none else some alt[1][0] + let funName := alt[2] + let pvars := alt[3].getArgs + let rhs := alt[5] + let rhs ← doSeqToCode (getDoSeqElems rhs) + pure { ref, var?, funName, pvars, rhs } + let elseBranch ← doSeqToCode (getDoSeqElems doMatchExpr[4][1][3]) + let matchCode ← mkMatchExpr ref meta discr alts elseBranch + concatWith matchCode doElems + /-- Generate `CodeBlock` for `doTry; doElems` ``` @@ -1640,6 +1731,8 @@ mutual doForToCode doElem doElems else if k == ``Parser.Term.doMatch then doMatchToCode doElem doElems + else if k == ``Parser.Term.doMatchExpr then + doMatchExprToCode doElem doElems else if k == ``Parser.Term.doTry then doTryToCode doElem doElems else if k == ``Parser.Term.doBreak then diff --git a/src/Lean/Elab/MatchExpr.lean b/src/Lean/Elab/MatchExpr.lean index a0b1cca905..c6e05113a7 100644 --- a/src/Lean/Elab/MatchExpr.lean +++ b/src/Lean/Elab/MatchExpr.lean @@ -63,9 +63,10 @@ def toAlt? (stx : Syntax) : Option Alt := let optVar := stx.getArg 1 if optVar.isNone then none else some ⟨optVar.getArg 0⟩ let funName := ⟨stx.getArg 2⟩ - let pvars := stx.getArg 3 |>.getArgs.toList.map fun - | `($arg:ident) => some arg - | _ => none + let pvars := stx.getArg 3 |>.getArgs.toList.reverse.map fun arg => + match arg with + | `(_) => none + | _ => some ⟨arg⟩ let rhs := stx.getArg 5 some { var?, funName, pvars, rhs } diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index a0fea460d5..836f212d00 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -1737,6 +1737,15 @@ def isDefEqNoConstantApprox (t s : Expr) : MetaM Bool := def etaExpand (e : Expr) : MetaM Expr := withDefault do forallTelescopeReducing (← inferType e) fun xs _ => mkLambdaFVars xs (mkAppN e xs) +/-- +If `e` is of the form `?m ...` instantiate metavars +-/ +def instantiateMVarsIfMVarApp (e : Expr) : MetaM Expr := do + if e.getAppFn.isMVar then + instantiateMVars e + else + return e + end Meta builtin_initialize diff --git a/tests/lean/run/match_expr.lean b/tests/lean/run/match_expr.lean new file mode 100644 index 0000000000..9c599b720e --- /dev/null +++ b/tests/lean/run/match_expr.lean @@ -0,0 +1,56 @@ +import Lean +open Lean Meta + +def test1 (e : Expr) : Option Expr := + match_expr e with + | c@Eq α a b => some (mkApp3 c α b a) + | Eq.refl _ a => some a + | _ => none + +/-- +info: 3 = 1 +--- +info: 3 +--- +info: 4 = 2 +-/ +#guard_msgs in +run_meta do + let eq ← mkEq (toExpr 1) (toExpr 3) + let eq := mkAnnotation `foo eq + let some eq := test1 eq | throwError "unexpected" + logInfo eq + let rfl ← mkEqRefl (toExpr 3) + let some n := test1 rfl | throwError "unexpected" + logInfo n + let eq := mkAnnotation `boo <| mkApp (mkAnnotation `bla (mkApp (mkAnnotation `foo eq.appFn!.appFn!) (toExpr 2))) (toExpr 4) + let some eq := test1 eq | throwError "unexpected" + logInfo eq + +def test2 (e : Expr) : MetaM Expr := do + match_expr e with + | HAdd.hAdd _ _ _ _ a b => mkSub a b + | HMul.hMul _ _ _ _ a b => mkAdd b a + | _ => return e + +/-- +info: 2 - 5 +--- +info: 5 + 2 +--- +info: 5 - 2 +-/ +#guard_msgs in +run_meta do + let e ← mkAdd (toExpr 2) (toExpr 5) + let e ← test2 e + logInfo e + let e ← mkMul (toExpr 2) (toExpr 5) + let e ← test2 e + logInfo e + let m ← mkFreshExprMVar none + let m ← test2 m + assert! m.isMVar + discard <| isDefEq m e + let m ← test2 m + logInfo m