lean4-htt/src/Lean/Compiler/IR/EmitUtil.lean
Leonardo de Moura 6858cb5fb6 chore: cleanup
2020-10-29 10:24:16 -07:00

85 lines
3.1 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.InitAttr
import Lean.Compiler.IR.CompilerM
/- Helper functions for backend code generators -/
namespace Lean.IR
/- Tail-call shape test.
   Yields `true` exactly when `b` looks like `let x := g ys; ret x`: a `vdecl`
   whose value is an `Expr.fap` of `g` and whose continuation immediately
   returns the variable that was just bound. Every other body yields `false`. -/
def isTailCallTo (g : Name) (b : FnBody) : Bool :=
  match b with
  | FnBody.vdecl boundVar _ (Expr.fap callee _) (FnBody.ret (Arg.var retVar)) =>
    -- the returned variable must be the one just bound, and the callee must be `g`
    boundVar == retVar && callee == g
  | _ => false
/- Returns `true` iff at least one name in `env.allImportedModuleNames`
   has `modulePrefix` as a prefix. -/
def usesModuleFrom (env : Environment) (modulePrefix : Name) : Bool :=
  let importedNames := env.allImportedModuleNames.toList;
  importedNames.any (fun importedName => modulePrefix.isPrefixOf importedName)
namespace CollectUsedDecls
/- Monad used by the functions in this namespace: the `Environment` is supplied
   through the reader layer, and the `NameSet` state accumulates the names
   recorded by `collect`. -/
abbrev M := ReaderT Environment (StateM NameSet)
/- Record `f` by inserting it into the `NameSet` held in the monad state. -/
@[inline] def collect (f : FunId) : M Unit :=
  modify (fun seen => seen.insert f)
/- Walk a function body and `collect` every function named by an `Expr.fap`
   or `Expr.pap` that appears as the value of a `vdecl`.
   Join-point declarations are traversed in both the join-point body and the
   continuation, `case` nodes in every alternative, and any other
   non-terminal node through its `.body`. -/
partial def collectFnBody : FnBody → M Unit
  -- `vdecl` binding an application: record the callee, then continue
  | FnBody.vdecl _ _ (Expr.fap callee _) rest => do collect callee; collectFnBody rest
  | FnBody.vdecl _ _ (Expr.pap callee _) rest => do collect callee; collectFnBody rest
  -- `vdecl` binding any other expression: nothing to record here
  | FnBody.vdecl _ _ _ rest => collectFnBody rest
  | FnBody.jdecl _ _ jpBody rest => do collectFnBody jpBody; collectFnBody rest
  | FnBody.case _ _ _ alts => alts.forM (fun alt => collectFnBody alt.body)
  -- remaining constructors: stop at terminals, otherwise follow the continuation
  | node => if node.isTerminal then pure () else collectFnBody node.body
/- Look up `getInitFnNameFor? env fn` in the current environment and, when it
   yields `some initFn`, record `initFn` with `collect`. Does nothing otherwise. -/
def collectInitDecl (fn : Name) : M Unit := do
  let currentEnv ← read
  match getInitFnNameFor? currentEnv fn with
  | none => pure ()
  | some initName => collect initName
/- Collect the names used by one declaration and return the resulting state.
   Both kinds of declaration go through `collectInitDecl` on their own name;
   an `fdecl` additionally has its body traversed by `collectFnBody`. -/
def collectDecl : Decl → M NameSet
  | Decl.fdecl declName _ _ body => do collectInitDecl declName; CollectUsedDecls.collectFnBody body; get
  | Decl.extern declName _ _ _ => do collectInitDecl declName; get
end CollectUsedDecls
/- Run `CollectUsedDecls.collectDecl` on `decl` with `env` as the reader value
   and `used` as the initial state (empty by default), returning the final
   set of collected names. -/
def collectUsedDecls (env : Environment) (decl : Decl) (used : NameSet := {}) : NameSet :=
  let collector := CollectUsedDecls.collectDecl decl env;
  collector.run' used
/- Hash map from a variable id to its `IRType`. -/
abbrev VarTypeMap := Std.HashMap VarId IRType
/- Hash map from a join point id to that join point's array of parameters. -/
abbrev JPParamsMap := Std.HashMap JoinPointId (Array Param)
namespace CollectMaps
/- A collector is a function that takes the pair of maps built so far and
   returns the updated pair. -/
abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap)
/- Collector that maps variable `x` to type `t` in the variable map and
   leaves the join-point map as it is. -/
@[inline] def collectVar (x : VarId) (t : IRType) : Collector :=
  fun maps =>
    match maps with
    | (varTypes, jpParams) => (varTypes.insert x t, jpParams)
/- Collector that applies `collectVar` to every parameter in `ps`, in array
   order, using each parameter's `x` and `ty` fields. -/
def collectParams (ps : Array Param) : Collector :=
  fun initial => ps.foldl (fun acc param => collectVar param.x param.ty acc) initial
/- Collector that maps join point `j` to its parameters `xs` in the
   join-point map and leaves the variable map as it is. -/
@[inline] def collectJP (j : JoinPointId) (xs : Array Param) : Collector :=
  fun maps =>
    match maps with
    | (varTypes, jpParams) => (varTypes, jpParams.insert j xs)
/- Traverse a function body and fill both maps: every `vdecl` contributes its
   variable and type, every `jdecl` contributes its join point, its parameters
   (as variables), and whatever its body declares.
   NOTE(review): the previous comment here was cut off after "assumes the
   variables in". It presumably referred to the normalized-index assumption
   stated on `mkVarJPMaps`; confirm. -/
partial def collectFnBody : FnBody → Collector
  | FnBody.vdecl x t _ rest => fun maps =>
    -- the continuation is processed first, then the binding itself is added
    collectVar x t (collectFnBody rest maps)
  | FnBody.jdecl j xs jpBody rest => fun maps =>
    let afterRest := collectFnBody rest maps;
    let afterJpBody := collectFnBody jpBody afterRest;
    collectJP j xs (collectParams xs afterJpBody)
  | FnBody.case _ _ _ alts => fun maps =>
    alts.foldl (fun acc alt => collectFnBody alt.body acc) maps
  -- remaining constructors: terminals add nothing, others defer to their continuation
  | node => fun maps => if node.isTerminal then maps else collectFnBody node.body maps
/- Collector for a whole declaration. An `fdecl` contributes what its body
   declares plus its own parameters; any other declaration leaves the maps
   unchanged. -/
def collectDecl : Decl → Collector
  | Decl.fdecl _ params _ body => fun maps => collectParams params (collectFnBody body maps)
  | _ => fun maps => maps
end CollectMaps
/- Build the pair `(v, j)` for declaration `d`, starting from two empty maps:
   `v` maps each variable/parameter to its type and `j` maps each join point
   to its parameters.
   `d` is expected to have normalized indexes (see `normids.lean`). -/
def mkVarJPMaps (d : Decl) : VarTypeMap × JPParamsMap :=
  let emptyMaps : VarTypeMap × JPParamsMap := ({}, {});
  CollectMaps.collectDecl d emptyMaps
end IR
end Lean