chore: port ConstantFold.lean
This commit is contained in:
parent
e6232b67b6
commit
fd316ef027
1 changed files with 75 additions and 127 deletions
|
|
@ -7,9 +7,6 @@ import Lean.Compiler.LCNF.CompilerM
|
|||
import Lean.Compiler.LCNF.InferType
|
||||
import Lean.Compiler.LCNF.PassManager
|
||||
|
||||
set_option warningAsError false
|
||||
#exit
|
||||
|
||||
namespace Lean.Compiler.LCNF.Simp
|
||||
namespace ConstantFold
|
||||
|
||||
|
|
@ -24,7 +21,7 @@ A constant folder for a specific function, takes all the arguments of a
|
|||
certain function and produces a new `Expr` + auxiliary declarations in
|
||||
the `FolderM` monad on success. If the folding fails it returns `none`.
|
||||
-/
|
||||
abbrev Folder := Array Expr → FolderM (Option Expr)
|
||||
abbrev Folder := Array Arg → FolderM (Option LetExpr)
|
||||
|
||||
/--
|
||||
A typeclass for detecting and producing literals of arbitrary types
|
||||
|
|
@ -32,25 +29,25 @@ inside of LCNF.
|
|||
-/
|
||||
class Literal (α : Type) where
|
||||
/--
|
||||
Attempt to turn the provied `Expr` into a value of type `α` if
|
||||
Attempt to turn the provided `Expr` into a value of type `α` if
|
||||
it is whatever concept of a literal `α` has. Note that this function
|
||||
does assume that the provided `Expr` does indeed have type `α`.
|
||||
-/
|
||||
getLit : Expr → CompilerM (Option α)
|
||||
getLit : FVarId → CompilerM (Option α)
|
||||
/--
|
||||
Turn a value of type `α` into a series of auxiliary `LetDecl`s + a
|
||||
final `Expr` putting them all together into a literal of type `α`,
|
||||
where again the idea of what a literal is depends on `α`.
|
||||
-/
|
||||
mkLit : α → FolderM Expr
|
||||
mkLit : α → FolderM LetExpr
|
||||
|
||||
export Literal (getLit mkLit)
|
||||
|
||||
/--
|
||||
A wrapper around `LCNF.mkAuxLetDecl` that will automaticaly store the
|
||||
A wrapper around `LCNF.mkAuxLetDecl` that will automatically store the
|
||||
`LetDecl` in the state of `FolderM`.
|
||||
-/
|
||||
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : FolderM FVarId := do
|
||||
def mkAuxLetDecl (e : LetExpr) (prefixName := `_x) : FolderM FVarId := do
|
||||
let decl ← LCNF.mkAuxLetDecl e prefixName
|
||||
modify fun s => s.push <| .let decl
|
||||
return decl.fvarId
|
||||
|
|
@ -64,60 +61,39 @@ def mkAuxLit [Literal α] (x : α) (prefixName := `_x) : FolderM FVarId := do
|
|||
let lit ← mkLit x
|
||||
mkAuxLetDecl lit prefixName
|
||||
|
||||
partial def getNatLit (e : Expr) : CompilerM (Option Nat) := do
|
||||
match e with
|
||||
| .lit (.natVal n) .. => return n
|
||||
| .fvar fvarId .. =>
|
||||
if let some decl ← findLetDecl? fvarId then
|
||||
getNatLit decl.value
|
||||
else
|
||||
return none
|
||||
| _ => return none
|
||||
partial def getNatLit (fvarId : FVarId) : CompilerM (Option Nat) := do
|
||||
let some (.value (.natVal n)) ← findLetExpr? fvarId | return none
|
||||
return n
|
||||
|
||||
def mkNatLit (n : Nat) : FolderM Expr :=
|
||||
return .lit (.natVal n)
|
||||
def mkNatLit (n : Nat) : FolderM LetExpr :=
|
||||
return .value (.natVal n)
|
||||
|
||||
instance : Literal Nat where
|
||||
getLit := getNatLit
|
||||
mkLit := mkNatLit
|
||||
|
||||
partial def getStringLit (e : Expr) : CompilerM (Option String) := do
|
||||
match e with
|
||||
| .lit (.strVal n) .. => return n
|
||||
| .fvar fvarId .. =>
|
||||
let some decl ← findLetDecl? fvarId | return none
|
||||
getStringLit decl.value
|
||||
| _ => return none
|
||||
partial def getStringLit (fvarId : FVarId) : CompilerM (Option String) := do
|
||||
let some (.value (.strVal s)) ← findLetExpr? fvarId | return none
|
||||
return s
|
||||
|
||||
def mkStringLit (n : String) : FolderM Expr :=
|
||||
return .lit (.strVal n)
|
||||
def mkStringLit (n : String) : FolderM LetExpr :=
|
||||
return .value (.strVal n)
|
||||
|
||||
instance : Literal String where
|
||||
getLit := getStringLit
|
||||
mkLit := mkStringLit
|
||||
|
||||
private partial def getLitAux [Inhabited α] (e : Expr) (ofNat : Nat → α) (ofNatName : Name) (toNat : α → Nat) : CompilerM (Option α) := do
|
||||
match e with
|
||||
| .fvar fvarId .. =>
|
||||
if let some decl ←findLetDecl? fvarId then
|
||||
getLitAux decl.value ofNat ofNatName toNat
|
||||
else
|
||||
return none
|
||||
| .app .. =>
|
||||
match e.getAppFn, e.getAppArgs with
|
||||
| .const name .., #[.fvar fvarId] =>
|
||||
if name == ofNatName then
|
||||
if let some natLit ← getLit (.fvar fvarId) then
|
||||
return ofNat natLit
|
||||
return none
|
||||
| _, _ => return none
|
||||
| _ => return none
|
||||
private partial def getLitAux [Inhabited α] (fvarId : FVarId) (ofNat : Nat → α) (ofNatName : Name) : CompilerM (Option α) := do
|
||||
let some (.const declName _ #[.fvar fvarId]) ← findLetExpr? fvarId | return none
|
||||
unless declName == ofNatName do return none
|
||||
let some natLit ← getLit fvarId | return none
|
||||
return ofNat natLit
|
||||
|
||||
def mkNatWrapperInstance [Inhabited α] (ofNat : Nat → α) (ofNatName : Name) (toNat : α → Nat) : Literal α where
|
||||
getLit := (getLitAux · ofNat ofNatName toNat)
|
||||
getLit := (getLitAux · ofNat ofNatName)
|
||||
mkLit x := do
|
||||
let helperId ← mkAuxLit <| toNat x
|
||||
return .app (mkConst ofNatName) (.fvar helperId)
|
||||
return .const ofNatName [] #[.fvar helperId]
|
||||
|
||||
instance : Literal UInt8 := mkNatWrapperInstance UInt8.ofNat ``UInt8.ofNat UInt8.toNat
|
||||
instance : Literal UInt16 := mkNatWrapperInstance UInt16.ofNat ``UInt16.ofNat UInt16.toNat
|
||||
|
|
@ -139,26 +115,17 @@ let _x.6 := @List.cons _ e _x.5
|
|||
```
|
||||
into: `[a, b, c, d ,e]` + The type contained in the list
|
||||
-/
|
||||
partial def getPseudoListLiteral (e : Expr) : CompilerM (Option (List FVarId × Expr × Level)) := do
|
||||
go e []
|
||||
partial def getPseudoListLiteral (fvarId : FVarId) : CompilerM (Option (List FVarId × Expr × Level)) := do
|
||||
go fvarId []
|
||||
where
|
||||
go (e : Expr) (fvarIds : List FVarId) : CompilerM (Option (List FVarId × Expr × Level)) := do
|
||||
go (fvarId : FVarId) (fvarIds : List FVarId) : CompilerM (Option (List FVarId × Expr × Level)) := do
|
||||
let some e ← findLetExpr? fvarId | return none
|
||||
match e with
|
||||
| .app .. =>
|
||||
if some ``List.nil == e.getAppFn.constName? then
|
||||
return some (fvarIds.reverse, e.getAppArgs[0]!, e.getAppFn.constLevels![0]!)
|
||||
else if some ``List.cons == e.getAppFn.constName? then
|
||||
let args := e.getAppArgs
|
||||
go args[2]! (args[1]!.fvarId! :: fvarIds)
|
||||
else
|
||||
return none
|
||||
| .fvar fvarId =>
|
||||
if let some decl ← findLetDecl? fvarId then
|
||||
go decl.value fvarIds
|
||||
else
|
||||
return none
|
||||
| _ =>
|
||||
return none
|
||||
| .const ``List.nil [u] #[.type α] =>
|
||||
return some (fvarIds.reverse, α, u)
|
||||
| .const ``List.cons _ #[_, .fvar h, .fvar t] =>
|
||||
go t (h :: fvarIds)
|
||||
| _ => return none
|
||||
|
||||
/--
|
||||
Turn an `#[a, b, c]` into:
|
||||
|
|
@ -171,12 +138,12 @@ let _x.26 := @Array.push _ _x.24 z
|
|||
_x.26
|
||||
```
|
||||
-/
|
||||
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM Expr := do
|
||||
def mkPseudoArrayLiteral (elements : Array FVarId) (typ : Expr) (typLevel : Level) : FolderM LetExpr := do
|
||||
let sizeLit ← mkAuxLit elements.size
|
||||
let mut literal ← mkAuxLetDecl <| mkApp2 (mkConst ``Array.mkEmpty [typLevel]) typ (.fvar sizeLit)
|
||||
let mut literal ← mkAuxLetDecl <| .const ``Array.mkEmpty [typLevel] #[.type typ, .fvar sizeLit]
|
||||
for element in elements do
|
||||
literal ← mkAuxLetDecl <| mkApp3 (mkConst ``Array.push [typLevel]) typ (.fvar literal) (.fvar element)
|
||||
return .fvar literal
|
||||
literal ← mkAuxLetDecl <| .const ``Array.push [typLevel] #[.type typ, .fvar literal, .fvar element]
|
||||
return .fvar literal #[]
|
||||
|
||||
/--
|
||||
Evaluate array literals at compile time, that is turn:
|
||||
|
|
@ -196,84 +163,66 @@ let _x.24 := @Array.push _ _x.22 y
|
|||
let _x.26 := @Array.push _ _x.24 z
|
||||
```
|
||||
-/
|
||||
def foldArrayLiteral : Folder := fun exprs => do
|
||||
if h:exprs.size = 2 then
|
||||
have h1 : 1 < Array.size exprs := by simp_all
|
||||
if let some (list, typ, level) ← getPseudoListLiteral exprs[1] then
|
||||
let arr := Array.mk list
|
||||
let lit ← mkPseudoArrayLiteral arr typ level
|
||||
return some lit
|
||||
return none
|
||||
def foldArrayLiteral : Folder := fun args => do
|
||||
let #[_, .fvar fvarId] := args | return none
|
||||
let some (list, typ, level) ← getPseudoListLiteral fvarId | return none
|
||||
let arr := Array.mk list
|
||||
let lit ← mkPseudoArrayLiteral arr typ level
|
||||
return some lit
|
||||
|
||||
/--
|
||||
Turn a unary function such as `Nat.succ` into a constant folder.
|
||||
-/
|
||||
def Folder.mkUnary [Literal α] [Literal β] (folder : α → β) : Folder := fun exprs => do
|
||||
if h:exprs.size = 1 then
|
||||
have h1 : 0 < Array.size exprs := by simp_all
|
||||
if let some arg1 ← getLit exprs[0] then
|
||||
let res := folder arg1
|
||||
return (←mkLit res)
|
||||
return none
|
||||
def Folder.mkUnary [Literal α] [Literal β] (folder : α → β) : Folder := fun args => do
|
||||
let #[.fvar fvarId] := args | return none
|
||||
let some arg1 ← getLit fvarId | return none
|
||||
let res := folder arg1
|
||||
mkLit res
|
||||
|
||||
/--
|
||||
Turn a binary function such as `Nat.add` into a constant folder.
|
||||
-/
|
||||
def Folder.mkBinary [Literal α] [Literal β] [Literal γ] (folder : α → β → γ) : Folder := fun exprs => do
|
||||
if h:exprs.size = 2 then
|
||||
have h1 : 0 < Array.size exprs := by simp_all
|
||||
have h2 : 1 < Array.size exprs := by simp_all
|
||||
if let some arg1 ← getLit exprs[0] then
|
||||
if let some arg2 ← getLit exprs[1] then
|
||||
let res := folder arg1 arg2
|
||||
return (←mkLit res)
|
||||
return none
|
||||
def Folder.mkBinary [Literal α] [Literal β] [Literal γ] (folder : α → β → γ) : Folder := fun args => do
|
||||
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
|
||||
let some arg₁ ← getLit fvarId₁ | return none
|
||||
let some arg₂ ← getLit fvarId₂ | return none
|
||||
mkLit <| folder arg₁ arg₂
|
||||
|
||||
/--
|
||||
Provide a folder for an operation with a left neutral element.
|
||||
-/
|
||||
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun exprs => do
|
||||
if h:exprs.size = 2 then
|
||||
have h1 : 0 < Array.size exprs := by simp_all
|
||||
have h2 : 1 < Array.size exprs := by simp_all
|
||||
if let some arg1 ← getLit exprs[0] then
|
||||
if arg1 == neutral then
|
||||
return exprs[1]
|
||||
return none
|
||||
def Folder.leftNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args => do
|
||||
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
|
||||
let some arg₁ ← getLit fvarId₁ | return none
|
||||
unless arg₁ == neutral do return none
|
||||
return some <| .fvar fvarId₂ #[]
|
||||
|
||||
/--
|
||||
Provide a folder for an operation with a right neutral element.
|
||||
-/
|
||||
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun exprs => do
|
||||
if h:exprs.size = 2 then
|
||||
have h1 : 0 < Array.size exprs := by simp_all
|
||||
have h2 : 1 < Array.size exprs := by simp_all
|
||||
if let some arg2 ← getLit exprs[1] then
|
||||
if arg2 == neutral then
|
||||
return exprs[0]
|
||||
return none
|
||||
def Folder.rightNeutral [Literal α] [BEq α] (neutral : α) : Folder := fun args => do
|
||||
let #[.fvar fvarId₁, .fvar fvarId₂] := args | return none
|
||||
let some arg₂ ← getLit fvarId₂ | return none
|
||||
unless arg₂ == neutral do return none
|
||||
return some <| .fvar fvarId₁ #[]
|
||||
|
||||
/--
|
||||
Provide a folder for an operation with a left annihilator.
|
||||
-/
|
||||
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun exprs => do
|
||||
if h:exprs.size = 2 then
|
||||
have h1 : 0 < Array.size exprs := by simp_all
|
||||
if let some arg1 ← getLit exprs[0] then
|
||||
if arg1 == annihilator then
|
||||
return (←mkLit zero)
|
||||
return none
|
||||
def Folder.leftAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun args => do
|
||||
let #[.fvar fvarId, _] := args | return none
|
||||
let some arg ← getLit fvarId | return none
|
||||
unless arg == annihilator do return none
|
||||
mkLit zero
|
||||
|
||||
/--
|
||||
Provide a folder for an operation with a right annihilator.
|
||||
-/
|
||||
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun exprs => do
|
||||
if h:exprs.size = 2 then
|
||||
have h1 : 1 < Array.size exprs := by simp_all
|
||||
if let some arg2 ← getLit exprs[1] then
|
||||
if arg2 == annihilator then
|
||||
return (←mkLit zero)
|
||||
return none
|
||||
def Folder.rightAnnihilator [Literal α] [BEq α] (annihilator : α) (zero : α) : Folder := fun args => do
|
||||
let #[_, .fvar fvarId] := args | return none
|
||||
let some arg ← getLit fvarId | return none
|
||||
unless arg == annihilator do return none
|
||||
mkLit zero
|
||||
|
||||
/--
|
||||
Pick the first folder out of `folders` that succeeds.
|
||||
|
|
@ -346,11 +295,10 @@ def stringFolders : List (Name × Folder) := [
|
|||
Apply all known folders to `decl`.
|
||||
-/
|
||||
def applyFolders (decl : LetDecl) (folders : SMap Name Folder) : CompilerM (Option (Array CodeDecl)) := do
|
||||
let e := decl.value
|
||||
match e.getAppFn with
|
||||
| .const name .. =>
|
||||
match decl.value with
|
||||
| .const name _ args =>
|
||||
if let some folder := folders.find? name then
|
||||
if let (some res, aux) ← folder e.getAppArgs |>.run #[] then
|
||||
if let (some res, aux) ← folder args |>.run #[] then
|
||||
let decl ← decl.updateValue res
|
||||
return some <| aux.push (.let decl)
|
||||
return none
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue