From 2ef0146140be166e50f95bdd5cecda2ec820ecb9 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 8 Feb 2022 15:06:14 -0800 Subject: [PATCH] fix: avoid unnecessary `matcheApp.addArg`s at `BRecOn` and `Fix` It fixes the following two cases from #998 ``` attribute [simp] Lean.Export.exportName attribute [simp] Lean.Export.exportLevel ``` --- .../Elab/PreDefinition/Structural/BRecOn.lean | 34 ++++++-- src/Lean/Elab/PreDefinition/WF/Fix.lean | 18 +++-- tests/lean/run/998.lean | 4 +- tests/lean/run/998Export.lean | 79 +++++++++++++++++++ 4 files changed, 118 insertions(+), 17 deletions(-) create mode 100644 tests/lean/run/998Export.lean diff --git a/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean b/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean index dbd98408d8..ec5688e5ee 100644 --- a/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean +++ b/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean @@ -80,6 +80,21 @@ private partial def toBelow (below : Expr) (numIndParams : Nat) (recArg : Expr) withBelowDict below numIndParams fun C belowDict => toBelowAux C belowDict recArg below +/-- + This method is used after `matcherApp.addArg arg` to check whether the new type of `arg` has been "refined/modified" + in at least one alternative. +-/ +def refinedArgType (matcherApp : MatcherApp) (arg : Expr) : MetaM Bool := do + let argType ← inferType arg + (Array.zip matcherApp.alts matcherApp.altNumParams).anyM fun (alt, numParams) => + lambdaTelescope alt fun xs altBody => do + if xs.size >= numParams then + let refinedArg := xs[numParams - 1] + trace[Meta.debug] "refinedArgType {argType} =?= {← inferType refinedArg}" + return !(← isDefEq (← inferType refinedArg) argType) + else + return false + private partial def replaceRecApps (recFnName : Name) (recArgInfo : RecArgInfo) (below : Expr) (e : Expr) : M Expr := let rec loop (below : Expr) (e : Expr) : M Expr := do match e with @@ -146,14 +161,17 @@ private partial def replaceRecApps (recFnName : Name) (recArgInfo : RecArgInfo) this may generate weird error messages, when it doesn't work. -/ trace[Elab.definition.structural] "below before matcherApp.addArg: {below} : {← inferType below}" let matcherApp ← mapError (matcherApp.addArg below) (fun msg => "failed to add `below` argument to 'matcher' application" ++ indentD msg) - let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) => - lambdaTelescope alt fun xs altBody => do - trace[Elab.definition.structural] "altNumParams: {numParams}, xs: {xs}" - unless xs.size >= numParams do - throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}" - let belowForAlt := xs[numParams - 1] - mkLambdaFVars xs (← loop belowForAlt altBody) - pure { matcherApp with alts := altsNew }.toExpr + if !(← refinedArgType matcherApp below) then + processApp e + else + let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) => + lambdaTelescope alt fun xs altBody => do + trace[Elab.definition.structural] "altNumParams: {numParams}, xs: {xs}" + unless xs.size >= numParams do + throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}" + let belowForAlt := xs[numParams - 1] + mkLambdaFVars xs (← loop belowForAlt altBody) + pure { matcherApp with alts := altsNew }.toExpr | none => processApp e | e => ensureNoRecFn recFnName e loop below e diff --git a/src/Lean/Elab/PreDefinition/WF/Fix.lean b/src/Lean/Elab/PreDefinition/WF/Fix.lean index c0cd20558e..7c8afca239 100644 --- a/src/Lean/Elab/PreDefinition/WF/Fix.lean +++ b/src/Lean/Elab/PreDefinition/WF/Fix.lean @@ -10,6 +10,7 @@ import Lean.Elab.Tactic.Basic import Lean.Elab.RecAppSyntax import Lean.Elab.PreDefinition.Basic import Lean.Elab.PreDefinition.Structural.Basic +import Lean.Elab.PreDefinition.Structural.BRecOn namespace Lean.Elab.WF open Meta @@ -62,13 +63,16 @@ private partial def replaceRecApps (recFnName : Name) (decrTactic? : Option Synt processApp e else let matcherApp ← mapError (matcherApp.addArg F) (fun msg => "failed to add functional argument to 'matcher' application" ++ indentD msg) - let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) => - lambdaTelescope alt fun xs altBody => do - unless xs.size >= numParams do - throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}" - let FAlt := xs[numParams - 1] - mkLambdaFVars xs (← loop FAlt altBody) - return { matcherApp with alts := altsNew, discrs := (← matcherApp.discrs.mapM (loop F)) }.toExpr + if !(← Structural.refinedArgType matcherApp F) then + processApp e + else + let altsNew ← (Array.zip matcherApp.alts matcherApp.altNumParams).mapM fun (alt, numParams) => + lambdaTelescope alt fun xs altBody => do + unless xs.size >= numParams do + throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}" + let FAlt := xs[numParams - 1] + mkLambdaFVars xs (← loop FAlt altBody) + return { matcherApp with alts := altsNew, discrs := (← matcherApp.discrs.mapM (loop F)) }.toExpr | none => processApp e | e => Structural.ensureNoRecFn recFnName e loop F e diff --git a/tests/lean/run/998.lean b/tests/lean/run/998.lean index a1fb0c3626..89ba3044a8 100644 --- a/tests/lean/run/998.lean +++ b/tests/lean/run/998.lean @@ -8,6 +8,6 @@ attribute [simp] Lean.Elab.Term.resolveLocalName.loop -- Mathlib -- attribute [simp] BinaryHeap.heapifyDown -- attribute [simp] ByteSlice.forIn.loop --- attribute [simp] Lean.Export.exportName --- attribute [simp] Lean.Export.exportLevel +-- attribute [simp] Lean.Export.exportName -- Fixed see 998Export.lean +-- attribute [simp] Lean.Export.exportLevel -- Fixed see 998Export.lean -- attribute [simp] Array.heapSort.loop diff --git a/tests/lean/run/998Export.lean b/tests/lean/run/998Export.lean new file mode 100644 index 0000000000..f3ef2af967 --- /dev/null +++ b/tests/lean/run/998Export.lean @@ -0,0 +1,79 @@ +import Lean +open Lean +open Std (HashMap HashSet) + +inductive Entry +| name (n : Name) +| level (n : Level) +| expr (n : Expr) +| defn (n : Name) +deriving Inhabited + +structure Alloc (α) [BEq α] [Hashable α] where + map : HashMap α Nat + next : Nat +deriving Inhabited + +namespace Export + +structure State where + names : Alloc Name := ⟨HashMap.empty.insert Name.anonymous 0, 1⟩ + levels : Alloc Level := ⟨HashMap.empty.insert levelZero 0, 1⟩ + exprs : Alloc Expr + defs : HashSet Name + stk : Array (Bool × Entry) +deriving Inhabited + +class OfState (α : Type) [BEq α] [Hashable α] where + get : State → Alloc α + modify : (Alloc α → Alloc α) → State → State + +instance : OfState Name where + get s := s.names + modify f s := { s with names := f s.names } + +instance : OfState Level where + get s := s.levels + modify f s := { s with levels := f s.levels } + +instance : OfState Expr where + get s := s.exprs + modify f s := { s with exprs := f s.exprs } + +end Export + +abbrev ExportM := StateT Export.State CoreM + +namespace Export + +def alloc {α} [BEq α] [Hashable α] [OfState α] (a : α) : ExportM Nat := do + let n := (OfState.get (α := α) (← get)).next + modify $ OfState.modify (α := α) fun s => {map := s.map.insert a n, next := n+1} + pure n + +def exportName (n : Name) : ExportM Nat := do + match (← get).names.map.find? n with + | some i => pure i + | none => match n with + | Name.anonymous => pure 0 + | Name.num p a _ => let i ← alloc n; IO.println s!"{i} #NI {← exportName p} {a}"; pure i + | Name.str p s _ => let i ← alloc n; IO.println s!"{i} #NS {← exportName p} {s}"; pure i + +attribute [simp] exportName + +def exportLevel (L : Level) : ExportM Nat := do + match (← get).levels.map.find? L with + | some i => pure i + | none => match L with + | Level.zero _ => pure 0 + | Level.succ l _ => + let i ← alloc L; IO.println s!"{i} #US {← exportLevel l}"; pure i + | Level.max l₁ l₂ _ => + let i ← alloc L; IO.println s!"{i} #UM {← exportLevel l₁} {← exportLevel l₂}"; pure i + | Level.imax l₁ l₂ _ => + let i ← alloc L; IO.println s!"{i} #UIM {← exportLevel l₁} {← exportLevel l₂}"; pure i + | Level.param n _ => + let i ← alloc L; IO.println s!"{i} #UP {← exportName n}"; pure i + | Level.mvar n _ => unreachable! + +attribute [simp] exportLevel