From d667d5ab5d4a23ec8ab74e376d8e100ee8925381 Mon Sep 17 00:00:00 2001 From: Daniel Fabian Date: Wed, 9 Mar 2022 09:25:57 +0000 Subject: [PATCH] feat: rewrite the tactic using `simp` as the basis. --- src/Init/Data/AC.lean | 4 +- src/Lean/Meta/Tactic/AC/Main.lean | 375 ++++++++---------------------- tests/lean/run/ac_refl.lean | 10 +- 3 files changed, 109 insertions(+), 280 deletions(-) diff --git a/src/Init/Data/AC.lean b/src/Init/Data/AC.lean index e04510578f..39d8efa642 100644 --- a/src/Init/Data/AC.lean +++ b/src/Init/Data/AC.lean @@ -314,8 +314,8 @@ theorem Context.eval_norm (ctx : Context α) (e : Expr) : evalList α ctx (norm cases h₁ : ContextInformation.isIdem ctx <;> cases h₂ : ContextInformation.isComm ctx <;> simp_all [evalList_removeNeutrals, eval_toList, toList_nonEmpty, evalList_mergeIdem, evalList_sort] -theorem Context.eq_of_norm (ctx : Context α) (a b : Expr) (h : norm ctx a = norm ctx b) : eval α ctx a = eval α ctx b := by - have h := congrArg (evalList α ctx) h +theorem Context.eq_of_norm (ctx : Context α) (a b : Expr) (h : norm ctx a == norm ctx b) : eval α ctx a = eval α ctx b := by + have h := congrArg (evalList α ctx) (eq_of_beq h) rw [eval_norm, eval_norm] at h assumption diff --git a/src/Lean/Meta/Tactic/AC/Main.lean b/src/Lean/Meta/Tactic/AC/Main.lean index 4568f00f6a..0c70b96e83 100644 --- a/src/Lean/Meta/Tactic/AC/Main.lean +++ b/src/Lean/Meta/Tactic/AC/Main.lean @@ -13,7 +13,6 @@ open Lean.Data.AC open Lean.Elab.Tactic abbrev ACExpr := Lean.Data.AC.Expr -open Lean.Data.AC.Expr structure PreContext where id : Nat @@ -23,44 +22,15 @@ structure PreContext where idem : Option Expr deriving Inhabited -instance : ContextInformation (Std.HashMap Nat (Option Expr) × PreContext) where - isComm ctx := ctx.2.comm.isSome - isIdem ctx := ctx.2.idem.isSome - isNeutral ctx x := ctx.1.find? x |>.bind id |>.isSome +instance : ContextInformation (PreContext × Array Bool) where + isComm ctx := ctx.1.comm.isSome + isIdem ctx := ctx.1.idem.isSome + isNeutral ctx x := ctx.2[x] instance : EvalInformation PreContext ACExpr where - arbitrary _ := var 0 - evalOp _ := op - evalVar _ x := var x - -structure ACNormContext where - preContexts : ExprMap (Option PreContext) - preContextsReverse : Array PreContext - exprIds : ExprMap Nat - exprIdsReverse : Array Expr - neutrals : Std.HashMap (Nat × Nat) (Option Lean.Expr) - -def emptyACNormContext : ACNormContext := - { preContexts := Std.HashMap.empty, - exprIds := Std.HashMap.empty, - exprIdsReverse := #[] - preContextsReverse := #[], - neutrals := Std.HashMap.empty } - -abbrev M α := StateT ACNormContext MetaM α - -def lazyCache - (lookup : ACNormContext → α → Option β) - (insert : ACNormContext → α → β → ACNormContext) - (create : ACNormContext → α → MetaM β) - (a : α) : M β := do - let state ← get - if let some b := lookup state a then - return b - else - let b ← create state a - set $ insert state a b - return b + arbitrary _ := Data.AC.Expr.var 0 + evalOp _ := Data.AC.Expr.op + evalVar _ x := Data.AC.Expr.var x def getInstance (cls : Name) (exprs : Array Expr) : MetaM (Option Expr) := do try @@ -72,183 +42,74 @@ def getInstance (cls : Name) (exprs : Array Expr) : MetaM (Option Expr) := do catch | _ => return none -def preContextCache : Expr → M (Option PreContext) := - lazyCache - (fun ctx => ctx.preContexts.find?) - (fun ctx a b => { ctx with - preContexts := ctx.preContexts.insert a b, - preContextsReverse := if let some b := b then ctx.preContextsReverse.push b else ctx.preContextsReverse }) - fun ctx expr => do - if let some assoc := ←getInstance ``IsAssociative #[expr] then - return some - { assoc, - op := expr - id := ctx.preContextsReverse.size - comm := ←getInstance ``IsCommutative #[expr] - idem := ←getInstance ``IsIdempotent #[expr] } +def preContext (expr : Expr) : MetaM (Option PreContext) := do + if let some assoc := ←getInstance ``IsAssociative #[expr] then + return some + { assoc, + op := expr + id := 0 + comm := ←getInstance ``IsCommutative #[expr] + idem := ←getInstance ``IsIdempotent #[expr] } - return none + return none -def exprId : Expr → M Nat := - lazyCache - (fun ctx => ctx.exprIds.find?) - (fun ctx a b => { ctx with exprIds := ctx.exprIds.insert a b, exprIdsReverse := ctx.exprIdsReverse.push a}) - fun ctx expr => pure ctx.exprIdsReverse.size - -def neutralCache (op : Nat) (var : Nat) : M (Option Expr) := - lazyCache - (fun ctx => ctx.neutrals.find?) - (fun ctx a b => { ctx with neutrals := ctx.neutrals.insert a b }) - (fun ctx (op, var) => do - let op := ctx.preContextsReverse[op].op - let var := ctx.exprIdsReverse[var] - getInstance ``IsNeutral #[op, var]) - (op, var) - -inductive NormalizedExpr -| maybeNormalized (opId : Nat) (l : NormalizedExpr) (r : NormalizedExpr) -| definitelyNormalized (varId : Nat) -| unnormalized (opId : Nat) (e : NormalizedExpr) - deriving Inhabited, Repr - -open NormalizedExpr - -def NormalizedExpr.norm : NormalizedExpr → M NormalizedExpr - | definitelyNormalized id => pure $ definitelyNormalized id - | e@(maybeNormalized opId _ _) => normalize opId e - | e@(unnormalized opId _) => normalize opId e -where - loop (opId : Nat) : NormalizedExpr → StateT (Std.HashMap Nat (Option Expr)) M ACExpr - | definitelyNormalized x => do - let isNeutral ← neutralCache opId x - modify fun state => state.insert x isNeutral - return var x - | unnormalized _ e => loop opId e - | maybeNormalized _ l r => return op (←loop opId l) (←loop opId r) - - normalize (opId : Nat) (e : NormalizedExpr) := do - let (orig, neutrals) ← loop opId e |>.run Std.HashMap.empty - let preContext := (←get).preContextsReverse[opId] - return convertBack opId $ evalList ACExpr preContext $ Lean.Data.AC.norm (neutrals, preContext) orig - - convertBack (opId : Nat) : ACExpr → NormalizedExpr - | var id => definitelyNormalized id - | op l r => maybeNormalized opId (convertBack opId l) (convertBack opId r) - -def NormalizedExpr.decide : NormalizedExpr → M (Option (Nat × NormalizedExpr)) - | definitelyNormalized _ => pure none - | unnormalized opId e => pure (opId, e) - | e@(maybeNormalized opId l r) => do - let rec loop : NormalizedExpr → StateT (Std.HashMap Nat (Option Expr)) M ACExpr - | definitelyNormalized x => do - let isNeutral ← neutralCache opId x - modify fun state => state.insert x isNeutral - return var x - | unnormalized _ e => loop e - | maybeNormalized _ l r => return op (←loop l) (←loop r) - - let (orig, neutrals) ← loop e |>.run Std.HashMap.empty - let preContext := (←get).preContextsReverse[opId] - let res := evalList ACExpr preContext $ Lean.Data.AC.norm (neutrals, preContext) orig - - return if orig == res then none else some (opId, e) - -open Lean.Expr +inductive PreExpr +| op (lhs rhs : PreExpr) +| var (e : Expr) @[matchPattern] def bin {x₁ x₂} (op l r : Expr) := - app (app op l x₁) r x₂ + Expr.app (Expr.app op l x₁) r x₂ -@[matchPattern] def eq {eq₁ eq₂ eq₃ eq₄ app₁ app₂ n₁} (l r : Expr) := - app (app (app (const (Lean.Name.str Lean.Name.anonymous "Eq" n₁) eq₁ eq₂) eq₃ eq₄) l app₁) r app₂ - -partial def findUnnormalizedOperator (e : Expr) : M NormalizedExpr := do - match e with - | lam _ dom b _ => decide [dom, b] >>= wrap - | forallE _ dom b _ => decide [dom, b] >>= wrap - | letE _ ty val b _ => decide [ty, val, b] >>= wrap - | mdata _ e _ => findUnnormalizedOperator e - | proj _ _ e _ => decide [e] >>= wrap - | bin opExpr lExpr rExpr => do - match ←preContextCache opExpr with - | none => decide [lExpr, rExpr] >>= wrap - | some pc => - match ←matchWithOp pc.id lExpr with - | (false, l) => return l - | (true, l) => - match ←matchWithOp pc.id rExpr with - | (false, r) => return r - | (true, r) => return maybeNormalized pc.id l r - - | app f a _ => decide [f, a] >>= wrap - | atom => - return definitelyNormalized (←exprId atom) - where - decide : List Lean.Expr → M (Option (Nat × NormalizedExpr)) - | [] => pure none - | x :: xs => do - match ←findUnnormalizedOperator x with - | definitelyNormalized _ => decide xs - | unnormalized opId e => pure $ some (opId, e) - | e@(maybeNormalized _ _ _) => return (←e.decide) <|> (←decide xs) - - wrap : Option (Nat × NormalizedExpr) → M NormalizedExpr - | none => return definitelyNormalized (←exprId e) - | some (opId, e) => return unnormalized opId e - - matchWithOp (op : Nat) (expr : Lean.Expr) : M (Bool × NormalizedExpr) := do - match ←findUnnormalizedOperator expr with - | e@(unnormalized _ _) => return (false, e) - | e@(definitelyNormalized _) => - return (true, e) - | e@(maybeNormalized op₂ _ _) => - match op == op₂ with - | true => return (true, e) - | false => - match ←e.decide with - | none => return (true, definitelyNormalized (←exprId expr)) - | some (op, e) => return (false, unnormalized op e) - -def buildProof (lhs rhs : NormalizedExpr) : M (Lean.Expr × Lean.Expr) := do - let vars := - getVars (maybeNormalized 0 lhs rhs) +def toACExpr (op l r : Expr) : MetaM (Array Expr × ACExpr) := do + let (preExpr, vars) ← + toPreExpr (mkApp2 op l r) |>.run Std.HashSet.empty - |>.2.toArray.insertionSort Nat.ble - + let vars := vars.toArray.insertionSort Expr.lt let varMap := vars.foldl (fun xs x => xs.insert x xs.size) Std.HashMap.empty |>.find! - let (preContext, context) ← mkContext vars - let ty ← mkEq (←convertType preContext.op lhs) (←convertType preContext.op rhs) let lhs ← convert varMap lhs - let rhs ← convert varMap rhs - let proof ← mkAppM ``Context.eq_of_norm #[context, lhs, rhs, ←mkEqRefl $ ←mkAppM ``Lean.Data.AC.norm #[context, lhs]] - return (proof, ty) + return (vars, toACExpr varMap preExpr) + where + toPreExpr : Expr → StateT ExprSet MetaM PreExpr + | e@(bin op₂ l r) => do + if ←isDefEq op op₂ then + return PreExpr.op (←toPreExpr l) (←toPreExpr r) + modify fun vars => vars.insert e + return PreExpr.var e + | e => do + modify fun vars => vars.insert e + return PreExpr.var e + + toACExpr (varMap : Expr → Nat) : PreExpr → ACExpr + | PreExpr.op l r => Data.AC.Expr.op (toACExpr varMap l) (toACExpr varMap r) + | PreExpr.var x => Data.AC.Expr.var (varMap x) + +def buildNormProof (preContext : PreContext) (l r : Expr) : MetaM (Lean.Expr × Lean.Expr) := do + let (vars, acExpr) ← toACExpr preContext.op l r + + let (isNeutrals, context) ← mkContext vars + let acExprNormed := Data.AC.evalList ACExpr preContext $ Data.AC.norm (preContext, isNeutrals) acExpr + let tgt ← convertTarget vars acExprNormed + let lhs ← convert acExpr + let rhs ← convert acExprNormed + let α ← inferType vars[0] + let u ← getLevel α + let proof := mkAppN (mkConst ``Context.eq_of_norm [u.dec.get!]) #[α, context, lhs, rhs, ←mkEqRefl (mkConst ``Bool.true)] + return (proof, tgt) where - getVars : NormalizedExpr → StateM (Std.HashSet Nat) Unit - | definitelyNormalized x => modify fun xs => xs.insert x - | unnormalized opId e => getVars e - | maybeNormalized opId l r => do getVars l; getVars r + mkContext (vars : Array Expr) : MetaM (Array Bool × Expr) := do + let arbitrary := vars[0] - mkContext (vars : Array Nat) : M (PreContext × Lean.Expr) := do - let op := - match lhs, rhs with - | definitelyNormalized _, definitelyNormalized _ => 0 - | unnormalized opId _, _ => opId - | maybeNormalized opId _ _, _ => opId - | _, unnormalized opId _ => opId - | _, maybeNormalized opId _ _ => opId - - let ctx ← get - let arbitrary := ctx.exprIdsReverse[vars[0]] - let preContext := ctx.preContextsReverse[op] - let vars : List Lean.Expr ← vars.toList.mapM fun x => do - let xExpr := ctx.exprIdsReverse[x] + let vars ← vars.mapM fun x => do let isNeutral ← - match ←neutralCache op x with - | none => mkAppOptM ``Option.none #[←mkAppM ``IsNeutral #[preContext.op, xExpr]] - | some isNeutral => mkAppM ``Option.some #[isNeutral] + match ←getInstance ``IsNeutral #[preContext.op, x] with + | none => pure (false, ←mkAppOptM ``Option.none #[←mkAppM ``IsNeutral #[preContext.op, x]]) + | some isNeutral => pure (true, ←mkAppM ``Option.some #[isNeutral]) - mkAppM ``Variable.mk #[xExpr, isNeutral] + return (isNeutral.1, ←mkAppM ``Variable.mk #[x, isNeutral.2]) - let vars ← Lean.Meta.mkListLit (←mkAppM ``Variable #[preContext.op]) vars + let (isNeutrals, vars) := vars.unzip + let vars := vars.toList + let vars ← mkListLit (←mkAppM ``Variable #[preContext.op]) vars let comm ← match preContext.comm with @@ -259,87 +120,51 @@ where match preContext.idem with | none => mkAppOptM ``Option.none #[←mkAppM ``IsIdempotent #[preContext.op]] | some idem => mkAppM ``Option.some #[idem] - return (preContext, ←mkAppM ``Lean.Data.AC.Context.mk #[preContext.op, preContext.assoc, comm, idem, vars, arbitrary]) - convert (varMap : Nat → Nat) : NormalizedExpr → MetaM Lean.Expr - | definitelyNormalized id => mkAppM ``var #[Lean.mkNatLit $ varMap id] - | unnormalized _ e => convert varMap e - | maybeNormalized _ l r => do mkAppM ``op #[←convert varMap l, ←convert varMap r] + return (isNeutrals, ←mkAppM ``Lean.Data.AC.Context.mk #[preContext.op, preContext.assoc, comm, idem, vars, arbitrary]) - convertType (op : Lean.Expr) : NormalizedExpr → M Lean.Expr - | definitelyNormalized id => do return (←get).exprIdsReverse[id] - | unnormalized _ e => convertType op e - | maybeNormalized _ l r => do mkAppM' op #[←convertType op l, ←convertType op r] + convert : ACExpr → MetaM Expr + | Data.AC.Expr.op l r => do mkAppM ``Data.AC.Expr.op #[←convert l, ←convert r] + | Data.AC.Expr.var x => mkAppM ``Data.AC.Expr.var #[mkNatLit x] -inductive ProofStrategy - | ac_rfl (lhs rhs : NormalizedExpr) - | simp - | norm (e : NormalizedExpr) + convertTarget (vars : Array Expr) : ACExpr → MetaM Expr + | Data.AC.Expr.op l r => do mkAppM' preContext.op #[←convertTarget vars l, ←convertTarget vars r] + | Data.AC.Expr.var x => return vars[x] -def pickStrategy (e : Expr) : M ProofStrategy := do - match e with - | eq l r => - match ←findUnnormalizedOperator l with - | lhs@(unnormalized _ _) => return ProofStrategy.norm lhs - | lhs@(maybeNormalized _ _ _) => - match ←findUnnormalizedOperator r with - | rhs@(unnormalized _ _) => return ProofStrategy.norm rhs - | rhs => return ProofStrategy.ac_rfl lhs rhs - | lhs@(definitelyNormalized _) => - match ←findUnnormalizedOperator r with - | rhs@(unnormalized _ _) => return ProofStrategy.norm rhs - | rhs@(maybeNormalized _ _ _) => return ProofStrategy.ac_rfl lhs rhs - | rhs@(definitelyNormalized _) => return ProofStrategy.simp - | e => return ProofStrategy.norm $ ←findUnnormalizedOperator e - -def addAcEq (mvarId : MVarId) (e : NormalizedExpr) (target : Expr) : M MVarId := do - let (proof, ty) ← buildProof e (←e.norm) - let goal ← withLocalDeclD `h_ac ty fun h_ac => - mkForallFVars #[h_ac] target - let goal ← mkFreshExprMVar goal - assignExprMVar mvarId (mkApp goal proof) - return goal.mvarId! - -partial def rewriteUnnormalized (mvarId : MVarId) : M Unit := - withMVarContext mvarId do - let target ← getMVarType mvarId - match ←pickStrategy target with - | ProofStrategy.ac_rfl lhs rhs => - trace[Meta.AC] "picking ac_rfl strategy {MessageData.ofGoal mvarId}" - try - let (proof, ty) ← buildProof lhs rhs - if ←isDefEq target ty then - assignExprMVar mvarId proof - else throwError "" - catch _ => throwError "cannot synthesize proof:\n{MessageData.ofGoal mvarId}" - | ProofStrategy.simp => - trace[Meta.AC] "picking simp strategy {MessageData.ofGoal mvarId}" - let simpCtx ← Simp.Context.mkDefault - let newGoal ← simpTarget mvarId simpCtx - unless newGoal.isNone do - throwError "cannot synthesize proof:\n{MessageData.ofGoal mvarId}" - | ProofStrategy.norm (definitelyNormalized _) => throwError "no unnormalized operators found" - | ProofStrategy.norm e => - trace[Meta.AC] "picking norm strategy {MessageData.ofGoal mvarId}" - let mvarId ← addAcEq mvarId e target - let (h_ac, mvarId) ← intro mvarId `h_ac - let simpCtx ← Simp.Context.mkDefault - withMVarContext mvarId do - let simpCtx := { simpCtx with simpTheorems := ←simpCtx.simpTheorems.add #[] (mkFVar h_ac) } - - trace[Meta.AC] "pre rewrite state:\n{MessageData.ofGoal mvarId}\n" - let mvarId ← simpTarget mvarId simpCtx - if let some mvarId := mvarId then - if not $ ←isDefEq target (←getMVarType mvarId) then - let mvarId ← clear mvarId h_ac - trace[Meta.AC] "post rewrite state:\n{MessageData.ofGoal mvarId}\n" - rewriteUnnormalized mvarId - else - throwError "cannot synthesize proof:\n{MessageData.ofGoal mvarId}" +def rewriteUnnormalized (mvarId : MVarId) : MetaM Unit := do + let simpCtx := + { + simpTheorems := {} + congrTheorems := (← getSimpCongrTheorems) + config := Simp.neutralConfig + } + let tgt ← getMVarType mvarId + let res ← Simp.main tgt simpCtx (methods := { post }) + let newGoal ← applySimpResultToTarget mvarId tgt res + applyRefl newGoal +where + post (e : Expr) : SimpM Simp.Step := do + let ctx ← read + match e, ctx.parent? with + | bin op₁ l r, some (bin op₂ _ _) => + if ←isDefEq op₁ op₂ then + return Simp.Step.done { expr := e } + match ←preContext op₁ with + | some pc => + let (proof, newTgt) ← buildNormProof pc l r + return Simp.Step.done { expr := newTgt, proof? := proof } + | none => return Simp.Step.done { expr := e } + | bin op l r, none => + match ←preContext op with + | some pc => + let (proof, newTgt) ← buildNormProof pc l r + return Simp.Step.done { expr := newTgt, proof? := proof } + | none => return Simp.Step.done { expr := e } + | e, _ => return Simp.Step.done { expr := e } @[builtinTactic ac_refl] def ac_refl_tactic : Lean.Elab.Tactic.Tactic := fun stx => do let goal ← getMainGoal - (rewriteUnnormalized goal).run' emptyACNormContext + rewriteUnnormalized goal builtin_initialize registerTraceClass `Meta.AC diff --git a/tests/lean/run/ac_refl.lean b/tests/lean/run/ac_refl.lean index 545a2c0729..abe6393330 100644 --- a/tests/lean/run/ac_refl.lean +++ b/tests/lean/run/ac_refl.lean @@ -59,6 +59,10 @@ theorem ex₂ (n m : Nat) (xs : Vector α (n+m)) (ys : Vector α (m+n)) : (f (n+ ac_refl -- Repro: Binders also trigger invalid proofs ---theorem ex₃ (n : Nat) : (fun x => n + x) = (fun x => x + n) := by --- ac_refl ---#print ex₃ +theorem ex₃ (n : Nat) : (fun x => n + x) = (fun x => x + n) := by + ac_refl +#print ex₃ + +-- Repro: the Prop universe doesn't work +example (p q : Prop) : p ∨ p ∨ q ∧ True = q ∨ p := by + ac_refl