From 4223bdf8aa64b9ea6bb90cd37cf35adabe99c0fe Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 15 Aug 2020 07:26:22 -0700 Subject: [PATCH] feat: add `expandNatValuePattern` --- src/Lean/Meta/EqnCompiler/DepElim.lean | 79 +++++++++++++++++--------- tests/lean/match3.lean | 11 ++++ tests/lean/match3.lean.expected.out | 3 + 3 files changed, 65 insertions(+), 28 deletions(-) create mode 100644 tests/lean/match3.lean create mode 100644 tests/lean/match3.lean.expected.out diff --git a/src/Lean/Meta/EqnCompiler/DepElim.lean b/src/Lean/Meta/EqnCompiler/DepElim.lean index f85791265e..f1404919eb 100644 --- a/src/Lean/Meta/EqnCompiler/DepElim.lean +++ b/src/Lean/Meta/EqnCompiler/DepElim.lean @@ -240,48 +240,61 @@ p.alts.any fun alt => match alt.patterns with | Pattern.as _ _ :: _ => true | _ => false -/- Return true if the next pattern of each remaining alternative is an inaccessible term or a variable -/ +private def hasCtorPattern (p : Problem) : Bool := +p.alts.any fun alt => match alt.patterns with + | Pattern.ctor _ _ _ _ :: _ => true + | _ => false + +private def hasValPattern (p : Problem) : Bool := +p.alts.any fun alt => match alt.patterns with + | Pattern.val _ :: _ => true + | _ => false + +private def hasNatValPattern (p : Problem) : Bool := +p.alts.any fun alt => match alt.patterns with + | Pattern.val v :: _ => v.isNatLit + | _ => false + +private def hasVarPattern (p : Problem) : Bool := +p.alts.any fun alt => match alt.patterns with + | Pattern.var _ :: _ => true + | _ => false + +private def hasArrayLitPattern (p : Problem) : Bool := +p.alts.any fun alt => match alt.patterns with + | Pattern.arrayLit _ _ :: _ => true + | _ => false + private def isVariableTransition (p : Problem) : Bool := p.alts.all fun alt => match alt.patterns with | Pattern.inaccessible _ :: _ => true | Pattern.var _ :: _ => true | _ => false -/- Return true if the next pattern of each remaining alternative is a constructor application or variable or inaccessible term -/ private def isConstructorTransition (p : Problem) : Bool := -(p.alts.any fun alt => match alt.patterns with - | Pattern.ctor _ _ _ _ :: _ => true - | _ => false) -&& -(p.alts.all fun alt => match alt.patterns with +hasCtorPattern p +&& p.alts.all fun alt => match alt.patterns with | Pattern.ctor _ _ _ _ :: _ => true | Pattern.var _ :: _ => true | Pattern.inaccessible _ :: _ => true - | _ => false) + | _ => false -/- Return true if the next pattern of the remaining alternatives contain variables AND values. -/ private def isValueTransition (p : Problem) : Bool := -let (ok, hasVar, hasVal) := p.alts.foldl - (fun (acc : Bool × Bool × Bool) (alt : Alt) => - let (ok, hasVar, hasVal) := acc; - match alt.patterns with - | Pattern.val _ :: _ => (ok, hasVar, true) - | Pattern.var _ :: _ => (ok, true, hasVal) - | _ => (false, hasVar, hasVal)) - (true, false, false); -ok && hasVar && hasVal +hasVarPattern p && hasValPattern p +&& p.alts.all fun alt => match alt.patterns with + | Pattern.val _ :: _ => true + | Pattern.var _ :: _ => true + | _ => false -/- Return true if the next pattern of the remaining alternatives contain variables AND array literals. -/ private def isArrayLitTransition (p : Problem) : Bool := -let (ok, hasVar, hasArray) := p.alts.foldl - (fun (acc : Bool × Bool × Bool) (alt : Alt) => - let (ok, hasVar, hasArray) := acc; - match alt.patterns with - | Pattern.arrayLit _ _ :: _ => (ok, hasVar, true) - | Pattern.var _ :: _ => (ok, true, hasArray) - | _ => (false, hasVar, hasArray)) - (true, false, false); -ok && hasVar && hasArray +hasArrayLitPattern p && hasVarPattern p +&& p.alts.all fun alt => match alt.patterns with + | Pattern.arrayLit _ _ :: _ => true + | Pattern.var _ :: _ => true + | _ => false + +private def isNatValueCtorTransition (p : Problem) : Bool := +hasCtorPattern p && hasNatValPattern p private def processNonVariable (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do trace! `Meta.EqnCompiler.match ("non variable step"); @@ -576,6 +589,13 @@ match p.vars with process { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs } s) s +private def expandNatValuePattern (p : Problem) : Problem := do +let alts := p.alts.map fun alt => match alt.patterns with + | Pattern.val (Expr.lit (Literal.natVal 0) _) :: ps => { alt with patterns := Pattern.ctor `Nat.zero [] [] [] :: ps } + | Pattern.val (Expr.lit (Literal.natVal (n+1)) _) :: ps => { alt with patterns := Pattern.ctor `Nat.succ [] [] [Pattern.val (mkNatLit n)] :: ps } + | _ => alt; +{ p with alts := alts } + private partial def process : Problem → State → MetaM State | p, s => withIncRecDepth do withGoalOf p (traceM `Meta.EqnCompiler.match p.toMessageData); @@ -593,6 +613,9 @@ private partial def process : Problem → State → MetaM State processValue process p s else if isArrayLitTransition p then processArrayLit process p s + else if isNatValueCtorTransition p then do + trace! `Meta.EqnCompiler.match ("nat value to constructor step"); + process (expandNatValuePattern p) s else do msg ← p.toMessageData; -- TODO: remaining cases diff --git a/tests/lean/match3.lean b/tests/lean/match3.lean new file mode 100644 index 0000000000..ed3f7685a3 --- /dev/null +++ b/tests/lean/match3.lean @@ -0,0 +1,11 @@ +new_frontend + +def f (x : Nat) : Nat := +match x with +| 30 => 31 +| y+1 => y +| 0 => 10 + +#eval f 20 +#eval f 0 +#eval f 30 diff --git a/tests/lean/match3.lean.expected.out b/tests/lean/match3.lean.expected.out new file mode 100644 index 0000000000..d1993f1f6b --- /dev/null +++ b/tests/lean/match3.lean.expected.out @@ -0,0 +1,3 @@ +19 +10 +31