diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index 72080ef727..2a48399e0b 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -155,6 +155,12 @@ structure Config where deriving Inhabited structure Context where + /-- + Name of the declaration being simplified. + We currently use this information because we are generating phase1 declarations on demand, + and it may trigger non-termination when trying to access the phase1 declaration. + -/ + declName : Name config : Config := {} discrCtorMap : FVarIdMap Expr := {} @@ -196,6 +202,16 @@ abbrev SimpM := ReaderT Context $ StateRefT State CompilerM instance : MonadFVarSubst SimpM where getSubst := return (← get).subst +/-- +Use `findExpr`, and if the result is a free variable, check whether it is in the map `discrCtorMap`. +We use this method when simplifying projections and cases-constructor. +-/ +def findCtor (e : Expr) : SimpM Expr := do + let e ← findExpr e + let .fvar fvarId := e | return e + let some ctor := (← read).discrCtorMap.find? fvarId | return e + return ctor + /-- Execute `x` with the information that `discr = ctorName ctorFields`. We use this information to simplify nested cases on the same discriminant. @@ -318,6 +334,28 @@ structure InlineCandidateInfo where def InlineCandidateInfo.arity : InlineCandidateInfo → Nat | { params, .. } => params.size +/-- +Return `some i` if `decl` is of the form +``` +def f (a_0 ... a_i ...) := + ... + cases a_i + | ... + | ... +``` +That is, `f` is a sequence of declarations followed by a `cases` on the parameter `i`. +We use this function to decide whether we should inline a declaration tagged with +`[inlineIfReduce]` or not. +-/ +def isCasesOnParam? (decl : Decl) : Option Nat := + go decl.value +where + go (code : Code) : Option Nat := + match code with + | .let _ k | .jp _ k | .fun _ k => go k + | .cases c => decl.params.findIdx? fun param => param.fvarId == c.discr + | _ => none + /-- Return `some info` if `e` should be inlined. -/ @@ -330,13 +368,20 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do let numArgs := e.getAppNumArgs let f := e.getAppFn if let .const declName us ← findExpr f then - unless mustInline || hasInlineAttribute (← getEnv) declName do return none + if declName == (← read).declName then return none -- TODO: remove after we start storing phase1 code in .olean files + let inlineIfReduce := hasInlineIfReduceAttribute (← getEnv) declName + unless mustInline || hasInlineAttribute (← getEnv) declName || inlineIfReduce do return none -- TODO: check whether function is recursive or not. -- We can skip the test and store function inline so far. let some decl ← getStage1Decl? declName | return none let arity := decl.getArity let inlinePartial := (← read).config.inlinePartial if !mustInline && !inlinePartial && numArgs < arity then return none + if inlineIfReduce then + let some paramIdx := isCasesOnParam? decl | return none + unless paramIdx < numArgs do return none + let arg ← findCtor (e.getArg! paramIdx) + unless arg.isConstructorApp (← getEnv) do return none let params := decl.instantiateParamsLevelParams us let value := decl.instantiateValueLevelParams us incInline @@ -746,16 +791,6 @@ private def addDefault (alts : Array Alt) : SimpM (Array Alt) := do altsNew := altsNew.push alt return altsNew.push (.default max.getCode) -/-- -Use `findExpr`, and if the result is a free variable, check whether it is in the map `discrCtorMap`. -We use this method when simplifying projections and cases-constructor. --/ -def findCtor (e : Expr) : SimpM Expr := do - let e ← findExpr e - let .fvar fvarId := e | return e - let some ctor := (← read).discrCtorMap.find? fvarId | return e - return ctor - /-- Try to simplify projections `.proj _ i s` where `s` is constructor. -/ @@ -1002,7 +1037,7 @@ partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do go decl config where go (decl : Decl) (config : Config) : CompilerM Decl := do - if let some decl ← decl.simp? |>.run { config } |>.run' {} then + if let some decl ← decl.simp? |>.run { config, declName := decl.name } |>.run' {} then -- TODO: bound number of steps? go decl config else diff --git a/tests/lean/run/inlineIfReduceLCNF.lean b/tests/lean/run/inlineIfReduceLCNF.lean new file mode 100644 index 0000000000..78b4343567 --- /dev/null +++ b/tests/lean/run/inlineIfReduceLCNF.lean @@ -0,0 +1,7 @@ +import Lean + +def f (x y z : Nat) : Array Nat := + #[x, y, z, y, x] + +set_option trace.Compiler.result true +#eval Lean.Compiler.compile #[``f]