From 1295bf52bc7bdcd4390d40c0ba97be381739cc60 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 6 May 2019 18:35:06 -0700 Subject: [PATCH] feat(library/init/lean/compiler/ir): add `Decl.checker` for debugging purposes We have also added a new `Context` object, and modified our IR invariant. Now, we assume there is no variable or join point shadowing. --- library/init/lean/compiler/ir/basic.lean | 41 ++++++++++- library/init/lean/compiler/ir/checker.lean | 80 +++++++++++++++++++++ library/init/lean/compiler/ir/default.lean | 3 + library/init/lean/compiler/ir/livevars.lean | 42 ++++------- 4 files changed, 137 insertions(+), 29 deletions(-) create mode 100644 library/init/lean/compiler/ir/checker.lean diff --git a/library/init/lean/compiler/ir/basic.lean b/library/init/lean/compiler/ir/basic.lean index 38efa279e6..58711c7317 100644 --- a/library/init/lean/compiler/ir/basic.lean +++ b/library/init/lean/compiler/ir/basic.lean @@ -323,6 +323,10 @@ inductive Decl | fdecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody) | extern (f : FunId) (xs : Array Param) (ty : IRType) +def Decl.id : Decl → FunId +| (Decl.fdecl f _ _ _) := f +| (Decl.extern f _ _) := f + @[export lean.ir.mk_decl_core] def mkDecl (f : FunId) (xs : Array Param) (ty : IRType) (b : FnBody) : Decl := Decl.fdecl f xs ty b /-- `Expr.isPure e` return `true` Iff `e` is in the `λPure` fragment. -/ @@ -358,7 +362,42 @@ abbrev IndexSet := RBTree Index Index.lt instance vsetInh : Inhabited IndexSet := ⟨{}⟩ /-- Mapping from variable (join point) indices to their declarations -/ -abbrev Context := RBMap Index FnBody Index.lt +structure Context := +(locals : RBMap Index FnBody Index.lt := {}) +(params : RBMap Index Param Index.lt := {}) + +def Context.addDecl (ctx : Context) (b : FnBody) : Context := +match b with +| FnBody.vdecl x _ _ _ := { locals := ctx.locals.insert x.idx b, .. ctx } +| FnBody.jdecl j _ _ _ _ := { locals := ctx.locals.insert j.idx b, .. ctx } +| other := ctx + +def Context.addParam (ctx : Context) (p : Param) : Context := +{ params := ctx.params.insert p.x.idx p, .. ctx } + +def Context.isJP (ctx : Context) (idx : Index) : Bool := +match ctx.locals.find idx with +| some (FnBody.jdecl _ _ _ _ _) := true +| other := false + +def Context.getJoinPointBody (ctx : Context) (j : JoinPointId) : Option FnBody := +match ctx.locals.find j.idx with +| some (FnBody.jdecl _ _ _ v _) := some v +| other := none + +def Context.isParam (ctx : Context) (idx : Index) : Bool := +ctx.params.contains idx + +def Context.isLocalVar (ctx : Context) (idx : Index) : Bool := +match ctx.locals.find idx with +| some (FnBody.vdecl _ _ _ _) := true +| other := false + +def Context.contains (ctx : Context) (idx : Index) : Bool := +ctx.locals.contains idx || ctx.params.contains idx + +def Context.eraseJoinPointDecl (ctx : Context) (j : JoinPointId) : Context := +{ locals := ctx.locals.erase j.idx, .. ctx } abbrev IndexRenaming := RBMap Index Index Index.lt diff --git a/library/init/lean/compiler/ir/checker.lean b/library/init/lean/compiler/ir/checker.lean new file mode 100644 index 0000000000..74546b3beb --- /dev/null +++ b/library/init/lean/compiler/ir/checker.lean @@ -0,0 +1,80 @@ +/- +Copyright (c) 2019 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import init.lean.compiler.ir.basic +import init.control.reader + +namespace Lean +namespace IR + +namespace Checker + +abbrev M := ExceptT String (ReaderT Context Id) + +def checkVar (x : VarId) : M Unit := +do ctx ← read, + unless (ctx.isLocalVar x.idx || ctx.isParam x.idx) $ throw ("unknown variable '" ++ toString x ++ "'") + +def checkJP (j : JoinPointId) : M Unit := +do ctx ← read, + unless (ctx.isJP j.idx) $ throw ("unknown join point '" ++ toString j ++ "'") + +def checkArg (a : Arg) : M Unit := +match a with +| Arg.var x := checkVar x +| other := pure () + +def checkArgs (as : Array Arg) : M Unit := +as.mfor checkArg + +def checkExpr : Expr → M Unit +| e := pure () + +@[inline] def withParams (ps : Array Param) (k : M Unit) : M Unit := +do ctx ← read, + ctx ← ps.mfoldl ctx (λ ctx p, do + when (ctx.contains p.x.idx) $ throw ("invalid parameter declaration, shadowing is not allowed"), + pure $ ctx.addParam p), + adaptReader (λ _, ctx) k + +local attribute [instance] monadInhabited + +partial def checkFnBody : FnBody → M Unit +| d@(FnBody.vdecl x _ v b) := do + checkExpr v, + ctx ← read, + when (ctx.contains x.idx) $ throw ("invalid variable declaration, shadowing is not allowed"), + adaptReader (λ ctx : Context, ctx.addDecl d) (checkFnBody b) +| d@(FnBody.jdecl j ys _ v b) := do + withParams ys (checkFnBody v), + ctx ← read, + when (ctx.contains j.idx) $ throw ("invalid join point declaration, shadowing is not allowed"), + adaptReader (λ ctx : Context, ctx.addDecl d) (checkFnBody b) +| (FnBody.set x _ y b) := checkVar x *> checkVar y *> checkFnBody b +| (FnBody.uset x _ y b) := checkVar x *> checkVar y *> checkFnBody b +| (FnBody.sset x _ _ y _ b) := checkVar x *> checkVar y *> checkFnBody b +| (FnBody.release x _ b) := checkVar x *> checkFnBody b +| (FnBody.inc x _ _ b) := checkVar x *> checkFnBody b +| (FnBody.dec x _ _ b) := checkVar x *> checkFnBody b +| (FnBody.mdata _ b) := checkFnBody b +| (FnBody.jmp j ys) := checkJP j *> checkArgs ys +| (FnBody.ret x) := checkArg x +| (FnBody.case _ x alts) := checkVar x *> alts.mfor (λ alt, checkFnBody alt.body) +| (FnBody.unreachable) := pure () + +def checkDecl : Decl → M Unit +| (Decl.fdecl f xs t b) := withParams xs (checkFnBody b) +| (Decl.extern f xs t) := withParams xs (pure ()) + +end Checker + +def Decl.check (d : Decl) : IO Unit := +match Checker.checkDecl d {} with +| Except.error msg := throw (IO.userError ("IR check failed at '" ++ toString d.id ++ "'")) +| other := pure () + +end IR +end Lean diff --git a/library/init/lean/compiler/ir/default.lean b/library/init/lean/compiler/ir/default.lean index b505d47290..ac9263ec47 100644 --- a/library/init/lean/compiler/ir/default.lean +++ b/library/init/lean/compiler/ir/default.lean @@ -11,6 +11,7 @@ import init.lean.compiler.ir.elimdead import init.lean.compiler.ir.simpcase import init.lean.compiler.ir.resetreuse import init.lean.compiler.ir.normids +import init.lean.compiler.ir.checker namespace Lean namespace IR @@ -18,6 +19,7 @@ namespace IR @[export lean.ir.test_core] def test (d : Decl) : IO Unit := do + d.check, IO.println d, IO.println $ ("Max index " ++ toString d.maxIndex), let d := d.pushProj, @@ -35,6 +37,7 @@ do let d := d.normalizeIds, IO.println "=== After normalize Ids ===", IO.println d, + d.check, pure () end IR diff --git a/library/init/lean/compiler/ir/livevars.lean b/library/init/lean/compiler/ir/livevars.lean index 3d9e3ee9d7..f38075193f 100644 --- a/library/init/lean/compiler/ir/livevars.lean +++ b/library/init/lean/compiler/ir/livevars.lean @@ -31,12 +31,14 @@ namespace IR namespace IsLive /- - IndexSet is the set of local joint points We use `State Context` instead of `ReaderT Context Id` because we remove non local joint points from `Context` whenever we visit them instead of - maintaining a set of visit non local join points. + maintaining a set of visited non local join points. + + Remark: we don't need to track local join points because we assume there is + no variable or join point shadowing in our IR. -/ -abbrev M := ReaderT IndexSet (State Context) +abbrev M := State Context @[inline] def visitVar (w : Index) (x : VarId) : M Bool := pure (HasIndex.visitVar w x) @[inline] def visitJP (w : Index) (x : JoinPointId) : M Bool := pure (HasIndex.visitJP w x) @@ -44,26 +46,11 @@ abbrev M := ReaderT IndexSet (State Context) @[inline] def visitArgs (w : Index) (as : Array Arg) : M Bool := pure (HasIndex.visitArgs w as) @[inline] def visitExpr (w : Index) (e : Expr) : M Bool := pure (HasIndex.visitExpr w e) -/- Search for `w` using `k` in a context where variable `x` is declared. -/ -@[inline] def withVDecl (x : VarId) (w : Index) (k : M Bool) : M Bool := -if w == x.idx then pure false -else k - -/- Search for `w` using `k` in a context where joint point `x` is declared. -/ -@[inline] def withJDecl (j : JoinPointId) (w : Index) (k : M Bool) : M Bool := -if w == j.idx then pure false -else adaptReader (λ localJPs : IndexSet, localJPs.insert j.idx) k - -/- Search for `w` using `k` in a context with `ys` parameters -/ -@[inline] def withParams (ys : Array Param) (w : Index) (k : M Bool) : M Bool := -if HasIndex.visitParams w ys then pure false -else k - local attribute [instance] monadInhabited partial def visitFnBody (w : Index) : FnBody → M Bool -| (FnBody.vdecl x _ v b) := visitExpr w v <||> withVDecl x w (visitFnBody b) -| (FnBody.jdecl j ys _ v b) := withParams ys w (visitFnBody v) <||> withJDecl j w (visitFnBody b) +| (FnBody.vdecl x _ v b) := visitExpr w v <||> visitFnBody b +| (FnBody.jdecl j ys _ v b) := visitFnBody v <||> visitFnBody b | (FnBody.set x _ y b) := visitVar w x <||> visitVar w y <||> visitFnBody b | (FnBody.uset x _ y b) := visitVar w x <||> visitVar w y <||> visitFnBody b | (FnBody.sset x _ _ y _ b) := visitVar w x <||> visitVar w y <||> visitFnBody b @@ -72,16 +59,15 @@ partial def visitFnBody (w : Index) : FnBody → M Bool | (FnBody.dec x _ _ b) := visitVar w x <||> visitFnBody b | (FnBody.mdata _ b) := visitFnBody b | (FnBody.jmp j ys) := visitArgs w ys <||> do { - localJPs ← read, - if localJPs.contains j.idx then pure false -- `j` is a local joint point, so we have already searched for `w` in its declaration. - else do ctx ← get, - match ctx.find j.idx with + match ctx.getJoinPointBody j with | some b := - -- `j` is not a local join point. + -- `j` is not a local join point since we assume we cannot shadow join point declarations. -- Instead of marking the join points that we have already been visited, we permanently remove `j` from the context. - set (ctx.erase j.idx) *> visitFnBody b - | none := pure false + set (ctx.eraseJoinPointDecl j) *> visitFnBody b + | none := + -- `j` must be a local join point. So do nothing since we have already visite its body. + pure false } | (FnBody.ret x) := visitArg w x | (FnBody.case _ x alts) := visitVar w x <||> alts.anyM (λ alt, visitFnBody alt.body) @@ -96,7 +82,7 @@ end IsLive Recall that we say that a join point `j` is free in `b` if `b` contains `FnBody.jmp j ys` and `j` is not local. -/ def FnBody.isLive (b : FnBody) (ctx : Context) (x : VarId) : Bool := -(IsLive.visitFnBody x.idx b {}).run' ctx +(IsLive.visitFnBody x.idx b).run' ctx end IR end Lean