perf: improve over-applied cases in ToLCNF (#12284)

This PR changes the handling of over-applied cases expressions in
`ToLCNF` to avoid generating function declarations that are called
immediately. For example, `ToLCNF` previously produced this:
```lean-4
set_option trace.Compiler.init true
/--
trace: [Compiler.init] size: 4
    def test x y : Bool :=
      fun _y.1 _y.2 : Bool :=
        cases x : Bool
        | PUnit.unit =>
          fun _f.3 a : Bool :=
            return a;
          let _x.4 := _f.3 _y.2;
          return _x.4;
      let _x.5 := _y.1 y;
      return _x.5
-/
#guard_msgs in
def test (x : Unit) (y : Bool) : Bool :=
  x.casesOn (fun a => a) y
```
which is now simplified to
```lean-4
set_option trace.Compiler.init true
/--
trace: [Compiler.init] size: 3
    def test x y : Bool :=
      cases x : Bool
      | PUnit.unit =>
        let a := y;
        return a
-/
#guard_msgs in
def test (x : Unit) (y : Bool) : Bool :=
  x.casesOn (fun a => a) y
```
This is especially relevant for #8309 because there `dite` is defined as
an over-applied `Bool.casesOn`.
This commit is contained in:
Rob23oba 2026-02-06 10:27:15 +01:00 committed by GitHub
parent 71e340eb97
commit 9b7a8eb7c8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 61 additions and 14 deletions

View file

@ -348,15 +348,15 @@ def mkParam (binderName : Name) (type : Expr) : M (Param .pure) := do
modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default }
return param
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg .pure) :
M (LetDecl .pure) := do
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (arg : Arg .pure)
(nondep : Bool) : M (LetDecl .pure) := do
let binderName ← cleanupBinderName binderName
let value' ← match arg with
| .fvar fvarId => pure <| .fvar fvarId #[]
| .erased | .type .. => pure .erased
let letDecl ← LCNF.mkLetDecl binderName type' value'
modify fun s => { s with
lctx := s.lctx.mkLetDecl letDecl.fvarId binderName type value false
lctx := s.lctx.mkLetDecl letDecl.fvarId binderName type value nondep
seq := s.seq.push <| .let letDecl
}
return letDecl
@ -385,6 +385,38 @@ where
else
return (ps, e.instantiateRev xs)
/--
Given `e` and `args` where `mkAppN e (args.map (·.toExpr))` is not necessarily well-typed
(because of dependent typing), returns `e.beta args'` where `args'` are new local declarations each
assigned to a value in `args` with adjusted type (such that the resulting expression is well-typed).
-/
def mkTypeCorrectApp (e : Expr) (args : Array (Arg .pure)) : M Expr := do
if args.isEmpty then
return e
let type ← liftMetaM <| do
let type ← Meta.inferType e
if type.getNumHeadForalls < args.size then
-- expose foralls
Meta.forallBoundedTelescope type args.size Meta.mkForallFVars
else
return type
go type 0 #[]
where
go (type : Expr) (i : Nat) (xs : Array Expr) : M Expr := do
if h : i < args.size then
match type with
| .forallE nm t b bi =>
let t := t.instantiateRev xs
let arg := args[i]
if ← liftMetaM <| Meta.isProp t then
go b (i + 1) (xs.push (mkLcProof t))
else
let decl ← mkLetDecl nm t arg.toExpr (← arg.inferType) arg (nondep := true)
go b (i + 1) (xs.push (.fvar decl.fvarId))
| _ => liftMetaM <| Meta.throwFunctionExpected (mkAppN e xs)
else
return e.beta xs
def mustEtaExpand (env : Environment) (e : Expr) : Bool :=
if let .const declName _ := e.getAppFn then
match env.find? declName with
@ -526,7 +558,7 @@ where
k args[arity...*]
```
-/
mkOverApplication (app : (Arg .pure)) (args : Array Expr) (arity : Nat) : M (Arg .pure) := do
mkOverApplication (app : Arg .pure) (args : Array Expr) (arity : Nat) : M (Arg .pure) := do
if args.size == arity then
return app
else
@ -541,11 +573,14 @@ where
/--
Visit a `matcher`/`casesOn` alternative.
-/
visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) : M (Expr × (Alt .pure)) := do
visitAlt (casesAltInfo : CasesAltInfo) (e : Expr) (overArgs : Array (Arg .pure)) :
M (Expr × (Alt .pure)) := do
withNewScope do
match casesAltInfo with
| .default numHyps =>
let c ← toCode (← visit (mkAppN e (Array.replicate numHyps erasedExpr)))
let e := mkAppN e (Array.replicate numHyps erasedExpr)
let e ← mkTypeCorrectApp e overArgs
let c ← toCode (← visit e)
let altType ← c.inferType
return (altType, .default c)
| .ctor ctorName numParams =>
@ -555,6 +590,7 @@ where
let (ps', e') ← ToLCNF.visitLambda e
ps := ps ++ ps'
e := e'
e ← mkTypeCorrectApp e overArgs
/-
Insert the free variable ids of fields that are type formers into `toAny`.
Recall that we do not want to have "data" occurring in types.
@ -579,7 +615,8 @@ where
visitCases (casesInfo : CasesInfo) (e : Expr) : M (Arg .pure) :=
etaIfUnderApplied e casesInfo.arity do
let args := e.getAppArgs
let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[*...casesInfo.arity]))
let overArgs ← (args.drop casesInfo.arity).mapM visitAppArg
let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args))
let typeName := casesInfo.indName
let .inductInfo indVal ← getConstInfo typeName | unreachable!
if casesInfo.numAlts == 0 then
@ -609,8 +646,7 @@ where
fieldArgs := fieldArgs.push fieldArg
return fieldArgs
let f := args[casesInfo.altsRange.lower]!
let result ← visit (mkAppN f fieldArgs)
mkOverApplication result args casesInfo.arity
visit (mkAppN (mkAppN f fieldArgs) (overArgs.map (·.toExpr)))
else
let mut alts := #[]
let discr ← visitAppArg args[casesInfo.discrPos]!
@ -618,14 +654,13 @@ where
| .fvar discrFVarId => pure discrFVarId
| .erased | .type .. => mkAuxLetDecl .erased
for i in casesInfo.altsRange, numParams in casesInfo.altNumParams do
let (altType, alt) ← visitAlt numParams args[i]!
let (altType, alt) ← visitAlt numParams args[i]! overArgs
resultType := joinTypes altType resultType
alts := alts.push alt
let cases := ⟨typeName, resultType, discrFVarId, alts⟩
let auxDecl ← mkAuxParam resultType
pushElement (.cases auxDecl cases)
let result := .fvar auxDecl.fvarId
mkOverApplication result args casesInfo.arity
return .fvar auxDecl.fvarId
visitCtor (arity : Nat) (e : Expr) : M (Arg .pure) :=
etaIfUnderApplied e arity do
@ -843,14 +878,14 @@ where
visitLet (e : Expr) (xs : Array Expr) : M (Arg .pure) := do
match e with
| .letE binderName type value body _ =>
| .letE binderName type value body nondep =>
let type := type.instantiateRev xs
let value := value.instantiateRev xs
if (← (liftMetaM <| Meta.isProp type) <||> isTypeFormerType type) then
visitLet body (xs.push value)
else
let type' ← toLCNFType type
let letDecl ← mkLetDecl binderName type value type' (← visit value)
let letDecl ← mkLetDecl binderName type value type' (← visit value) nondep
visitLet body (xs.push (.fvar letDecl.fvarId))
| _ =>
let e := e.instantiateRev xs

12
tests/lean/run/12284.lean Normal file
View file

@ -0,0 +1,12 @@
set_option trace.Compiler.init true
/--
trace: [Compiler.init] size: 3
def test x y : Bool :=
cases x : Bool
| PUnit.unit =>
let a := y;
return a
-/
#guard_msgs in
def test (x : Unit) (y : Bool) : Bool :=
x.casesOn (fun a => a) y