From 5302b7889abcd018cbc84d63dd04eefadca0969f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 6 Mar 2024 10:29:20 -0800 Subject: [PATCH] fix: fold raw `Nat` literals at `dsimp` (#3624) closes #2916 Remark: this PR also renames `Expr.natLit?` ==> `Expr.rawNatLit?`. Motivation: consistent naming convention: `Expr.isRawNatLit`. --- src/Lean/Expr.lean | 2 +- src/Lean/Meta/Reduce.lean | 2 +- src/Lean/Meta/Tactic/Simp/Main.lean | 48 ++++++++++++------ tests/lean/run/2916.lean | 77 +++++++++++++++++++++++++++++ tests/lean/run/maze.lean | 4 +- tests/lean/run/meta2.lean | 2 +- 6 files changed, 115 insertions(+), 20 deletions(-) create mode 100644 tests/lean/run/2916.lean diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 0bdaaa4131..0db02e1705 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -924,7 +924,7 @@ def isRawNatLit : Expr → Bool | lit (Literal.natVal _) => true | _ => false -def natLit? : Expr → Option Nat +def rawNatLit? : Expr → Option Nat | lit (Literal.natVal v) => v | _ => none diff --git a/src/Lean/Meta/Reduce.lean b/src/Lean/Meta/Reduce.lean index aa20e6a5ab..e5144d08a5 100644 --- a/src/Lean/Meta/Reduce.lean +++ b/src/Lean/Meta/Reduce.lean @@ -33,7 +33,7 @@ partial def reduce (e : Expr) (explicitOnly skipTypes skipProofs := true) : Meta else args ← args.modifyM i visit if f.isConstOf ``Nat.succ && args.size == 1 && args[0]!.isRawNatLit then - return mkRawNatLit (args[0]!.natLit?.get! + 1) + return mkRawNatLit (args[0]!.rawNatLit?.get! + 1) else return mkAppN f args | Expr.lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs (← visit b) diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 0ad49a6f76..ecd1ca67ae 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -35,6 +35,21 @@ def Config.updateArith (c : Config) : CoreM Config := do def isOfNatNatLit (e : Expr) : Bool := e.isAppOfArity ``OfNat.ofNat 3 && e.appFn!.appArg!.isRawNatLit +/-- +If `e` is a raw Nat literal and `OfNat.ofNat` is not in the list of declarations to unfold, +return an `OfNat.ofNat`-application. +-/ +def foldRawNatLit (e : Expr) : SimpM Expr := do + match e.rawNatLit? with + | some n => + /- If `OfNat.ofNat` is marked to be unfolded, we do not pack orphan nat literals as `OfNat.ofNat` applications + to avoid non-termination. See issue #788. -/ + if (← readThe Simp.Context).isDeclToUnfold ``OfNat.ofNat then + return e + else + return toExpr n + | none => return e + private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do matchConst e.getAppFn (fun _ => pure none) fun cinfo _ => do match (← getProjectionFnInfo? cinfo.name) with @@ -179,7 +194,7 @@ private def reduceStep (e : Expr) : SimpM Expr := do trace[Meta.Tactic.simp.rewrite] "unfold {mkConst e.getAppFn.constName!}, {e} ==> {e'}" recordSimpTheorem (.decl e.getAppFn.constName!) return e' - | none => return e + | none => foldRawNatLit e private partial def reduce (e : Expr) : SimpM Expr := withIncRecDepth do let e' ← reduceStep e @@ -233,17 +248,6 @@ def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := do else f -def simpLit (e : Expr) : SimpM Result := do - match e.natLit? with - | some n => - /- If `OfNat.ofNat` is marked to be unfolded, we do not pack orphan nat literals as `OfNat.ofNat` applications - to avoid non-termination. See issue #788. -/ - if (← readThe Simp.Context).isDeclToUnfold ``OfNat.ofNat then - return { expr := e } - else - return { expr := (← mkNumeral (mkConst ``Nat) n) } - | none => return { expr := e } - def simpProj (e : Expr) : SimpM Result := do match (← reduceProj? e) with | some e => return { expr := e } @@ -406,13 +410,27 @@ private def dsimpReduce : DSimproc := fun e => do eNew ← reduceFVar (← getConfig) (← getSimpTheorems) eNew if eNew != e then return .visit eNew else return .done e +/-- +Auliliary `dsimproc` for not visiting `OfNat.ofNat` application subterms. +This is the `dsimp` equivalent of the approach used at `visitApp`. +Recall that we fold orphan raw Nat literals. +-/ +private def doNotVisitOfNat : DSimproc := fun e => do + if isOfNatNatLit e then + if (← readThe Simp.Context).isDeclToUnfold ``OfNat.ofNat then + return .continue e + else + return .done e + else + return .continue e + @[export lean_dsimp] private partial def dsimpImpl (e : Expr) : SimpM Expr := do let cfg ← getConfig unless cfg.dsimp do return e let m ← getMethods - let pre := m.dpre + let pre := m.dpre >> doNotVisitOfNat let post := m.dpost >> dsimpReduce transform (usedLetOnly := cfg.zeta) e (pre := pre) (post := post) @@ -533,7 +551,7 @@ def congr (e : Expr) : SimpM Result := do def simpApp (e : Expr) : SimpM Result := do if isOfNatNatLit e then - -- Recall that we expand "orphan" kernel nat literals `n` into `ofNat n` + -- Recall that we expand "orphan" kernel Nat literals `n` into `OfNat.ofNat n` return { expr := e } else congr e @@ -549,7 +567,7 @@ def simpStep (e : Expr) : SimpM Result := do | .const .. => simpConst e | .bvar .. => unreachable! | .sort .. => return { expr := e } - | .lit .. => simpLit e + | .lit .. => return { expr := e } | .mvar .. => return { expr := (← instantiateMVars e) } | .fvar .. => return { expr := (← reduceFVar (← getConfig) (← getSimpTheorems) e) } diff --git a/tests/lean/run/2916.lean b/tests/lean/run/2916.lean new file mode 100644 index 0000000000..f90249abe8 --- /dev/null +++ b/tests/lean/run/2916.lean @@ -0,0 +1,77 @@ +set_option pp.coercions false -- Show `OfNat.ofNat` when present for clarity + +/-- +warning: declaration uses 'sorry' +--- +info: x : Nat +⊢ OfNat.ofNat 2 = x +-/ +#guard_msgs in +example : nat_lit 2 = x := by + simp only + trace_state + sorry + +/-- +warning: declaration uses 'sorry' +--- +info: x : Nat +⊢ OfNat.ofNat 2 = x +-/ +#guard_msgs in +example : nat_lit 2 = x := by + dsimp only -- dsimp made no progress + trace_state + sorry + +/-- +warning: declaration uses 'sorry' +--- +info: α : Nat → Type +f : (n : Nat) → α n +x : α (OfNat.ofNat 2) +⊢ f (OfNat.ofNat 2) = x +-/ +#guard_msgs in +example (α : Nat → Type) (f : (n : Nat) → α n) (x : α 2) : f (nat_lit 2) = x := by + simp only + trace_state + sorry + +/-- +info: x : Nat +f : Nat → Nat +h : f (OfNat.ofNat 2) = x +⊢ f (OfNat.ofNat 2) = x +--- +info: x : Nat +f : Nat → Nat +h : f (OfNat.ofNat 2) = x +⊢ f 2 = x +-/ +#guard_msgs in +example (f : Nat → Nat) (h : f 2 = x) : f 2 = x := by + trace_state + simp [OfNat.ofNat] + trace_state + assumption + +/-- +warning: declaration uses 'sorry' +--- +info: α : Nat → Type +f : (n : Nat) → α n +x : α (OfNat.ofNat 2) +⊢ f (OfNat.ofNat 2) = x +--- +info: α : Nat → Type +f : (n : Nat) → α n +x : α (OfNat.ofNat 2) +⊢ f 2 = x +-/ +#guard_msgs in +example (α : Nat → Type) (f : (n : Nat) → α n) (x : α 2) : f 2 = x := by + trace_state + simp [OfNat.ofNat] + trace_state + sorry diff --git a/tests/lean/run/maze.lean b/tests/lean/run/maze.lean index 431aaecdc5..23c2f64c29 100644 --- a/tests/lean/run/maze.lean +++ b/tests/lean/run/maze.lean @@ -100,8 +100,8 @@ def extractXY : Lean.Expr → Lean.MetaM Coords let sizeArgs := Lean.Expr.getAppArgs e' let x ← Lean.Meta.whnf sizeArgs[0]! let y ← Lean.Meta.whnf sizeArgs[1]! - let numCols := (Lean.Expr.natLit? x).get! - let numRows := (Lean.Expr.natLit? y).get! + let numCols := (Lean.Expr.rawNatLit? x).get! + let numRows := (Lean.Expr.rawNatLit? y).get! return Coords.mk numCols numRows partial def extractWallList : Lean.Expr → Lean.MetaM (List Coords) diff --git a/tests/lean/run/meta2.lean b/tests/lean/run/meta2.lean index 230835f6ab..7f7a9dbf2f 100644 --- a/tests/lean/run/meta2.lean +++ b/tests/lean/run/meta2.lean @@ -676,7 +676,7 @@ check t; (match t.arrayLit? with | some (_, xs) => do checkM $ pure $ xs.length == 2; - (match (xs.get! 0).natLit?, (xs.get! 1).natLit? with + (match (xs.get! 0).rawNatLit?, (xs.get! 1).rawNatLit? with | some 1, some 2 => pure () | _, _ => throwError "nat lits expected") | none => throwError "array lit expected")