feat: substitute auxiliary equations introduced by the split tactic
This commit is contained in:
parent
314bd3ae4c
commit
3c964f3b9f
5 changed files with 32 additions and 11 deletions
|
|
@ -162,4 +162,4 @@ theorem Term.constFold_sound (e : Term ctx ty) : e.constFold.denote env = e.deno
|
|||
| plus a b iha ihb =>
|
||||
split
|
||||
next he₁ he₂ => simp [← iha, ← ihb, he₁, he₂]
|
||||
next he₁ he₂ _ _ _ => simp [← he₁, ← he₂, iha, ihb]
|
||||
next => simp [iha, ihb]
|
||||
|
|
|
|||
|
|
@ -239,4 +239,4 @@ theorem constFold_sound (e : Term' Ty.denote ty) : denote (constFold e) = denote
|
|||
| plus a b iha ihb =>
|
||||
split
|
||||
next he₁ he₂ => simp [← iha, ← ihb, he₁, he₂]
|
||||
next he₁ he₂ _ _ _ => simp [← he₁, ← he₂, iha, ihb]
|
||||
next => simp [iha, ihb]
|
||||
|
|
|
|||
|
|
@ -61,9 +61,15 @@ private def simpMatchTargetCore (mvarId : MVarId) (matchDeclName : Name) (matchE
|
|||
| some proof => replaceTargetEq mvarId r.expr proof
|
||||
| none => replaceTargetDefEq mvarId r.expr
|
||||
|
||||
private def generalizeMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : MetaM (Array FVarId × MVarId) := do
|
||||
/--
|
||||
Use `generalize` to make sure each discriminant is a free variable.
|
||||
Return the tuple `(discrsNew, discrEqs, mvarId)`. `discrsNew` in an array representing the new discriminants, `discrEqs` is an array of auxiliary equality hypotheses
|
||||
that connect the new discriminants to the original terms they represent.
|
||||
Remark: `discrEqs.size ≤ discrsNew.size`
|
||||
-/
|
||||
private def generalizeMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : MetaM (Array FVarId × Array FVarId × MVarId) := do
|
||||
if discrs.all (·.isFVar) then
|
||||
return (discrs.map (·.fvarId!), mvarId)
|
||||
return (discrs.map (·.fvarId!), #[], mvarId)
|
||||
else
|
||||
let discrsToGeneralize := discrs.filter fun d => !d.isFVar
|
||||
let args ← discrsToGeneralize.mapM fun d => return { expr := d, hName? := (← mkFreshUserName `h) : GeneralizeArg }
|
||||
|
|
@ -76,7 +82,14 @@ private def generalizeMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : Meta
|
|||
else
|
||||
result := result.push fvarIdsNew[j]
|
||||
j := j + 1
|
||||
return (result, mvarId)
|
||||
return (result, fvarIdsNew[j:], mvarId)
|
||||
|
||||
private def substDiscrEqs (mvarId : MVarId) (discrEqs : Array FVarId) : MetaM MVarId := do
|
||||
let mut mvarId := mvarId
|
||||
for fvarId in discrEqs.reverse do
|
||||
trace[Meta.Tactic.split] "subst auxiliary eq {mkFVar fvarId} : {← inferType (mkFVar fvarId)}"
|
||||
mvarId ← trySubst mvarId fvarId
|
||||
return mvarId
|
||||
|
||||
def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Level) (params : Array Expr) (discrs : Array Expr) : MetaM (List MVarId) := do
|
||||
let some info ← getMatcherInfo? matcherDeclName | throwError "'applyMatchSplitter' failed, '{matcherDeclName}' is not a 'match' auxiliary declaration."
|
||||
|
|
@ -87,10 +100,14 @@ def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Le
|
|||
us := us.set! uElimPos levelZero
|
||||
let splitter := mkAppN (mkConst matchEqns.splitterName us.toList) params
|
||||
let motiveType := (← whnfForall (← inferType splitter)).bindingDomain!
|
||||
let (discrFVarIds, mvarId) ← generalizeMatchDiscrs mvarId discrs
|
||||
trace[Meta.Tactic.split] "applyMatchSplitter\n{mvarId}"
|
||||
let (discrFVarIds, discrEqs, mvarId) ← generalizeMatchDiscrs mvarId discrs
|
||||
trace[Meta.Tactic.split] "after generalizeMatchDiscrs\n{mvarId}"
|
||||
let mvarId ← generalizeTargetsEq mvarId motiveType (discrFVarIds.map mkFVar)
|
||||
trace[Meta.Tactic.split] "after generalize\n{mvarId}"
|
||||
let numEqs := discrs.size
|
||||
let (discrFVarIdsNew, mvarId) ← introN mvarId discrs.size
|
||||
trace[Meta.Tactic.split] "after introN\n{mvarId}"
|
||||
let discrsNew := discrFVarIdsNew.map mkFVar
|
||||
withMVarContext mvarId do
|
||||
let motive ← mkLambdaFVars discrsNew (← getMVarType mvarId)
|
||||
|
|
@ -102,9 +119,11 @@ def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Le
|
|||
let (_, mvarIds) ← mvarIds.foldlM (init := (0, [])) fun (i, mvarIds) mvarId => do
|
||||
let numParams := matchEqns.splitterAltNumParams[i]
|
||||
let (_, mvarId) ← introN mvarId numParams
|
||||
trace[Meta.Tactic.split] "before unifyEqs\n{mvarId}"
|
||||
match (← Cases.unifyEqs numEqs mvarId {}) with
|
||||
| none => return (i+1, mvarIds) -- case was solved
|
||||
| some (mvarId, _) =>
|
||||
let mvarId ← substDiscrEqs mvarId discrEqs
|
||||
return (i+1, mvarId::mvarIds)
|
||||
return mvarIds.reverse
|
||||
|
||||
|
|
|
|||
|
|
@ -350,18 +350,20 @@ theorem Stmt.simplify_correct (h : (σ, s) ⇓ σ') : (σ, s.simplify) ⇓ σ' :
|
|||
revert ih₂ -- This is a hack to make sure the next split simplify the two match expressions: TODO: make sure `simp` can do it
|
||||
split <;> intro ih₂
|
||||
next h => rw [h] at heq; simp at heq
|
||||
next h _ _ => rw [h] at heq; apply Bigstep.whileTrue heq ih₁ ih₂
|
||||
next => apply Bigstep.whileTrue heq ih₁ ih₂
|
||||
| whileFalse heq =>
|
||||
split
|
||||
next => exact Bigstep.skip
|
||||
next h _ _ => apply Bigstep.whileFalse; rw [← h]; simp [heq]
|
||||
next => apply Bigstep.whileFalse; simp [heq]
|
||||
| ifFalse heq h ih =>
|
||||
rw [← Expr.eval_simplify] at heq
|
||||
split <;> simp_all
|
||||
rw [← Expr.eval_simplify] at heq
|
||||
apply Bigstep.ifFalse heq ih
|
||||
| ifTrue heq h ih =>
|
||||
rw [← Expr.eval_simplify] at heq
|
||||
split <;> simp_all
|
||||
rw [← Expr.eval_simplify] at heq
|
||||
apply Bigstep.ifTrue heq ih
|
||||
|
||||
@[simp] def Expr.constProp (e : Expr) (σ : State) : Expr :=
|
||||
|
|
@ -579,9 +581,9 @@ theorem Stmt.constProp_correct (h₁ : (σ₁, s) ⇓ σ₂) (h₂ : σ₁' ≼
|
|||
rw [← Expr.eval_simplify, h] at heq'
|
||||
simp at heq'
|
||||
apply Bigstep.assign; simp [*]
|
||||
next h _ _ =>
|
||||
next =>
|
||||
have heq' := Expr.eval_constProp_of_eq_of_sub heq h₂
|
||||
rw [← Expr.eval_simplify, h] at heq'
|
||||
rw [← Expr.eval_simplify] at heq'
|
||||
apply Bigstep.assign heq'
|
||||
| seq h₁ h₂ ih₁ ih₂ =>
|
||||
apply Bigstep.seq (ih₁ h₂) (ih₂ (constProp_sub h₁ h₂))
|
||||
|
|
|
|||
|
|
@ -52,4 +52,4 @@ theorem Term.constFold_sound (e : Term ctx ty) : e.constFold.denote env = e.deno
|
|||
| plus a b iha ihb =>
|
||||
split
|
||||
next he₁ he₂ => simp [← iha, ← ihb, he₁, he₂]
|
||||
next he₁ he₂ _ _ _ => simp [← he₁, ← he₂, iha, ihb]
|
||||
next => simp [iha, ihb]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue