lean4-htt/src/Lean/Meta/Tactic/Split.lean
2022-06-13 17:10:14 -07:00

323 lines
15 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Match.MatchEqs
import Lean.Meta.Tactic.Generalize
namespace Lean.Meta
namespace Split
/--
Simp context used when simplifying `match`-expressions for `split`:
no simp theorems, the current congruence theorems, and `Simp.neutralConfig`.
-/
def getSimpMatchContext : MetaM Simp.Context := do
  let congrTheorems ← getSimpCongrTheorems
  return { config := Simp.neutralConfig, simpTheorems := {}, congrTheorems := congrTheorems }
/--
Simplify the `match`-applications occurring in `e`. Each one is first reduced with
`reduceRecMatcher?`. If that fails, it is simplified with `Simp.simpMatchCore?`, using
`SplitIf.discharge?` as the discharger.
-/
def simpMatch (e : Expr) : MetaM Simp.Result := do
  let ctx ← getSimpMatchContext
  Simp.main e ctx (methods := { pre })
where
  pre (e : Expr) : SimpM Simp.Step := do
    let some app ← matchMatcherApp? e
      | return Simp.Step.visit { expr := e }
    -- Reducing the matcher outright is preferred over rewriting it.
    if let some e' := (← reduceRecMatcher? e) then
      return Simp.Step.done { expr := e' }
    let some step ← Simp.simpMatchCore? app e SplitIf.discharge?
      | return Simp.Step.visit { expr := e }
    return step
/-- Apply `simpMatch` to the target of `mvarId` and return the resulting goal. -/
def simpMatchTarget (mvarId : MVarId) : MetaM MVarId :=
  withMVarContext mvarId do
    let target ← instantiateMVars (← getMVarType mvarId)
    applySimpResultToTarget mvarId target (← simpMatch target)
/--
Simplify the applications of `matchDeclName` occurring in `e`.

Each application is first reduced with `reduceRecMatcher?`. If that fails, the theorem
`matchEqDeclName` is tried under `withReducible`, with `SplitIf.discharge?` as the discharger.
Every other subterm is visited unchanged.
-/
private def simpMatchCore (matchDeclName : Name) (matchEqDeclName : Name) (e : Expr) : MetaM Simp.Result := do
  let ctx ← getSimpMatchContext
  Simp.main e ctx (methods := { pre })
where
  pre (e : Expr) : SimpM Simp.Step := do
    unless e.isAppOf matchDeclName do
      return Simp.Step.visit { expr := e }
    if let some e' := (← reduceRecMatcher? e) then
      return Simp.Step.done { expr := e' }
    -- Fall back to rewriting with the equation theorem of the selected alternative.
    let isRfl ← isRflTheorem matchEqDeclName
    let r? ← withReducible <| Simp.tryTheorem? e { proof := mkConst matchEqDeclName, name? := matchEqDeclName, rfl := isRfl } SplitIf.discharge?
    match r? with
    | some r => return Simp.Step.done r
    | none => return Simp.Step.visit { expr := e }
/--
Run `simpMatchCore matchDeclName matchEqDeclName` on the target of `mvarId`.
The target is replaced using the proof when one is produced, and definitionally otherwise.
-/
private def simpMatchTargetCore (mvarId : MVarId) (matchDeclName : Name) (matchEqDeclName : Name) : MetaM MVarId :=
  withMVarContext mvarId do
    let target ← instantiateMVars (← getMVarType mvarId)
    let r ← simpMatchCore matchDeclName matchEqDeclName target
    if let some proof := r.proof? then
      replaceTargetEq mvarId r.expr proof
    else
      replaceTargetDefEq mvarId r.expr
/--
For each index `i` of `lhs`, introduce a local hypothesis `heq : lhs[i] = rhs[i]`, built with
`mkEqHEq` (so it may be a heterogeneous equality). Also build its reflexivity proof
(`mkEqRefl`/`mkHEqRefl` of `lhs[i]`). Then run `k` on the array of hypotheses and the array of
proofs. `rhs` is indexed with the indices of `lhs`.
-/
private partial def withEqs (lhs rhs : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α := do
  go 0 #[] #[]
where
  go (i : Nat) (hs : Array Expr) (rfls : Array Expr) : MetaM α := do
    if i >= lhs.size then
      k hs rfls
    else
      let userName ← mkFreshUserName `heq
      let eqType ← mkEqHEq lhs[i] rhs[i]
      withLocalDeclD userName eqType fun h => do
        let isHomogeneous := (← inferType h).isEq
        let pf ← if isHomogeneous then mkEqRefl lhs[i] else mkHEqRefl lhs[i]
        go (i+1) (hs.push h) (rfls.push pf)
/--
This method makes sure each discriminant is a free variable.
Return the tuple `(discrsNew, discrEqs, mvarId)`. `discrsNew` is an array representing the new discriminants, `discrEqs` is an array of auxiliary equality hypotheses
that connect the new discriminants to the original terms they represent.
If every discriminant is already a free variable, they are returned as-is, `discrEqs` is empty, and `mvarId` is unchanged.
Otherwise `mvarId` is assigned, and a new goal with one new discriminant and one equation per original discriminant is returned.
Throws if no `match`-application of `matcherDeclName` with exactly these discriminants occurs in the target,
or if the generalized target is not type correct.
Remark: `discrEqs.size ≤ discrsNew.size`
Remark:
We should only generalize `discrs` occurrences as `match`-expression discriminants.
For example, given the following goal.
```
x : Nat
⊢ (match g x with
| 0 => 1
| Nat.succ y => g x) =
2 * x + 1
```
we should not generalize the `g x` in the rhs of the second alternative, and the two resulting goals
for the `split` tactic should be
```
case h_1
x x✝ : Nat
h✝ : g x = 0
⊢ 1 = 2 * x + 1
case h_2
x x✝ y✝ : Nat
h✝ : g x = Nat.succ y✝
⊢ g x = 2 * x + 1
```
-/
private partial def generalizeMatchDiscrs (mvarId : MVarId) (matcherDeclName : Name) (motiveType : Expr) (discrs : Array Expr) : MetaM (Array FVarId × Array FVarId × MVarId) := withMVarContext mvarId do
-- Fast path: nothing to generalize, and no discriminant equations are introduced.
if discrs.all (·.isFVar) then
return (discrs.map (·.fvarId!), #[], mvarId)
let some matcherInfo ← getMatcherInfo? matcherDeclName | unreachable!
let numDiscrEqs := matcherInfo.getNumDiscrEqs -- Number of `h : discr = pattern` equations
-- Build the generalized target in a context with one variable per discriminant (`discrVars`, taken from the
-- motive's binders) and one hypothesis `discr = discrVar` per discriminant (`eqs`), with reflexivity proofs `rfls`.
let (targetNew, rfls) ←
forallTelescope motiveType fun discrVars _ =>
withEqs discrs discrVars fun eqs rfls => do
-- Set once at least one matching `match`-application has been rewritten.
let foundRef ← IO.mkRef false
let rec mkNewTarget (e : Expr) : MetaM Expr := do
let pre (e : Expr) : MetaM TransformStep := do
-- Only full applications of `matcherDeclName` are candidates.
if !e.isAppOf matcherDeclName || e.getAppNumArgs != matcherInfo.arity then
return .visit e
let some matcherApp ← matchMatcherApp? e | return .visit e
-- The discriminants must be syntactically equal (`==`) to `discrs`; otherwise leave this application alone.
for matcherDiscr in matcherApp.discrs, discr in discrs do
unless matcherDiscr == discr do
trace[Meta.Tactic.split] "discr mismatch {matcherDiscr} != {discr}"
return .visit e
-- Replace the discriminants with the new variables and rebuild every alternative.
let matcherApp := { matcherApp with discrs := discrVars }
foundRef.set true
let mut altsNew := #[]
for i in [:matcherApp.alts.size] do
let alt := matcherApp.alts[i]
let altNumParams := matcherApp.altNumParams[i]
let altNew ← lambdaTelescope alt fun xs body => do
if xs.size < altNumParams || xs.size < numDiscrEqs then
throwError "'applyMatchSplitter' failed, unexpected `match` alternative"
-- Binders beyond `altNumParams` belong to the alternative's body: recursively generalize inside it and re-abstract them.
let body ← mkLambdaFVars xs[altNumParams:] (← mkNewTarget body)
-- `ys` are the pattern variables; the last `numDiscrEqs` alternative parameters are the `h : discr = pattern` hypotheses.
let ys := xs[:altNumParams - numDiscrEqs]
if numDiscrEqs == 0 then
mkLambdaFVars ys body
else
-- Restate each `h : discr = pattern` in terms of the new discriminant variables (see `withNewAltEqs`).
let altEqs := xs[altNumParams - numDiscrEqs : altNumParams]
withNewAltEqs matcherInfo eqs altEqs fun altEqsNew subst => do
let body := body.replaceFVars altEqs subst
mkLambdaFVars (ys++altEqsNew) body
altsNew := altsNew.push altNew
return .done { matcherApp with alts := altsNew }.toExpr
transform (← instantiateMVars e) pre
let targetNew ← mkNewTarget (← getMVarType mvarId)
unless (← foundRef.get) do
throwError "'applyMatchSplitter' failed, did not find discriminants"
-- Abstract the new discriminant variables and their equations; the result must still type check.
let targetNew ← mkForallFVars (discrVars ++ eqs) targetNew
unless (← isTypeCorrect targetNew) do
throwError "'applyMatchSplitter' failed, failed to generalize target"
return (targetNew, rfls)
let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew (← getMVarTag mvarId)
trace[Meta.Tactic.split] "targetNew:\n{mvarNew.mvarId!}"
-- Close the original goal with the generalized goal applied to the original discriminants and the reflexivity proofs.
assignExprMVar mvarId (mkAppN (mkAppN mvarNew discrs) rfls)
-- Introduce the new discriminants, then one equation per discriminant.
let (discrs', mvarId') ← introNP mvarNew.mvarId! discrs.size
let (discrEqs, mvarId') ← introNP mvarId' discrs.size
return (discrs', discrEqs, mvarId')
where
/-
- `eqs` are free variables `h_eq : discr = discrVar`. `eqs.size == discrs.size`
- `altEqs` are free variables of the form `h_altEq : discr = pattern`. `altEqs.size = numDiscrEqs ≤ discrs.size`
This method executes `k altEqsNew subst` where
- `altEqsNew` are fresh free variables of the form `h_altEqNew : discrVar = pattern`
- `subst` are terms of the form `h_eq.trans h_altEqNew : discr = pattern`. We use `subst` later to replace occurrences of `h_altEq` with `h_eq.trans h_altEqNew`.
Only the entries of `eqs` whose discriminant has a `hName?` in `matcherInfo.discrInfos` are paired with `altEqs`, in order.
Both `Eq` and `HEq` hypotheses are supported; any other shape throws.
-/
withNewAltEqs (matcherInfo : MatcherInfo) (eqs : Array Expr) (altEqs : Array Expr) (k : Array Expr → Array Expr → MetaM Expr) : MetaM Expr := do
let eqs' := (eqs.zip matcherInfo.discrInfos).filterMap fun (eq, info) => if info.hName?.isNone then none else some eq
-- `eqs'.size == altEqs.size ≤ eqs.size`
let rec go (i : Nat) (altEqsNew : Array Expr) (subst : Array Expr) : MetaM Expr := do
if i < altEqs.size then
let altEqDecl ← getFVarLocalDecl altEqs[i]
let eq := eqs'[i]
let eqType ← inferType eq
let altEqType := altEqDecl.type
match eqType.eq?, altEqType.eq? with
| some (_, _, discrVar), some (_, _ /- discr -/, pattern) =>
withLocalDeclD altEqDecl.userName (← mkEq discrVar pattern) fun altEqNew => do
go (i+1) (altEqsNew.push altEqNew) (subst.push (← mkEqTrans eq altEqNew))
| _, _ =>
match eqType.heq?, altEqType.heq? with
| some (_, _, _, discrVar), some (_, _ /- discr -/, _, pattern) =>
withLocalDeclD altEqDecl.userName (← mkHEq discrVar pattern) fun altEqNew => do
go (i+1) (altEqsNew.push altEqNew) (subst.push (← mkHEqTrans eq altEqNew))
| _, _ =>
throwError "'applyMatchSplitter' failed, unexpected discriminant equalities"
else
k altEqsNew subst
go 0 #[] #[]
/--
Try to eliminate the discriminant equations `discrEqs` from goal `mvarId` by substitution.

Each equation is first mapped through the current substitution and skipped if it no longer maps
to a free variable. Otherwise it is converted with `heqToEq` and substituted with `substCore?`,
first with `symm := false` and then with `symm := true`. If neither direction applies, the
equation is kept. Returns the final goal.
-/
private def substDiscrEqs (mvarId : MVarId) (fvarSubst : FVarSubst) (discrEqs : Array FVarId) : MetaM MVarId :=
  withMVarContext mvarId do
    let mut goal := mvarId
    let mut subst := fvarSubst
    for eqId in discrEqs do
      -- Earlier substitutions (the incoming one or those done in this loop) may have replaced this hypothesis.
      let .fvar eqId _ := subst.apply (mkFVar eqId) | continue
      let (eqId, goalEq) ← heqToEq goal eqId
      let mut result? ← substCore? goalEq eqId (symm := false) subst
      if result?.isNone then
        result? ← substCore? goalEq eqId (symm := true) subst
      match result? with
      | some (subst', goal') =>
        goal := goal'
        subst := subst'
      | none =>
        goal := goalEq
    return goal
/--
Core of `applyMatchSplitter`.

Generalizes the discriminants of the `match` to free variables, applies the splitter of
`matcherDeclName`, and tries to close each resulting goal with `Cases.unifyEqs?`. The goals that
remain are returned, each paired with the index of the `match` alternative it came from.
Alternatives closed by `Cases.unifyEqs?` are dropped, so the indices need not be consecutive.
That is why callers that need per-alternative data (such as equation theorems) must use the index.
-/
private def applyMatchSplitterCore (mvarId : MVarId) (matcherDeclName : Name) (us : Array Level) (params : Array Expr) (discrs : Array Expr) : MetaM (List (Nat × MVarId)) := do
  let some info ← getMatcherInfo? matcherDeclName | throwError "'applyMatchSplitter' failed, '{matcherDeclName}' is not a 'match' auxiliary declaration."
  let matchEqns ← Match.getEquationsFor matcherDeclName
  -- splitterPre does not have the correct universe elimination level, but this is fine, we only use it to compute the `motiveType`,
  -- and we only care about the `motiveType` arguments, and not the resulting `Sort u`.
  let splitterPre := mkAppN (mkConst matchEqns.splitterName us.toList) params
  let motiveType := (← whnfForall (← inferType splitterPre)).bindingDomain!
  trace[Meta.Tactic.split] "applyMatchSplitter\n{mvarId}"
  let (discrFVarIds, discrEqs, mvarId) ← generalizeMatchDiscrs mvarId matcherDeclName motiveType discrs
  trace[Meta.Tactic.split] "after generalizeMatchDiscrs\n{mvarId}"
  let mvarId ← generalizeTargetsEq mvarId motiveType (discrFVarIds.map mkFVar)
  withMVarContext mvarId do trace[Meta.Tactic.split] "discrEqs after generalizeTargetsEq: {discrEqs.map mkFVar}"
  trace[Meta.Tactic.split] "after generalize\n{mvarId}"
  let numEqs := discrs.size
  let (discrFVarIdsNew, mvarId) ← introN mvarId discrs.size
  trace[Meta.Tactic.split] "after introN\n{mvarId}"
  let discrsNew := discrFVarIdsNew.map mkFVar
  let mvarType ← getMVarType mvarId
  let elimUniv ← withMVarContext mvarId <| getLevel mvarType
  -- Instantiate the splitter's elimination universe with the goal's sort; splitters without one only eliminate into `Prop`.
  let us ←
    if let some uElimPos := info.uElimPos? then
      pure <| us.set! uElimPos elimUniv
    else
      unless elimUniv.isZero do
        throwError "match-splitter can only eliminate into `Prop`"
      pure us
  let splitter := mkAppN (mkConst matchEqns.splitterName us.toList) params
  withMVarContext mvarId do
    let motive ← mkLambdaFVars discrsNew mvarType
    let splitter := mkAppN (mkApp splitter motive) discrsNew
    check splitter
    trace[Meta.Tactic.split] "after check splitter"
    let mvarIds ← apply mvarId splitter
    unless mvarIds.length == matchEqns.size do
      throwError "'applyMatchSplitter' failed, unexpected number of goals created after applying splitter for '{matcherDeclName}'."
    -- `altIdx` is the alternative index of each splitter goal; it is recorded with every goal we keep.
    let (_, goals) ← mvarIds.foldlM (init := ((0 : Nat), ([] : List (Nat × MVarId)))) fun (altIdx, goals) mvarId => do
      let numParams := matchEqns.splitterAltNumParams[altIdx]
      let (_, mvarId) ← introN mvarId numParams
      trace[Meta.Tactic.split] "before unifyEqs\n{mvarId}"
      match (← Cases.unifyEqs? (numEqs + info.getNumDiscrEqs) mvarId {}) with
      | none => return (altIdx+1, goals) -- case was solved
      | some (mvarId, fvarSubst) =>
        trace[Meta.Tactic.split] "after unifyEqs\n{mvarId}"
        let mvarId ← substDiscrEqs mvarId fvarSubst discrEqs
        return (altIdx+1, (altIdx, mvarId) :: goals)
    return goals.reverse

/--
Split goal `mvarId` using the splitter of the `match` auxiliary declaration `matcherDeclName`,
instantiated with universe levels `us`, parameters `params`, and discriminants `discrs`.
Return the remaining goals in alternative order. Goals closed by `Cases.unifyEqs?` are omitted.
-/
def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Level) (params : Array Expr) (discrs : Array Expr) : MetaM (List MVarId) := do
  let goals ← applyMatchSplitterCore mvarId matcherDeclName us params discrs
  return goals.map (·.2)

/--
Split the `match`-application `e` occurring in the target of `mvarId`.

In each resulting goal the `match` is simplified with the equation theorem of the alternative the
goal came from. Alternatives eliminated by the splitter are skipped without shifting this
correspondence. Failures are rethrown via `throwNestedTacticEx `splitMatch`.
-/
def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do
  try
    let some app ← matchMatcherApp? e | throwError "match application expected"
    let matchEqns ← Match.getEquationsFor app.matcherName
    let goals ← applyMatchSplitterCore mvarId app.matcherName app.matcherLevels app.params app.discrs
    -- Index `eqnNames` by alternative, not by position in `goals`: solved alternatives are absent from `goals`.
    goals.mapM fun (altIdx, goal) =>
      simpMatchTargetCore goal app.matcherName matchEqns.eqnNames[altIdx]
  catch ex =>
    throwNestedTacticEx `splitMatch ex
/--
Return an `if-then-else` or `match`-expression occurring in `e` to split, or `none` if there is none.

- `splitIte`: when `false`, `if-then-else` expressions are never returned.
- `exceptionSet`: expressions that must not be returned (e.g. `match`es that already failed to split).

An `if-then-else` is a candidate only when its condition has no loose bound variables. A `match`
is a candidate only when none of its discriminants has loose bound variables. When the selected
candidate is an `if-then-else`, a candidate nested inside its condition is preferred.
-/
partial def findSplit? (env : Environment) (e : Expr) (splitIte := true) (exceptionSet : ExprSet := {}) : Option Expr :=
  go e
where
  go (e : Expr) : Option Expr :=
    if let some target := e.find? isCandidate then
      -- Inspect the candidate found, not the enclosing `e`: `getArg! 1 5` is only meaningful for `ite`/`dite`.
      if target.isIte || target.isDIte then
        let cond := target.getArg! 1 5
        -- Try to find a nested `if` in `cond`
        go cond |>.getD target
      else
        some target
    else
      none
  isCandidate (e : Expr) : Bool := Id.run do
    if exceptionSet.contains e then
      false
    else if splitIte && (e.isIte || e.isDIte) then
      !(e.getArg! 1 5).hasLooseBVars
    else if let some info := isMatcherAppCore? env e then
      let args := e.getAppArgs
      for i in [info.getFirstDiscrPos : info.getFirstDiscrPos + info.numDiscrs] do
        if args[i].hasLooseBVars then
          return false
      return true
    else
      false
end Split
open Split
/--
Split an `if-then-else` or `match`-expression occurring in the target of `mvarId`.

Candidates are chosen with `findSplit?`. When splitting a `match` fails, it is added to the
exception set and the next candidate is tried. Returns `none` when nothing could be split; the
whole operation runs under `commitWhenSome?`.
-/
partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List MVarId)) := commitWhenSome? do
  let target ← instantiateMVars (← getMVarType mvarId)
  let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do
    if let some e := findSplit? (← getEnv) target splitIte badCases then
      if e.isIte || e.isDIte then
        return (← splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId]
      else
        try
          -- `splitMatch` may assign `mvarId` before failing, and `try`/`catch` in `MetaM` does not
          -- backtrack; `commitIfNoEx` restores the state so the next candidate starts from a clean goal.
          return some (← commitIfNoEx (splitMatch mvarId e))
        catch _ =>
          go (badCases.insert e)
    else
      trace[Meta.Tactic.split] "did not find term to split\n{MessageData.ofGoal mvarId}"
      return none
  go {}
/--
Split an `if-then-else` or `match`-expression occurring in the type of the local declaration `fvarId`.
Returns `none` when `findSplit?` finds nothing to split; the operation runs under `commitWhenSome?`.
-/
def splitLocalDecl? (mvarId : MVarId) (fvarId : FVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do
  withMVarContext mvarId do
    let env ← getEnv
    let declType ← instantiateMVars (← inferType (mkFVar fvarId))
    let some e := findSplit? env declType | return none
    if e.isIte || e.isDIte then
      return (← splitIfLocalDecl? mvarId fvarId).map fun (mvarId₁, mvarId₂) => [mvarId₁, mvarId₂]
    -- For a `match`, revert `fvarId` so it becomes part of the target, split the target, and then
    -- reintroduce as many hypotheses as `revert` removed in every resulting goal.
    let (reverted, mvarId) ← revert mvarId #[fvarId]
    let goals ← splitMatch mvarId e
    let goals ← goals.mapM fun goal => do return (← introNP goal reverted.size).2
    return some goals
-- Register the `Meta.Tactic.split` trace class used by the `trace[Meta.Tactic.split]` messages in this file.
builtin_initialize registerTraceClass `Meta.Tactic.split
end Lean.Meta