From 3c5b3cd91f1e059e73dfbc82e042a6eec1bc8fdc Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 6 Feb 2020 14:03:54 -0800 Subject: [PATCH] feat: add `Expr.replace` helper function --- src/Init/Lean.lean | 1 + src/Init/Lean/ReplaceExpr.lean | 75 ++++++++++++++++++++++++++++++++++ tests/lean/run/replace.lean | 16 ++++++++ 3 files changed, 92 insertions(+) create mode 100644 src/Init/Lean/ReplaceExpr.lean create mode 100644 tests/lean/run/replace.lean diff --git a/src/Init/Lean.lean b/src/Init/Lean.lean index c8fc6010e6..cc803ea7a6 100644 --- a/src/Init/Lean.lean +++ b/src/Init/Lean.lean @@ -22,3 +22,4 @@ import Init.Lean.Linter import Init.Lean.Meta import Init.Lean.Eval import Init.Lean.Structure +import Init.Lean.ReplaceExpr diff --git a/src/Init/Lean/ReplaceExpr.lean b/src/Init/Lean/ReplaceExpr.lean new file mode 100644 index 0000000000..e3a072cce0 --- /dev/null +++ b/src/Init/Lean/ReplaceExpr.lean @@ -0,0 +1,75 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Lean.Expr + +namespace Lean +namespace Expr + +namespace ReplaceImpl + +abbrev cacheSize : USize := 8192 + +structure State := +(keys : Array Expr) +(used : Array Bool) +(results : Array Expr) + +abbrev ReplaceM := StateM State + +@[inline] unsafe def cache (i : USize) (key : Expr) (result : Expr) : ReplaceM Expr := do +modify $ fun s => { keys := s.keys.uset i key lcProof, used := s.used.uset i true lcProof, results := s.results.uset i result lcProof }; +pure result + +@[specialize] unsafe def replaceUnsafeM (f? : Expr → Option Expr) (size : USize) : Expr → ReplaceM Expr +| e => do + c ← get; + let h := ptrAddrUnsafe e; + let i := h % size; + if c.used.uget i lcProof && ptrAddrUnsafe (c.keys.uget i lcProof) == h then + pure $ c.results.uget i lcProof + else match f? e with + | some eNew => cache i e eNew + | none => match e with + | Expr.forallE _ d b _ => do d ← replaceUnsafeM d; b ← replaceUnsafeM b; cache i e $ e.updateForallE! d b + | Expr.lam _ d b _ => do d ← replaceUnsafeM d; b ← replaceUnsafeM b; cache i e $ e.updateLambdaE! d b + | Expr.mdata _ b _ => do b ← replaceUnsafeM b; cache i e $ e.updateMData! b + | Expr.letE _ t v b _ => do t ← replaceUnsafeM t; v ← replaceUnsafeM v; b ← replaceUnsafeM b; cache i e $ e.updateLet! t v b + | Expr.app f a _ => do f ← replaceUnsafeM f; a ← replaceUnsafeM a; cache i e $ e.updateApp! f a + | Expr.proj _ _ b _ => do b ← replaceUnsafeM b; cache i e $ e.updateProj! b + | Expr.localE _ _ _ _ => unreachable! + | e => pure e + +def initCache : State := +{ keys := mkArray cacheSize.toNat (arbitrary _), + results := mkArray cacheSize.toNat (arbitrary _), + used := mkArray cacheSize.toNat false } + +@[inline] unsafe def replaceUnsafe (f? : Expr → Option Expr) (e : Expr) : Expr := +(replaceUnsafeM f? cacheSize e).run' initCache + +end ReplaceImpl + +/- TODO: use withPtrAddr, withPtrEq to avoid unsafe tricks above. + We also need an invariant at `State` and proofs for the `uget` operations. -/ + +@[implementedBy ReplaceImpl.replaceUnsafe] +partial def replace (f? : Expr → Option Expr) : Expr → Expr +| e => + /- This is a reference implementation for the unsafe one above -/ + match f? e with + | some eNew => eNew + | none => match e with + | Expr.forallE _ d b _ => let d := replace d; let b := replace b; e.updateForallE! d b + | Expr.lam _ d b _ => let d := replace d; let b := replace b; e.updateLambdaE! d b + | Expr.mdata _ b _ => let b := replace b; e.updateMData! b + | Expr.letE _ t v b _ => let t := replace t; let v := replace v; let b := replace b; e.updateLet! t v b + | Expr.app f a _ => let f := replace f; let a := replace a; e.updateApp! f a + | Expr.proj _ _ b _ => let b := replace b; e.updateProj! b + | e => e +end Expr + +end Lean diff --git a/tests/lean/run/replace.lean b/tests/lean/run/replace.lean new file mode 100644 index 0000000000..2fa50cd9be --- /dev/null +++ b/tests/lean/run/replace.lean @@ -0,0 +1,16 @@ +import Init.Lean + +open Lean + +def mkBig : Nat → Expr +| 0 => mkConst `a +| (n+1) => mkApp2 (mkConst `f []) (mkBig n) (mkBig n) + +def replaceTest (e : Expr) : Expr := +e.replace $ fun e => match e with + | Expr.const c _ _ => if c == `f then mkConst `g else none + | _ => none + +#eval replaceTest $ mkBig 4 + +#eval (replaceTest $ mkBig 128).getAppFn