feat: att new LCNF/toLCNF.lean

This commit is contained in:
Leonardo de Moura 2022-08-23 20:21:06 -07:00
parent 3e2f8c61ec
commit 083523ee9c
2 changed files with 494 additions and 18 deletions

View file

@ -0,0 +1,487 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.ProjFns
import Lean.Compiler.LCNF.Types
import Lean.Compiler.LCNF.Bind
import Lean.Compiler.LCNF.InferType
import Lean.Compiler.LCNF.Util
namespace Lean.Compiler.LCNF
namespace ToLCNF
/--
Return `true` if `e` is a `lcProof` application.
Recall that we use `lcProof` to erase all nested proofs.
-/
def isLCProof (e : Expr) : Bool :=
e.isAppOfArity ``lcProof 1
/-- Create the temporary `lcProof` -/
def mkLcProof (p : Expr) :=
mkApp (mkConst ``lcProof []) p
inductive Element where
| jp (decl : FunDecl)
| fun (decl : FunDecl)
| let (decl : LetDecl)
| cases (fvarId : FVarId) (cases : Cases)
| unreach
deriving Inhabited
def seqToCode (seq : Array Element) (e : Expr) : CompilerM Code := do
let e ← mkAuxLetDecl e
go seq.size (.return e.fvarId!)
where
go (i : Nat) (c : Code) : CompilerM Code := do
if i > 0 then
match seq[i-1]! with
| .jp decl => go (i - 1) (.jp decl c)
| .fun decl => go (i - 1) (.fun decl c)
| .let decl => go (i - 1) (.let decl c)
| .unreach => return .unreach (← c.inferType)
| .cases fvarId cases =>
if let .return fvarId' := c then
if fvarId == fvarId' then
return .cases cases
else
-- `cases` is dead code
go (i - 1) c
else
/- Create a join point for `c` and jump to it from `cases` -/
let jpDecl ← mkAuxJpDecl' fvarId c
let cases ← (Code.cases cases).bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
go (i - 1) (.jp jpDecl cases)
else
return c
structure State where
/-- Local context containing the original Lean types (not LCNF ones). -/
lctx : LocalContext := {}
/-- Cache from Lean regular expression to LCNF expression. -/
cache : Std.PHashMap Expr Expr := {}
/-- `toLCNFType` cache -/
typeCache : Std.HashMap Expr Expr := {}
/-- isTypeFormerType cache -/
isTypeFormerTypeCache : Std.HashMap Expr Bool := {}
/-- LCNF sequence, we chain it to create a LCNF `Code` object. -/
seq : Array Element := #[]
abbrev M := StateRefT State CompilerM
@[inline] def liftMetaM (x : MetaM α) : M α := do
x.run' { lctx := (← get).lctx }
/-- Create `Code` that executes the current `seq` and then returns `e` -/
def toCode (e : Expr) : M Code := do
seqToCode (← get).seq e
/-- Add LCNF element to the current sequence -/
def pushElement (elem : Element) : M Unit := do
modify fun s => { s with seq := s.seq.push elem }
def mkUnreachable (type : Expr) : M Expr := do
pushElement .unreach
return .fvar (← mkAuxParam type).fvarId
def run (x : M α) : CompilerM α :=
x |>.run' {}
/--
Return true iff `type` is `Sort _` or `As → Sort _`.
-/
private partial def isTypeFormerType (type : Expr) : M Bool := do
match quick (← getEnv) type with
| .true => return true
| .false => return false
| .undef =>
if let some result := (← get).isTypeFormerTypeCache.find? type then
return result
let result ← liftMetaM <| Meta.isTypeFormerType type
modify fun s => { s with isTypeFormerTypeCache := s.isTypeFormerTypeCache.insert type result }
return result
where
quick (env : Environment) : Expr → LBool
| .forallE _ _ b _ => quick env b
| .mdata _ b => quick env b
| .letE .. => .undef
| .sort _ => .true
| .bvar .. => .false
| type =>
match type.getAppFn with
| .bvar .. => .false
| .const declName _ =>
if let some (.inductInfo ..) := env.find? declName then
.false
else
.undef
| _ => .undef
def withNewScope (x : M α) : M α := do
let saved ← get
-- typeCache and isTypeFormerTypeCache are not backtrackable
let saved := { saved with typeCache := {}, isTypeFormerTypeCache := {} }
modify fun s => { s with seq := #[] }
try
x
finally
let saved := { saved with
typeCache := (← get).typeCache
isTypeFormerTypeCache := (← get).isTypeFormerTypeCache
}
set saved
def toLCNFType (type : Expr) : M Expr := do
match (← get).typeCache.find? type with
| some type' => return type'
| none =>
let type' ← liftMetaM <| LCNF.toLCNFType type
modify fun s => { s with typeCache := s.typeCache.insert type type' }
return type'
/-- Create a new local declaration using a Lean regular type. -/
def mkParam (binderName : Name) (type : Expr) : M Param := do
let type' ← toLCNFType type
let param ← LCNF.mkParam binderName type'
modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default }
return param
def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (value' : Expr) : M LetDecl := do
let binderName ← if binderName.eraseMacroScopes.isInternal then mkFreshBinderName binderName.eraseMacroScopes else pure binderName
let letDecl ← LCNF.mkLetDecl binderName type' value'
modify fun s => { s with
lctx := s.lctx.mkLetDecl letDecl.fvarId binderName type value false
seq := s.seq.push <| .let letDecl
}
return letDecl
def visitLambda (e : Expr) : M (Array Param × Expr) :=
go e #[] #[]
where
go (e : Expr) (xs : Array Expr) (ps : Array Param) := do
if let .lam binderName type body _ := e then
let type := type.instantiateRev xs
let p ← mkParam binderName type
go body (xs.push p.toExpr) (ps.push p)
else
return (ps, e.instantiateRev xs)
def visitBoundedLambda (e : Expr) (n : Nat) : M (Array Param × Expr) :=
go e n #[] #[]
where
go (e : Expr) (n : Nat) (xs : Array Expr) (ps : Array Param) := do
if n == 0 then
return (ps, e.instantiateRev xs)
else if let .lam binderName type body _ := e then
let type := type.instantiateRev xs
let p ← mkParam binderName type
go body (n-1) (xs.push p.toExpr) (ps.push p)
else
return (ps, e.instantiateRev xs)
/--
Eta-expand with `n` lambdas.
-/
def etaExpandN (e : Expr) (n : Nat) : M Expr := do
if n == 0 then
return e
else liftMetaM do
Meta.forallBoundedTelescope (← Meta.inferType e) n fun xs _ =>
Meta.mkLambdaFVars xs (mkAppN e xs)
/--
Put the given expression in `LCNF`.
- Nested proofs are replaced with `lcProof`-applications.
- Eta-expand applications of declarations that satisfy `shouldEtaExpand`.
- Put computationally relevant expressions in A-normal form.
-/
partial def toLCNF (e : Expr) : CompilerM Code := do
run do
let e ← visit e
toCode e
where
visitCore (e : Expr) : M Expr := withIncRecDepth do
if let some e := (← get).cache.find? e then
return e
let r ← match e with
| .app .. => visitApp e
| .const .. => visitApp e
| .proj s i e => visitProj s i e
| .mdata d e => visitMData d e
| .lam .. => visitLambda e
| .letE .. => visitLet e #[]
| .lit .. => mkAuxLetDecl e
| .forallE .. => unreachable!
| .mvar .. => throwError "unexpected occurrence of metavariable in code generator{indentExpr e}"
| .bvar .. => unreachable!
| _ => pure e
modify fun s => { s with cache := s.cache.insert e r }
return r
visit (e : Expr) : M Expr := withIncRecDepth do
if isLCProof e then
return mkConst ``lcErased
let type ← liftMetaM <| Meta.inferType e
if (← liftMetaM <| Meta.isProp type) then
/- We erase proofs. -/
return mkConst ``lcErased
if (← isTypeFormerType type) then
/-
We erase type formers unless they occur as application arguments.
Recall that we usually do not generate code for functions that return type,
by this branch can be reachable if we cannot establish whether the function produces a type or not.
See `visitAppArg`
-/
return mkConst ``lcErased
visitCore e
visitAppArg (e : Expr) : M Expr := do
if isLCProof e then
return mkConst ``lcErased
let type ← liftMetaM <| Meta.inferType e
if (← liftMetaM <| Meta.isProp type) then
/- We erase proofs. -/
return mkConst ``lcErased
if (← isTypeFormerType type) then
/- Types and Type formers are not put into A-normal form -/
toLCNFType e
else
visitCore e
/-- Visit args, and return `f args` -/
visitAppDefault (f : Expr) (args : Array Expr) : M Expr := do
if f.isConstOf ``lcErased then
return f
else
let args ← args.mapM visitAppArg
mkAuxLetDecl <| mkAppN f args
/-- Eta expand if under applied, otherwise apply k -/
etaIfUnderApplied (e : Expr) (arity : Nat) (k : M Expr) : M Expr := do
let numArgs := e.getAppNumArgs
if numArgs < arity then
visit (← etaExpandN e (arity - numArgs))
else
k
/--
If `args.size == arity`, then just return `app`.
Otherwise return
```
let k := app
k args[arity:]
```
-/
mkOverApplication (app : Expr) (args : Array Expr) (arity : Nat) : M Expr := do
if args.size == arity then
mkAuxLetDecl app
else
let k ← mkAuxLetDecl app
let mut args := args
for i in [arity : args.size] do
args ← args.modifyM i visitAppArg
mkAuxLetDecl (mkAppN k args[arity:])
/--
Visit a `matcher`/`casesOn` alternative.
-/
visitAlt (ctorName : Name) (numParams : Nat) (e : Expr) : M (Expr × Alt) := do
withNewScope do
let mut (ps, e) ← visitBoundedLambda e numParams
if ps.size < numParams then
e ← etaExpandN e (numParams - ps.size)
let (ps', e') ← ToLCNF.visitLambda e
ps := ps ++ ps'
e := e'
let c ← toCode (← visit e)
let eType ← inferType e
return (eType, AltCore.alt ctorName ps c)
visitCases (casesInfo : CasesInfo) (e : Expr) : M Expr :=
etaIfUnderApplied e casesInfo.arity do
let args := e.getAppArgs
let mut alts := #[]
let typeName := casesInfo.declName.getPrefix
let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[:casesInfo.arity]))
let discr ← visitAppArg args[casesInfo.discrPos]!
let .inductInfo indVal ← getConstInfo typeName | unreachable!
for i in casesInfo.altsRange, numParams in casesInfo.altNumParams, ctorName in indVal.ctors do
let (altType, alt) ← visitAlt ctorName numParams args[i]!
unless compatibleTypes altType resultType do
resultType := anyTypeExpr
alts := alts.push alt
let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts }
let auxDecl ← mkAuxParam resultType
pushElement (.cases auxDecl.fvarId cases)
let result := .fvar auxDecl.fvarId
if args.size == casesInfo.arity then
return result
else
mkOverApplication result args casesInfo.arity
visitCtor (arity : Nat) (e : Expr) : M Expr :=
etaIfUnderApplied e arity do
visitAppDefault e.getAppFn e.getAppArgs
visitQuotLift (e : Expr) : M Expr := do
let arity := 6
etaIfUnderApplied e arity do
let mut args := e.getAppArgs
let α := args[0]!
let r := args[1]!
let f ← visitAppArg args[3]!
let q ← visitAppArg args[5]!
let .const _ [u, _] := e.getAppFn | unreachable!
let invq ← mkAuxLetDecl (mkApp3 (.const ``Quot.lcInv [u]) α r q)
let r := mkApp f invq
mkOverApplication r args arity
visitEqRec (e : Expr) : M Expr :=
let arity := 6
etaIfUnderApplied e arity do
let args := e.getAppArgs
let f := e.getAppFn
let recType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN f args[:arity]))
let minor := if e.isAppOf ``Eq.rec || e.isAppOf ``Eq.ndrec then args[3]! else args[5]!
let minor ← visit minor
let minorType ← inferType minor
let cast ← if compatibleTypes minorType recType then
-- Recall that many types become compatible after LCNF conversion
-- Example: `Fin 10` and `Fin n`
pure minor
else
mkLcCast (← mkAuxLetDecl minor) recType
mkOverApplication cast args arity
visitFalseRec (e : Expr) : M Expr :=
let arity := 2
etaIfUnderApplied e arity do
let type ← toLCNFType (← liftMetaM do Meta.inferType e)
mkUnreachable type
visitAndRec (e : Expr) : M Expr :=
let arity := 5
etaIfUnderApplied e arity do
let args := e.getAppArgs
let ha := mkLcProof args[0]! -- We should not use `lcErased` here since we use it to create a pre-LCNF Expr.
let hb := mkLcProof args[1]!
let minor := if e.isAppOf ``And.rec then args[3]! else args[4]!
let minor := minor.beta #[ha, hb]
visit (mkAppN minor args[arity:])
visitNoConfusion (e : Expr) : M Expr := do
let .const declName _ := e.getAppFn | unreachable!
let typeName := declName.getPrefix
let .inductInfo inductVal ← getConstInfo typeName | unreachable!
let arity := inductVal.numParams + inductVal.numIndices + 1 /- motive -/ + 2 /- lhs/rhs-/ + 1 /- equality -/
etaIfUnderApplied e arity do
let args := e.getAppArgs
let lhs ← liftMetaM do Meta.whnf args[inductVal.numParams + inductVal.numIndices + 1]!
let rhs ← liftMetaM do Meta.whnf args[inductVal.numParams + inductVal.numIndices + 2]!
let lhs := lhs.toCtorIfLit
let rhs := rhs.toCtorIfLit
match lhs.isConstructorApp? (← getEnv), rhs.isConstructorApp? (← getEnv) with
| some lhsCtorVal, some rhsCtorVal =>
if lhsCtorVal.name == rhsCtorVal.name then
etaIfUnderApplied e (arity+1) do
let major := args[arity]!
let major ← expandNoConfusionMajor major lhsCtorVal.numFields
let major := mkAppN major args[arity+1:]
visit major
else
mkUnreachable (← inferType e)
| _, _ =>
throwError "code generator failed, unsupported occurrence of `{declName}`"
expandNoConfusionMajor (major : Expr) (numFields : Nat) : M Expr := do
match numFields with
| 0 => return major
| n+1 =>
if let .lam _ d b _ := major then
let proof := mkLcProof d
expandNoConfusionMajor (b.instantiate1 proof) n
else
expandNoConfusionMajor (← etaExpandN major (n+1)) (n+1)
visitProjFn (projInfo : ProjectionFunctionInfo) (e : Expr) : M Expr := do
let typeName := projInfo.ctorName.getPrefix
if isRuntimeBultinType typeName then
let numArgs := e.getAppNumArgs
let arity := projInfo.numParams + 1
if numArgs < arity then
visit (← etaExpandN e (arity - numArgs))
else
visitAppDefault e.getAppFn e.getAppArgs
else
let .const declName us := e.getAppFn | unreachable!
let info ← getConstInfo declName
let f ← Core.instantiateValueLevelParams info us
visit (f.beta e.getAppArgs)
visitApp (e : Expr) : M Expr := do
if let .const declName _ := e.getAppFn then
if declName == ``Quot.lift then
visitQuotLift e
else if declName == ``Quot.mk then
visitCtor 3 e
else if declName == ``Eq.casesOn || declName == ``Eq.rec || declName == ``Eq.ndrec then
visitEqRec e
else if declName == ``And.rec || declName == ``And.casesOn then
visitAndRec e
else if declName == ``False.rec || declName == ``Empty.rec || declName == ``False.casesOn || declName == ``Empty.casesOn then
visitFalseRec e
else if let some casesInfo ← getCasesInfo? declName then
visitCases casesInfo e
else if let some arity ← getCtorArity? declName then
visitCtor arity e
else if isNoConfusion (← getEnv) declName then
visitNoConfusion e
else if let some projInfo ← getProjectionFnInfo? declName then
visitProjFn projInfo e
else
e.withApp visitAppDefault
else
e.withApp fun f args => do visitAppDefault (← visit f) args
visitLambda (e : Expr) : M Expr := do
let funDecl ← withNewScope do
let (ps, e) ← ToLCNF.visitLambda e
let e ← visit e
let c ← toCode e
mkAuxFunDecl ps c
pushElement (.fun funDecl)
return .fvar funDecl.fvarId
visitMData (mdata : MData) (e : Expr) : M Expr := do
if isCompilerRelevantMData mdata then
mkAuxLetDecl <| .mdata mdata (← visit e)
else
visit e
visitProj (s : Name) (i : Nat) (e : Expr) : M Expr := do
mkAuxLetDecl <| .proj s i (← visit e)
visitLet (e : Expr) (xs : Array Expr) : M Expr := do
match e with
| .letE binderName type value body _ =>
let type := type.instantiateRev xs
let value := value.instantiateRev xs
if (← (liftMetaM <| Meta.isProp type) <||> isTypeFormerType type) then
visitLet body (xs.push value)
else
let type' ← toLCNFType type
let value' ← visit value
let letDecl ← mkLetDecl binderName type value type' value'
pushElement (.let letDecl)
visitLet body (xs.push (.fvar letDecl.fvarId))
| _ =>
let e := e.instantiateRev xs
visit e
end ToLCNF
end Lean.Compiler.LCNF

