feat: add simpMatch and use it at splitMatch

This commit is contained in:
Leonardo de Moura 2021-08-31 12:53:17 -07:00
parent 6d4422e5ac
commit c7d797f5b6
3 changed files with 46 additions and 28 deletions

View file

@ -9,7 +9,43 @@ import Lean.Meta.Tactic.Generalize
namespace Lean.Meta
namespace Split
private def genMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : MetaM (Array FVarId × MVarId) := do
private def getSimpMatchContext : MetaM Simp.Context :=
return {
simpLemmas := {}
congrLemmas := (← getCongrLemmas)
config.zeta := false
config.beta := false
config.eta := false
config.iota := false
config.proj := false
config.decide := false
}
private def simpMatchPre (matchDeclName : Name) (matchEqDeclName : Name) (e : Expr) : SimpM Simp.Step := do
if e.isAppOf matchDeclName then
-- First try to reduce matcher
match (← reduceRecMatcher? e) with
| some e' => return Simp.Step.done { expr := e' }
| none =>
-- Try lemma
match (← Simp.tryLemma? e { proof := mkConst matchEqDeclName, name? := matchEqDeclName } SplitIf.discharge?) with
| none => return Simp.Step.visit { expr := e }
| some r => return Simp.Step.done r
else
return Simp.Step.visit { expr := e }
private def simpMatch (matchDeclName : Name) (matchEqDeclName : Name) (e : Expr) : MetaM Simp.Result := do
Simp.main e (← getSimpMatchContext) (methods := { pre := simpMatchPre matchDeclName matchEqDeclName })
private def simpMatchTarget (mvarId : MVarId) (matchDeclName : Name) (matchEqDeclName : Name) : MetaM MVarId := do
withMVarContext mvarId do
let target ← instantiateMVars (← getMVarType mvarId)
let r ← simpMatch matchDeclName matchEqDeclName target
match r.proof? with
| some proof => replaceTargetEq mvarId r.expr proof
| none => replaceTargetDefEq mvarId r.expr
private def generalizeMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : MetaM (Array FVarId × MVarId) := do
if discrs.all (·.isFVar) then
return (discrs.map (·.fvarId!), mvarId)
else
@ -19,35 +55,30 @@ private def genMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : MetaM (Arra
def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do
let some app ← matchMatcherApp? e | throwError "match application expected"
let (discrFVarIds, mvarId) ← genMatchDiscrs mvarId app.discrs
let (discrFVarIds, mvarId) ← generalizeMatchDiscrs mvarId app.discrs
trace[Meta.debug] "split [1]:\n{MessageData.ofGoal mvarId}"
let (reverted, mvarId) ← revert mvarId discrFVarIds
trace[Meta.debug] "split [2]:\n{MessageData.ofGoal mvarId}"
let (discrFVarIds, mvarId) ← introNP mvarId discrFVarIds.size
let numExtra := reverted.size - discrFVarIds.size
trace[Meta.debug] "split [3]:\n{MessageData.ofGoal mvarId}"
let discrs := discrFVarIds.map mkFVar
let matchEqns ← Match.getEquationsFor app.matcherName
withMVarContext mvarId do
let motive ← mkLambdaFVars discrs (← getMVarType mvarId)
trace[Meta.debug] "split [4]: {motive}"
-- Fix universe
let mut us := app.matcherLevels
if let some uElimPos := app.uElimPos? then
-- Set universe elimination level to zero (Prop).
us := us.set! uElimPos levelZero
trace[Meta.debug] "us: {us}"
let splitter := mkAppN (mkConst matchEqns.splitterName us.toList) app.params
let splitter := mkAppN (mkApp splitter motive) discrs
trace[Meta.debug] "splitter: {splitter}"
check splitter -- TODO
let mvarIds ← apply mvarId splitter
let (_, mvarIds) ← mvarIds.foldlM (init := (0, [])) fun (i, mvarIds) mvarId => do
trace[Meta.debug] "split [5]:\n{MessageData.ofGoal mvarId}"
let numParams := matchEqns.splitterAltNumParams[i]
-- TODO: use equation lemmas to reduce `match`-expressions
let (_, mvarId) ← introN mvarId numParams
let (_, mvarId) ← introNP mvarId numExtra
trace[Meta.debug] "before simpMatch:\n{MessageData.ofGoal mvarId}"
let mvarId ← simpMatchTarget mvarId app.matcherName matchEqns.eqnNames[i]
return (i+1, mvarId::mvarIds)
return mvarIds.reverse

View file

@ -22,14 +22,7 @@ test% f.match_1
theorem ex (x : List Nat) : f x > 0 := by
simp [f]
induction x using f.match_1.splitter
next => simp [f.match_1.eq_1]
next x => simp [f.match_1.eq_2]
next x h1 h2 =>
rw [f.match_1.eq_3]
. decide
. exact h1
. exact h2
split <;> decide
test% Std.RBNode.balance1.match_1
#check @Std.RBNode.balance1.match_1.splitter

View file

@ -7,13 +7,9 @@ def f (xs : List Nat) : Nat :=
theorem ex1 (xs : List Nat) (hr : xs.reverse = xs) (ys : Nat) : ys > 0 → f xs > 0 := by
simp [f]
split
next => intro hys; simp
next => intro hys; simp; apply Nat.zero_lt_succ
next zs n₁ n₂ =>
intro hys
rw [f.match_1.eq_3]
anyGoals assumption
decide
next => intro hys; decide
next => intro hys; apply Nat.zero_lt_succ
next zs n₁ n₂ => intro hys; decide
def g (xs : List Nat) : Nat :=
match xs with
@ -23,7 +19,5 @@ def g (xs : List Nat) : Nat :=
theorem ex2 (xs : List Nat) : g xs > 0 := by
simp [g]
split
. simp; apply Nat.zero_lt_succ
. rw [g.match_1.eq_2]
. decide
. assumption
next a b c d e => apply Nat.zero_lt_succ
next h => decide