From d52b8e3cc162012b7b70ed1d98ba4adc8e5e56ec Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Wed, 16 Apr 2025 15:53:18 -0700 Subject: [PATCH] fix: use lcAny in more cases of type erasure (#7990) This PR adopts lcAny in more cases of type erasure in the new code generator. --- src/Lean/Compiler/LCNF/InferType.lean | 8 ++-- src/Lean/Compiler/LCNF/MonoTypes.lean | 8 +++- src/Lean/Compiler/LCNF/ToLCNF.lean | 2 +- src/Lean/Compiler/LCNF/Types.lean | 21 +++++----- tests/lean/lcnfTypes.lean.expected.out | 53 +++++++++++++------------- tests/lean/run/erased.lean | 4 +- 6 files changed, 52 insertions(+), 44 deletions(-) diff --git a/src/Lean/Compiler/LCNF/InferType.lean b/src/Lean/Compiler/LCNF/InferType.lean index 1bb54b7571..418026de19 100644 --- a/src/Lean/Compiler/LCNF/InferType.lean +++ b/src/Lean/Compiler/LCNF/InferType.lean @@ -142,7 +142,7 @@ mutual fType := instantiateRevRangeArgs fType j i args |>.headBeta match fType with | .forallE _ _ b _ => j := i; fType := b - | _ => return erasedExpr + | _ => return anyExpr return instantiateRevRangeArgs fType j args.size args |>.headBeta partial def inferAppType (e : Expr) : InferTypeM Expr := do @@ -157,7 +157,7 @@ mutual fType := fType.instantiateRevRange j i args |>.headBeta match fType with | .forallE _ _ b _ => j := i; fType := b - | _ => return erasedExpr + | _ => return anyExpr return fType.instantiateRevRange j args.size args |>.headBeta partial def inferProjType (structName : Name) (idx : Nat) (s : FVarId) : InferTypeM Expr := do @@ -167,6 +167,8 @@ mutual if structType.isErased then /- TODO: after we erase universe variables, we can just extract a better type using just `structName` and `idx`. -/ return erasedExpr + else if structType.isAny then + return anyExpr else matchConstStructure structType.getAppFn failed fun structVal structLvls ctorVal => let structTypeArgs := structType.getAppArgs @@ -179,7 +181,7 @@ mutual | .forallE _ _ body _ => if body.hasLooseBVars then -- This can happen when one of the fields is a type or type former. - ctorType := body.instantiate1 erasedExpr + ctorType := body.instantiate1 anyExpr else ctorType := body | _ => diff --git a/src/Lean/Compiler/LCNF/MonoTypes.lean b/src/Lean/Compiler/LCNF/MonoTypes.lean index 8983ca43e1..2aa5e61112 100644 --- a/src/Lean/Compiler/LCNF/MonoTypes.lean +++ b/src/Lean/Compiler/LCNF/MonoTypes.lean @@ -85,7 +85,11 @@ partial def toMonoType (type : Expr) : CoreM Expr := do where visitApp (f : Expr) (args : Array Expr) : CoreM Expr := do match f with - | .const ``lcErased _ => return erasedExpr + | .const ``lcErased _ => + if args.all (·.isErased) then + return erasedExpr + else + return anyExpr | .const ``lcAny _ => return anyExpr | .const ``Decidable _ => return mkConst ``Bool | .const declName us => @@ -101,7 +105,7 @@ where if d matches .const ``lcErased _ | .sort _ then result := mkApp result (← toMonoType arg) else - result := mkApp result erasedExpr + result := mkApp result anyExpr type := b.instantiate1 arg return result | _ => return anyExpr diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 06ec970439..84a7b29fca 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -302,7 +302,7 @@ are type formers. This can happen when we have a field whose type is, for exampl def applyToAny (type : Expr) : M Expr := do let toAny := (← get).toAny return type.replace fun - | .fvar fvarId => if toAny.contains fvarId then some erasedExpr else none + | .fvar fvarId => if toAny.contains fvarId then some anyExpr else none | _ => none def toLCNFType (type : Expr) : M Expr := do diff --git a/src/Lean/Compiler/LCNF/Types.lean b/src/Lean/Compiler/LCNF/Types.lean index fe16d3414f..a3ba0621bd 100644 --- a/src/Lean/Compiler/LCNF/Types.lean +++ b/src/Lean/Compiler/LCNF/Types.lean @@ -18,6 +18,9 @@ def anyExpr := mkConst ``lcAny def _root_.Lean.Expr.isErased (e : Expr) := e.isAppOf ``lcErased +def _root_.Lean.Expr.isAny (e : Expr) := + e.isAppOf ``lcAny + def isPropFormerTypeQuick : Expr → Bool | .forallE _ _ b _ => isPropFormerTypeQuick b | .sort .zero => true @@ -132,7 +135,7 @@ partial def toLCNFType (type : Expr) : MetaM Expr := do | .forallE .. => visitForall type #[] | .app .. => type.withApp visitApp | .fvar .. => visitApp type #[] - | _ => return erasedExpr + | _ => return mkConst ``lcAny where whnfEta (type : Expr) : MetaM Expr := do let type ← whnf type @@ -156,10 +159,10 @@ where visitApp (f : Expr) (args : Array Expr) := do let fNew ← match f with | .const declName us => - let .inductInfo _ ← getConstInfo declName | return erasedExpr + let .inductInfo _ ← getConstInfo declName | return anyExpr pure <| .const declName us | .fvar .. => pure f - | _ => return erasedExpr + | _ => return anyExpr let mut result := fNew for arg in args do if (← isProp arg) then @@ -169,13 +172,13 @@ where else if (← isTypeFormer arg) then result := mkApp result (← toLCNFType arg) else - result := mkApp result erasedExpr + result := mkApp result (mkConst ``lcAny) return result mutual partial def joinTypes (a b : Expr) : Expr := - joinTypes? a b |>.getD erasedExpr + joinTypes? a b |>.getD (mkConst ``lcAny) partial def joinTypes? (a b : Expr) : Option Expr := do if a.isErased || b.isErased then @@ -194,16 +197,16 @@ partial def joinTypes? (a b : Expr) : Option Expr := do | .app f a, .app g b => (do return .app (← joinTypes? f g) (← joinTypes? a b)) <|> - return erasedExpr + return (mkConst ``lcAny) | .forallE n d₁ b₁ _, .forallE _ d₂ b₂ _ => (do return .forallE n (← joinTypes? d₁ d₂) (joinTypes b₁ b₂) .default) <|> - return erasedExpr + return (mkConst ``lcAny) | .lam n d₁ b₁ _, .lam _ d₂ b₂ _ => (do return .lam n (← joinTypes? d₁ d₂) (joinTypes b₁ b₂) .default) <|> - return erasedExpr - | _, _ => return erasedExpr + return (mkConst ``lcAny) + | _, _ => return (mkConst ``lcAny) end diff --git a/tests/lean/lcnfTypes.lean.expected.out b/tests/lean/lcnfTypes.lean.expected.out index fd19f152d1..23460d8668 100644 --- a/tests/lean/lcnfTypes.lean.expected.out +++ b/tests/lean/lcnfTypes.lean.expected.out @@ -1,20 +1,20 @@ -Vec.zip : {α : Type u_1} → {n : Nat} → {β : Type u_2} → Vec α ◾ → Vec β ◾ → Vec (α × β) ◾ -mkConstTuple : {α : Type u_1} → α → Nat → ◾ -Fin.add : {n : Nat} → Fin ◾ → Fin ◾ → Fin ◾ -Vec.cons : {α : Type u} → {n : Nat} → α → Vec α ◾ → Vec α ◾ -Eq.rec : {α : Sort u_1} → {a : α} → {motive : α → ◾ → Sort u} → motive ◾ ◾ → {a : α} → ◾ → motive ◾ ◾ +Vec.zip : {α : Type u_1} → {n : Nat} → {β : Type u_2} → Vec α lcAny → Vec β lcAny → Vec (α × β) lcAny +mkConstTuple : {α : Type u_1} → α → Nat → lcAny +Fin.add : {n : Nat} → Fin lcAny → Fin lcAny → Fin lcAny +Vec.cons : {α : Type u} → {n : Nat} → α → Vec α lcAny → Vec α lcAny +Eq.rec : {α : Sort u_1} → {a : α} → {motive : α → ◾ → Sort u} → motive lcAny lcAny → {a : α} → ◾ → motive lcAny lcAny GetElem.getElem : {coll : Type u} → {idx : Type v} → {elem : Type w} → {valid : coll → idx → Prop} → [self : GetElem coll idx elem ◾] → coll → idx → ◾ → elem -Term.constFold : {ctx : List Ty} → {ty : Ty} → _root_.Term ◾ ◾ → _root_.Term ◾ ◾ -Term.denote : {ctx : List Ty} → {ty : Ty} → _root_.Term ◾ ◾ → HList ◾ ◾ → ◾ -HList.get : {α : Type u_1} → {β : α → Type u_2} → {is : List α} → {i : α} → HList β ◾ → Member ◾ ◾ → β ◾ -Member.head : {α : Type u_1} → {a : α} → {as : List α} → Member ◾ ◾ +Term.constFold : {ctx : List Ty} → {ty : Ty} → _root_.Term lcAny lcAny → _root_.Term lcAny lcAny +Term.denote : {ctx : List Ty} → {ty : Ty} → _root_.Term lcAny lcAny → HList lcAny lcAny → lcAny +HList.get : {α : Type u_1} → {β : α → Type u_2} → {is : List α} → {i : α} → HList β lcAny → Member lcAny lcAny → β lcAny +Member.head : {α : Type u_1} → {a : α} → {as : List α} → Member lcAny lcAny Ty.denote : Ty → Type MonadControl.liftWith : {m : Type u → Type v} → - {n : Type u → Type w} → [self : MonadControl m n] → {α : Type u} → (({β : Type u} → n β → m ◾) → m α) → n α -MonadControl.restoreM : {m : Type u → Type v} → {n : Type u → Type w} → [self : MonadControl m n] → {α : Type u} → m ◾ → n α -Decidable.casesOn : {p : Prop} → {motive : Decidable ◾ → Sort u} → Decidable ◾ → (◾ → motive ◾) → (◾ → motive ◾) → motive ◾ + {n : Type u → Type w} → [self : MonadControl m n] → {α : Type u} → (({β : Type u} → n β → m lcAny) → m α) → n α +MonadControl.restoreM : {m : Type u → Type v} → {n : Type u → Type w} → [self : MonadControl m n] → {α : Type u} → m lcAny → n α +Decidable.casesOn : {p : Prop} → {motive : Decidable ◾ → Sort u} → Decidable ◾ → (◾ → motive lcAny) → (◾ → motive lcAny) → motive lcAny Lean.getConstInfo : {m : Type → Type} → [Monad m] → [MonadEnv m] → [MonadError m] → Name → m ConstantInfo Lean.Meta.instMonadMetaM : Monad fun α => Context → ST.Ref PUnit State → Core.Context → ST.Ref PUnit Core.State → PUnit → EStateM.Result Exception PUnit α @@ -29,26 +29,25 @@ Lean.Elab.Term.elabTerm : Syntax → Context → ST.Ref PUnit State → Core.Context → ST.Ref PUnit Core.State → PUnit → EStateM.Result Exception PUnit Expr Nat.add : Nat → Nat → Nat -Magma.mul : Magma → ◾ → ◾ → ◾ -weird1 : Bool → ◾ -lamAny₁ : Bool → Monad ◾ -lamAny₂ : Bool → Monad ◾ -Term.constFold : List Ty → Ty → _root_.Term lcErased lcErased → _root_.Term lcErased lcErased -Term.denote : lcErased -HList.get : lcErased → lcErased → List lcAny → lcAny → HList lcAny lcErased lcErased → Member lcAny lcErased lcErased → lcAny -Member.head : lcErased → lcAny → List lcAny → Member lcAny lcErased lcErased +Magma.mul : Magma → lcAny → lcAny → lcAny +weird1 : Bool → lcAny +lamAny₁ : Bool → Monad fun α => lcAny +lamAny₂ : Bool → Monad lcAny +Term.constFold : List Ty → Ty → _root_.Term lcAny lcAny → _root_.Term lcAny lcAny +Term.denote : List Ty → Ty → _root_.Term lcAny lcAny → HList Ty lcAny lcAny → lcAny +HList.get : lcErased → lcErased → List lcAny → lcAny → HList lcAny lcAny lcAny → Member lcAny lcAny lcAny → lcAny +Member.head : lcErased → lcAny → List lcAny → Member lcAny lcAny lcAny Ty.denote : lcErased -MonadControl.liftWith : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → ((lcErased → lcAny → lcAny) → lcAny) → lcAny -MonadControl.restoreM : lcErased → lcErased → MonadControl lcErased lcErased → lcErased → lcAny → lcAny +MonadControl.liftWith : lcErased → lcErased → MonadControl lcAny lcAny → lcErased → ((lcErased → lcAny → lcAny) → lcAny) → lcAny +MonadControl.restoreM : lcErased → lcErased → MonadControl lcAny lcAny → lcErased → lcAny → lcAny Decidable.casesOn : lcErased → lcErased → Bool → (lcErased → lcAny) → (lcErased → lcAny) → lcAny -Lean.getConstInfo : lcErased → Monad lcErased → MonadEnv lcErased → MonadError lcErased → Name → lcAny -Lean.Meta.instMonadMetaM : Monad lcErased -Lean.Meta.inferType : Expr → Context → lcErased → Core.Context → lcErased → PUnit → EStateM.Result Exception PUnit Expr +Lean.getConstInfo : lcErased → Monad lcAny → MonadEnv lcAny → MonadError lcAny → Name → lcAny +Lean.Meta.instMonadMetaM : Monad lcAny +Lean.Meta.inferType : Expr → Context → lcAny → Core.Context → lcAny → PUnit → EStateM.Result Exception PUnit Expr Lean.Elab.Term.elabTerm : Syntax → Option Expr → Bool → Bool → - Elab.Term.Context → - lcErased → Context → lcErased → Core.Context → lcErased → PUnit → EStateM.Result Exception PUnit Expr + Elab.Term.Context → lcAny → Context → lcAny → Core.Context → lcAny → PUnit → EStateM.Result Exception PUnit Expr Nat.add : Nat → Nat → Nat Fin.add : Nat → Nat → Nat → Nat diff --git a/tests/lean/run/erased.lean b/tests/lean/run/erased.lean index 1028112390..642da1d034 100644 --- a/tests/lean/run/erased.lean +++ b/tests/lean/run/erased.lean @@ -21,8 +21,8 @@ set_option pp.letVarTypes true set_option trace.Compiler.result true /-- info: [Compiler.result] size: 1 - def Erased.mk (α : lcErased) (a : lcAny) : PSigma lcErased lcErased := - let _x.1 : PSigma lcErased lcErased := PSigma.mk lcErased ◾ ◾ ◾; + def Erased.mk (α : lcErased) (a : lcAny) : PSigma lcErased lcAny := + let _x.1 : PSigma lcErased lcAny := PSigma.mk lcErased ◾ ◾ ◾; return _x.1 -/ #guard_msgs in