feat: add FunDecl.etaExpand

This commit is contained in:
Leonardo de Moura 2022-08-27 08:49:22 -07:00
parent 11c8253f6c
commit 0f40dfc063
4 changed files with 57 additions and 15 deletions

View file

@ -46,4 +46,28 @@ where
return c
| .unreach .. => return c
def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do
let typeArity := getArrowArity decl.type
let valueArity := decl.getArity
if typeArity <= valueArity then
-- It can be < because of the "any" type
return decl
else
let valueType ← instantiateForall decl.type decl.params
let psNew ← mkNewParams valueType #[] #[]
let params := decl.params ++ psNew
let xs := psNew.map fun p => Expr.fvar p.fvarId
let value ← decl.value.bind fun fvarId => do
let auxDecl ← mkAuxLetDecl (mkAppN (.fvar fvarId) xs)
return .let auxDecl (.return auxDecl.fvarId)
decl.update decl.type params value
where
mkNewParams (type : Expr) (xs : Array Expr) (ps : Array Param) : CompilerM (Array Param) := do
match type with
| .forallE _ d b _ =>
let d := d.instantiateRev xs
let p ← mkAuxParam d
mkNewParams b (xs.push (.fvar p.fvarId)) (ps.push p)
| _ => return ps
end Lean.Compiler.LCNF

View file

@ -182,12 +182,8 @@ def Code.inferParamType (params : Array Param) (code : Code) : CompilerM Expr :=
def AltCore.inferType (alt : Alt) : CompilerM Expr :=
alt.getCode.inferType
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM Expr := do
if e.isFVar then
return e
else
let letDecl ← mkLetDecl (← mkFreshBinderName prefixName) (← inferType e) e
return .fvar letDecl.fvarId
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM LetDecl := do
mkLetDecl (← mkFreshBinderName prefixName) (← inferType e) e
def mkForallParams (params : Array Param) (type : Expr) : CompilerM Expr :=
InferType.mkForallParams params type |>.run {}
@ -206,4 +202,17 @@ def mkAuxJpDecl' (fvarId : FVarId) (code : Code) (prefixName := `_jp) : Compiler
let params := #[{ fvarId, binderName := y, type := yType }]
mkAuxFunDecl params code prefixName
def instantiateForall (type : Expr) (params : Array Param) : CoreM Expr :=
go type 0
where
go (type : Expr) (i : Nat) : CoreM Expr :=
if h : i < params.size then
let p := params[i]
match type with
| .forallE _ _ b _ => go (b.instantiate1 (.fvar p.fvarId)) (i+1)
| _ => throwError "invalid instantiateForall, too many parameters"
else
return type
termination_by go i => params.size - i
end Lean.Compiler.LCNF

View file

@ -37,28 +37,32 @@ inductive Element where
deriving Inhabited
def seqToCode (seq : Array Element) (e : Expr) : CompilerM Code := do
let e ← mkAuxLetDecl e
go seq.size (.return e.fvarId!)
if let .fvar fvarId := e then
go seq seq.size (.return fvarId)
else
let decl ← mkAuxLetDecl e
let seq := seq.push (.let decl)
go seq seq.size (.return decl.fvarId)
where
go (i : Nat) (c : Code) : CompilerM Code := do
go (seq : Array Element) (i : Nat) (c : Code) : CompilerM Code := do
if i > 0 then
match seq[i-1]! with
| .jp decl => go (i - 1) (.jp decl c)
| .fun decl => go (i - 1) (.fun decl c)
| .let decl => go (i - 1) (.let decl c)
| .jp decl => go seq (i - 1) (.jp decl c)
| .fun decl => go seq (i - 1) (.fun decl c)
| .let decl => go seq (i - 1) (.let decl c)
| .unreach => return .unreach (← c.inferType)
| .cases fvarId cases =>
if let .return fvarId' := c then
if fvarId == fvarId' then
go (i - 1) (.cases cases)
go seq (i - 1) (.cases cases)
else
-- `cases` is dead code
go (i - 1) c
go seq (i - 1) c
else
/- Create a join point for `c` and jump to it from `cases` -/
let jpDecl ← mkAuxJpDecl' fvarId c
let cases ← (Code.cases cases).bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
go (i - 1) (.jp jpDecl cases)
go seq (i - 1) (.jp jpDecl cases)
else
return c

View file

@ -237,4 +237,9 @@ def isClass? (type : Expr) : CoreM (Option Name) := do
else
return none
def getArrowArity (e : Expr) :=
match e with
| .forallE _ _ b _ => getArrowArity b + 1
| _ => 0
end Lean.Compiler.LCNF