diff --git a/src/Lean/Compiler/CompilerM.lean b/src/Lean/Compiler/CompilerM.lean index 3e317cac41..9956936f2d 100644 --- a/src/Lean/Compiler/CompilerM.lean +++ b/src/Lean/Compiler/CompilerM.lean @@ -152,6 +152,20 @@ def visitLambda (e : Expr) : CompilerM (Array Expr × Expr) := do let (fvars, e) ← visitLambdaCore e return (fvars, e.instantiateRev fvars) +/-- +Similar to `visitLambda` but for arrow-types. +-/ +def visitArrow (type : Expr) : CompilerM (Array Expr × Expr) := do + go type #[] +where + go (type : Expr) (fvars : Array Expr) := do + if let .forallE binderName type body binderInfo := type then + let type := type.instantiateRev fvars + let fvar ← mkLocalDecl binderName type binderInfo + go body (fvars.push fvar) + else + return (fvars, type.instantiateRev fvars) + /-- Given an expression representing a `match` return a tuple consisting of: 1. The motive diff --git a/src/Lean/Compiler/Simp.lean b/src/Lean/Compiler/Simp.lean index c7e4a76e04..acc798416f 100644 --- a/src/Lean/Compiler/Simp.lean +++ b/src/Lean/Compiler/Simp.lean @@ -264,11 +264,14 @@ def isOnceOrMustInline (binderName : Name) : SimpM Bool := do | some .once | some .mustInline => return true | _ => return false +def isSmallValue (value : Expr) : SimpM Bool := do + lcnfSizeLe value (← read).config.smallThreshold + def shouldInlineLocal (localDecl : LocalDecl) : SimpM Bool := do if (← isOnceOrMustInline localDecl.userName) then return true else - lcnfSizeLe localDecl.value (← read).config.smallThreshold + isSmallValue localDecl.value structure InlineCandidateInfo where isLocal : Bool @@ -365,6 +368,23 @@ private def simpUsingEtaReduction (e : Expr) : Expr := | .letE n t v b d => .letE n t v (simpUsingEtaReduction b) d | _ => e +private def etaExpand (type : Expr) (value : Expr) : CompilerM Expr := do + let typeArity := getArrowArity type + let valueArity := getLambdaArity value + if typeArity <= valueArity then + -- It can be < because of the "any" type + return value + else + withNewScope do + let (xs, _) ← visitArrow type + let value := getLambdaBody value + let value := value.instantiateRev xs[:valueArity] + let valueType ← inferType value + let f ← mkLocalDecl (← mkFreshUserName `_f) valueType + let k ← mkLambda #[f] (mkAppN f xs[valueArity:]) + let value ← attachJp value k + mkLambda xs value + /-- Auxiliary function for projecting "type class dictionary access". That is, we are trying to extract one of the type class instance elements. @@ -434,9 +454,6 @@ where def betaReduce (e : Expr) (args : Array Expr) : SimpM Expr := do -- TODO: add necessary casts to `args` - let rec getLambdaBody : Expr → Expr - | .lam _ _ b _ => getLambdaBody b - | b => b let result ← instantiateRevInternalize (getLambdaBody e) args trace[Meta.debug] "inline:\n{result}" return result @@ -632,6 +649,7 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do modify fun s => { s with counter := s.counter + 1 } match e with | .letE binderName type value body nonDep => + let type := type.instantiateRev xs let mut value := value.instantiateRev xs if value.isLambda then unless (← isOnceOrMustInline binderName) do @@ -640,6 +658,25 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do we do it after its is inlined and we have information about the actual arguments. -/ value ← visitLambda value + unless isJpBinderName binderName || (← isSmallValue value) do + /- + This lambda is not going to be inlined. So, we eta-expand it IF it is not a join point. + Recall that local function declarations that are not join points will be lambda lifted + anyway. Eta-expanding here also creates new simplification opportunities for + monadic local functions before we perform the lambda-lifting. + For example, consider the local function + ``` + let _x.23 := fun xs body => + ... + let _x.29 := StateRefT'.lift _x.24 + let _x.30 := _x.25 _x.29 + let _x.31 := fun a => ... + ReaderT.bind _x.30 _x.31 + ``` + The function applications `StateRefT'.lift` and `ReaderT.bind` are not inlined because + they are partially applied. After, we eta-expand this code, it will be reduced at this stage. + -/ + value ← etaExpand type value else if let some value' ← simpValue? value then if value'.isLet then let e := mkFlatLet binderName type value' body nonDep @@ -653,7 +690,6 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do else if let some e ← inlineApp? value xs body then return e else - let type := type.instantiateRev xs let x ← mkLetDecl binderName type value nonDep visitLet body (xs.push x) | _ => diff --git a/src/Lean/Compiler/Util.lean b/src/Lean/Compiler/Util.lean index 467cbf29fc..f44b0ac225 100644 --- a/src/Lean/Compiler/Util.lean +++ b/src/Lean/Compiler/Util.lean @@ -170,11 +170,19 @@ where return false return true +def getArrowArity (e : Expr) := + match e with + | .forallE _ _ b _ => getArrowArity b + 1 + | _ => 0 + def getLambdaArity (e : Expr) := match e with | .lam _ _ b _ => getLambdaArity b + 1 | _ => 0 +def getLambdaBody : Expr → Expr + | .lam _ _ b _ => getLambdaBody b + | b => b /-- Whether a given local declaration is a join point.