chore: port ToLCNF

This commit is contained in:
Leonardo de Moura 2022-10-31 06:51:00 -07:00
parent 7e2c476a77
commit 01791b0c19
2 changed files with 133 additions and 125 deletions

View file

@ -53,9 +53,9 @@ def checkpoint (stepName : Name) (decls : Array Decl) : CompilerM Unit := do
if (← Lean.isTracingEnabledFor clsName) then
Lean.addTrace clsName m!"size: {decl.size}\n{← ppDecl' decl}"
if compiler.check.get (← getOptions) then
decl.check
pure () -- TODO: decl.check
if compiler.check.get (← getOptions) then
checkDeadLocalDecls decls
pure () -- TODO: checkDeadLocalDecls decls
namespace PassManager

View file

@ -70,7 +70,7 @@ where
findFun? (f : FVarId) : CompilerM (Option FunDecl) := do
if let some funDecl ← findFunDecl? f then
return funDecl
else if let some { value := .fvar f', .. } ← findLetDecl? f then
else if let some { value := .fvar f' #[], .. } ← findLetDecl? f then
findFun? f'
else
return none
@ -87,43 +87,43 @@ where
```
where `_alt.<idx>` is an auxiliary declaration created by `inlineMatcher`
-/
if decl.fvarId == fvarId && decl.value.isApp && decl.value.getAppFn.isFVar then
let f := decl.value.getAppFn.fvarId!
let binderName ← getBinderName f
if binderName.getPrefix == `_alt then
if let some funDecl ← findFun? f then
let args := decl.value.getAppArgs
eraseLetDecl decl
if let some altJp := (← get).find? f then
/- We already have an auxiliary join point for `f`, then, we just use it. -/
return .jmp altJp.fvarId args
else
/-
We have not created a join point for `f` yet.
The join point has the form
```
jp altJp jpParams :=
let _x := f jpParams
jmp jpDecl _x
```
Then, we replace the current `let`-declaration with `jmp altJp args`
-/
let mut jpParams := #[]
let mut subst := {}
let mut jpArgs := #[]
/- Remark: `funDecl.params.size` may be greater than `args.size`. -/
for param in funDecl.params[:args.size] do
let type ← replaceExprFVars param.type subst (translator := true)
let paramNew ← mkAuxParam type
jpParams := jpParams.push paramNew
let arg := .fvar paramNew.fvarId
subst := subst.insert param.fvarId arg
jpArgs := jpArgs.push arg
let letDecl ← mkAuxLetDecl (mkAppN decl.value.getAppFn jpArgs)
let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId])
let altJp ← mkAuxJpDecl jpParams jpValue
modify fun map => map.insert f altJp
return .jmp altJp.fvarId args
if decl.fvarId == fvarId then
match decl.value with
| .fvar f args =>
let binderName ← getBinderName f
if binderName.getPrefix == `_alt then
if let some funDecl ← findFun? f then
eraseLetDecl decl
if let some altJp := (← get).find? f then
/- We already have an auxiliary join point for `f`, then, we just use it. -/
return .jmp altJp.fvarId args
else
/-
We have not created a join point for `f` yet.
The join point has the form
```
jp altJp jpParams :=
let _x := f jpParams
jmp jpDecl _x
```
Then, we replace the current `let`-declaration with `jmp altJp args`
-/
let mut jpParams := #[]
let mut subst := {}
let mut jpArgs := #[]
/- Remark: `funDecl.params.size` may be greater than `args.size`. -/
for param in funDecl.params[:args.size] do
let type ← replaceExprFVars param.type subst (translator := true)
let paramNew ← mkAuxParam type
jpParams := jpParams.push paramNew
subst := subst.insert param.fvarId (Expr.fvar paramNew.fvarId)
jpArgs := jpArgs.push (Arg.fvar paramNew.fvarId)
let letDecl ← mkAuxLetDecl (.fvar f jpArgs)
let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId])
let altJp ← mkAuxJpDecl jpParams jpValue
modify fun map => map.insert f altJp
return .jmp altJp.fvarId args
| _ => pure ()
let k ← go k
if let some altJp := (← get).find? decl.fvarId then
-- The new join point depends on this variable. Thus, we must insert it here
@ -148,13 +148,8 @@ where
| .return fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
| .jmp .. | .unreach .. => return code
def seqToCode (seq : Array Element) (e : Expr) : CompilerM Code := do
if let .fvar fvarId := e then
go seq seq.size (.return fvarId)
else
let decl ← mkAuxLetDecl e
let seq := seq.push (.let decl)
go seq seq.size (.return decl.fvarId)
def seqToCode (seq : Array Element) (k : Code) : CompilerM Code := do
go seq seq.size k
where
go (seq : Array Element) (i : Nat) (c : Code) : CompilerM Code := do
if i > 0 then
@ -200,7 +195,7 @@ structure State where
/-- Local context containing the original Lean types (not LCNF ones). -/
lctx : LocalContext := {}
/-- Cache from Lean regular expression to LCNF expression. -/
cache : PHashMap Expr Expr := {}
cache : PHashMap Expr LetExpr := {}
/-- `toLCNFType` cache -/
typeCache : HashMap Expr Expr := {}
/-- isTypeFormerType cache -/
@ -221,26 +216,27 @@ abbrev M := StateRefT State CompilerM
@[inline] def liftMetaM (x : MetaM α) : M α := do
x.run' { lctx := (← get).lctx }
/-- Create `Code` that executes the current `seq` and then returns `e` -/
def toCode (e : Expr) : M Code := do
seqToCode (← get).seq e
/-- Add LCNF element to the current sequence -/
def pushElement (elem : Element) : M Unit := do
modify fun s => { s with seq := s.seq.push elem }
def mkUnreachable (type : Expr) : M Expr := do
def mkUnreachable (type : Expr) : M LetExpr := do
let p ← mkAuxParam type
pushElement (.unreach p)
return .fvar p.fvarId
return .fvar p.fvarId #[]
def mkAuxLetDecl (e : Expr) (prefixName := `_x) : M Expr := do
if e.isFVar then
return e
else
let letDecl ← mkLetDecl (← mkFreshBinderName prefixName) (← inferType e) e
def mkAuxLetDecl (e : LetExpr) (prefixName := `_x) : M FVarId := do
match e with
| .fvar fvarId #[] => return fvarId
| _ =>
let letDecl ← mkLetDecl (← mkFreshBinderName prefixName) (← e.inferType) e
pushElement (.let letDecl)
return .fvar letDecl.fvarId
return letDecl.fvarId
/-- Create `Code` that executes the current `seq` and then returns `result` -/
def toCode (result : LetExpr) : M Code := do
let fvarId ← mkAuxLetDecl result
seqToCode (← get).seq (.return fvarId)
def run (x : M α) : CompilerM α :=
x |>.run' {}
@ -325,7 +321,7 @@ def mkParam (binderName : Name) (type : Expr) : M Param := do
modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default }
return param
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (value' : Expr) : M LetDecl := do
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (value' : LetExpr) : M LetDecl := do
let binderName ← cleanupBinderName binderName
let letDecl ← LCNF.mkLetDecl binderName type' value'
modify fun s => { s with
@ -396,6 +392,11 @@ partial def etaReduceImplicit (e : Expr) : Expr :=
e
| _ => e
def litToValue (lit : Literal) : Value :=
match lit with
| .natVal val => .natVal val
| .strVal val => .strVal val
/--
Put the given expression in `LCNF`.
@ -405,73 +406,77 @@ Put the given expression in `LCNF`.
-/
partial def toLCNF (e : Expr) : CompilerM Code := do
run do
let e ← visit e
toCode e
let e' ← visit e
toCode e'
where
visitCore (e : Expr) : M Expr := withIncRecDepth do
if let some e := (← get).cache.find? e then
return e
let r ← match e with
visitCore (e : Expr) : M LetExpr := withIncRecDepth do
if let some fvarId := (← get).cache.find? e then
return fvarId
let r : LetExpr ← match e with
| .app .. => visitApp e
| .const .. => visitApp e
| .proj s i e => visitProj s i e
| .mdata d e => visitMData d e
| .lam .. => visitLambda e
| .letE .. => visitLet e #[]
| .lit .. => mkAuxLetDecl e
| .forallE .. => unreachable!
| .mvar .. => throwError "unexpected occurrence of metavariable in code generator{indentExpr e}"
| .bvar .. => unreachable!
| .fvar fvarId => if (← get).toAny.contains fvarId then pure erasedExpr else pure e
| _ => pure e
| .lit val => return .value (litToValue val)
| .fvar fvarId => if (← get).toAny.contains fvarId then pure .erased else pure (.fvar fvarId #[])
| .forallE .. | .mvar .. | .bvar .. | .sort .. => unreachable!
modify fun s => { s with cache := s.cache.insert e r }
return r
visit (e : Expr) : M Expr := withIncRecDepth do
visit (e : Expr) : M LetExpr := withIncRecDepth do
if isLCProof e then
return erasedExpr
return .erased
let type ← liftMetaM <| Meta.inferType e
if (← liftMetaM <| Meta.isProp type) then
/- We erase proofs. -/
return erasedExpr
return .erased
if (← isTypeFormerType type) then
/-
We erase type formers unless they occur as application arguments.
Recall that we usually do not generate code for functions that return type,
by this branch can be reachable if we cannot establish whether the function produces a type or not.
See `visitAppArg`
-/
return erasedExpr
return .erased
visitCore e
visitAppArg (e : Expr) : M Expr := do
visitAppArg (e : Expr) : M Arg := do
if isLCProof e then
return erasedExpr
return .erased
let type ← liftMetaM <| Meta.inferType e
if (← liftMetaM <| Meta.isProp type) then
/- We erase proofs. -/
return erasedExpr
return .erased
if (← isTypeFormerType type) then
/- Predicates are erased (e.g., `Eq`) -/
if isPredicateType (← toLCNFType type) then
return erasedExpr
return .erased
else
/- Types and Type formers are not put into A-normal form -/
toLCNFType e
return .type (← toLCNFType e)
else
visitCore e
match (← visitCore e) with
| .erased => return .erased
| e => return .fvar (← mkAuxLetDecl e)
/-- Visit args, and return `f args` -/
visitAppDefault (f : Expr) (args : Array Expr) : M Expr := do
if f.isErased then
return f
else
visitAppDefault (f : LetExpr) (args : Array Expr) : M LetExpr := do
match f with
| .erased => return .erased
| _ =>
let f ← mkAuxLetDecl f
let args ← args.mapM visitAppArg
mkAuxLetDecl <| mkAppN f args
return .fvar f args
/-- Giving `f` a constant `.const declName us`, convert `args` into `args'`, and return `.const declName us args'` -/
visitAppDefaultConst (f : Expr) (args : Array Expr) : M LetExpr := do
let .const declName us := f | unreachable!
let args ← args.mapM visitAppArg
return .const declName us args
/-- Eta expand if under applied, otherwise apply k -/
etaIfUnderApplied (e : Expr) (arity : Nat) (k : M Expr) : M Expr := do
etaIfUnderApplied (e : Expr) (arity : Nat) (k : M LetExpr) : M LetExpr := do
let numArgs := e.getAppNumArgs
if numArgs < arity then
visit (← etaExpandN e (arity - numArgs))
@ -486,15 +491,15 @@ where
k args[arity:]
```
-/
mkOverApplication (app : Expr) (args : Array Expr) (arity : Nat) : M Expr := do
mkOverApplication (app : LetExpr) (args : Array Expr) (arity : Nat) : M LetExpr := do
if args.size == arity then
mkAuxLetDecl app
return app
else
let k ← mkAuxLetDecl app
let mut args := args
for i in [arity : args.size] do
args ← args.modifyM i visitAppArg
mkAuxLetDecl (mkAppN k args[arity:])
let k ← mkAuxLetDecl app
let mut argsNew := #[]
for i in [arity : args.size] do
argsNew := argsNew.push (← visitAppArg args[i]!)
return .fvar k argsNew
/--
Visit a `matcher`/`casesOn` alternative.
@ -528,7 +533,7 @@ where
let altType ← c.inferType
return (altType, .alt ctorName ps c)
visitCases (casesInfo : CasesInfo) (e : Expr) : M Expr :=
visitCases (casesInfo : CasesInfo) (e : Expr) : M LetExpr :=
etaIfUnderApplied e casesInfo.arity do
let args := e.getAppArgs
let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[:casesInfo.arity]))
@ -540,31 +545,32 @@ where
let typeName := casesInfo.declName.getPrefix
let discr ← visitAppArg args[casesInfo.discrPos]!
let .inductInfo indVal ← getConstInfo typeName | unreachable!
if !discr.isFVar then
match discr with
| .erased | .type .. =>
/-
This can happen for inductive predicates that can eliminate into type (e.g., `And`, `Iff`).
TODO: add support for them. Right now, we have hard-coded support for the ones defined at `Init`.
-/
throwError "unsupported `{casesInfo.declName}` application during code generation"
else
| .fvar discrFVarId =>
for i in casesInfo.altsRange, numParams in casesInfo.altNumParams, ctorName in indVal.ctors do
let (altType, alt) ← visitAlt ctorName numParams args[i]!
resultType := joinTypes altType resultType
alts := alts.push alt
let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts }
let cases : Cases := { typeName, discr := discrFVarId, resultType, alts }
let auxDecl ← mkAuxParam resultType
pushElement (.cases auxDecl cases)
let result := .fvar auxDecl.fvarId
let result := .fvar auxDecl.fvarId #[]
if args.size == casesInfo.arity then
return result
else
mkOverApplication result args casesInfo.arity
visitCtor (arity : Nat) (e : Expr) : M Expr :=
visitCtor (arity : Nat) (e : Expr) : M LetExpr :=
etaIfUnderApplied e arity do
visitAppDefault e.getAppFn e.getAppArgs
visitAppDefaultConst e.getAppFn e.getAppArgs
visitQuotLift (e : Expr) : M Expr := do
visitQuotLift (e : Expr) : M LetExpr := do
let arity := 6
etaIfUnderApplied e arity do
let mut args := e.getAppArgs
@ -573,11 +579,13 @@ where
let f ← visitAppArg args[3]!
let q ← visitAppArg args[5]!
let .const _ [u, _] := e.getAppFn | unreachable!
let invq ← mkAuxLetDecl (mkApp3 (.const ``Quot.lcInv [u]) α r q)
let r := mkApp f invq
mkOverApplication r args arity
let invq ← mkAuxLetDecl (.const ``Quot.lcInv [u] #[.type α, .type r, q])
match f with
| .erased => return .erased
| .type _ => unreachable!
| .fvar fvarId => mkOverApplication (.fvar fvarId #[.fvar invq]) args arity
visitEqRec (e : Expr) : M Expr :=
visitEqRec (e : Expr) : M LetExpr :=
let arity := 6
etaIfUnderApplied e arity do
let args := e.getAppArgs
@ -585,13 +593,13 @@ where
let minor ← visit minor
mkOverApplication minor args arity
visitFalseRec (e : Expr) : M Expr :=
visitFalseRec (e : Expr) : M LetExpr :=
let arity := 2
etaIfUnderApplied e arity do
let type ← toLCNFType (← liftMetaM do Meta.inferType e)
mkUnreachable type
visitAndIffRecCore (e : Expr) (minorPos : Nat) : M Expr :=
visitAndIffRecCore (e : Expr) (minorPos : Nat) : M LetExpr :=
let arity := 5
etaIfUnderApplied e arity do
let args := e.getAppArgs
@ -601,7 +609,7 @@ where
let minor := minor.beta #[ha, hb]
visit (mkAppN minor args[arity:])
visitNoConfusion (e : Expr) : M Expr := do
visitNoConfusion (e : Expr) : M LetExpr := do
let .const declName _ := e.getAppFn | unreachable!
let typeName := declName.getPrefix
let .inductInfo inductVal ← getConstInfo typeName | unreachable!
@ -635,7 +643,7 @@ where
else
expandNoConfusionMajor (← etaExpandN major (n+1)) (n+1)
visitProjFn (projInfo : ProjectionFunctionInfo) (e : Expr) : M Expr := do
visitProjFn (projInfo : ProjectionFunctionInfo) (e : Expr) : M LetExpr := do
let typeName := projInfo.ctorName.getPrefix
if isRuntimeBultinType typeName then
let numArgs := e.getAppNumArgs
@ -643,14 +651,14 @@ where
if numArgs < arity then
visit (← etaExpandN e (arity - numArgs))
else
visitAppDefault e.getAppFn e.getAppArgs
visitAppDefaultConst e.getAppFn e.getAppArgs
else
let .const declName us := e.getAppFn | unreachable!
let info ← getConstInfo declName
let f ← Core.instantiateValueLevelParams info us
visit (f.beta e.getAppArgs)
visitApp (e : Expr) : M Expr := do
visitApp (e : Expr) : M LetExpr := do
if let .const declName _ := e.getAppFn then
if declName == ``Quot.lift then
visitQuotLift e
@ -673,11 +681,11 @@ where
else if let some projInfo ← getProjectionFnInfo? declName then
visitProjFn projInfo e
else
e.withApp visitAppDefault
e.withApp visitAppDefaultConst
else
e.withApp fun f args => do visitAppDefault (← visit f) args
visitLambda (e : Expr) : M Expr := do
visitLambda (e : Expr) : M LetExpr := do
let b := etaReduceImplicit e
/-
Note: we don't want to eta-reduce arbitrary lambda expressions since it can
@ -711,20 +719,20 @@ where
let c ← toCode e
mkAuxFunDecl ps c
pushElement (.fun funDecl)
return .fvar funDecl.fvarId
return .fvar funDecl.fvarId #[]
visitMData (mdata : MData) (e : Expr) : M Expr := do
visitMData (mdata : MData) (e : Expr) : M LetExpr := do
if let some (.app (.lam n t b ..) v) := letFunAnnotation? (.mdata mdata e) then
visitLet (.letE n t v b (nonDep := true)) #[]
else if isCompilerRelevantMData mdata then
mkAuxLetDecl <| .mdata mdata (← visit e)
else
visit e
visitProj (s : Name) (i : Nat) (e : Expr) : M Expr := do
mkAuxLetDecl <| .proj s i (← visit e)
visitProj (s : Name) (i : Nat) (e : Expr) : M LetExpr := do
match (← visit e) with
| .erased => return .erased
| e => return .proj s i (← mkAuxLetDecl e)
visitLet (e : Expr) (xs : Array Expr) : M Expr := do
visitLet (e : Expr) (xs : Array Expr) : M LetExpr := do
match e with
| .letE binderName type value body _ =>
let type := type.instantiateRev xs