199 lines
6.5 KiB
Text
199 lines
6.5 KiB
Text
/-
|
|
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
|
Released under Apache 2.0 license as described in the file LICENSE.
|
|
Authors: Leonardo de Moura
|
|
-/
|
|
module
|
|
|
|
prelude
|
|
public import Lean.Compiler.LCNF.CompilerM
|
|
public import Lean.Compiler.LCNF.PhaseExt
|
|
public import Lean.Compiler.LCNF.InferType
|
|
public import Lean.Compiler.LCNF.Internalize
|
|
|
|
public section
|
|
|
|
namespace Lean.Compiler.LCNF
|
|
/-!
|
|
# Function arity reduction
|
|
|
|
This module finds "used" parameters in a declaration, and then
|
|
creates an auxiliary declaration that contains only used parameters.
|
|
For example:
|
|
```
|
|
def f (x y : Nat) : Nat :=
|
|
let _x.1 := Nat.add x x
|
|
let _x.2 := Nat.mul _x.1 _x.1
|
|
_x.2
|
|
```
|
|
is converted into
|
|
```
|
|
def f._redArg (x : Nat) : Nat :=
  let _x.1 := Nat.add x x
  let _x.2 := Nat.mul _x.1 _x.1
  _x.2
def f (x y : Nat) : Nat :=
  let _x.1 := f._redArg x
|
|
_x.1
|
|
```
|
|
Note that any `f` full application is going to be inlined in the next `simp` pass.
|
|
|
|
This module has basic support for detecting "unused" variables in recursive definitions.
|
|
For example, the `y` in the following definition is correctly treated as "unused"
|
|
```
|
|
def f (x y : Nat) : Nat :=
|
|
cases x
|
|
| zero => x
|
|
| succ _x.1 =>
|
|
let _x.2 := f _x.1 y
|
|
let _x.3 := Nat.mul _x.2 _x.2
|
|
_x.3
|
|
```
|
|
This module does not have similar support for mutual recursive applications.
|
|
We assume this limitation is irrelevant in practice.
|
|
-/
|
|
namespace FindUsed
|
|
|
|
/-- Read-only environment for the used-parameter analysis. -/
structure Context where
  /-- The declaration under analysis; calls to `decl.name` are treated as recursive calls. -/
  decl : Decl
  /-- The free-variable ids of the parameters of `decl`. -/
  params : FVarIdSet
|
|
|
|
/-- Mutable state of the used-parameter analysis. -/
structure State where
  /-- Parameters of the analyzed declaration that have been found to be used so far. -/
  used : FVarIdHashSet := {}
|
|
|
|
/-- Monad of the analysis: reads a `Context` and accumulates a `State` on top of `CompilerM`. -/
abbrev FindUsedM := ReaderT Context (StateRefT State CompilerM)
|
|
|
|
/--
Marks `fvarId` as used if it is one of the parameters of the declaration being analyzed.
Free variables that are not parameters are ignored.
-/
def visitFVar (fvarId : FVarId) : FindUsedM Unit := do
  let ctx ← read
  unless ctx.params.contains fvarId do
    return
  modify fun st => { st with used := st.used.insert fvarId }
|
|
|
|
/--
Visits an argument: a free-variable argument is forwarded to `visitFVar`,
while erased and type arguments mention no free variable and are skipped.
-/
def visitArg (arg : Arg) : FindUsedM Unit :=
  match arg with
  | .fvar fvarId => visitFVar fvarId
  | .type .. | .erased => pure ()
|
|
|
|
/--
Visits the right-hand side of a `let` and records every parameter it uses.

