feat: complete reduceArity pass

This commit is contained in:
Leonardo de Moura 2022-10-16 16:00:33 -07:00
parent 1a02c326e5
commit 72c576f62a
5 changed files with 149 additions and 53 deletions

View file

@ -5,10 +5,10 @@ Authors: Leonardo de Moura
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.PhaseExt
import Lean.Compiler.LCNF.InferType
import Lean.Compiler.LCNF.Internalize
namespace Lean.Compiler.LCNF
namespace ReduceArity
/-!
# Function arity reduction
@ -47,6 +47,7 @@ def f (x y : Nat) : Nat :=
This module does not have similar support for mutual recursive applications.
We assume this limitation is irrelevant in practice.
-/
namespace FindUsed
structure Context where
decl : Decl
@ -79,8 +80,13 @@ def visitExpr (e : Expr) : FindUsedM Unit := do
for param in decl.params, arg in args do
unless arg.isFVarOf param.fvarId do
visitArg arg
-- over-application
for arg in args[decl.params.size:] do
visitArg arg
-- partial-application
for param in decl.params[args.size:] do
-- If recursive function is partially applied, we assume missing parameters are used because we don't want to eta-expand.
visitFVar param.fvarId
else
visitArg f
args.forM visitArg
@ -104,20 +110,79 @@ def collectUsedParams (decl : Decl) : CompilerM FVarIdSet := do
let (_, { used, .. }) ← visit decl.value |>.run { decl, params } |>.run {}
return used
end ReduceArity
open ReduceArity
end FindUsed
def Decl.reduceArity (decl : Decl) : CompilerM Decl := do
namespace ReduceArity
structure Context where
declName : Name
auxDeclName : Name
paramMask : Array Bool
abbrev ReduceM := ReaderT Context CompilerM
partial def reduce (code : Code) : ReduceM Code := do
match code with
| .let decl k =>
if decl.value.isAppOf (← read).declName then
let mut args := #[]
for used in (← read).paramMask, arg in decl.value.getAppArgs do
if used then
args := args.push arg
let decl ← decl.updateValue (mkAppN (mkConst (← read).auxDeclName) args)
return code.updateLet! decl (← reduce k)
else
return code.updateLet! decl (← reduce k)
| .fun decl k | .jp decl k =>
let decl ← decl.updateValue (← reduce decl.value)
return code.updateFun! decl (← reduce k)
| .cases c =>
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← reduce alt.getCode)
return code.updateAlts! alts
| .unreach .. | .jmp .. | .return .. => return code
end ReduceArity
open FindUsed ReduceArity Internalize
def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
let used ← collectUsedParams decl
if used.size == decl.params.size then
return decl -- Declarations uses all parameters
return #[decl] -- Declarations uses all parameters
else
trace[Compiler.reduceArity] "{decl.name}, used params: {used.toList.map mkFVar}"
-- TODO: create auxiliary function wth used parameters only
return decl
let mask := decl.params.map fun param => used.contains param.fvarId
let auxName := decl.name ++ `_redArg
let mkAuxDecl : CompilerM Decl := do
let params := decl.params.filter fun param => used.contains param.fvarId
let value ← reduce decl.value |>.run { declName := decl.name, auxDeclName := auxName, paramMask := mask }
let type ← value.inferType
let type ← mkForallParams params type
let auxDecl := { decl with name := auxName, levelParams := [], type, params, value }
auxDecl.saveMono
return auxDecl
let updateDecl : InternalizeM Decl := do
let params ← decl.params.mapM internalizeParam
let mut args := #[]
for used in mask, param in params do
if used then
args := args.push param.toExpr
let letDecl ← mkAuxLetDecl (mkAppN (mkConst auxName) args)
let value := .let letDecl (.return letDecl.fvarId)
let decl := { decl with params, value, inlineAttr? := some .inline, recursive := false }
decl.saveMono
return decl
let unusedParams := decl.params.filter fun param => !used.contains param.fvarId
let auxDecl ← mkAuxDecl
let decl ← updateDecl |>.run' {}
eraseParams unusedParams
return #[auxDecl, decl]
def reduceArity : Pass :=
.mkPerDeclaration `reduceArity (Decl.reduceArity ·) .mono
def reduceArity : Pass where
phase := .mono
name := `reduceArity
run := fun decls => do
decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.reduceArity)
builtin_initialize
registerTraceClass `Compiler.reduceArity (inherited := true)

View file

@ -9,17 +9,17 @@ holes.lean:10:9-10:12: error: don't know how to synthesize implicit argument
context:
x : Nat
⊢ Type
holes.lean:11:4-11:5: error: don't know how to synthesize placeholder
context:
x : Nat
y : Nat := g x + g x
⊢ Nat
holes.lean:11:8-11:13: error: don't know how to synthesize placeholder
context:
case hole
x : Nat
y : Nat := g x + g x
⊢ Nat
holes.lean:11:4-11:5: error: don't know how to synthesize placeholder
context:
x : Nat
y : Nat := g x + g x
⊢ Nat
holes.lean:13:10-13:11: error: failed to infer binder type
holes.lean:15:16-15:17: error: failed to infer binder type
holes.lean:19:0-19:3: error: don't know how to synthesize implicit argument

View file

