162 lines
No EOL
5.7 KiB
Text
162 lines
No EOL
5.7 KiB
Text
/-
|
|
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
|
Released under Apache 2.0 license as described in the file LICENSE.
|
|
Authors: Leonardo de Moura
|
|
-/
|
|
import Lean.Compiler.LCNF.MonoTypes
|
|
import Lean.Compiler.LCNF.InferType
|
|
|
|
namespace Lean.Compiler.LCNF
|
|
|
|
structure ToMonoM.State where
|
|
typeParams : FVarIdSet := {}
|
|
|
|
abbrev ToMonoM := StateRefT ToMonoM.State CompilerM
|
|
|
|
def Param.toMono (param : Param) : ToMonoM Param := do
|
|
if isTypeFormerType param.type then
|
|
modify fun { typeParams, .. } => { typeParams := typeParams.insert param.fvarId }
|
|
param.update (← toMonoType param.type)
|
|
|
|
def isTrivialConstructorApp? (e : Expr) : ToMonoM (Option Expr) := do
|
|
let some ctorInfo := e.isConstructorApp? (← getEnv) | return none
|
|
let some info ← hasTrivialStructure? ctorInfo.induct | return none
|
|
assert! ctorInfo.numParams + info.fieldIdx < e.getAppNumArgs
|
|
return e.getArg! (ctorInfo.numParams + info.fieldIdx)
|
|
|
|
def fvarToMono (e : Expr) : ToMonoM Expr := do
|
|
if (← get).typeParams.contains e.fvarId! then
|
|
return erasedExpr
|
|
else
|
|
return e
|
|
|
|
def argToMono (arg : Expr) : ToMonoM Expr := do
|
|
if arg.isFVar then
|
|
fvarToMono arg
|
|
else
|
|
return erasedExpr
|
|
|
|
def ctorAppToMono (ctorInfo : ConstructorVal) (args : Array Expr) : ToMonoM Expr := do
|
|
let argsNew ← args[:ctorInfo.numParams].toArray.mapM fun param => do
|
|
-- We only preserve constructor parameters that are types
|
|
if isTypeFormerType (← inferType param) then
|
|
toMonoType param
|
|
else
|
|
return erasedExpr
|
|
let argsNew := argsNew ++ (← args[ctorInfo.numParams:].toArray.mapM argToMono)
|
|
return mkAppN (mkConst ctorInfo.name) argsNew
|
|
|
|
partial def _root_.Lean.Expr.toMono (e : Expr) : ToMonoM Expr := do
|
|
match e with
|
|
| .fvar .. => fvarToMono e
|
|
| .lit .. => return e
|
|
| .const declName _ => return mkConst declName
|
|
| .sort .. => return erasedExpr
|
|
| .mvar .. | .bvar .. | .letE .. => unreachable!
|
|
| .mdata _ b => return e.updateMData! (← b.toMono)
|
|
| .proj structName fieldIdx b =>
|
|
if let some info ← hasTrivialStructure? structName then
|
|
if info.fieldIdx == fieldIdx then
|
|
b.toMono
|
|
else
|
|
return erasedExpr
|
|
else
|
|
return e.updateProj! (← b.toMono)
|
|
| .forallE .. | .lam .. => return erasedExpr
|
|
| .app .. =>
|
|
if e.isAppOf ``Decidable.isTrue then
|
|
return mkConst ``Bool.true
|
|
else if e.isAppOf ``Decidable.isFalse then
|
|
return mkConst ``Bool.false
|
|
else if let some arg ← isTrivialConstructorApp? e then
|
|
arg.toMono
|
|
else if let some ctorInfo := e.isConstructorApp? (← getEnv) then
|
|
ctorAppToMono ctorInfo e.getAppArgs
|
|
else
|
|
let f := e.getAppFn
|
|
let args := e.getAppArgs
|
|
let args ← args.mapM argToMono
|
|
return mkAppN (← f.toMono) args
|
|
|
|
def LetDecl.toMono (decl : LetDecl) : ToMonoM LetDecl := do
|
|
let type ← toMonoType decl.type
|
|
let value ← decl.value.toMono
|
|
decl.update type value
|
|
|
|
mutual
|
|
|
|
partial def FunDeclCore.toMono (decl : FunDecl) : ToMonoM FunDecl := do
|
|
let type ← toMonoType decl.type
|
|
let params ← decl.params.mapM (·.toMono)
|
|
let value ← decl.value.toMono
|
|
decl.update type params value
|
|
|
|
/-- Convert `cases` `Decidable` => `Bool` -/
|
|
partial def decToMono (c : Cases) (_ : c.typeName == ``Decidable) : ToMonoM Code := do
|
|
let resultType ← toMonoType c.resultType
|
|
let alts ← c.alts.mapM fun alt => do
|
|
match alt with
|
|
| .default k => return alt.updateCode (← k.toMono)
|
|
| .alt ctorName ps k =>
|
|
eraseParams ps
|
|
let ctorName := if ctorName == ``Decidable.isTrue then ``Bool.true else ``Bool.false
|
|
return .alt ctorName #[] (← k.toMono)
|
|
return .cases { c with resultType, alts, typeName := ``Bool }
|
|
|
|
/-- Eliminate `cases` for trivial structure. See `hasTrivialStructure?` -/
|
|
partial def trivialStructToMono (info : TrivialStructureInfo) (c : Cases) : ToMonoM Code := do
|
|
assert! c.alts.size == 1
|
|
let .alt ctorName ps k := c.alts[0]! | unreachable!
|
|
assert! ctorName == info.ctorName
|
|
assert! info.fieldIdx < ps.size
|
|
let p := ps[info.fieldIdx]!
|
|
eraseParams ps
|
|
/- We reuse `p`s `fvarId` to avoid substitution -/
|
|
let decl := { fvarId := p.fvarId, binderName := p.binderName, type := (← toMonoType p.type), value := .fvar c.discr }
|
|
modifyLCtx fun lctx => lctx.addLetDecl decl
|
|
let k ← k.toMono
|
|
return .let decl k
|
|
|
|
partial def Code.toMono (code : Code) : ToMonoM Code := do
|
|
match code with
|
|
| .let decl k => return code.updateLet! (← decl.toMono) (← k.toMono)
|
|
| .fun decl k | .jp decl k => return code.updateFun! (← decl.toMono) (← k.toMono)
|
|
| .unreach type => return .unreach (← toMonoType type)
|
|
| .return .. | .jmp .. => return code
|
|
| .cases c =>
|
|
if h : c.typeName == ``Decidable then
|
|
decToMono c h
|
|
else if let some info ← hasTrivialStructure? c.typeName then
|
|
trivialStructToMono info c
|
|
else
|
|
-- TODO: `casesOn` `[implementedBy]` support
|
|
let type ← toMonoType c.resultType
|
|
let alts ← c.alts.mapM fun alt =>
|
|
match alt with
|
|
| .default k => return alt.updateCode (← k.toMono)
|
|
| .alt _ ps k => return alt.updateAlt! (← ps.mapM (·.toMono)) (← k.toMono)
|
|
return code.updateCases! type c.discr alts
|
|
|
|
end
|
|
|
|
def Decl.toMono (decl : Decl) : CompilerM Decl := do
|
|
go |>.run' {}
|
|
where
|
|
go : ToMonoM Decl := do
|
|
let type ← toMonoType decl.type
|
|
let params ← decl.params.mapM (·.toMono)
|
|
let value ← decl.value.toMono
|
|
let decl := { decl with type, params, value, levelParams := [] }
|
|
decl.saveMono
|
|
return decl
|
|
|
|
def toMono : Pass where
|
|
name := `toMono
|
|
run := fun decls => decls.mapM (·.toMono)
|
|
phase := .base
|
|
phaseOut := .mono
|
|
|
|
builtin_initialize
|
|
registerTraceClass `Compiler.toMono (inherited := true)
|
|
|
|
end Lean.Compiler.LCNF |