fix: smart unfolding bug in over applications
This commit is contained in:
parent
eb7539ef77
commit
4e261b15e5
3 changed files with 21 additions and 20 deletions
|
|
@ -856,25 +856,31 @@ def mkAppRevRange (f : Expr) (beginIdx endIdx : Nat) (revArgs : Array Expr) : Ex
|
|||
If `useZeta` is true, the function also performs zeta-reduction to create further
|
||||
opportunities for beta reduction.
|
||||
-/
|
||||
partial def betaRev (f : Expr) (revArgs : Array Expr) (useZeta := false) : Expr :=
|
||||
partial def betaRev (f : Expr) (revArgs : Array Expr) (useZeta := false) (preserveMData := false) : Expr :=
|
||||
if revArgs.size == 0 then f
|
||||
else
|
||||
let sz := revArgs.size
|
||||
let rec go : Expr → Nat → Expr
|
||||
| Expr.lam _ _ b _, i =>
|
||||
let rec go (e : Expr) (i : Nat) : Expr :=
|
||||
match e with
|
||||
| Expr.lam _ _ b _ =>
|
||||
if i + 1 < sz then
|
||||
go b (i+1)
|
||||
else
|
||||
let n := sz - (i + 1)
|
||||
mkAppRevRange (b.instantiateRange n sz revArgs) 0 n revArgs
|
||||
| e@(Expr.letE _ _ v b _), i =>
|
||||
| Expr.letE _ _ v b _ =>
|
||||
if useZeta && i < sz then
|
||||
go (b.instantiate1 v) i
|
||||
else
|
||||
let n := sz - i
|
||||
mkAppRevRange (e.instantiateRange n sz revArgs) 0 n revArgs
|
||||
| Expr.mdata _ b _, i => go b i
|
||||
| b, i =>
|
||||
| Expr.mdata k b _=>
|
||||
if preserveMData then
|
||||
let n := sz - i
|
||||
mkMData k (mkAppRevRange (b.instantiateRange n sz revArgs) 0 n revArgs)
|
||||
else
|
||||
go b i
|
||||
| b =>
|
||||
let n := sz - i
|
||||
mkAppRevRange (b.instantiateRange n sz revArgs) 0 n revArgs
|
||||
go f 0
|
||||
|
|
|
|||
|
|
@ -290,12 +290,12 @@ end
|
|||
successK val
|
||||
|
||||
@[specialize] private def deltaBetaDefinition (c : ConstantInfo) (lvls : List Level) (revArgs : Array Expr)
|
||||
(failK : Unit → α) (successK : Expr → α) : α :=
|
||||
(failK : Unit → α) (successK : Expr → α) (preserveMData := false) : α :=
|
||||
if c.levelParams.length != lvls.length then
|
||||
failK ()
|
||||
else
|
||||
let val := c.instantiateValueLevelParams lvls
|
||||
let val := val.betaRev revArgs
|
||||
let val := val.betaRev revArgs (preserveMData := preserveMData)
|
||||
successK val
|
||||
|
||||
inductive ReduceMatcherResult where
|
||||
|
|
@ -584,7 +584,8 @@ mutual
|
|||
if smartUnfolding.get (← getOptions) then
|
||||
match ((← getEnv).find? (mkSmartUnfoldingNameFor fInfo.name)) with
|
||||
| some fAuxInfo@(ConstantInfo.defnInfo _) =>
|
||||
deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (fun _ => pure none) fun e₁ =>
|
||||
-- We use `preserveMData := true` to make sure the smart unfolding annotation are not erased in an over-application.
|
||||
deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (preserveMData := true) (fun _ => pure none) fun e₁ =>
|
||||
smartUnfoldingReduce? e₁
|
||||
| _ =>
|
||||
if (← getMatcherInfo? fInfo.name).isSome then
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ inductive Term : List Ty → Ty → Type
|
|||
| lam : Term (dom :: ctx) ran → Term ctx (.fn dom ran)
|
||||
| «let» : Term ctx ty₁ → Term (ty₁ :: ctx) ty₂ → Term ctx ty₂
|
||||
|
||||
def Term.denote : Term ctx ty → HList Ty.denote ctx → ty.denote
|
||||
@[simp] def Term.denote : Term ctx ty → HList Ty.denote ctx → ty.denote
|
||||
| var h, env => env.get h
|
||||
| const n, _ => n
|
||||
| plus a b, env => a.denote env + b.denote env
|
||||
|
|
@ -36,7 +36,7 @@ def Term.denote : Term ctx ty → HList Ty.denote ctx → ty.denote
|
|||
| lam b, env => fun x => b.denote (x :: env)
|
||||
| «let» a b, env => b.denote (a.denote env :: env)
|
||||
|
||||
def Term.constFold : Term ctx ty → Term ctx ty
|
||||
@[simp] def Term.constFold : Term ctx ty → Term ctx ty
|
||||
| const n => const n
|
||||
| var h => var h
|
||||
| app f a => app f.constFold a.constFold
|
||||
|
|
@ -48,14 +48,8 @@ def Term.constFold : Term ctx ty → Term ctx ty
|
|||
| a', b' => plus a' b'
|
||||
|
||||
theorem Term.constFold_sound (e : Term ctx ty) : e.constFold.denote env = e.denote env := by
|
||||
induction e with
|
||||
| const => rfl
|
||||
| var => rfl
|
||||
| app f a ihf iha => simp [constFold]; rw [denote, denote, iha, ihf]
|
||||
| lam b ih => simp [constFold]; rw [denote, denote]; simp [ih]
|
||||
| «let» a b iha ihb => simp [constFold]; rw [denote, denote, iha, ihb]
|
||||
induction e with simp [*]
|
||||
| plus a b iha ihb =>
|
||||
simp [constFold]
|
||||
split
|
||||
next he₁ he₂ => rw [denote, denote, ← iha, ← ihb, he₁, he₂, denote, denote]
|
||||
next he₁ he₂ _ _ _ => rw [denote, denote, ← he₁, ← he₂, iha, ihb]
|
||||
next he₁ he₂ => simp [← iha, ← ihb, he₁, he₂]
|
||||
next he₁ he₂ _ _ _ => simp [← he₁, ← he₂, iha, ihb]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue