diff --git a/src/Lean/Meta/AppBuilder.lean b/src/Lean/Meta/AppBuilder.lean index b6ad609875..b08017a056 100644 --- a/src/Lean/Meta/AppBuilder.lean +++ b/src/Lean/Meta/AppBuilder.lean @@ -440,6 +440,18 @@ def mkFunExt (h : Expr) : MetaM Expr := def mkPropExt (h : Expr) : MetaM Expr := mkAppM ``propext #[h] +/-- Return `let_congr h₁ h₂` -/ +def mkLetCongr (h₁ h₂ : Expr) : MetaM Expr := + mkAppM ``let_congr #[h₁, h₂] + +/-- Return `let_val_congr b h` -/ +def mkLetValCongr (b h : Expr) : MetaM Expr := + mkAppM ``let_val_congr #[b, h] + +/-- Return `let_body_congr a h` -/ +def mkLetBodyCongr (a h : Expr) : MetaM Expr := + mkAppM ``let_body_congr #[a, h] + /-- Return `of_eq_true h` -/ def mkOfEqTrue (h : Expr) : MetaM Expr := mkAppM ``of_eq_true #[h] diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 8ffdc1092c..0b155daf4c 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -323,13 +323,39 @@ where return { expr := (← dsimp e) } simpLet (e : Expr) : M Result := do - if (← getConfig).zeta then - match e with - | Expr.letE _ _ v b _ => return { expr := b.instantiate1 v } - | _ => unreachable! - else - -- TODO: simplify nondependent let-decls - return { expr := (← dsimp e) } + match e with + | Expr.letE n t v b _ => + if (← getConfig).zeta then + return { expr := b.instantiate1 v } + else + withLocalDeclD n t fun x => do + let bx := b.instantiate1 x + /- The following step is potentially very expensive when we have many nested let-decls. + TODO: handle a block of nested let decls in a single pass if this becomes a performance problem. -/ + if (← isTypeCorrect bx) then + let bxType ← whnf (← inferType bx) + let rbx ← simp bx + let hb? ← match rbx.proof? with + | none => pure none + | some h => pure (some (← mkLambdaFVars #[x] h)) + if (← dependsOn bxType x.fvarId!) then + /- The type of the body depends on `x`. So, we use `let_body_congr` -/ + let v' ← dsimp v + let e' := mkLet n t v' (← abstract rbx.expr #[x]) + match hb? with + | none => return { expr := e' } + | some h => return { expr := e', proof? := some (← mkLetBodyCongr v' h) } + else + /- The type of the body does not depend on `x`. So, we use `let_congr` -/ + let rv ← simp v + let e' := mkLet n t rv.expr (← abstract rbx.expr #[x]) + match rv.proof?, hb? with + | none, none => return { expr := e' } + | some h, none => return { expr := e', proof? := some (← mkLetValCongr (← mkLambdaFVars #[x] rbx.expr) h) } + | _, some h => return { expr := e', proof? := some (← mkLetCongr (← rv.getProof) h) } + else + return { expr := (← dsimp e) } + | _ => unreachable! cacheResult (cfg : Config) (r : Result) : M Result := do if cfg.memoize then diff --git a/tests/lean/simpZetaFalse.lean b/tests/lean/simpZetaFalse.lean new file mode 100644 index 0000000000..da973bcc5f --- /dev/null +++ b/tests/lean/simpZetaFalse.lean @@ -0,0 +1,28 @@ +constant f : Nat → Nat +axiom f_eq (x : Nat) : f (f x) = x + +theorem ex1 (x : Nat) (h : f (f x) = x) : (let y := x*x; if f (f x) = x then 1 else y + 1) = 1 := by + simp (config := { zeta := false }) only [h] + traceState + simp + +#print ex1 -- uses let_congr + +theorem ex2 (x z : Nat) (h : f (f x) = x) (h' : z = x) : (let y := f (f x); y) = z := by + simp (config := { zeta := false }) only [h] + traceState + simp [h'] + +#print ex2 -- uses let_val_congr + +theorem ex3 (x z : Nat) : (let α := Nat; (fun x : α => 0 + x)) = id := by + simp (config := { zeta := false }) + traceState -- should not simplify let body since `fun α : Nat => fun x : α => 0 + x` is not type correct + simp [id] + +theorem ex4 (p : Prop) (h : p) : (let n := 10; fun x : { z : Nat // z < n } => x = x) = fun z => p := by + simp (config := { zeta := false }) + traceState + simp [h] + +#print ex4 -- uses let_body_congr diff --git a/tests/lean/simpZetaFalse.lean.expected.out b/tests/lean/simpZetaFalse.lean.expected.out new file mode 100644 index 0000000000..2fa03ad295 --- /dev/null +++ b/tests/lean/simpZetaFalse.lean.expected.out @@ -0,0 +1,52 @@ +x : Nat +h : f (f x) = x +⊢ (let y := x * x; + if True then 1 else y + 1) = + 1 +theorem ex1 : ∀ (x : Nat), + f (f x) = x → + (let y := x * x; + if f (f x) = x then 1 else y + 1) = + 1 := +fun x h => + Eq.mpr + (congrFun + (congrArg Eq + (let_congr (Eq.refl (x * x)) + fun y => + ite_congr (Eq.trans (congrFun (congrArg Eq h) x) (eq_self x)) (fun a => Eq.refl 1) + fun a => Eq.refl (y + 1))) + 1) + (of_eq_true (Eq.trans (congrFun (congrArg Eq (ite_true 1 (x * x + 1))) 1) (eq_true_of_decide (Eq.refl true)))) +x z : Nat +h : f (f x) = x +h' : z = x +⊢ (let y := x; + y) = + z +theorem ex2 : ∀ (x z : Nat), + f (f x) = x → + z = x → + (let y := f (f x); + y) = + z := +fun x z h h' => + Eq.mpr (congrFun (congrArg Eq (let_val_congr (fun y => y) h)) z) + (of_eq_true (Eq.trans (congrArg (Eq x) h') (eq_self x))) +x z : Nat +⊢ (let α := Nat; + fun x => 0 + x) = + id +p : Prop +h : p +⊢ (let n := 10; + fun x => True) = + fun z => p +theorem ex4 : ∀ (p : Prop), + p → + (let n := 10; + fun x => x = x) = + fun z => p := +fun p h => + Eq.mpr (congrFun (congrArg Eq (let_body_congr 10 fun n => funext fun x => eq_self x)) fun z => p) + (of_eq_true (Eq.trans (congrArg (Eq fun x => True) (funext fun z => eq_true h)) (eq_self fun x => True)))