feat: rewrite the tactic using simp as the basis.

This commit is contained in:
Daniel Fabian 2022-03-09 09:25:57 +00:00 committed by Leonardo de Moura
parent ed63274874
commit d667d5ab5d
3 changed files with 109 additions and 280 deletions

View file

@ -314,8 +314,8 @@ theorem Context.eval_norm (ctx : Context α) (e : Expr) : evalList α ctx (norm
cases h₁ : ContextInformation.isIdem ctx <;> cases h₂ : ContextInformation.isComm ctx <;>
simp_all [evalList_removeNeutrals, eval_toList, toList_nonEmpty, evalList_mergeIdem, evalList_sort]
theorem Context.eq_of_norm (ctx : Context α) (a b : Expr) (h : norm ctx a = norm ctx b) : eval α ctx a = eval α ctx b := by
have h := congrArg (evalList α ctx) h
theorem Context.eq_of_norm (ctx : Context α) (a b : Expr) (h : norm ctx a == norm ctx b) : eval α ctx a = eval α ctx b := by
have h := congrArg (evalList α ctx) (eq_of_beq h)
rw [eval_norm, eval_norm] at h
assumption

View file

@ -13,7 +13,6 @@ open Lean.Data.AC
open Lean.Elab.Tactic
abbrev ACExpr := Lean.Data.AC.Expr
open Lean.Data.AC.Expr
structure PreContext where
id : Nat
@ -23,44 +22,15 @@ structure PreContext where
idem : Option Expr
deriving Inhabited
instance : ContextInformation (Std.HashMap Nat (Option Expr) × PreContext) where
isComm ctx := ctx.2.comm.isSome
isIdem ctx := ctx.2.idem.isSome
isNeutral ctx x := ctx.1.find? x |>.bind id |>.isSome
instance : ContextInformation (PreContext × Array Bool) where
isComm ctx := ctx.1.comm.isSome
isIdem ctx := ctx.1.idem.isSome
isNeutral ctx x := ctx.2[x]
instance : EvalInformation PreContext ACExpr where
arbitrary _ := var 0
evalOp _ := op
evalVar _ x := var x
structure ACNormContext where
preContexts : ExprMap (Option PreContext)
preContextsReverse : Array PreContext
exprIds : ExprMap Nat
exprIdsReverse : Array Expr
neutrals : Std.HashMap (Nat × Nat) (Option Lean.Expr)
def emptyACNormContext : ACNormContext :=
{ preContexts := Std.HashMap.empty,
exprIds := Std.HashMap.empty,
exprIdsReverse := #[]
preContextsReverse := #[],
neutrals := Std.HashMap.empty }
abbrev M α := StateT ACNormContext MetaM α
def lazyCache
(lookup : ACNormContext → α → Option β)
(insert : ACNormContext → α → β → ACNormContext)
(create : ACNormContext → α → MetaM β)
(a : α) : M β := do
let state ← get
if let some b := lookup state a then
return b
else
let b ← create state a
set $ insert state a b
return b
arbitrary _ := Data.AC.Expr.var 0
evalOp _ := Data.AC.Expr.op
evalVar _ x := Data.AC.Expr.var x
def getInstance (cls : Name) (exprs : Array Expr) : MetaM (Option Expr) := do
try
@ -72,183 +42,74 @@ def getInstance (cls : Name) (exprs : Array Expr) : MetaM (Option Expr) := do
catch
| _ => return none
def preContextCache : Expr → M (Option PreContext) :=
lazyCache
(fun ctx => ctx.preContexts.find?)
(fun ctx a b => { ctx with
preContexts := ctx.preContexts.insert a b,
preContextsReverse := if let some b := b then ctx.preContextsReverse.push b else ctx.preContextsReverse })
fun ctx expr => do
if let some assoc := ←getInstance ``IsAssociative #[expr] then
return some
{ assoc,
op := expr
id := ctx.preContextsReverse.size
comm := ←getInstance ``IsCommutative #[expr]
idem := ←getInstance ``IsIdempotent #[expr] }
def preContext (expr : Expr) : MetaM (Option PreContext) := do
if let some assoc := ←getInstance ``IsAssociative #[expr] then
return some
{ assoc,
op := expr
id := 0
comm := ←getInstance ``IsCommutative #[expr]
idem := ←getInstance ``IsIdempotent #[expr] }
return none
return none
def exprId : Expr → M Nat :=
lazyCache
(fun ctx => ctx.exprIds.find?)
(fun ctx a b => { ctx with exprIds := ctx.exprIds.insert a b, exprIdsReverse := ctx.exprIdsReverse.push a})
fun ctx expr => pure ctx.exprIdsReverse.size
def neutralCache (op : Nat) (var : Nat) : M (Option Expr) :=
lazyCache
(fun ctx => ctx.neutrals.find?)
(fun ctx a b => { ctx with neutrals := ctx.neutrals.insert a b })
(fun ctx (op, var) => do
let op := ctx.preContextsReverse[op].op
let var := ctx.exprIdsReverse[var]
getInstance ``IsNeutral #[op, var])
(op, var)
inductive NormalizedExpr
| maybeNormalized (opId : Nat) (l : NormalizedExpr) (r : NormalizedExpr)
| definitelyNormalized (varId : Nat)
| unnormalized (opId : Nat) (e : NormalizedExpr)
deriving Inhabited, Repr
open NormalizedExpr
def NormalizedExpr.norm : NormalizedExpr → M NormalizedExpr
| definitelyNormalized id => pure $ definitelyNormalized id
| e@(maybeNormalized opId _ _) => normalize opId e
| e@(unnormalized opId _) => normalize opId e
where
loop (opId : Nat) : NormalizedExpr → StateT (Std.HashMap Nat (Option Expr)) M ACExpr
| definitelyNormalized x => do
let isNeutral ← neutralCache opId x
modify fun state => state.insert x isNeutral
return var x
| unnormalized _ e => loop opId e
| maybeNormalized _ l r => return op (←loop opId l) (←loop opId r)
normalize (opId : Nat) (e : NormalizedExpr) := do
let (orig, neutrals) ← loop opId e |>.run Std.HashMap.empty
let preContext := (←get).preContextsReverse[opId]
return convertBack opId $ evalList ACExpr preContext $ Lean.Data.AC.norm (neutrals, preContext) orig
convertBack (opId : Nat) : ACExpr → NormalizedExpr
| var id => definitelyNormalized id
| op l r => maybeNormalized opId (convertBack opId l) (convertBack opId r)
def NormalizedExpr.decide : NormalizedExpr → M (Option (Nat × NormalizedExpr))
| definitelyNormalized _ => pure none
| unnormalized opId e => pure (opId, e)
| e@(maybeNormalized opId l r) => do
let rec loop : NormalizedExpr → StateT (Std.HashMap Nat (Option Expr)) M ACExpr
| definitelyNormalized x => do
let isNeutral ← neutralCache opId x
modify fun state => state.insert x isNeutral
return var x
| unnormalized _ e => loop e
| maybeNormalized _ l r => return op (←loop l) (←loop r)
let (orig, neutrals) ← loop e |>.run Std.HashMap.empty
let preContext := (←get).preContextsReverse[opId]
let res := evalList ACExpr preContext $ Lean.Data.AC.norm (neutrals, preContext) orig
return if orig == res then none else some (opId, e)
open Lean.Expr
inductive PreExpr
| op (lhs rhs : PreExpr)
| var (e : Expr)
@[matchPattern] def bin {x₁ x₂} (op l r : Expr) :=
app (app op l x₁) r x₂
Expr.app (Expr.app op l x₁) r x₂
@[matchPattern] def eq {eq₁ eq₂ eq₃ eq₄ app₁ app₂ n₁} (l r : Expr) :=
app (app (app (const (Lean.Name.str Lean.Name.anonymous "Eq" n₁) eq₁ eq₂) eq₃ eq₄) l app₁) r app₂
partial def findUnnormalizedOperator (e : Expr) : M NormalizedExpr := do
match e with
| lam _ dom b _ => decide [dom, b] >>= wrap
| forallE _ dom b _ => decide [dom, b] >>= wrap
| letE _ ty val b _ => decide [ty, val, b] >>= wrap
| mdata _ e _ => findUnnormalizedOperator e
| proj _ _ e _ => decide [e] >>= wrap
| bin opExpr lExpr rExpr => do
match ←preContextCache opExpr with
| none => decide [lExpr, rExpr] >>= wrap
| some pc =>
match ←matchWithOp pc.id lExpr with
| (false, l) => return l
| (true, l) =>
match ←matchWithOp pc.id rExpr with
| (false, r) => return r
| (true, r) => return maybeNormalized pc.id l r
| app f a _ => decide [f, a] >>= wrap
| atom =>
return definitelyNormalized (←exprId atom)
where
decide : List Lean.Expr → M (Option (Nat × NormalizedExpr))
| [] => pure none
| x :: xs => do
match ←findUnnormalizedOperator x with
| definitelyNormalized _ => decide xs
| unnormalized opId e => pure $ some (opId, e)
| e@(maybeNormalized _ _ _) => return (←e.decide) <|> (←decide xs)
wrap : Option (Nat × NormalizedExpr) → M NormalizedExpr
| none => return definitelyNormalized (←exprId e)
| some (opId, e) => return unnormalized opId e
matchWithOp (op : Nat) (expr : Lean.Expr) : M (Bool × NormalizedExpr) := do
match ←findUnnormalizedOperator expr with
| e@(unnormalized _ _) => return (false, e)
| e@(definitelyNormalized _) =>
return (true, e)
| e@(maybeNormalized op₂ _ _) =>
match op == op₂ with
| true => return (true, e)
| false =>
match ←e.decide with
| none => return (true, definitelyNormalized (←exprId expr))
| some (op, e) => return (false, unnormalized op e)
def buildProof (lhs rhs : NormalizedExpr) : M (Lean.Expr × Lean.Expr) := do
let vars :=
getVars (maybeNormalized 0 lhs rhs)
def toACExpr (op l r : Expr) : MetaM (Array Expr × ACExpr) := do
let (preExpr, vars) ←
toPreExpr (mkApp2 op l r)
|>.run Std.HashSet.empty
|>.2.toArray.insertionSort Nat.ble
let vars := vars.toArray.insertionSort Expr.lt
let varMap := vars.foldl (fun xs x => xs.insert x xs.size) Std.HashMap.empty |>.find!
let (preContext, context) ← mkContext vars
let ty ← mkEq (←convertType preContext.op lhs) (←convertType preContext.op rhs) let lhs ← convert varMap lhs
let rhs ← convert varMap rhs
let proof ← mkAppM ``Context.eq_of_norm #[context, lhs, rhs, ←mkEqRefl $ ←mkAppM ``Lean.Data.AC.norm #[context, lhs]]
return (proof, ty)
return (vars, toACExpr varMap preExpr)
where
toPreExpr : Expr → StateT ExprSet MetaM PreExpr
| e@(bin op₂ l r) => do
if ←isDefEq op op₂ then
return PreExpr.op (←toPreExpr l) (←toPreExpr r)
modify fun vars => vars.insert e
return PreExpr.var e
| e => do
modify fun vars => vars.insert e
return PreExpr.var e
toACExpr (varMap : Expr → Nat) : PreExpr → ACExpr
| PreExpr.op l r => Data.AC.Expr.op (toACExpr varMap l) (toACExpr varMap r)
| PreExpr.var x => Data.AC.Expr.var (varMap x)
def buildNormProof (preContext : PreContext) (l r : Expr) : MetaM (Lean.Expr × Lean.Expr) := do
let (vars, acExpr) ← toACExpr preContext.op l r
let (isNeutrals, context) ← mkContext vars
let acExprNormed := Data.AC.evalList ACExpr preContext $ Data.AC.norm (preContext, isNeutrals) acExpr
let tgt ← convertTarget vars acExprNormed
let lhs ← convert acExpr
let rhs ← convert acExprNormed
let α ← inferType vars[0]
let u ← getLevel α
let proof := mkAppN (mkConst ``Context.eq_of_norm [u.dec.get!]) #[α, context, lhs, rhs, ←mkEqRefl (mkConst ``Bool.true)]
return (proof, tgt)
where
getVars : NormalizedExpr → StateM (Std.HashSet Nat) Unit
| definitelyNormalized x => modify fun xs => xs.insert x
| unnormalized opId e => getVars e
| maybeNormalized opId l r => do getVars l; getVars r
mkContext (vars : Array Expr) : MetaM (Array Bool × Expr) := do
let arbitrary := vars[0]
mkContext (vars : Array Nat) : M (PreContext × Lean.Expr) := do
let op :=
match lhs, rhs with
| definitelyNormalized _, definitelyNormalized _ => 0
| unnormalized opId _, _ => opId
| maybeNormalized opId _ _, _ => opId
| _, unnormalized opId _ => opId
| _, maybeNormalized opId _ _ => opId
let ctx ← get
let arbitrary := ctx.exprIdsReverse[vars[0]]
let preContext := ctx.preContextsReverse[op]
let vars : List Lean.Expr ← vars.toList.mapM fun x => do
let xExpr := ctx.exprIdsReverse[x]
let vars ← vars.mapM fun x => do
let isNeutral ←
match ←neutralCache op x with
| none => mkAppOptM ``Option.none #[←mkAppM ``IsNeutral #[preContext.op, xExpr]]
| some isNeutral => mkAppM ``Option.some #[isNeutral]
match ←getInstance ``IsNeutral #[preContext.op, x] with
| none => pure (false, ←mkAppOptM ``Option.none #[←mkAppM ``IsNeutral #[preContext.op, x]])
| some isNeutral => pure (true, ←mkAppM ``Option.some #[isNeutral])
mkAppM ``Variable.mk #[xExpr, isNeutral]
return (isNeutral.1, ←mkAppM ``Variable.mk #[x, isNeutral.2])
let vars ← Lean.Meta.mkListLit (←mkAppM ``Variable #[preContext.op]) vars
let (isNeutrals, vars) := vars.unzip
let vars := vars.toList
let vars ← mkListLit (←mkAppM ``Variable #[preContext.op]) vars
let comm ←
match preContext.comm with
@ -259,87 +120,51 @@ where
match preContext.idem with
| none => mkAppOptM ``Option.none #[←mkAppM ``IsIdempotent #[preContext.op]]
| some idem => mkAppM ``Option.some #[idem]
return (preContext, ←mkAppM ``Lean.Data.AC.Context.mk #[preContext.op, preContext.assoc, comm, idem, vars, arbitrary])
convert (varMap : Nat → Nat) : NormalizedExpr → MetaM Lean.Expr
| definitelyNormalized id => mkAppM ``var #[Lean.mkNatLit $ varMap id]
| unnormalized _ e => convert varMap e
| maybeNormalized _ l r => do mkAppM ``op #[←convert varMap l, ←convert varMap r]
return (isNeutrals, ←mkAppM ``Lean.Data.AC.Context.mk #[preContext.op, preContext.assoc, comm, idem, vars, arbitrary])
convertType (op : Lean.Expr) : NormalizedExpr → M Lean.Expr
| definitelyNormalized id => do return (←get).exprIdsReverse[id]
| unnormalized _ e => convertType op e
| maybeNormalized _ l r => do mkAppM' op #[←convertType op l, ←convertType op r]
convert : ACExpr → MetaM Expr
| Data.AC.Expr.op l r => do mkAppM ``Data.AC.Expr.op #[←convert l, ←convert r]
| Data.AC.Expr.var x => mkAppM ``Data.AC.Expr.var #[mkNatLit x]
inductive ProofStrategy
| ac_rfl (lhs rhs : NormalizedExpr)
| simp
| norm (e : NormalizedExpr)
convertTarget (vars : Array Expr) : ACExpr → MetaM Expr
| Data.AC.Expr.op l r => do mkAppM' preContext.op #[←convertTarget vars l, ←convertTarget vars r]
| Data.AC.Expr.var x => return vars[x]
def pickStrategy (e : Expr) : M ProofStrategy := do
match e with
| eq l r =>
match ←findUnnormalizedOperator l with
| lhs@(unnormalized _ _) => return ProofStrategy.norm lhs
| lhs@(maybeNormalized _ _ _) =>
match ←findUnnormalizedOperator r with
| rhs@(unnormalized _ _) => return ProofStrategy.norm rhs
| rhs => return ProofStrategy.ac_rfl lhs rhs
| lhs@(definitelyNormalized _) =>
match ←findUnnormalizedOperator r with
| rhs@(unnormalized _ _) => return ProofStrategy.norm rhs
| rhs@(maybeNormalized _ _ _) => return ProofStrategy.ac_rfl lhs rhs
| rhs@(definitelyNormalized _) => return ProofStrategy.simp
| e => return ProofStrategy.norm $ ←findUnnormalizedOperator e
def addAcEq (mvarId : MVarId) (e : NormalizedExpr) (target : Expr) : M MVarId := do
let (proof, ty) ← buildProof e (←e.norm)
let goal ← withLocalDeclD `h_ac ty fun h_ac =>
mkForallFVars #[h_ac] target
let goal ← mkFreshExprMVar goal
assignExprMVar mvarId (mkApp goal proof)
return goal.mvarId!
partial def rewriteUnnormalized (mvarId : MVarId) : M Unit :=
withMVarContext mvarId do
let target ← getMVarType mvarId
match ←pickStrategy target with
| ProofStrategy.ac_rfl lhs rhs =>
trace[Meta.AC] "picking ac_rfl strategy {MessageData.ofGoal mvarId}"
try
let (proof, ty) ← buildProof lhs rhs
if ←isDefEq target ty then
assignExprMVar mvarId proof
else throwError ""
catch _ => throwError "cannot synthesize proof:\n{MessageData.ofGoal mvarId}"
| ProofStrategy.simp =>
trace[Meta.AC] "picking simp strategy {MessageData.ofGoal mvarId}"
let simpCtx ← Simp.Context.mkDefault
let newGoal ← simpTarget mvarId simpCtx
unless newGoal.isNone do
throwError "cannot synthesize proof:\n{MessageData.ofGoal mvarId}"
| ProofStrategy.norm (definitelyNormalized _) => throwError "no unnormalized operators found"
| ProofStrategy.norm e =>
trace[Meta.AC] "picking norm strategy {MessageData.ofGoal mvarId}"
let mvarId ← addAcEq mvarId e target
let (h_ac, mvarId) ← intro mvarId `h_ac
let simpCtx ← Simp.Context.mkDefault
withMVarContext mvarId do
let simpCtx := { simpCtx with simpTheorems := ←simpCtx.simpTheorems.add #[] (mkFVar h_ac) }
trace[Meta.AC] "pre rewrite state:\n{MessageData.ofGoal mvarId}\n"
let mvarId ← simpTarget mvarId simpCtx
if let some mvarId := mvarId then
if not $ ←isDefEq target (←getMVarType mvarId) then
let mvarId ← clear mvarId h_ac
trace[Meta.AC] "post rewrite state:\n{MessageData.ofGoal mvarId}\n"
rewriteUnnormalized mvarId
else
throwError "cannot synthesize proof:\n{MessageData.ofGoal mvarId}"
def rewriteUnnormalized (mvarId : MVarId) : MetaM Unit := do
let simpCtx :=
{
simpTheorems := {}
congrTheorems := (← getSimpCongrTheorems)
config := Simp.neutralConfig
}
let tgt ← getMVarType mvarId
let res ← Simp.main tgt simpCtx (methods := { post })
let newGoal ← applySimpResultToTarget mvarId tgt res
applyRefl newGoal
where
post (e : Expr) : SimpM Simp.Step := do
let ctx ← read
match e, ctx.parent? with
| bin op₁ l r, some (bin op₂ _ _) =>
if ←isDefEq op₁ op₂ then
return Simp.Step.done { expr := e }
match ←preContext op₁ with
| some pc =>
let (proof, newTgt) ← buildNormProof pc l r
return Simp.Step.done { expr := newTgt, proof? := proof }
| none => return Simp.Step.done { expr := e }
| bin op l r, none =>
match ←preContext op with
| some pc =>
let (proof, newTgt) ← buildNormProof pc l r
return Simp.Step.done { expr := newTgt, proof? := proof }
| none => return Simp.Step.done { expr := e }
| e, _ => return Simp.Step.done { expr := e }
@[builtinTactic ac_refl] def ac_refl_tactic : Lean.Elab.Tactic.Tactic := fun stx => do
let goal ← getMainGoal
(rewriteUnnormalized goal).run' emptyACNormContext
rewriteUnnormalized goal
builtin_initialize
registerTraceClass `Meta.AC

View file

@ -59,6 +59,10 @@ theorem ex₂ (n m : Nat) (xs : Vector α (n+m)) (ys : Vector α (m+n)) : (f (n+
ac_refl
-- Repro: Binders also trigger invalid proofs
--theorem ex₃ (n : Nat) : (fun x => n + x) = (fun x => x + n) := by
-- ac_refl
--#print ex₃
theorem ex₃ (n : Nat) : (fun x => n + x) = (fun x => x + n) := by
ac_refl
#print ex₃
-- Repro: the Prop universe doesn't work
example (p q : Prop) : p p q ∧ True = q p := by
ac_refl