From 72c576f62ab3d2e2850cfde6ea613e0f57ae9841 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 16 Oct 2022 16:00:33 -0700 Subject: [PATCH] feat: complete `reduceArity` pass --- src/Lean/Compiler/LCNF/ReduceArity.lean | 85 +++++++++++++++++++++--- tests/lean/holes.lean.expected.out | 10 +-- tests/lean/jason1.lean.expected.out | 76 ++++++++++----------- tests/lean/reduceArity.lean | 14 ++++ tests/lean/reduceArity.lean.expected.out | 17 +++++ 5 files changed, 149 insertions(+), 53 deletions(-) create mode 100644 tests/lean/reduceArity.lean create mode 100644 tests/lean/reduceArity.lean.expected.out diff --git a/src/Lean/Compiler/LCNF/ReduceArity.lean b/src/Lean/Compiler/LCNF/ReduceArity.lean index 4e7933e23b..2bdc19f697 100644 --- a/src/Lean/Compiler/LCNF/ReduceArity.lean +++ b/src/Lean/Compiler/LCNF/ReduceArity.lean @@ -5,10 +5,10 @@ Authors: Leonardo de Moura -/ import Lean.Compiler.LCNF.CompilerM import Lean.Compiler.LCNF.PhaseExt +import Lean.Compiler.LCNF.InferType +import Lean.Compiler.LCNF.Internalize namespace Lean.Compiler.LCNF -namespace ReduceArity - /-! # Function arity reduction @@ -47,6 +47,7 @@ def f (x y : Nat) : Nat := This module does not have similar support for mutual recursive applications. We assume this limitation is irrelevant in practice. -/ +namespace FindUsed structure Context where decl : Decl @@ -79,8 +80,13 @@ def visitExpr (e : Expr) : FindUsedM Unit := do for param in decl.params, arg in args do unless arg.isFVarOf param.fvarId do visitArg arg + -- over-application for arg in args[decl.params.size:] do visitArg arg + -- partial-application + for param in decl.params[args.size:] do + -- If recursive function is partially applied, we assume missing parameters are used because we don't want to eta-expand. + visitFVar param.fvarId else visitArg f args.forM visitArg @@ -104,20 +110,79 @@ def collectUsedParams (decl : Decl) : CompilerM FVarIdSet := do let (_, { used, .. }) ← visit decl.value |>.run { decl, params } |>.run {} return used -end ReduceArity -open ReduceArity +end FindUsed -def Decl.reduceArity (decl : Decl) : CompilerM Decl := do +namespace ReduceArity + +structure Context where + declName : Name + auxDeclName : Name + paramMask : Array Bool + +abbrev ReduceM := ReaderT Context CompilerM + +partial def reduce (code : Code) : ReduceM Code := do + match code with + | .let decl k => + if decl.value.isAppOf (← read).declName then + let mut args := #[] + for used in (← read).paramMask, arg in decl.value.getAppArgs do + if used then + args := args.push arg + let decl ← decl.updateValue (mkAppN (mkConst (← read).auxDeclName) args) + return code.updateLet! decl (← reduce k) + else + return code.updateLet! decl (← reduce k) + | .fun decl k | .jp decl k => + let decl ← decl.updateValue (← reduce decl.value) + return code.updateFun! decl (← reduce k) + | .cases c => + let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← reduce alt.getCode) + return code.updateAlts! alts + | .unreach .. | .jmp .. | .return .. => return code + +end ReduceArity + +open FindUsed ReduceArity Internalize + +def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do let used ← collectUsedParams decl if used.size == decl.params.size then - return decl -- Declarations uses all parameters + return #[decl] -- Declarations uses all parameters else trace[Compiler.reduceArity] "{decl.name}, used params: {used.toList.map mkFVar}" - -- TODO: create auxiliary function wth used parameters only - return decl + let mask := decl.params.map fun param => used.contains param.fvarId + let auxName := decl.name ++ `_redArg + let mkAuxDecl : CompilerM Decl := do + let params := decl.params.filter fun param => used.contains param.fvarId + let value ← reduce decl.value |>.run { declName := decl.name, auxDeclName := auxName, paramMask := mask } + let type ← value.inferType + let type ← mkForallParams params type + let auxDecl := { decl with name := auxName, levelParams := [], type, params, value } + auxDecl.saveMono + return auxDecl + let updateDecl : InternalizeM Decl := do + let params ← decl.params.mapM internalizeParam + let mut args := #[] + for used in mask, param in params do + if used then + args := args.push param.toExpr + let letDecl ← mkAuxLetDecl (mkAppN (mkConst auxName) args) + let value := .let letDecl (.return letDecl.fvarId) + let decl := { decl with params, value, inlineAttr? := some .inline, recursive := false } + decl.saveMono + return decl + let unusedParams := decl.params.filter fun param => !used.contains param.fvarId + let auxDecl ← mkAuxDecl + let decl ← updateDecl |>.run' {} + eraseParams unusedParams + return #[auxDecl, decl] -def reduceArity : Pass := - .mkPerDeclaration `reduceArity (Decl.reduceArity ·) .mono +def reduceArity : Pass where + phase := .mono + name := `reduceArity + run := fun decls => do + decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.reduceArity) builtin_initialize registerTraceClass `Compiler.reduceArity (inherited := true) diff --git a/tests/lean/holes.lean.expected.out b/tests/lean/holes.lean.expected.out index ce202d9c24..1fbb76e993 100644 --- a/tests/lean/holes.lean.expected.out +++ b/tests/lean/holes.lean.expected.out @@ -9,17 +9,17 @@ holes.lean:10:9-10:12: error: don't know how to synthesize implicit argument context: x : Nat ⊢ Type -holes.lean:11:4-11:5: error: don't know how to synthesize placeholder -context: -x : Nat -y : Nat := g x + g x -⊢ Nat holes.lean:11:8-11:13: error: don't know how to synthesize placeholder context: case hole x : Nat y : Nat := g x + g x ⊢ Nat +holes.lean:11:4-11:5: error: don't know how to synthesize placeholder +context: +x : Nat +y : Nat := g x + g x +⊢ Nat holes.lean:13:10-13:11: error: failed to infer binder type holes.lean:15:16-15:17: error: failed to infer binder type holes.lean:19:0-19:3: error: don't know how to synthesize implicit argument diff --git a/tests/lean/jason1.lean.expected.out b/tests/lean/jason1.lean.expected.out index 7022bdeaca..c75eebeb58 100644 --- a/tests/lean/jason1.lean.expected.out +++ b/tests/lean/jason1.lean.expected.out @@ -1,4 +1,4 @@ -jason1.lean:47:40-47:57: error: don't know how to synthesize implicit argument +jason1.lean:48:100-48:117: error: don't know how to synthesize implicit argument @TySyntaxLayer.nat G T EG getCtx (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝) context: G T Tm : Type @@ -12,6 +12,42 @@ GAlgebra : CtxSyntaxLayer G T EG getCtx → G TAlgebra : TySyntaxLayer G T EG getCtx → T Γ✝ : G ⊢ G +jason1.lean:48:119-48:124: error: don't know how to synthesize implicit argument + @EGrfl + (getCtx + (TAlgebra + (@TySyntaxLayer.nat G T EG getCtx + (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)))) +context: +G T Tm : Type +EG : G → G → Type +ET : T → T → Type +ETm : Tm → Tm → Type +EGrfl : {Γ : G} → EG Γ Γ +getCtx : T → G +getTy : Tm → T +GAlgebra : CtxSyntaxLayer G T EG getCtx → G +TAlgebra : TySyntaxLayer G T EG getCtx → T +Γ✝ : G +⊢ G +jason1.lean:48:125-48:130: error: don't know how to synthesize implicit argument + @EGrfl + (getCtx + (TAlgebra + (@TySyntaxLayer.nat G T EG getCtx + (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)))) +context: +G T Tm : Type +EG : G → G → Type +ET : T → T → Type +ETm : Tm → Tm → Type +EGrfl : {Γ : G} → EG Γ Γ +getCtx : T → G +getTy : Tm → T +GAlgebra : CtxSyntaxLayer G T EG getCtx → G +TAlgebra : TySyntaxLayer G T EG getCtx → T +Γ✝ : G +⊢ G jason1.lean:48:41-48:130: error: don't know how to synthesize implicit argument @TySyntaxLayer.arrow G T EG getCtx (getCtx @@ -58,7 +94,7 @@ GAlgebra : CtxSyntaxLayer G T EG getCtx → G TAlgebra : TySyntaxLayer G T EG getCtx → T Γ✝ : G ⊢ G -jason1.lean:48:100-48:117: error: don't know how to synthesize implicit argument +jason1.lean:47:40-47:57: error: don't know how to synthesize implicit argument @TySyntaxLayer.nat G T EG getCtx (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝) context: G T Tm : Type @@ -72,42 +108,6 @@ GAlgebra : CtxSyntaxLayer G T EG getCtx → G TAlgebra : TySyntaxLayer G T EG getCtx → T Γ✝ : G ⊢ G -jason1.lean:48:119-48:124: error: don't know how to synthesize implicit argument - @EGrfl - (getCtx - (TAlgebra - (@TySyntaxLayer.nat G T EG getCtx - (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)))) -context: -G T Tm : Type -EG : G → G → Type -ET : T → T → Type -ETm : Tm → Tm → Type -EGrfl : {Γ : G} → EG Γ Γ -getCtx : T → G -getTy : Tm → T -GAlgebra : CtxSyntaxLayer G T EG getCtx → G -TAlgebra : TySyntaxLayer G T EG getCtx → T -Γ✝ : G -⊢ G -jason1.lean:48:125-48:130: error: don't know how to synthesize implicit argument - @EGrfl - (getCtx - (TAlgebra - (@TySyntaxLayer.nat G T EG getCtx - (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)))) -context: -G T Tm : Type -EG : G → G → Type -ET : T → T → Type -ETm : Tm → Tm → Type -EGrfl : {Γ : G} → EG Γ Γ -getCtx : T → G -getTy : Tm → T -GAlgebra : CtxSyntaxLayer G T EG getCtx → G -TAlgebra : TySyntaxLayer G T EG getCtx → T -Γ✝ : G -⊢ G jason1.lean:46:40-46:57: error: don't know how to synthesize implicit argument @TySyntaxLayer.top G T EG getCtx (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝) context: diff --git a/tests/lean/reduceArity.lean b/tests/lean/reduceArity.lean new file mode 100644 index 0000000000..8e125db5dd --- /dev/null +++ b/tests/lean/reduceArity.lean @@ -0,0 +1,14 @@ +import Lean +open Lean Compiler LCNF +@[noinline] def double (x : Nat) := x + x + +set_option pp.funBinderTypes true +set_option trace.Compiler.result true in +def g (n : Nat) (a b : α) (f : α → α) := + match n with + | 0 => a + | n+1 => f (g n a b f) + +set_option trace.Compiler.result true in +def h (n : Nat) (a : Nat) := + g n a a double + g a n n double diff --git a/tests/lean/reduceArity.lean.expected.out b/tests/lean/reduceArity.lean.expected.out new file mode 100644 index 0000000000..526c156999 --- /dev/null +++ b/tests/lean/reduceArity.lean.expected.out @@ -0,0 +1,17 @@ +[Compiler.result] size: 5 + def g._redArg (n : Nat) (a : ◾) (f : ◾ → ◾) : ◾ := + cases n : ◾ + | Nat.zero => + a + | Nat.succ (n.1 : Nat) => + let _x.2 := g._redArg n.1 a f + let _x.3 := f _x.2 + _x.3 +[Compiler.result] size: 1 def g (α : ◾) (n : Nat) (a : ◾) (b : ◾) (f : ◾ → ◾) : ◾ := let _x.1 := g._redArg n a f _x.1 +[Compiler.result] size: 4 + def h (n : Nat) (a : Nat) : Nat := + let _x.1 := double + let _x.2 := g._redArg n a _x.1 + let _x.3 := g._redArg a n _x.1 + let _x.4 := Nat.add _x.2 _x.3 + _x.4