refactor: make SubExpr.Pos a definition

Instead of an abbreviation. It is easier to understand
Pos operations in terms of 'push' and 'pop' rather than
through arithmetic.
This commit is contained in:
E.W.Ayers 2022-06-03 15:12:23 -04:00 committed by Leonardo de Moura
parent bbc196eeb7
commit 2fe933cdf5
5 changed files with 100 additions and 14 deletions

View file

@ -33,14 +33,14 @@ variable [MonadLiftT IO m]
def getExpr : m Expr := return (← readThe SubExpr).expr
def getPos : m Pos := return (← readThe SubExpr).pos
def descend (child : Expr) (childIdx : Pos) (x : m α) : m α :=
withTheReader SubExpr (fun cfg => { cfg with expr := child, pos := cfg.pos * maxChildren + childIdx }) x
def descend (child : Expr) (childIdx : Nat) (x : m α) : m α :=
withTheReader SubExpr (fun cfg => { cfg with expr := child, pos := cfg.pos.push childIdx }) x
def withAppFn (x : m α) : m α := do descend (← getExpr).appFn! 0 x
def withAppArg (x : m α) : m α := do descend (← getExpr).appArg! 1 x
def withType (x : m α) : m α := do
descend (← Meta.inferType (← getExpr)) (maxChildren - 1) x -- phantom positions for types
descend (← Meta.inferType (← getExpr)) Pos.typeCoord x -- phantom positions for types
partial def withAppFnArgs (xf : m α) (xa : α → m α) : m α := do
if (← getExpr).isApp then
@ -80,20 +80,20 @@ def withLetBody (x : m α) : m α := do
def withNaryFn (x : m α) : m α := do
let e ← getExpr
let n := e.getAppNumArgs
let newPos := (← getPos) * (maxChildren ^ n)
let newPos := (← getPos).asNat * (Pos.maxChildren ^ n)
withTheReader SubExpr (fun cfg => { cfg with expr := e.getAppFn, pos := newPos }) x
def withNaryArg (argIdx : Nat) (x : m α) : m α := do
let e ← getExpr
let args := e.getAppArgs
let newPos := (← getPos) * (maxChildren ^ (args.size - argIdx)) + 1
let newPos := (← getPos).asNat * (Pos.maxChildren ^ (args.size - argIdx)) + 1
withTheReader SubExpr (fun cfg => { cfg with expr := args[argIdx], pos := newPos }) x
end Descend
structure HoleIterator where
curr : Nat := 2
top : Nat := maxChildren
top : Nat := Pos.maxChildren
deriving Inhabited
section Hole
@ -107,7 +107,7 @@ def HoleIterator.toPos (iter : HoleIterator) : Pos :=
def HoleIterator.next (iter : HoleIterator) : HoleIterator :=
if (iter.curr+1) == iter.top then
⟨2*iter.top, maxChildren*iter.top⟩
⟨2*iter.top, Pos.maxChildren*iter.top⟩
else ⟨iter.curr+1, iter.top⟩
/-- The positioning scheme guarantees that there will be an infinite number of extra positions

View file