A direct recursive call (a `.const` whose name is the analyzed declaration's name) is handled
specially: passing a parameter back unchanged in its own position does not count as a use.
-/
def visitLetValue (e : LetValue) : FindUsedM Unit := do
  match e with
  | .lit .. | .erased => pure ()
  | .proj _ _ fvarId => visitFVar fvarId
  | .fvar fvarId args =>
    visitFVar fvarId
    args.forM visitArg
  | .const declName _ args =>
    let thisDecl := (← read).decl
    unless declName == thisDecl.name do
      -- Not a recursive call: every argument is an ordinary use.
      args.forM visitArg
      return
    -- Recursive call: an argument only counts when it is not the parameter in the same position.
    for param in thisDecl.params, arg in args do
      if let .fvar fvarId := arg then
        unless fvarId == param.fvarId do
          visitFVar fvarId
    -- over-application: arguments beyond the declared parameters are ordinary uses.
    for extraArg in args[thisDecl.params.size...*] do
      visitArg extraArg
    -- partial-application
    for missingParam in thisDecl.params[args.size...*] do
      -- If recursive function is partially applied, we assume missing parameters are used because we don't want to eta-expand.
      visitFVar missingParam.fvarId
|
|
|
|
/--
Traverses `code` and records every parameter of the analyzed declaration that it uses,
descending into local function and join point bodies and into every `cases` alternative.
-/
partial def visit (code : Code) : FindUsedM Unit := do
  match code with
  | .return fvarId => visitFVar fvarId
  | .jmp _ args => args.forM visitArg
  | .unreach .. => pure ()
  | .let letDecl rest =>
    visitLetValue letDecl.value
    visit rest
  | .fun funDecl rest | .jp funDecl rest =>
    visit funDecl.value
    visit rest
  | .cases cs =>
    -- The discriminant itself is a use.
    visitFVar cs.discr
    for alt in cs.alts do
      visit alt.getCode
|
|
|
|
/--
Runs the used-parameter analysis on `decl` and returns the set of parameters
of `decl` that were found to be used by its body.
-/
def collectUsedParams (decl : Decl) : CompilerM FVarIdHashSet := do
  -- Gather the parameter ids so that `visitFVar` can tell parameters from other free variables.
  let mut params : FVarIdSet := {}
  for p in decl.params do
    params := params.insert p.fvarId
  let (_, finalState) ← decl.value.forCodeM visit |>.run { decl, params } |>.run {}
  return finalState.used
|
|
|
|
end FindUsed
|
|
|
|
namespace ReduceArity
|
|
|
|
/-- Read-only environment for rewriting recursive calls to the reduced-arity declaration. -/
structure Context where
  /-- Name of the original declaration whose calls are rewritten. -/
  declName : Name
  /-- Name of the auxiliary declaration that rewritten calls target. -/
  auxDeclName : Name
  /-- One flag per parameter of the original declaration; `true` means the argument is kept. -/
  paramMask : Array Bool
|
|
|
|
/-- Monad of the call-rewriting traversal: `CompilerM` extended with a read-only `Context`. -/
abbrev ReduceM : Type → Type := ReaderT Context CompilerM
|
|
|
|
/--
Rewrites every `let` in `code` whose value is a call to `declName` (from the context) into a call
to `auxDeclName`, keeping only the arguments whose position is flagged in `paramMask`.

Arguments beyond `paramMask.size` (an over-application) are preserved after the kept prefix.
The zipped loop alone would stop at the end of the mask and silently drop them.
Local function / join point bodies and all `cases` alternatives are rewritten recursively.
-/
partial def reduce (code : Code) : ReduceM Code := do
  match code with
  | .let decl k =>
    -- Only direct calls to the declaration being reduced are rewritten.
    let .const declName _ args := decl.value | do return code.updateLet! decl (← reduce k)
    unless declName == (← read).declName do return code.updateLet! decl (← reduce k)
    let mask := (← read).paramMask
    let mut argsNew := #[]
    -- This loop stops at the shorter of `mask` and `args`, so a partial application keeps
    -- exactly the used arguments among the ones that were supplied.
    for used in mask, arg in args do
      if used then
        argsNew := argsNew.push arg
    -- over-application: the trailing arguments have no mask entry and must be forwarded as-is.
    for arg in args[mask.size...*] do
      argsNew := argsNew.push arg
    let decl ← decl.updateValue (.const (← read).auxDeclName [] argsNew)
    return code.updateLet! decl (← reduce k)
  | .fun decl k | .jp decl k =>
    let decl ← decl.updateValue (← reduce decl.value)
    return code.updateFun! decl (← reduce k)
  | .cases c =>
    let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← reduce alt.getCode)
    return code.updateAlts! alts
  | .unreach .. | .jmp .. | .return .. => return code
|
|
|
|
end ReduceArity
|
|
|
|
open FindUsed ReduceArity Internalize
|
|
|
|
/--
Splits `decl` into an auxiliary declaration `decl.name ++ `_redArg` that takes only the used
parameters, and an inlinable wrapper with the original name and parameters that calls it.

Returns `#[decl]` unchanged for extern declarations, when every parameter is used, and when no
parameter is used. Otherwise returns `#[auxDecl, wrapper]`.
-/
def Decl.reduceArity (decl : Decl) : CompilerM (Array Decl) := do
  -- Only declarations with an LCNF body can be analyzed.
  let .code code := decl.value | return #[decl]
  let used ← collectUsedParams decl
  -- Do nothing if all params were used, or if no params were used. In the latter case,
  -- this would promote the decl to a constant, which could execute unreachable code.
  if used.size == 0 || used.size == decl.params.size then
    return #[decl]
  trace[Compiler.reduceArity] "{decl.name}, used params: {used.toList.map mkFVar}"
  let paramMask := decl.params.map fun p => used.contains p.fvarId
  let redArgName := decl.name ++ `_redArg
  -- Auxiliary declaration: the original body with recursive calls redirected, over used params only.
  let buildAuxDecl : CompilerM Decl := do
    let value ← decl.value.mapCodeM reduce |>.run { declName := decl.name, auxDeclName := redArgName, paramMask }
    let params := decl.params.filter fun p => used.contains p.fvarId
    let resultType ← code.inferType
    let type ← mkForallParams params resultType
    let aux := { decl with name := redArgName, levelParams := [], type, params, value }
    aux.saveMono
    return aux
  -- Wrapper: keeps the original name and arity, and only forwards the used parameters.
  let buildWrapper : InternalizeM Decl := do
    let params ← decl.params.mapM internalizeParam
    let mut args := #[]
    for keep in paramMask, param in params do
      if keep then
        args := args.push param.toArg
    let call ← mkAuxLetDecl (.const redArgName [] args)
    let value := .code (.let call (.return call.fvarId))
    let wrapper := { decl with params, value, inlineAttr? := some .inline, recursive := false }
    wrapper.saveMono
    return wrapper
  let unusedParams := decl.params.filter fun p => !used.contains p.fvarId
  let auxDecl ← buildAuxDecl
  let wrapper ← buildWrapper |>.run' {}
  -- The unused parameters belong to neither resulting declaration anymore.
  eraseParams unusedParams
  return #[auxDecl, wrapper]
|
|
|
|
/--
The `reduceArity` compiler pass, registered for the mono phase. It applies `Decl.reduceArity`
to each declaration and concatenates the results in order.
-/
def reduceArity : Pass where
  name := `reduceArity
  phase := .mono
  run decls := do
    let mut result : Array Decl := #[]
    for decl in decls do
      result := result ++ (← decl.reduceArity)
    return result
|
|
|
|
-- Registers the trace class used by the `trace[Compiler.reduceArity]` message in `Decl.reduceArity`.
builtin_initialize registerTraceClass `Compiler.reduceArity (inherited := true)
|
|
|
|
end Lean.Compiler.LCNF
|