From 2fe933cdf59cf492e9ef489411f1f7973b7cb5f8 Mon Sep 17 00:00:00 2001 From: "E.W.Ayers" Date: Fri, 3 Jun 2022 15:12:23 -0400 Subject: [PATCH] refactor: make SubExpr.Pos a definition Instead of an abbreviation. It is easier to understand Pos operations in terms of 'push' and 'pop' rather than through arithmetic. --- .../PrettyPrinter/Delaborator/SubExpr.lean | 14 ++-- src/Lean/SubExpr.lean | 81 ++++++++++++++++++- src/Lean/Widget/InteractiveCode.lean | 4 +- tests/lean/PPRoundtrip.lean | 2 +- tests/lean/run/subexpr.lean | 13 +++ 5 files changed, 100 insertions(+), 14 deletions(-) create mode 100644 tests/lean/run/subexpr.lean diff --git a/src/Lean/PrettyPrinter/Delaborator/SubExpr.lean b/src/Lean/PrettyPrinter/Delaborator/SubExpr.lean index b00e20da70..3e93c1f462 100644 --- a/src/Lean/PrettyPrinter/Delaborator/SubExpr.lean +++ b/src/Lean/PrettyPrinter/Delaborator/SubExpr.lean @@ -33,14 +33,14 @@ variable [MonadLiftT IO m] def getExpr : m Expr := return (← readThe SubExpr).expr def getPos : m Pos := return (← readThe SubExpr).pos -def descend (child : Expr) (childIdx : Pos) (x : m α) : m α := - withTheReader SubExpr (fun cfg => { cfg with expr := child, pos := cfg.pos * maxChildren + childIdx }) x +def descend (child : Expr) (childIdx : Nat) (x : m α) : m α := + withTheReader SubExpr (fun cfg => { cfg with expr := child, pos := cfg.pos.push childIdx }) x def withAppFn (x : m α) : m α := do descend (← getExpr).appFn! 0 x def withAppArg (x : m α) : m α := do descend (← getExpr).appArg! 1 x def withType (x : m α) : m α := do - descend (← Meta.inferType (← getExpr)) (maxChildren - 1) x -- phantom positions for types + descend (← Meta.inferType (← getExpr)) Pos.typeCoord x -- phantom positions for types partial def withAppFnArgs (xf : m α) (xa : α → m α) : m α := do if (← getExpr).isApp then @@ -80,20 +80,20 @@ def withLetBody (x : m α) : m α := do def withNaryFn (x : m α) : m α := do let e ← getExpr let n := e.getAppNumArgs - let newPos := (← getPos) * (maxChildren ^ n) + let newPos := (← getPos).asNat * (Pos.maxChildren ^ n) withTheReader SubExpr (fun cfg => { cfg with expr := e.getAppFn, pos := newPos }) x def withNaryArg (argIdx : Nat) (x : m α) : m α := do let e ← getExpr let args := e.getAppArgs - let newPos := (← getPos) * (maxChildren ^ (args.size - argIdx)) + 1 + let newPos := (← getPos).asNat * (Pos.maxChildren ^ (args.size - argIdx)) + 1 withTheReader SubExpr (fun cfg => { cfg with expr := args[argIdx], pos := newPos }) x end Descend structure HoleIterator where curr : Nat := 2 - top : Nat := maxChildren + top : Nat := Pos.maxChildren deriving Inhabited section Hole @@ -107,7 +107,7 @@ def HoleIterator.toPos (iter : HoleIterator) : Pos := def HoleIterator.next (iter : HoleIterator) : HoleIterator := if (iter.curr+1) == iter.top then - ⟨2*iter.top, maxChildren*iter.top⟩ + ⟨2*iter.top, Pos.maxChildren*iter.top⟩ else ⟨iter.curr+1, iter.top⟩ /-- The positioning scheme guarantees that there will be an infinite number of extra positions diff --git a/src/Lean/SubExpr.lean b/src/Lean/SubExpr.lean index 0da91e00dc..07af20a339 100644 --- a/src/Lean/SubExpr.lean +++ b/src/Lean/SubExpr.lean @@ -1,9 +1,10 @@ /- Copyright (c) 2021 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Sebastian Ullrich, Daniel Selsam, Wojciech Nawrocki +Authors: Sebastian Ullrich, Daniel Selsam, Wojciech Nawrocki, E.W.Ayers -/ import Lean.Meta.Basic +import Lean.Data.Json import Std.Data.RBMap namespace Lean @@ -11,7 +12,80 @@ namespace Lean /-- A position of a subexpression in an expression. See docstring of `SubExpr` for more detail.-/ -abbrev SubExpr.Pos := Nat +def SubExpr.Pos := Nat + +namespace SubExpr.Pos + +def maxChildren := 4 + +/-- The coordinate `3 = maxChildren - 1` is +reserved to denote the type of the expression. -/ +def typeCoord : Nat := maxChildren - 1 + +def asNat : Pos → Nat := id + +instance : Inhabited Pos := show Inhabited Nat by infer_instance +instance : Ord Pos := show Ord Nat by infer_instance +instance : FromJson Pos := show FromJson Nat by infer_instance +instance : ToJson Pos := show ToJson Nat by infer_instance +instance : Repr Pos := show Repr Nat by infer_instance +instance : ToString Pos := show ToString Nat by infer_instance + +/-- The Pos representing the root subexpression. -/ +def root : Pos := (1 : Nat) + +def isRoot (p : Pos) : Bool := p.asNat == 1 + +/-- The coordinate deepest in the Pos. -/ +def head (p : Pos) : Nat := + if p.isRoot then panic! "already at top" + else p.asNat % maxChildren + +def tail (p : Pos) : Pos := + if p.isRoot then panic! "already at top" + else (p.asNat - p.head) / maxChildren + +def push (p : Pos) (c : Nat) : Pos := + if c >= maxChildren then panic! s!"invalid coordinate {c}" + else p.asNat * maxChildren + c + +/-- `pushNZeros p count` runs `.push 0` `count` times. -/ +def pushNZeros (p : Pos) (count : Nat) : Pos := + p.asNat * (maxChildren ^ count) + +variable {α : Type} [Inhabited α] + +/-- Fold over the position starting at the root and heading to the leaf-/ +def foldl (f : α → Nat → α) : α → Pos → α := + fix2 (fun r a p => if p.isRoot then a else f (r a p.tail) p.head) + +/-- Fold over the position starting at the root and heading to the leaf-/ +def foldr (f : Nat → α → α) : Pos → α → α := + fix2 (fun r p a => if p.isRoot then a else r p.tail (f p.head a)) + +def foldrM [Monad M] (f : Nat → α → M α) : Pos → α → M α := + fix2 (fun r p a => if p.isRoot then pure a else f p.head a >>= r p.tail) + +def depth (p : Pos) := + p.foldr (fun _ => Nat.succ) 0 + +/-- Returns true if `pred` is true for each coordinate in `p`.-/ +def all (pred : Nat → Bool) (p : Pos) : Bool := + OptionT.run (m := Id) (foldrM (fun n a => if pred n then pure a else failure) p ()) |>.isSome + +def append : Pos → Pos → Pos := foldl push + +/-- Creates a subexpression `Pos` from an array of 'coordinates'. +Each coordinate is a number {0,1,2} expressing which child subexpression should be explored. +The first coordinate in the array corresponds to the root of the expression tree. -/ +def ofArray (ps : Array Nat) : Pos := + ps.foldl push root + +/-- Decodes a subexpression `Pos` as a sequence of coordinates. See `Pos.fromArray` for details.-/ +def toArray (p : Pos) : Array Nat := + foldl Array.push #[] p + +end SubExpr.Pos /-- An expression and the position of a subexpression within this expression. @@ -30,8 +104,7 @@ structure SubExpr where namespace SubExpr -abbrev maxChildren : Pos := 4 -def mkRoot (e : Expr) : SubExpr := ⟨e, 1⟩ +def mkRoot (e : Expr) : SubExpr := ⟨e, Pos.root⟩ end SubExpr diff --git a/src/Lean/Widget/InteractiveCode.lean b/src/Lean/Widget/InteractiveCode.lean index 965b34c11c..7f25dd8b8c 100644 --- a/src/Lean/Widget/InteractiveCode.lean +++ b/src/Lean/Widget/InteractiveCode.lean @@ -46,8 +46,8 @@ where def ppExprTagged (e : Expr) (explicit : Bool := false) : MetaM CodeWithInfos := do let optsPerPos := if explicit then Std.RBMap.ofList [ - (1, KVMap.empty.setBool `pp.all true), - (1, KVMap.empty.setBool `pp.tagAppFns true) + (SubExpr.Pos.root, KVMap.empty.setBool `pp.all true), + (SubExpr.Pos.root, KVMap.empty.setBool `pp.tagAppFns true) ] else {} diff --git a/tests/lean/PPRoundtrip.lean b/tests/lean/PPRoundtrip.lean index 0ee518c053..221f412a59 100644 --- a/tests/lean/PPRoundtrip.lean +++ b/tests/lean/PPRoundtrip.lean @@ -58,7 +58,7 @@ section #eval checkM `(id Nat) #eval checkM `(Sum Nat Nat) end -#eval checkM `(id (id Nat)) (Std.RBMap.empty.insert 5 $ KVMap.empty.insert `pp.explicit true) +#eval checkM `(id (id Nat)) (Std.RBMap.empty.insert (SubExpr.Pos.encode #[1]) $ KVMap.empty.insert `pp.explicit true) -- specify the expected type of `a` in a way that is not erased by the delaborator def typeAs.{u} (α : Type u) (a : α) := () diff --git a/tests/lean/run/subexpr.lean b/tests/lean/run/subexpr.lean new file mode 100644 index 0000000000..4af899d204 --- /dev/null +++ b/tests/lean/run/subexpr.lean @@ -0,0 +1,13 @@ +import Lean +open Lean.SubExpr + +def ps := [#[], #[0], #[1], #[0,1], #[1,0] , #[0,0], #[1,2,3]] +theorem Pos.roundtrip : + true = ps.all fun x => x == (Pos.toArray <| Pos.ofArray <| x) + := by native_decide + +theorem Pos.append_roundtrip : + true = (List.all + (ps.bind fun p => ps.map fun q => (p,q)) + (fun (x,y) => (x ++ y) == (Pos.toArray <| (Pos.append (Pos.ofArray x) (Pos.ofArray y)))) + ) := by native_decide