@ -1,9 +1,10 @@
/-
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sebastian Ullrich, Daniel Selsam, Wojciech Nawrocki
Authors: Sebastian Ullrich, Daniel Selsam, Wojciech Nawrocki, E.W.Ayers
-/
import Lean.Meta.Basic
import Lean.Data.Json
import Std.Data.RBMap
namespace Lean
@ -11,7 +12,80 @@ namespace Lean
/-- A position of a subexpression in an expression.
See docstring of `SubExpr` for more detail.-/
abbrev SubExpr.Pos := Nat
def SubExpr.Pos := Nat
namespace SubExpr.Pos
def maxChildren := 4
/-- The coordinate `3 = maxChildren - 1` is
reserved to denote the type of the expression. -/
def typeCoord : Nat := maxChildren - 1
def asNat : Pos → Nat := id
instance : Inhabited Pos := show Inhabited Nat by infer_instance
instance : Ord Pos := show Ord Nat by infer_instance
instance : FromJson Pos := show FromJson Nat by infer_instance
instance : ToJson Pos := show ToJson Nat by infer_instance
instance : Repr Pos := show Repr Nat by infer_instance
instance : ToString Pos := show ToString Nat by infer_instance
/-- The Pos representing the root subexpression. -/
def root : Pos := (1 : Nat)
def isRoot (p : Pos) : Bool := p.asNat == 1
/-- The coordinate deepest in the Pos. -/
def head (p : Pos) : Nat :=
if p.isRoot then panic! "already at top"
else p.asNat % maxChildren
def tail (p : Pos) : Pos :=
if p.isRoot then panic! "already at top"
else (p.asNat - p.head) / maxChildren
def push (p : Pos) (c : Nat) : Pos :=
if c >= maxChildren then panic! s!"invalid coordinate {c}"
else p.asNat * maxChildren + c
/-- `pushNZeros p count` runs `.push 0` `count` times. -/
def pushNZeros (p : Pos) (count : Nat) : Pos :=
p.asNat * (maxChildren ^ count)
variable {α : Type} [Inhabited α]
/-- Fold over the position starting at the root and heading to the leaf-/
def foldl (f : α → Nat → α) : α → Pos → α :=
fix2 (fun r a p => if p.isRoot then a else f (r a p.tail) p.head)
/-- Fold over the position starting at the root and heading to the leaf-/
def foldr (f : Nat → αα) : Pos → αα :=
fix2 (fun r p a => if p.isRoot then a else r p.tail (f p.head a))
def foldrM [Monad M] (f : Nat → α → M α) : Pos → α → M α :=
fix2 (fun r p a => if p.isRoot then pure a else f p.head a >>= r p.tail)
def depth (p : Pos) :=
p.foldr (fun _ => Nat.succ) 0
/-- Returns true if `pred` is true for each coordinate in `p`.-/
def all (pred : Nat → Bool) (p : Pos) : Bool :=
OptionT.run (m := Id) (foldrM (fun n a => if pred n then pure a else failure) p ()) |>.isSome
def append : Pos → Pos → Pos := foldl push
/-- Creates a subexpression `Pos` from an array of 'coordinates'.
Each coordinate is a number {0,1,2} expressing which child subexpression should be explored.
The first coordinate in the array corresponds to the root of the expression tree. -/
def ofArray (ps : Array Nat) : Pos :=
ps.foldl push root
/-- Decodes a subexpression `Pos` as a sequence of coordinates. See `Pos.fromArray` for details.-/
def toArray (p : Pos) : Array Nat :=
foldl Array.push #[] p
end SubExpr.Pos
/-- An expression and the position of a subexpression within this expression.
@ -30,8 +104,7 @@ structure SubExpr where
namespace SubExpr
abbrev maxChildren : Pos := 4
def mkRoot (e : Expr) : SubExpr := ⟨e, 1⟩
def mkRoot (e : Expr) : SubExpr := ⟨e, Pos.root⟩
end SubExpr

View file

@ -46,8 +46,8 @@ where
def ppExprTagged (e : Expr) (explicit : Bool := false) : MetaM CodeWithInfos := do
let optsPerPos := if explicit then
Std.RBMap.ofList [
(1, KVMap.empty.setBool `pp.all true),
(1, KVMap.empty.setBool `pp.tagAppFns true)
(SubExpr.Pos.root, KVMap.empty.setBool `pp.all true),
(SubExpr.Pos.root, KVMap.empty.setBool `pp.tagAppFns true)
]
else
{}

View file

@ -58,7 +58,7 @@ section
#eval checkM `(id Nat)
#eval checkM `(Sum Nat Nat)
end
#eval checkM `(id (id Nat)) (Std.RBMap.empty.insert 5 $ KVMap.empty.insert `pp.explicit true)
#eval checkM `(id (id Nat)) (Std.RBMap.empty.insert (SubExpr.Pos.encode #[1]) $ KVMap.empty.insert `pp.explicit true)
-- specify the expected type of `a` in a way that is not erased by the delaborator
def typeAs.{u} (α : Type u) (a : α) := ()

View file

@ -0,0 +1,13 @@
import Lean
open Lean.SubExpr
def ps := [#[], #[0], #[1], #[0,1], #[1,0] , #[0,0], #[1,2,3]]
theorem Pos.roundtrip :
true = ps.all fun x => x == (Pos.toArray <| Pos.ofArray <| x)
:= by native_decide
theorem Pos.append_roundtrip :
true = (List.all
(ps.bind fun p => ps.map fun q => (p,q))
(fun (x,y) => (x ++ y) == (Pos.toArray <| (Pos.append (Pos.ofArray x) (Pos.ofArray y))))
) := by native_decide