feat: complete reduceArity pass
This commit is contained in:
parent
1a02c326e5
commit
72c576f62a
5 changed files with 149 additions and 53 deletions
|
|
@ -5,10 +5,10 @@ Authors: Leonardo de Moura
|
|||
-/
|
||||
import Lean.Compiler.LCNF.CompilerM
|
||||
import Lean.Compiler.LCNF.PhaseExt
|
||||
import Lean.Compiler.LCNF.InferType
|
||||
import Lean.Compiler.LCNF.Internalize
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
namespace ReduceArity
|
||||
|
||||
/-!
|
||||
# Function arity reduction
|
||||
|
||||
|
|
@ -47,6 +47,7 @@ def f (x y : Nat) : Nat :=
|
|||
This module does not have similar support for mutual recursive applications.
|
||||
We assume this limitation is irrelevant in practice.
|
||||
-/
|
||||
namespace FindUsed
|
||||
|
||||
structure Context where
|
||||
decl : Decl
|
||||
|
|
@ -79,8 +80,13 @@ def visitExpr (e : Expr) : FindUsedM Unit := do
|
|||
for param in decl.params, arg in args do
|
||||
unless arg.isFVarOf param.fvarId do
|
||||
visitArg arg
|
||||
-- over-application
|
||||
for arg in args[decl.params.size:] do
|
||||
visitArg arg
|
||||
-- partial-application
|
||||
for param in decl.params[args.size:] do
|
||||
-- If recursive function is partially applied, we assume missing parameters are used because we don't want to eta-expand.
|
||||
visitFVar param.fvarId
|
||||
else
|
||||
visitArg f
|
||||
args.forM visitArg
|
||||
|
|
@ -104,20 +110,79 @@ def collectUsedParams (decl : Decl) : CompilerM FVarIdSet := do
|
|||
let (_, { used, .. }) ← visit decl.value |>.run { decl, params } |>.run {}
|
||||
return used
|
||||
|
||||
end ReduceArity
|
||||
open ReduceArity
|
||||
end FindUsed
|
||||
|
||||
def Decl.reduceArity (decl : Decl) : CompilerM Decl := do
|
||||
namespace ReduceArity
|
||||
|
||||
structure Context where
|
||||
declName : Name
|
||||
auxDeclName : Name
|
||||
paramMask : Array Bool
|
||||
|
||||
abbrev ReduceM := ReaderT Context CompilerM
|
||||
|
||||
partial def reduce (code : Code) : ReduceM Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
if decl.value.isAppOf (← read).declName then
|
||||
let mut args := #[]
|
||||
for used in (← read).paramMask, arg in decl.value.getAppArgs do
|
||||
if used then
|
||||
args := args.push arg
|
||||
let decl ← decl.updateValue (mkAppN (mkConst (← read).auxDeclName) args)
|
||||
return code.updateLet! decl (← reduce k)
|
||||
else
|
||||
return code.updateLet! decl (← reduce k)
|
||||
| .fun decl k | .jp decl k =>
|
||||
let decl ← decl.updateValue (← reduce decl.value)
|
||||
return code.updateFun! decl (← reduce k)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← reduce alt.getCode)
|
||||
return code.updateAlts! alts
|
||||
| .unreach .. | .jmp .. | .return .. => return code
|
||||
|
||||
end ReduceArity
|
||||
|
||||
open FindUsed ReduceArity Internalize
|
||||
|
||||
def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
|
||||
let used ← collectUsedParams decl
|
||||
if used.size == decl.params.size then
|
||||
return decl -- Declarations uses all parameters
|
||||
return #[decl] -- Declarations uses all parameters
|
||||
else
|
||||
trace[Compiler.reduceArity] "{decl.name}, used params: {used.toList.map mkFVar}"
|
||||
-- TODO: create auxiliary function wth used parameters only
|
||||
return decl
|
||||
let mask := decl.params.map fun param => used.contains param.fvarId
|
||||
let auxName := decl.name ++ `_redArg
|
||||
let mkAuxDecl : CompilerM Decl := do
|
||||
let params := decl.params.filter fun param => used.contains param.fvarId
|
||||
let value ← reduce decl.value |>.run { declName := decl.name, auxDeclName := auxName, paramMask := mask }
|
||||
let type ← value.inferType
|
||||
let type ← mkForallParams params type
|
||||
let auxDecl := { decl with name := auxName, levelParams := [], type, params, value }
|
||||
auxDecl.saveMono
|
||||
return auxDecl
|
||||
let updateDecl : InternalizeM Decl := do
|
||||
let params ← decl.params.mapM internalizeParam
|
||||
let mut args := #[]
|
||||
for used in mask, param in params do
|
||||
if used then
|
||||
args := args.push param.toExpr
|
||||
let letDecl ← mkAuxLetDecl (mkAppN (mkConst auxName) args)
|
||||
let value := .let letDecl (.return letDecl.fvarId)
|
||||
let decl := { decl with params, value, inlineAttr? := some .inline, recursive := false }
|
||||
decl.saveMono
|
||||
return decl
|
||||
let unusedParams := decl.params.filter fun param => !used.contains param.fvarId
|
||||
let auxDecl ← mkAuxDecl
|
||||
let decl ← updateDecl |>.run' {}
|
||||
eraseParams unusedParams
|
||||
return #[auxDecl, decl]
|
||||
|
||||
def reduceArity : Pass :=
|
||||
.mkPerDeclaration `reduceArity (Decl.reduceArity ·) .mono
|
||||
def reduceArity : Pass where
|
||||
phase := .mono
|
||||
name := `reduceArity
|
||||
run := fun decls => do
|
||||
decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.reduceArity)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.reduceArity (inherited := true)
|
||||
|
|
|
|||
|
|
@ -9,17 +9,17 @@ holes.lean:10:9-10:12: error: don't know how to synthesize implicit argument
|
|||
context:
|
||||
x : Nat
|
||||
⊢ Type
|
||||
holes.lean:11:4-11:5: error: don't know how to synthesize placeholder
|
||||
context:
|
||||
x : Nat
|
||||
y : Nat := g x + g x
|
||||
⊢ Nat
|
||||
holes.lean:11:8-11:13: error: don't know how to synthesize placeholder
|
||||
context:
|
||||
case hole
|
||||
x : Nat
|
||||
y : Nat := g x + g x
|
||||
⊢ Nat
|
||||
holes.lean:11:4-11:5: error: don't know how to synthesize placeholder
|
||||
context:
|
||||
x : Nat
|
||||
y : Nat := g x + g x
|
||||
⊢ Nat
|
||||
holes.lean:13:10-13:11: error: failed to infer binder type
|
||||
holes.lean:15:16-15:17: error: failed to infer binder type
|
||||
holes.lean:19:0-19:3: error: don't know how to synthesize implicit argument
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
jason1.lean:47:40-47:57: error: don't know how to synthesize implicit argument
|
||||
jason1.lean:48:100-48:117: error: don't know how to synthesize implicit argument
|
||||
@TySyntaxLayer.nat G T EG getCtx (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)
|
||||
context:
|
||||
G T Tm : Type
|
||||
|
|
@ -12,6 +12,42 @@ GAlgebra : CtxSyntaxLayer G T EG getCtx → G
|
|||
TAlgebra : TySyntaxLayer G T EG getCtx → T
|
||||
Γ✝ : G
|
||||
⊢ G
|
||||
jason1.lean:48:119-48:124: error: don't know how to synthesize implicit argument
|
||||
@EGrfl
|
||||
(getCtx
|
||||
(TAlgebra
|
||||
(@TySyntaxLayer.nat G T EG getCtx
|
||||
(?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝))))
|
||||
context:
|
||||
G T Tm : Type
|
||||
EG : G → G → Type
|
||||
ET : T → T → Type
|
||||
ETm : Tm → Tm → Type
|
||||
EGrfl : {Γ : G} → EG Γ Γ
|
||||
getCtx : T → G
|
||||
getTy : Tm → T
|
||||
GAlgebra : CtxSyntaxLayer G T EG getCtx → G
|
||||
TAlgebra : TySyntaxLayer G T EG getCtx → T
|
||||
Γ✝ : G
|
||||
⊢ G
|
||||
jason1.lean:48:125-48:130: error: don't know how to synthesize implicit argument
|
||||
@EGrfl
|
||||
(getCtx
|
||||
(TAlgebra
|
||||
(@TySyntaxLayer.nat G T EG getCtx
|
||||
(?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝))))
|
||||
context:
|
||||
G T Tm : Type
|
||||
EG : G → G → Type
|
||||
ET : T → T → Type
|
||||
ETm : Tm → Tm → Type
|
||||
EGrfl : {Γ : G} → EG Γ Γ
|
||||
getCtx : T → G
|
||||
getTy : Tm → T
|
||||
GAlgebra : CtxSyntaxLayer G T EG getCtx → G
|
||||
TAlgebra : TySyntaxLayer G T EG getCtx → T
|
||||
Γ✝ : G
|
||||
⊢ G
|
||||
jason1.lean:48:41-48:130: error: don't know how to synthesize implicit argument
|
||||
@TySyntaxLayer.arrow G T EG getCtx
|
||||
(getCtx
|
||||
|
|
@ -58,7 +94,7 @@ GAlgebra : CtxSyntaxLayer G T EG getCtx → G
|
|||
TAlgebra : TySyntaxLayer G T EG getCtx → T
|
||||
Γ✝ : G
|
||||
⊢ G
|
||||
jason1.lean:48:100-48:117: error: don't know how to synthesize implicit argument
|
||||
jason1.lean:47:40-47:57: error: don't know how to synthesize implicit argument
|
||||
@TySyntaxLayer.nat G T EG getCtx (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)
|
||||
context:
|
||||
G T Tm : Type
|
||||
|
|
@ -72,42 +108,6 @@ GAlgebra : CtxSyntaxLayer G T EG getCtx → G
|
|||
TAlgebra : TySyntaxLayer G T EG getCtx → T
|
||||
Γ✝ : G
|
||||
⊢ G
|
||||
jason1.lean:48:119-48:124: error: don't know how to synthesize implicit argument
|
||||
@EGrfl
|
||||
(getCtx
|
||||
(TAlgebra
|
||||
(@TySyntaxLayer.nat G T EG getCtx
|
||||
(?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝))))
|
||||
context:
|
||||
G T Tm : Type
|
||||
EG : G → G → Type
|
||||
ET : T → T → Type
|
||||
ETm : Tm → Tm → Type
|
||||
EGrfl : {Γ : G} → EG Γ Γ
|
||||
getCtx : T → G
|
||||
getTy : Tm → T
|
||||
GAlgebra : CtxSyntaxLayer G T EG getCtx → G
|
||||
TAlgebra : TySyntaxLayer G T EG getCtx → T
|
||||
Γ✝ : G
|
||||
⊢ G
|
||||
jason1.lean:48:125-48:130: error: don't know how to synthesize implicit argument
|
||||
@EGrfl
|
||||
(getCtx
|
||||
(TAlgebra
|
||||
(@TySyntaxLayer.nat G T EG getCtx
|
||||
(?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝))))
|
||||
context:
|
||||
G T Tm : Type
|
||||
EG : G → G → Type
|
||||
ET : T → T → Type
|
||||
ETm : Tm → Tm → Type
|
||||
EGrfl : {Γ : G} → EG Γ Γ
|
||||
getCtx : T → G
|
||||
getTy : Tm → T
|
||||
GAlgebra : CtxSyntaxLayer G T EG getCtx → G
|
||||
TAlgebra : TySyntaxLayer G T EG getCtx → T
|
||||
Γ✝ : G
|
||||
⊢ G
|
||||
jason1.lean:46:40-46:57: error: don't know how to synthesize implicit argument
|
||||
@TySyntaxLayer.top G T EG getCtx (?m G T Tm EG ET ETm EGrfl getCtx getTy GAlgebra TAlgebra getTyStep Γ✝)
|
||||
context:
|
||||
|
|
|
|||
14
tests/lean/reduceArity.lean
Normal file
14
tests/lean/reduceArity.lean
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
import Lean
|
||||
open Lean Compiler LCNF
|
||||
@[noinline] def double (x : Nat) := x + x
|
||||
|
||||
set_option pp.funBinderTypes true
|
||||
set_option trace.Compiler.result true in
|
||||
def g (n : Nat) (a b : α) (f : α → α) :=
|
||||
match n with
|
||||
| 0 => a
|
||||
| n+1 => f (g n a b f)
|
||||
|
||||
set_option trace.Compiler.result true in
|
||||
def h (n : Nat) (a : Nat) :=
|
||||
g n a a double + g a n n double
|
||||
17
tests/lean/reduceArity.lean.expected.out
Normal file
17
tests/lean/reduceArity.lean.expected.out
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
[Compiler.result] size: 5
|
||||
def g._redArg (n : Nat) (a : ◾) (f : ◾ → ◾) : ◾ :=
|
||||
cases n : ◾
|
||||
| Nat.zero =>
|
||||
a
|
||||
| Nat.succ (n.1 : Nat) =>
|
||||
let _x.2 := g._redArg n.1 a f
|
||||
let _x.3 := f _x.2
|
||||
_x.3
|
||||
[Compiler.result] size: 1 def g (α : ◾) (n : Nat) (a : ◾) (b : ◾) (f : ◾ → ◾) : ◾ := let _x.1 := g._redArg n a f _x.1
|
||||
[Compiler.result] size: 4
|
||||
def h (n : Nat) (a : Nat) : Nat :=
|
||||
let _x.1 := double
|
||||
let _x.2 := g._redArg n a _x.1
|
||||
let _x.3 := g._redArg a n _x.1
|
||||
let _x.4 := Nat.add _x.2 _x.3
|
||||
_x.4
|
||||
Loading…
Add table
Reference in a new issue