feat: add FunDecl.etaExpand
This commit is contained in:
parent
11c8253f6c
commit
0f40dfc063
4 changed files with 57 additions and 15 deletions
|
|
@ -46,4 +46,28 @@ where
|
|||
return c
|
||||
| .unreach .. => return c
|
||||
|
||||
def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do
|
||||
let typeArity := getArrowArity decl.type
|
||||
let valueArity := decl.getArity
|
||||
if typeArity <= valueArity then
|
||||
-- It can be < because of the "any" type
|
||||
return decl
|
||||
else
|
||||
let valueType ← instantiateForall decl.type decl.params
|
||||
let psNew ← mkNewParams valueType #[] #[]
|
||||
let params := decl.params ++ psNew
|
||||
let xs := psNew.map fun p => Expr.fvar p.fvarId
|
||||
let value ← decl.value.bind fun fvarId => do
|
||||
let auxDecl ← mkAuxLetDecl (mkAppN (.fvar fvarId) xs)
|
||||
return .let auxDecl (.return auxDecl.fvarId)
|
||||
decl.update decl.type params value
|
||||
where
|
||||
mkNewParams (type : Expr) (xs : Array Expr) (ps : Array Param) : CompilerM (Array Param) := do
|
||||
match type with
|
||||
| .forallE _ d b _ =>
|
||||
let d := d.instantiateRev xs
|
||||
let p ← mkAuxParam d
|
||||
mkNewParams b (xs.push (.fvar p.fvarId)) (ps.push p)
|
||||
| _ => return ps
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
@ -182,12 +182,8 @@ def Code.inferParamType (params : Array Param) (code : Code) : CompilerM Expr :=
|
|||
def AltCore.inferType (alt : Alt) : CompilerM Expr :=
|
||||
alt.getCode.inferType
|
||||
|
||||
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM Expr := do
|
||||
if e.isFVar then
|
||||
return e
|
||||
else
|
||||
let letDecl ← mkLetDecl (← mkFreshBinderName prefixName) (← inferType e) e
|
||||
return .fvar letDecl.fvarId
|
||||
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : CompilerM LetDecl := do
|
||||
mkLetDecl (← mkFreshBinderName prefixName) (← inferType e) e
|
||||
|
||||
def mkForallParams (params : Array Param) (type : Expr) : CompilerM Expr :=
|
||||
InferType.mkForallParams params type |>.run {}
|
||||
|
|
@ -206,4 +202,17 @@ def mkAuxJpDecl' (fvarId : FVarId) (code : Code) (prefixName := `_jp) : Compiler
|
|||
let params := #[{ fvarId, binderName := y, type := yType }]
|
||||
mkAuxFunDecl params code prefixName
|
||||
|
||||
def instantiateForall (type : Expr) (params : Array Param) : CoreM Expr :=
|
||||
go type 0
|
||||
where
|
||||
go (type : Expr) (i : Nat) : CoreM Expr :=
|
||||
if h : i < params.size then
|
||||
let p := params[i]
|
||||
match type with
|
||||
| .forallE _ _ b _ => go (b.instantiate1 (.fvar p.fvarId)) (i+1)
|
||||
| _ => throwError "invalid instantiateForall, too many parameters"
|
||||
else
|
||||
return type
|
||||
termination_by go i => params.size - i
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -37,28 +37,32 @@ inductive Element where
|
|||
deriving Inhabited
|
||||
|
||||
def seqToCode (seq : Array Element) (e : Expr) : CompilerM Code := do
|
||||
let e ← mkAuxLetDecl e
|
||||
go seq.size (.return e.fvarId!)
|
||||
if let .fvar fvarId := e then
|
||||
go seq seq.size (.return fvarId)
|
||||
else
|
||||
let decl ← mkAuxLetDecl e
|
||||
let seq := seq.push (.let decl)
|
||||
go seq seq.size (.return decl.fvarId)
|
||||
where
|
||||
go (i : Nat) (c : Code) : CompilerM Code := do
|
||||
go (seq : Array Element) (i : Nat) (c : Code) : CompilerM Code := do
|
||||
if i > 0 then
|
||||
match seq[i-1]! with
|
||||
| .jp decl => go (i - 1) (.jp decl c)
|
||||
| .fun decl => go (i - 1) (.fun decl c)
|
||||
| .let decl => go (i - 1) (.let decl c)
|
||||
| .jp decl => go seq (i - 1) (.jp decl c)
|
||||
| .fun decl => go seq (i - 1) (.fun decl c)
|
||||
| .let decl => go seq (i - 1) (.let decl c)
|
||||
| .unreach => return .unreach (← c.inferType)
|
||||
| .cases fvarId cases =>
|
||||
if let .return fvarId' := c then
|
||||
if fvarId == fvarId' then
|
||||
go (i - 1) (.cases cases)
|
||||
go seq (i - 1) (.cases cases)
|
||||
else
|
||||
-- `cases` is dead code
|
||||
go (i - 1) c
|
||||
go seq (i - 1) c
|
||||
else
|
||||
/- Create a join point for `c` and jump to it from `cases` -/
|
||||
let jpDecl ← mkAuxJpDecl' fvarId c
|
||||
let cases ← (Code.cases cases).bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
|
||||
go (i - 1) (.jp jpDecl cases)
|
||||
go seq (i - 1) (.jp jpDecl cases)
|
||||
else
|
||||
return c
|
||||
|
||||
|
|
|
|||
|
|
@ -237,4 +237,9 @@ def isClass? (type : Expr) : CoreM (Option Name) := do
|
|||
else
|
||||
return none
|
||||
|
||||
def getArrowArity (e : Expr) :=
|
||||
match e with
|
||||
| .forallE _ _ b _ => getArrowArity b + 1
|
||||
| _ => 0
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
Loading…
Add table
Reference in a new issue