View file

@ -3,7 +3,7 @@ Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Match.MatcherInfo
import Lean.CoreM
import Lean.Util.Recognizers
namespace Lean.Compiler.LCNF
@ -25,14 +25,15 @@ def isLcCast? (e : Expr) : Option Expr :=
else
none
/--
Store information about `matcher` and `casesOn` declarations.
Store information about `casesOn` declarations.
We treat them uniformly in the code generator.
-/
structure CasesInfo where
declName : Name
arity : Nat
numParams : Nat
discrsRange : Std.Range
discrPos : Nat
altsRange : Std.Range
altNumParams : Array Nat
motivePos : Nat
@ -47,25 +48,13 @@ def getCasesInfo? (declName : Name) : CoreM (Option CasesInfo) := do
let numParams := val.numParams
let motivePos := numParams
let arity := numParams + 1 /- motive -/ + val.numIndices + 1 /- major -/ + val.numCtors
let majorPos := numParams + 1 /- motive -/ + val.numIndices
let discrPos := numParams + 1 /- motive -/ + val.numIndices
-- We view indices as discriminants
let discrsRange := { start := numParams + 1, stop := majorPos + 1 }
let altsRange := { start := majorPos + 1, stop := arity }
let altsRange := { start := discrPos + 1, stop := arity }
let altNumParams ← val.ctors.toArray.mapM fun ctor => do
let .ctorInfo ctorVal ← getConstInfo ctor | unreachable!
return ctorVal.numFields
return some { numParams, motivePos, arity, discrsRange, altsRange, altNumParams }
def CasesInfo.geNumDiscrs (casesInfo : CasesInfo) : Nat :=
casesInfo.discrsRange.stop - casesInfo.discrsRange.start
def CasesInfo.updateResultingType (casesInfo : CasesInfo) (casesArgs : Array Expr) (typeNew : Expr) : Array Expr :=
casesArgs.modify casesInfo.motivePos fun motive => go motive
where
go (e : Expr) : Expr :=
match e with
| .lam n b d bi => .lam n b (go d) bi
| _ => typeNew
return some { declName, numParams, motivePos, arity, discrPos, altsRange, altNumParams }
def isCasesApp? (e : Expr) : CoreM (Option CasesInfo) := do
let .const declName _ := e.getAppFn | return none