feat: support for [inlineIfReduce] at new compiler

This commit is contained in:
Leonardo de Moura 2022-09-13 18:21:14 -07:00
parent e8246e026d
commit fccb60fb69
2 changed files with 54 additions and 12 deletions

View file

@ -155,6 +155,12 @@ structure Config where
deriving Inhabited
structure Context where
/--
Name of the declaration being simplified.
We currently use this information because we are generating phase1 declarations on demand,
and it may trigger non-termination when trying to access the phase1 declaration.
-/
declName : Name
config : Config := {}
discrCtorMap : FVarIdMap Expr := {}
@ -196,6 +202,16 @@ abbrev SimpM := ReaderT Context $ StateRefT State CompilerM
instance : MonadFVarSubst SimpM where
getSubst := return (← get).subst
/--
Use `findExpr`, and if the result is a free variable, check whether it is in the map `discrCtorMap`.
We use this method when simplifying projections and cases-constructor.
-/
def findCtor (e : Expr) : SimpM Expr := do
let e ← findExpr e
let .fvar fvarId := e | return e
let some ctor := (← read).discrCtorMap.find? fvarId | return e
return ctor
/--
Execute `x` with the information that `discr = ctorName ctorFields`.
We use this information to simplify nested cases on the same discriminant.
@ -318,6 +334,28 @@ structure InlineCandidateInfo where
def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
| { params, .. } => params.size
/--
Return `some i` if `decl` is of the form
```
def f (a_0 ... a_i ...) :=
...
cases a_i
| ...
| ...
```
That is, `f` is a sequence of declarations followed by a `cases` on the parameter `i`.
We use this function to decide whether we should inline a declaration tagged with
`[inlineIfReduce]` or not.
-/
def isCasesOnParam? (decl : Decl) : Option Nat :=
go decl.value
where
go (code : Code) : Option Nat :=
match code with
| .let _ k | .jp _ k | .fun _ k => go k
| .cases c => decl.params.findIdx? fun param => param.fvarId == c.discr
| _ => none
/--
Return `some info` if `e` should be inlined.
-/
@ -330,13 +368,20 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
let numArgs := e.getAppNumArgs
let f := e.getAppFn
if let .const declName us ← findExpr f then
unless mustInline || hasInlineAttribute (← getEnv) declName do return none
if declName == (← read).declName then return none -- TODO: remove after we start storing phase1 code in .olean files
let inlineIfReduce := hasInlineIfReduceAttribute (← getEnv) declName
unless mustInline || hasInlineAttribute (← getEnv) declName || inlineIfReduce do return none
-- TODO: check whether function is recursive or not.
-- We can skip the test and store function inline so far.
let some decl ← getStage1Decl? declName | return none
let arity := decl.getArity
let inlinePartial := (← read).config.inlinePartial
if !mustInline && !inlinePartial && numArgs < arity then return none
if inlineIfReduce then
let some paramIdx := isCasesOnParam? decl | return none
unless paramIdx < numArgs do return none
let arg ← findCtor (e.getArg! paramIdx)
unless arg.isConstructorApp (← getEnv) do return none
let params := decl.instantiateParamsLevelParams us
let value := decl.instantiateValueLevelParams us
incInline
@ -746,16 +791,6 @@ private def addDefault (alts : Array Alt) : SimpM (Array Alt) := do
altsNew := altsNew.push alt
return altsNew.push (.default max.getCode)
/--
Use `findExpr`, and if the result is a free variable, check whether it is in the map `discrCtorMap`.
We use this method when simplifying projections and cases-constructor.
-/
def findCtor (e : Expr) : SimpM Expr := do
let e ← findExpr e
let .fvar fvarId := e | return e
let some ctor := (← read).discrCtorMap.find? fvarId | return e
return ctor
/--
Try to simplify projections `.proj _ i s` where `s` is constructor.
-/
@ -1002,7 +1037,7 @@ partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
go decl config
where
go (decl : Decl) (config : Config) : CompilerM Decl := do
if let some decl ← decl.simp? |>.run { config } |>.run' {} then
if let some decl ← decl.simp? |>.run { config, declName := decl.name } |>.run' {} then
-- TODO: bound number of steps?
go decl config
else

View file

@ -0,0 +1,7 @@
import Lean
def f (x y z : Nat) : Array Nat :=
#[x, y, z, y, x]
set_option trace.Compiler.result true
#eval Lean.Compiler.compile #[``f]