feat: use simp instead of rewrite inside of ac_refl
This commit is contained in:
parent
fda1c5b192
commit
73a59e5bc4
2 changed files with 37 additions and 11 deletions
|
|
@ -273,7 +273,7 @@ where
|
|||
|
||||
inductive ProofStrategy
|
||||
| ac_rfl (lhs rhs : NormalizedExpr)
|
||||
| rfl (tgt : Expr)
|
||||
| simp
|
||||
| norm (e : NormalizedExpr)
|
||||
|
||||
def pickStrategy (e : Expr) : M ProofStrategy := do
|
||||
|
|
@ -289,7 +289,7 @@ def pickStrategy (e : Expr) : M ProofStrategy := do
|
|||
match ←findUnnormalizedOperator r with
|
||||
| rhs@(unnormalized _ _) => return ProofStrategy.norm rhs
|
||||
| rhs@(maybeNormalized _ _ _) => return ProofStrategy.ac_rfl lhs rhs
|
||||
| rhs@(definitelyNormalized _) => return ProofStrategy.rfl l
|
||||
| rhs@(definitelyNormalized _) => return ProofStrategy.simp
|
||||
| e => return ProofStrategy.norm $ ←findUnnormalizedOperator e
|
||||
|
||||
def addAcEq (mvarId : MVarId) (e : NormalizedExpr) (target : Expr) : M MVarId := do
|
||||
|
|
@ -312,19 +312,30 @@ partial def rewriteUnnormalized (mvarId : MVarId) : M Unit :=
|
|||
assignExprMVar mvarId proof
|
||||
else throwError ""
|
||||
catch _ => throwError "cannot synthesize proof:\n{MessageData.ofGoal mvarId}"
|
||||
| ProofStrategy.rfl tgt =>
|
||||
trace[Meta.AC] "picking rfl strategy {MessageData.ofGoal mvarId}"
|
||||
assignExprMVar mvarId (←mkAppM ``Eq.refl #[tgt])
|
||||
| ProofStrategy.simp =>
|
||||
trace[Meta.AC] "picking simp strategy {MessageData.ofGoal mvarId}"
|
||||
let simpCtx ← Simp.Context.mkDefault
|
||||
let newGoal ← simpTarget mvarId simpCtx
|
||||
unless newGoal.isNone do
|
||||
throwError "cannot synthesize proof:\n{MessageData.ofGoal mvarId}"
|
||||
| ProofStrategy.norm (definitelyNormalized _) => throwError "no unnormalized operators found"
|
||||
| ProofStrategy.norm e =>
|
||||
trace[Meta.AC] "picking norm strategy {MessageData.ofGoal mvarId}"
|
||||
let mvarId ← addAcEq mvarId e target
|
||||
let (h_ac, mvarId) ← intro mvarId `h_ac
|
||||
let res ← rewrite mvarId (←getMVarType mvarId) (mkFVar h_ac)
|
||||
let [] := res.mvarIds | throwError "no meta variables expected after rewrite"
|
||||
let mvarId ← replaceTargetEq mvarId res.eNew res.eqProof
|
||||
let mvarId ← clear mvarId h_ac
|
||||
rewriteUnnormalized mvarId
|
||||
let simpCtx ← Simp.Context.mkDefault
|
||||
withMVarContext mvarId do
|
||||
let simpCtx := { simpCtx with simpTheorems := ←simpCtx.simpTheorems.add #[] (mkFVar h_ac) }
|
||||
|
||||
trace[Meta.AC] "pre rewrite state:\n{MessageData.ofGoal mvarId}\n"
|
||||
let mvarId ← simpTarget mvarId simpCtx
|
||||
if let some mvarId := mvarId then
|
||||
if not $ ←isDefEq target (←getMVarType mvarId) then
|
||||
let mvarId ← clear mvarId h_ac
|
||||
trace[Meta.AC] "post rewrite state:\n{MessageData.ofGoal mvarId}\n"
|
||||
rewriteUnnormalized mvarId
|
||||
else
|
||||
throwError "cannot synthesize proof:\n{MessageData.ofGoal mvarId}"
|
||||
|
||||
syntax (name := ac_refl) "ac_refl " : tactic
|
||||
@[builtinTactic ac_refl] def ac_refl_tactic : Lean.Elab.Tactic.Tactic := fun stx => do
|
||||
|
|
|
|||
|
|
@ -40,7 +40,8 @@ example (x y z : Nat) : (x + y) * (0 + z) = (x + y) * z:= by ac_refl
|
|||
|
||||
example (x y z : Nat) : (x + y) * (0 + z) = 1 * z * (y + 0 + x) := by ac_refl
|
||||
|
||||
example (x y z : Nat) : max (0 + (max x (max z (max (0 + 0) ((max 1 0) + 0 + 0) * y)))) y = max (max x y) z := by ac_refl
|
||||
theorem ex₁ (x y z : Nat) : max (0 + (max x (max z (max (0 + 0) ((max 1 0) + 0 + 0) * y)))) y = max (max x y) z := by ac_refl
|
||||
#print ex₁
|
||||
|
||||
example (x y : Nat) : 1 + 0 + 0 = 0 + 1 := by ac_refl
|
||||
|
||||
|
|
@ -48,3 +49,17 @@ example (x y : Nat) : (x + y = 42) = (y + x = 42) := by ac_refl
|
|||
|
||||
example (x y : Nat) (P : Prop) : (x + y = 42 → P) = (y + x = 42 → P) := by ac_refl
|
||||
|
||||
inductive Vector (α : Type u) : Nat → Type u where
|
||||
| nil : Vector α 0
|
||||
| cons : α → Vector α n → Vector α (n+1)
|
||||
|
||||
def f (n : Nat) (xs : Vector α n) := xs
|
||||
|
||||
-- Repro: Dependent types trigger incorrect proofs
|
||||
theorem ex₂ (n m : Nat) (xs : Vector α (n+m)) (ys : Vector α (m+n)) : (f (n+m) xs, f (m+n) ys, n+m) = (f (n+m) xs, f (m+n) ys, m+n) := by
|
||||
ac_refl
|
||||
|
||||
-- Repro: Binders also trigger invalid proofs
|
||||
--theorem ex₃ (n : Nat) : (fun x => n + x) = (fun x => x + n) := by
|
||||
-- ac_refl
|
||||
--#print ex₃
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue