diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 25acd72585..66f8cb479d 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -856,25 +856,31 @@ def mkAppRevRange (f : Expr) (beginIdx endIdx : Nat) (revArgs : Array Expr) : Ex If `useZeta` is true, the function also performs zeta-reduction to create further opportunities for beta reduction. -/ -partial def betaRev (f : Expr) (revArgs : Array Expr) (useZeta := false) : Expr := +partial def betaRev (f : Expr) (revArgs : Array Expr) (useZeta := false) (preserveMData := false) : Expr := if revArgs.size == 0 then f else let sz := revArgs.size - let rec go : Expr → Nat → Expr - | Expr.lam _ _ b _, i => + let rec go (e : Expr) (i : Nat) : Expr := + match e with + | Expr.lam _ _ b _ => if i + 1 < sz then go b (i+1) else let n := sz - (i + 1) mkAppRevRange (b.instantiateRange n sz revArgs) 0 n revArgs - | e@(Expr.letE _ _ v b _), i => + | Expr.letE _ _ v b _ => if useZeta && i < sz then go (b.instantiate1 v) i else let n := sz - i mkAppRevRange (e.instantiateRange n sz revArgs) 0 n revArgs - | Expr.mdata _ b _, i => go b i - | b, i => + | Expr.mdata k b _=> + if preserveMData then + let n := sz - i + mkMData k (mkAppRevRange (b.instantiateRange n sz revArgs) 0 n revArgs) + else + go b i + | b => let n := sz - i mkAppRevRange (b.instantiateRange n sz revArgs) 0 n revArgs go f 0 diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index 86f68e10c8..b0cb8ce480 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -290,12 +290,12 @@ end successK val @[specialize] private def deltaBetaDefinition (c : ConstantInfo) (lvls : List Level) (revArgs : Array Expr) - (failK : Unit → α) (successK : Expr → α) : α := + (failK : Unit → α) (successK : Expr → α) (preserveMData := false) : α := if c.levelParams.length != lvls.length then failK () else let val := c.instantiateValueLevelParams lvls - let val := val.betaRev revArgs + let val := val.betaRev revArgs (preserveMData := preserveMData) successK val inductive ReduceMatcherResult where @@ -584,7 +584,8 @@ mutual if smartUnfolding.get (← getOptions) then match ((← getEnv).find? (mkSmartUnfoldingNameFor fInfo.name)) with | some fAuxInfo@(ConstantInfo.defnInfo _) => - deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) fun e₁ => + -- We use `preserveMData := true` to make sure the smart unfolding annotation are not erased in an over-application. + deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (preserveMData := true) (fun _ => pure none) fun e₁ => smartUnfoldingReduce? e₁ | _ => if (← getMatcherInfo? fInfo.name).isSome then diff --git a/tests/lean/run/deBruijn.lean b/tests/lean/run/deBruijn.lean index 4f1716ba1a..5a82bd7b0b 100644 --- a/tests/lean/run/deBruijn.lean +++ b/tests/lean/run/deBruijn.lean @@ -28,7 +28,7 @@ inductive Term : List Ty → Ty → Type | lam : Term (dom :: ctx) ran → Term ctx (.fn dom ran) | «let» : Term ctx ty₁ → Term (ty₁ :: ctx) ty₂ → Term ctx ty₂ -def Term.denote : Term ctx ty → HList Ty.denote ctx → ty.denote +@[simp] def Term.denote : Term ctx ty → HList Ty.denote ctx → ty.denote | var h, env => env.get h | const n, _ => n | plus a b, env => a.denote env + b.denote env @@ -36,7 +36,7 @@ def Term.denote : Term ctx ty → HList Ty.denote ctx → ty.denote | lam b, env => fun x => b.denote (x :: env) | «let» a b, env => b.denote (a.denote env :: env) -def Term.constFold : Term ctx ty → Term ctx ty +@[simp] def Term.constFold : Term ctx ty → Term ctx ty | const n => const n | var h => var h | app f a => app f.constFold a.constFold @@ -48,14 +48,8 @@ def Term.constFold : Term ctx ty → Term ctx ty | a', b' => plus a' b' theorem Term.constFold_sound (e : Term ctx ty) : e.constFold.denote env = e.denote env := by - induction e with - | const => rfl - | var => rfl - | app f a ihf iha => simp [constFold]; rw [denote, denote, iha, ihf] - | lam b ih => simp [constFold]; rw [denote, denote]; simp [ih] - | «let» a b iha ihb => simp [constFold]; rw [denote, denote, iha, ihb] + induction e with simp [*] | plus a b iha ihb => - simp [constFold] split - next he₁ he₂ => rw [denote, denote, ← iha, ← ihb, he₁, he₂, denote, denote] - next he₁ he₂ _ _ _ => rw [denote, denote, ← he₁, ← he₂, iha, ihb] + next he₁ he₂ => simp [← iha, ← ihb, he₁, he₂] + next he₁ he₂ _ _ _ => simp [← he₁, ← he₂, iha, ihb]