From 451abdf79deb2df07ac4d39aa38b492f09abad7a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 10 Jul 2022 08:56:58 -0700 Subject: [PATCH] fix: `Level.update*` functions see #1291 --- src/Lean/Level.lean | 48 ++++++++++++------- src/kernel/level.cpp | 8 ++-- tests/lean/updateLevelIssues.lean | 26 ++++++++++ .../lean/updateLevelIssues.lean.expected.out | 3 ++ 4 files changed, 64 insertions(+), 21 deletions(-) create mode 100644 tests/lean/updateLevelIssues.lean create mode 100644 tests/lean/updateLevelIssues.lean.expected.out diff --git a/src/Lean/Level.lean b/src/Lean/Level.lean index 07e1c7ed32..0b0d76603a 100644 --- a/src/Lean/Level.lean +++ b/src/Lean/Level.lean @@ -477,9 +477,7 @@ instance : Quote Level `level where end Level -/- Similar to `mkLevelMax`, but applies cheap simplifications -/ -@[export lean_level_mk_max_simp] -def mkLevelMax' (u v : Level) : Level := +@[inline] private def mkLevelMaxCore (u v : Level) (elseK : Unit → Level) : Level := let subsumes (u v : Level) : Bool := if v.isExplicit && u.getOffset ≥ v.getOffset then true else match u with @@ -493,16 +491,32 @@ def mkLevelMax' (u v : Level) : Level := else if u.getLevelOffset == v.getLevelOffset then if u.getOffset ≥ v.getOffset then u else v else - mkLevelMax u v + elseK () -/- Similar to `mkLevelIMax`, but applies cheap simplifications -/ -@[export lean_level_mk_imax_simp] -def mkLevelIMax' (u v : Level) : Level := +/- Similar to `mkLevelMax`, but applies cheap simplifications -/ +@[export lean_level_mk_max_simp] +def mkLevelMax' (u v : Level) : Level := + mkLevelMaxCore u v fun _ => mkLevelMax u v + +@[export lean_level_simp_max] +def simpLevelMax' (u v : Level) (d : Level) : Level := + mkLevelMaxCore u v fun _ => d + +@[inline] private def mkLevelIMaxCore (u v : Level) (elseK : Unit → Level) : Level := if v.isNeverZero then mkLevelMax' u v else if v.isZero then v else if u.isZero then v else if u == v then u - else mkLevelIMax u v + else elseK () + +/- Similar to `mkLevelIMax`, but applies cheap simplifications -/ +@[export lean_level_mk_imax_simp] +def mkLevelIMax' (u v : Level) : Level := + mkLevelIMaxCore u v fun _ => mkLevelIMax u v + +@[export lean_level_simp_imax] +def simpLevelIMax' (u v : Level) (d : Level) := + mkLevelIMaxCore u v fun _ => d namespace Level @@ -520,27 +534,27 @@ def updateSucc (lvl : Level) (newLvl : Level) (h : lvl.isSucc) : Level := mkLevelSucc newLvl @[inline] def updateSucc! (lvl : Level) (newLvl : Level) : Level := -match lvl with - | succ lvl d => updateSucc (succ lvl d) newLvl rfl - | _ => panic! "succ level expected" +match h : lvl with + | succ .. => updateSucc lvl newLvl (h ▸ rfl) + | _ => panic! "succ level expected" @[extern "lean_level_update_max"] def updateMax (lvl : Level) (newLhs : Level) (newRhs : Level) (h : lvl.isMax) : Level := mkLevelMax' newLhs newRhs @[inline] def updateMax! (lvl : Level) (newLhs : Level) (newRhs : Level) : Level := - match lvl with - | max lhs rhs d => updateMax (max lhs rhs d) newLhs newRhs rfl - | _ => panic! "max level expected" + match h : lvl with + | max .. => updateMax lvl newLhs newRhs (h ▸ rfl) + | _ => panic! "max level expected" @[extern "lean_level_update_imax"] def updateIMax (lvl : Level) (newLhs : Level) (newRhs : Level) (h : lvl.isIMax) : Level := mkLevelIMax' newLhs newRhs @[inline] def updateIMax! (lvl : Level) (newLhs : Level) (newRhs : Level) : Level := - match lvl with - | imax lhs rhs d => updateIMax (imax lhs rhs d) newLhs newRhs rfl - | _ => panic! "imax level expected" + match h : lvl with + | imax .. => updateIMax lvl newLhs newRhs (h ▸ rfl) + | _ => panic! "imax level expected" def mkNaryMax : List Level → Level | [] => levelZero diff --git a/src/kernel/level.cpp b/src/kernel/level.cpp index 1c75c4b099..9ed8d461d9 100644 --- a/src/kernel/level.cpp +++ b/src/kernel/level.cpp @@ -31,6 +31,8 @@ extern "C" object * lean_level_mk_max(obj_arg, obj_arg); extern "C" object * lean_level_mk_imax(obj_arg, obj_arg); extern "C" object * lean_level_mk_max_simp(obj_arg, obj_arg); extern "C" object * lean_level_mk_imax_simp(obj_arg, obj_arg); +extern "C" object * lean_level_simp_max(obj_arg, obj_arg, obj_arg); +extern "C" object * lean_level_simp_imax(obj_arg, obj_arg, obj_arg); level mk_succ(level const & l) { return level(lean_level_mk_succ(l.to_obj_arg())); } level mk_max_core(level const & l1, level const & l2) { return level(lean_level_mk_max(l1.to_obj_arg(), l2.to_obj_arg())); } @@ -303,8 +305,7 @@ extern "C" LEAN_EXPORT object * lean_level_update_succ(obj_arg l, obj_arg new_ar extern "C" LEAN_EXPORT object * lean_level_update_max(obj_arg l, obj_arg new_lhs, obj_arg new_rhs) { if (max_lhs(TO_REF(level, l)).raw() == new_lhs && max_rhs(TO_REF(level, l)).raw() == new_rhs) { - lean_dec(new_lhs); lean_dec(new_rhs); - return l; + return lean_level_simp_max(new_lhs, new_rhs, l); } else { lean_dec_ref(l); return lean_level_mk_max_simp(new_lhs, new_rhs); @@ -313,8 +314,7 @@ extern "C" LEAN_EXPORT object * lean_level_update_max(obj_arg l, obj_arg new_lhs extern "C" LEAN_EXPORT object * lean_level_update_imax(obj_arg l, obj_arg new_lhs, obj_arg new_rhs) { if (imax_lhs(TO_REF(level, l)).raw() == new_lhs && imax_rhs(TO_REF(level, l)).raw() == new_rhs) { - lean_dec(new_lhs); lean_dec(new_rhs); - return l; + return lean_level_simp_imax(new_lhs, new_rhs, l); } else { lean_dec_ref(l); return lean_level_mk_imax_simp(new_lhs, new_rhs); diff --git a/tests/lean/updateLevelIssues.lean b/tests/lean/updateLevelIssues.lean new file mode 100644 index 0000000000..0149b1699d --- /dev/null +++ b/tests/lean/updateLevelIssues.lean @@ -0,0 +1,26 @@ +import Lean +open Lean + +@[noinline] def noinline (a : α) := a + +#eval + let b := levelZero + let a1 := mkLevelParam `a + let a2 := mkLevelParam (noinline `a) + let l := mkLevelMax a1 b + (l.updateMax! a1 b).isMax == (l.updateMax! a2 b).isMax + +#eval + let b := levelZero + let a1 := mkLevelParam `a + let l := mkLevelMax a1 b + assert! (l.updateMax! a1 b) == a1 + toString (l.updateMax! a1 b) + +#eval + let b := mkLevelParam `b + let a1 := mkLevelParam `a + let l := mkLevelMax a1 b + assert! (l.updateMax! a1 b) == l + assert! ptrAddrUnsafe (l.updateMax! a1 b) == ptrAddrUnsafe l + toString (l.updateMax! a1 b) diff --git a/tests/lean/updateLevelIssues.lean.expected.out b/tests/lean/updateLevelIssues.lean.expected.out new file mode 100644 index 0000000000..468289e520 --- /dev/null +++ b/tests/lean/updateLevelIssues.lean.expected.out @@ -0,0 +1,3 @@ +true +"a" +"max a b"