@ -1,4 +1,4 @@
jason1.lean:47:40-47:57: error: don't know how to synthesize implicit argument
jason1.lean:48:100-48:117: error: don't know how to synthesize implicit argument
@TySyntaxLayer.nat G T EG getCtx (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)
context:
G T Tm : Type
@ -12,6 +12,42 @@ GAlgebra : CtxSyntaxLayer G T EG getCtx → G
TAlgebra : TySyntaxLayer G T EG getCtx → T
Γ✝ : G
⊢ G
jason1.lean:48:119-48:124: error: don't know how to synthesize implicit argument
@EGrfl
(getCtx
(TAlgebra
(@TySyntaxLayer.nat G T EG getCtx
(?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝))))
context:
G T Tm : Type
EG : G → G → Type
ET : T → T → Type
ETm : Tm → Tm → Type
EGrfl : {Γ : G} → EG Γ Γ
getCtx : T → G
getTy : Tm → T
GAlgebra : CtxSyntaxLayer G T EG getCtx → G
TAlgebra : TySyntaxLayer G T EG getCtx → T
Γ✝ : G
⊢ G
jason1.lean:48:125-48:130: error: don't know how to synthesize implicit argument
@EGrfl
(getCtx
(TAlgebra
(@TySyntaxLayer.nat G T EG getCtx
(?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝))))
context:
G T Tm : Type
EG : G → G → Type
ET : T → T → Type
ETm : Tm → Tm → Type
EGrfl : {Γ : G} → EG Γ Γ
getCtx : T → G
getTy : Tm → T
GAlgebra : CtxSyntaxLayer G T EG getCtx → G
TAlgebra : TySyntaxLayer G T EG getCtx → T
Γ✝ : G
⊢ G
jason1.lean:48:41-48:130: error: don't know how to synthesize implicit argument
@TySyntaxLayer.arrow G T EG getCtx
(getCtx
@ -58,7 +94,7 @@ GAlgebra : CtxSyntaxLayer G T EG getCtx → G
TAlgebra : TySyntaxLayer G T EG getCtx → T
Γ✝ : G
⊢ G
jason1.lean:48:100-48:117: error: don't know how to synthesize implicit argument
jason1.lean:47:40-47:57: error: don't know how to synthesize implicit argument
@TySyntaxLayer.nat G T EG getCtx (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)
context:
G T Tm : Type
@ -72,42 +108,6 @@ GAlgebra : CtxSyntaxLayer G T EG getCtx → G
TAlgebra : TySyntaxLayer G T EG getCtx → T
Γ✝ : G
⊢ G
jason1.lean:48:119-48:124: error: don't know how to synthesize implicit argument
@EGrfl
(getCtx
(TAlgebra
(@TySyntaxLayer.nat G T EG getCtx
(?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝))))
context:
G T Tm : Type
EG : G → G → Type
ET : T → T → Type
ETm : Tm → Tm → Type
EGrfl : {Γ : G} → EG Γ Γ
getCtx : T → G
getTy : Tm → T
GAlgebra : CtxSyntaxLayer G T EG getCtx → G
TAlgebra : TySyntaxLayer G T EG getCtx → T
Γ✝ : G
⊢ G
jason1.lean:48:125-48:130: error: don't know how to synthesize implicit argument
@EGrfl
(getCtx
(TAlgebra
(@TySyntaxLayer.nat G T EG getCtx
(?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝))))
context:
G T Tm : Type
EG : G → G → Type
ET : T → T → Type
ETm : Tm → Tm → Type
EGrfl : {Γ : G} → EG Γ Γ
getCtx : T → G
getTy : Tm → T
GAlgebra : CtxSyntaxLayer G T EG getCtx → G
TAlgebra : TySyntaxLayer G T EG getCtx → T
Γ✝ : G
⊢ G
jason1.lean:46:40-46:57: error: don't know how to synthesize implicit argument
@TySyntaxLayer.top G T EG getCtx (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)
context:

View file

@ -0,0 +1,14 @@
import Lean
open Lean Compiler LCNF
@[noinline] def double (x : Nat) := x + x
set_option pp.funBinderTypes true
set_option trace.Compiler.result true in
def g (n : Nat) (a b : α) (f : αα) :=
match n with
| 0 => a
| n+1 => f (g n a b f)
set_option trace.Compiler.result true in
def h (n : Nat) (a : Nat) :=
g n a a double + g a n n double

View file

@ -0,0 +1,17 @@
[Compiler.result] size: 5
def g._redArg (n : Nat) (a : ◾) (f : ◾ → ◾) : ◾ :=
cases n : ◾
| Nat.zero =>
a
| Nat.succ (n.1 : Nat) =>
let _x.2 := g._redArg n.1 a f
let _x.3 := f _x.2
_x.3
[Compiler.result] size: 1 def g (α : ◾) (n : Nat) (a : ◾) (b : ◾) (f : ◾ → ◾) : ◾ := let _x.1 := g._redArg n a f _x.1
[Compiler.result] size: 4
def h (n : Nat) (a : Nat) : Nat :=
let _x.1 := double
let _x.2 := g._redArg n a _x.1
let _x.3 := g._redArg a n _x.1
let _x.4 := Nat.add _x.2 _x.3
_x.4