From 0f40dfc0637e402e846f8a41c96457d2dc0bef99 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 27 Aug 2022 08:49:22 -0700 Subject: [PATCH] feat: add `FunDecl.etaExpand` --- src/Lean/Compiler/LCNF/Bind.lean | 24 ++++++++++++++++++++++++ src/Lean/Compiler/LCNF/InferType.lean | 21 +++++++++++++++------ src/Lean/Compiler/LCNF/ToLCNF.lean | 22 +++++++++++++--------- src/Lean/Compiler/LCNF/Types.lean | 5 +++++ 4 files changed, 57 insertions(+), 15 deletions(-) diff --git a/src/Lean/Compiler/LCNF/Bind.lean b/src/Lean/Compiler/LCNF/Bind.lean index a392fa480d..8158e98154 100644 --- a/src/Lean/Compiler/LCNF/Bind.lean +++ b/src/Lean/Compiler/LCNF/Bind.lean @@ -46,4 +46,28 @@ where return c | .unreach .. => return c +def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do + let typeArity := getArrowArity decl.type + let valueArity := decl.getArity + if typeArity <= valueArity then + -- It can be < because of the "any" type + return decl + else + let valueType ← instantiateForall decl.type decl.params + let psNew ← mkNewParams valueType #[] #[] + let params := decl.params ++ psNew + let xs := psNew.map fun p => Expr.fvar p.fvarId + let value ← decl.value.bind fun fvarId => do + let auxDecl ← mkAuxLetDecl (mkAppN (.fvar fvarId) xs) + return .let auxDecl (.return auxDecl.fvarId) + decl.update decl.type params value +where + mkNewParams (type : Expr) (xs : Array Expr) (ps : Array Param) : CompilerM (Array Param) := do + match type with + | .forallE _ d b _ => + let d := d.instantiateRev xs + let p ← mkAuxParam d + mkNewParams b (xs.push (.fvar p.fvarId)) (ps.push p) + | _ => return ps + end Lean.Compiler.LCNF \ No newline at end of file diff --git a/src/Lean/Compiler/LCNF/InferType.lean b/src/Lean/Compiler/LCNF/InferType.lean index 6fef698e90..ded4239e02 100644 --- a/src/Lean/Compiler/LCNF/InferType.lean +++ b/src/Lean/Compiler/LCNF/InferType.lean @@ -182,12 +182,8 @@ def Code.inferParamType (params : Array Param) (code : Code) : CompilerM Expr := def AltCore.inferType (alt : Alt) : CompilerM Expr := alt.getCode.inferType -def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM Expr := do - if e.isFVar then - return e - else - let letDecl ← mkLetDecl (← mkFreshBinderName prefixName) (← inferType e) e - return .fvar letDecl.fvarId +def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM LetDecl := do + mkLetDecl (← mkFreshBinderName prefixName) (← inferType e) e def mkForallParams (params : Array Param) (type : Expr) : CompilerM Expr := InferType.mkForallParams params type |>.run {} @@ -206,4 +202,17 @@ def mkAuxJpDecl' (fvarId : FVarId) (code : Code) (prefixName := `_jp) : Compiler let params := #[{ fvarId, binderName := y, type := yType }] mkAuxFunDecl params code prefixName +def instantiateForall (type : Expr) (params : Array Param) : CoreM Expr := + go type 0 +where + go (type : Expr) (i : Nat) : CoreM Expr := + if h : i < params.size then + let p := params[i] + match type with + | .forallE _ _ b _ => go (b.instantiate1 (.fvar p.fvarId)) (i+1) + | _ => throwError "invalid instantiateForall, too many parameters" + else + return type +termination_by go i => params.size - i + end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index 31a0b32b0f..cb1b1ec2c6 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -37,28 +37,32 @@ inductive Element where deriving Inhabited def seqToCode (seq : Array Element) (e : Expr) : CompilerM Code := do - let e ← mkAuxLetDecl e - go seq.size (.return e.fvarId!) + if let .fvar fvarId := e then + go seq seq.size (.return fvarId) + else + let decl ← mkAuxLetDecl e + let seq := seq.push (.let decl) + go seq seq.size (.return decl.fvarId) where - go (i : Nat) (c : Code) : CompilerM Code := do + go (seq : Array Element) (i : Nat) (c : Code) : CompilerM Code := do if i > 0 then match seq[i-1]! with - | .jp decl => go (i - 1) (.jp decl c) - | .fun decl => go (i - 1) (.fun decl c) - | .let decl => go (i - 1) (.let decl c) + | .jp decl => go seq (i - 1) (.jp decl c) + | .fun decl => go seq (i - 1) (.fun decl c) + | .let decl => go seq (i - 1) (.let decl c) | .unreach => return .unreach (← c.inferType) | .cases fvarId cases => if let .return fvarId' := c then if fvarId == fvarId' then - go (i - 1) (.cases cases) + go seq (i - 1) (.cases cases) else -- `cases` is dead code - go (i - 1) c + go seq (i - 1) c else /- Create a join point for `c` and jump to it from `cases` -/ let jpDecl ← mkAuxJpDecl' fvarId c let cases ← (Code.cases cases).bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId] - go (i - 1) (.jp jpDecl cases) + go seq (i - 1) (.jp jpDecl cases) else return c diff --git a/src/Lean/Compiler/LCNF/Types.lean b/src/Lean/Compiler/LCNF/Types.lean index 117276897d..414029a3ea 100644 --- a/src/Lean/Compiler/LCNF/Types.lean +++ b/src/Lean/Compiler/LCNF/Types.lean @@ -237,4 +237,9 @@ def isClass? (type : Expr) : CoreM (Option Name) := do else return none +def getArrowArity (e : Expr) := + match e with + | .forallE _ _ b _ => getArrowArity b + 1 + | _ => 0 + end Lean.Compiler.LCNF \ No newline at end of file