refactor: make SubExpr.Pos a definition
Instead of an abbreviation. It is easier to understand Pos operations in terms of 'push' and 'pop' rather than through arithmetic.
This commit is contained in:
parent
bbc196eeb7
commit
2fe933cdf5
5 changed files with 100 additions and 14 deletions
|
|
@ -33,14 +33,14 @@ variable [MonadLiftT IO m]
|
|||
def getExpr : m Expr := return (← readThe SubExpr).expr
|
||||
def getPos : m Pos := return (← readThe SubExpr).pos
|
||||
|
||||
def descend (child : Expr) (childIdx : Pos) (x : m α) : m α :=
|
||||
withTheReader SubExpr (fun cfg => { cfg with expr := child, pos := cfg.pos * maxChildren + childIdx }) x
|
||||
def descend (child : Expr) (childIdx : Nat) (x : m α) : m α :=
|
||||
withTheReader SubExpr (fun cfg => { cfg with expr := child, pos := cfg.pos.push childIdx }) x
|
||||
|
||||
def withAppFn (x : m α) : m α := do descend (← getExpr).appFn! 0 x
|
||||
def withAppArg (x : m α) : m α := do descend (← getExpr).appArg! 1 x
|
||||
|
||||
def withType (x : m α) : m α := do
|
||||
descend (← Meta.inferType (← getExpr)) (maxChildren - 1) x -- phantom positions for types
|
||||
descend (← Meta.inferType (← getExpr)) Pos.typeCoord x -- phantom positions for types
|
||||
|
||||
partial def withAppFnArgs (xf : m α) (xa : α → m α) : m α := do
|
||||
if (← getExpr).isApp then
|
||||
|
|
@ -80,20 +80,20 @@ def withLetBody (x : m α) : m α := do
|
|||
def withNaryFn (x : m α) : m α := do
|
||||
let e ← getExpr
|
||||
let n := e.getAppNumArgs
|
||||
let newPos := (← getPos) * (maxChildren ^ n)
|
||||
let newPos := (← getPos).asNat * (Pos.maxChildren ^ n)
|
||||
withTheReader SubExpr (fun cfg => { cfg with expr := e.getAppFn, pos := newPos }) x
|
||||
|
||||
def withNaryArg (argIdx : Nat) (x : m α) : m α := do
|
||||
let e ← getExpr
|
||||
let args := e.getAppArgs
|
||||
let newPos := (← getPos) * (maxChildren ^ (args.size - argIdx)) + 1
|
||||
let newPos := (← getPos).asNat * (Pos.maxChildren ^ (args.size - argIdx)) + 1
|
||||
withTheReader SubExpr (fun cfg => { cfg with expr := args[argIdx], pos := newPos }) x
|
||||
|
||||
end Descend
|
||||
|
||||
structure HoleIterator where
|
||||
curr : Nat := 2
|
||||
top : Nat := maxChildren
|
||||
top : Nat := Pos.maxChildren
|
||||
deriving Inhabited
|
||||
|
||||
section Hole
|
||||
|
|
@ -107,7 +107,7 @@ def HoleIterator.toPos (iter : HoleIterator) : Pos :=
|
|||
|
||||
def HoleIterator.next (iter : HoleIterator) : HoleIterator :=
|
||||
if (iter.curr+1) == iter.top then
|
||||
⟨2*iter.top, maxChildren*iter.top⟩
|
||||
⟨2*iter.top, Pos.maxChildren*iter.top⟩
|
||||
else ⟨iter.curr+1, iter.top⟩
|
||||
|
||||
/-- The positioning scheme guarantees that there will be an infinite number of extra positions
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
/-
|
||||
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sebastian Ullrich, Daniel Selsam, Wojciech Nawrocki
|
||||
Authors: Sebastian Ullrich, Daniel Selsam, Wojciech Nawrocki, E.W.Ayers
|
||||
-/
|
||||
import Lean.Meta.Basic
|
||||
import Lean.Data.Json
|
||||
import Std.Data.RBMap
|
||||
|
||||
namespace Lean
|
||||
|
|
@ -11,7 +12,80 @@ namespace Lean
|
|||
/-- A position of a subexpression in an expression.
|
||||
|
||||
See docstring of `SubExpr` for more detail.-/
|
||||
abbrev SubExpr.Pos := Nat
|
||||
def SubExpr.Pos := Nat
|
||||
|
||||
namespace SubExpr.Pos
|
||||
|
||||
def maxChildren := 4
|
||||
|
||||
/-- The coordinate `3 = maxChildren - 1` is
|
||||
reserved to denote the type of the expression. -/
|
||||
def typeCoord : Nat := maxChildren - 1
|
||||
|
||||
def asNat : Pos → Nat := id
|
||||
|
||||
instance : Inhabited Pos := show Inhabited Nat by infer_instance
|
||||
instance : Ord Pos := show Ord Nat by infer_instance
|
||||
instance : FromJson Pos := show FromJson Nat by infer_instance
|
||||
instance : ToJson Pos := show ToJson Nat by infer_instance
|
||||
instance : Repr Pos := show Repr Nat by infer_instance
|
||||
instance : ToString Pos := show ToString Nat by infer_instance
|
||||
|
||||
/-- The Pos representing the root subexpression. -/
|
||||
def root : Pos := (1 : Nat)
|
||||
|
||||
def isRoot (p : Pos) : Bool := p.asNat == 1
|
||||
|
||||
/-- The coordinate deepest in the Pos. -/
|
||||
def head (p : Pos) : Nat :=
|
||||
if p.isRoot then panic! "already at top"
|
||||
else p.asNat % maxChildren
|
||||
|
||||
def tail (p : Pos) : Pos :=
|
||||
if p.isRoot then panic! "already at top"
|
||||
else (p.asNat - p.head) / maxChildren
|
||||
|
||||
def push (p : Pos) (c : Nat) : Pos :=
|
||||
if c >= maxChildren then panic! s!"invalid coordinate {c}"
|
||||
else p.asNat * maxChildren + c
|
||||
|
||||
/-- `pushNZeros p count` runs `.push 0` `count` times. -/
|
||||
def pushNZeros (p : Pos) (count : Nat) : Pos :=
|
||||
p.asNat * (maxChildren ^ count)
|
||||
|
||||
variable {α : Type} [Inhabited α]
|
||||
|
||||
/-- Fold over the position starting at the root and heading to the leaf-/
|
||||
def foldl (f : α → Nat → α) : α → Pos → α :=
|
||||
fix2 (fun r a p => if p.isRoot then a else f (r a p.tail) p.head)
|
||||
|
||||
/-- Fold over the position starting at the root and heading to the leaf-/
|
||||
def foldr (f : Nat → α → α) : Pos → α → α :=
|
||||
fix2 (fun r p a => if p.isRoot then a else r p.tail (f p.head a))
|
||||
|
||||
def foldrM [Monad M] (f : Nat → α → M α) : Pos → α → M α :=
|
||||
fix2 (fun r p a => if p.isRoot then pure a else f p.head a >>= r p.tail)
|
||||
|
||||
def depth (p : Pos) :=
|
||||
p.foldr (fun _ => Nat.succ) 0
|
||||
|
||||
/-- Returns true if `pred` is true for each coordinate in `p`.-/
|
||||
def all (pred : Nat → Bool) (p : Pos) : Bool :=
|
||||
OptionT.run (m := Id) (foldrM (fun n a => if pred n then pure a else failure) p ()) |>.isSome
|
||||
|
||||
def append : Pos → Pos → Pos := foldl push
|
||||
|
||||
/-- Creates a subexpression `Pos` from an array of 'coordinates'.
|
||||
Each coordinate is a number {0,1,2} expressing which child subexpression should be explored.
|
||||
The first coordinate in the array corresponds to the root of the expression tree. -/
|
||||
def ofArray (ps : Array Nat) : Pos :=
|
||||
ps.foldl push root
|
||||
|
||||
/-- Decodes a subexpression `Pos` as a sequence of coordinates. See `Pos.fromArray` for details.-/
|
||||
def toArray (p : Pos) : Array Nat :=
|
||||
foldl Array.push #[] p
|
||||
|
||||
end SubExpr.Pos
|
||||
|
||||
/-- An expression and the position of a subexpression within this expression.
|
||||
|
||||
|
|
@ -30,8 +104,7 @@ structure SubExpr where
|
|||
|
||||
namespace SubExpr
|
||||
|
||||
abbrev maxChildren : Pos := 4
|
||||
def mkRoot (e : Expr) : SubExpr := ⟨e, 1⟩
|
||||
def mkRoot (e : Expr) : SubExpr := ⟨e, Pos.root⟩
|
||||
|
||||
end SubExpr
|
||||
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ where
|
|||
def ppExprTagged (e : Expr) (explicit : Bool := false) : MetaM CodeWithInfos := do
|
||||
let optsPerPos := if explicit then
|
||||
Std.RBMap.ofList [
|
||||
(1, KVMap.empty.setBool `pp.all true),
|
||||
(1, KVMap.empty.setBool `pp.tagAppFns true)
|
||||
(SubExpr.Pos.root, KVMap.empty.setBool `pp.all true),
|
||||
(SubExpr.Pos.root, KVMap.empty.setBool `pp.tagAppFns true)
|
||||
]
|
||||
else
|
||||
{}
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ section
|
|||
#eval checkM `(id Nat)
|
||||
#eval checkM `(Sum Nat Nat)
|
||||
end
|
||||
#eval checkM `(id (id Nat)) (Std.RBMap.empty.insert 5 $ KVMap.empty.insert `pp.explicit true)
|
||||
#eval checkM `(id (id Nat)) (Std.RBMap.empty.insert (SubExpr.Pos.encode #[1]) $ KVMap.empty.insert `pp.explicit true)
|
||||
|
||||
-- specify the expected type of `a` in a way that is not erased by the delaborator
|
||||
def typeAs.{u} (α : Type u) (a : α) := ()
|
||||
|
|
|
|||
13
tests/lean/run/subexpr.lean
Normal file
13
tests/lean/run/subexpr.lean
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import Lean
|
||||
open Lean.SubExpr
|
||||
|
||||
def ps := [#[], #[0], #[1], #[0,1], #[1,0] , #[0,0], #[1,2,3]]
|
||||
theorem Pos.roundtrip :
|
||||
true = ps.all fun x => x == (Pos.toArray <| Pos.ofArray <| x)
|
||||
:= by native_decide
|
||||
|
||||
theorem Pos.append_roundtrip :
|
||||
true = (List.all
|
||||
(ps.bind fun p => ps.map fun q => (p,q))
|
||||
(fun (x,y) => (x ++ y) == (Pos.toArray <| (Pos.append (Pos.ofArray x) (Pos.ofArray y))))
|
||||
) := by native_decide
|
||||
Loading…
Add table
Reference in a new issue