From 8227d3afcd178c8ae418bcb6e15f995f7c763e3f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 16 Mar 2021 15:51:36 -0700 Subject: [PATCH] feat: support for simplifying `match` discriminants --- src/Lean/Meta/Match/MatcherInfo.lean | 43 ++++++++-------- src/Lean/Meta/Tactic/Simp/Main.lean | 75 +++++++++++++++++++++++++--- tests/lean/docStr.lean.expected.out | 16 +++--- tests/lean/run/simpMatchDiscr.lean | 46 +++++++++++++++++ 4 files changed, 145 insertions(+), 35 deletions(-) create mode 100644 tests/lean/run/simpMatchDiscr.lean diff --git a/src/Lean/Meta/Match/MatcherInfo.lean b/src/Lean/Meta/Match/MatcherInfo.lean index bd99527581..9a7f161a5b 100644 --- a/src/Lean/Meta/Match/MatcherInfo.lean +++ b/src/Lean/Meta/Match/MatcherInfo.lean @@ -23,9 +23,14 @@ structure MatcherInfo where altNumParams : Array Nat uElimPos? : Option Nat -def MatcherInfo.numAlts (matcherInfo : MatcherInfo) : Nat := - matcherInfo.altNumParams.size +def MatcherInfo.numAlts (info : MatcherInfo) : Nat := + info.altNumParams.size +def MatcherInfo.arity (info : MatcherInfo) : Nat := + info.numParams + 1 + info.numDiscrs + info.numAlts + +def MatcherInfo.getMotivePos (info : MatcherInfo) : Nat := + info.numParams namespace Extension structure Entry where @@ -42,8 +47,8 @@ def State.switch (s : State) : State := { s with map := s.map.switch } builtin_initialize extension : SimplePersistentEnvExtension Entry State ← registerSimplePersistentEnvExtension { - name := `matcher, - addEntryFn := State.addEntry, + name := `matcher + addEntryFn := State.addEntry addImportedFn := fun es => (mkStateFromImportedEntries State.addEntry {} es).switch } @@ -62,13 +67,11 @@ end Match export Match (MatcherInfo) -def getMatcherInfo? (declName : Name) : MetaM (Option MatcherInfo) := do - let env ← getEnv - return Match.Extension.getMatcherInfo? env declName +def getMatcherInfo? (declName : Name) : MetaM (Option MatcherInfo) := + return Match.Extension.getMatcherInfo? (← getEnv) declName -def isMatcher (declName : Name) : MetaM Bool := do - let info? ← getMatcherInfo? declName - return info?.isSome +def isMatcher (declName : Name) : MetaM Bool := + return (← getMatcherInfo? declName).isSome structure MatcherApp where matcherName : Name @@ -86,19 +89,19 @@ def matchMatcherApp? (e : Expr) : MetaM (Option MatcherApp) := | Expr.const declName declLevels _ => do let some info ← getMatcherInfo? declName | pure none let args := e.getAppArgs - if args.size < info.numParams + 1 + info.numDiscrs + info.numAlts then + if args.size < info.arity then return none else return some { - matcherName := declName, - matcherLevels := declLevels.toArray, - uElimPos? := info.uElimPos?, - params := args.extract 0 info.numParams, - motive := args.get! info.numParams, - discrs := args.extract (info.numParams + 1) (info.numParams + 1 + info.numDiscrs), - altNumParams := info.altNumParams, - alts := args.extract (info.numParams + 1 + info.numDiscrs) (info.numParams + 1 + info.numDiscrs + info.numAlts), - remaining := args.extract (info.numParams + 1 + info.numDiscrs + info.numAlts) args.size + matcherName := declName + matcherLevels := declLevels.toArray + uElimPos? := info.uElimPos? + params := args.extract 0 info.numParams + motive := args[info.getMotivePos] + discrs := args[info.numParams + 1 : info.numParams + 1 + info.numDiscrs] + altNumParams := info.altNumParams + alts := args[info.numParams + 1 + info.numDiscrs : info.numParams + 1 + info.numDiscrs + info.numAlts] + remaining := args[info.numParams + 1 + info.numDiscrs + info.numAlts : args.size] } | _ => return none diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 531d40861f..502dd38ca0 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -41,6 +41,14 @@ private def mkCongr (r₁ r₂ : Result) : MetaM Result := | none, some h => return { expr := e, proof? := (← Meta.mkCongrArg r₁.expr h) } | some h₁, some h₂ => return { expr := e, proof? := (← Meta.mkCongr h₁ h₂) } +private def mkCongrDep (r₁ r₂ : Result) : MetaM Result := do + let e := mkApp r₁.expr r₂.expr + match r₁.proof?, r₂.proof? with + | none, none => return { expr := e } + | some h, none => return { expr := e, proof? := (← Meta.mkCongrFun h r₂.expr) } + | none, some h => return { expr := e, proof? := (← Meta.mkCongrDepArg r₁.expr h) } + | some h₁, some h₂ => return { expr := e, proof? := (← Meta.mkCongrDep h₁ h₂) } + private def mkImpCongr (r₁ r₂ : Result) : MetaM Result := do let e ← mkArrow r₁.expr r₂.expr match r₁.proof?, r₂.proof? with @@ -112,6 +120,28 @@ private partial def reduce (e : Expr) : SimpM Expr := withIncRecDepth do | some e => reduce e | none => return e +/-- + Compute a mask `m` of size `numDiscrs` s.t. discriminant `i` can be rewritten if `m[i] == true`. + `motive` is the motive of the matcher application. + + For example, suppose the motive is `fun (n : Nat) (v : Vec α n) => Nat`, then + we can rewrite the second discriminant. We cannot rewrite the first because the type + of the second depends on it. -/ +private def getMatcherDiscrCongrMask (motive : Expr) (numDiscrs : Nat) : Array Bool := + let updateMask (e : Expr) (i : Nat) (mask : Array Bool) : Array Bool := + if !e.hasLooseBVars then mask + else i.fold (init := mask) fun j mask => if e.hasLooseBVar j then mask.set! (i - j - 1) false else mask + let rec loop (e : Expr) (i : Nat) (mask : Array Bool): Array Bool := + match e with + | Expr.lam _ d b _ => loop b (i+1) (updateMask e i mask) + | _ => + if i != numDiscrs then + -- This is an ill-formed matcher application, it should not happen in practice. + mkArray numDiscrs false + else + updateMask e i mask + loop motive 0 (mkArray numDiscrs true) + private partial def dsimp (e : Expr) : M Expr := do transform e (post := fun e => return TransformStep.done (← reduce e)) @@ -221,16 +251,47 @@ where else return none + congrMatch (info : MatcherInfo) (e : Expr) : M Result := + withParent e <| e.withApp fun f args => do + if args.size < info.arity then + congrDefault e -- partially applied matcher application + else + let mask := getMatcherDiscrCongrMask args[info.getMotivePos] info.numDiscrs + -- We don't rewrite the parameters nor the motive + let mut r : Result := { expr := mkAppN f args[:info.numParams+1] } + -- process discriminants + let firstDiscrPos := info.numParams+1 + for i in [firstDiscrPos : firstDiscrPos + info.numDiscrs] do + if mask[i - firstDiscrPos] then + r ← mkCongrDep r (← simp args[i]) + else + r ← mkCongrFun r (← dsimp args[i]) + -- process alternatives + let firstAltPos := firstDiscrPos + info.numDiscrs + for i in [firstAltPos : firstAltPos + info.numAlts] do + r ← mkCongr r (← simp args[i]) + -- process over-application arguments + for i in [info.arity : args.size] do + let fType ← whnfD (← inferType r.expr) + if fType.isArrow then + r ← mkCongr r (← simp args[i]) + else + r ← mkCongrFun r (← dsimp args[i]) + return r + congr (e : Expr) : M Result := do let f := e.getAppFn if f.isConst then - let congrLemmas ← getCongrLemmas - let cs := congrLemmas.get f.constName! - for c in cs do - match (← tryCongrLemma? c e) with - | none => pure () - | some r => return r - congrDefault e + if let some matcherInfo ← getMatcherInfo? f.constName! then + congrMatch matcherInfo e + else + let congrLemmas ← getCongrLemmas + let cs := congrLemmas.get f.constName! + for c in cs do + match (← tryCongrLemma? c e) with + | none => pure () + | some r => return r + congrDefault e else congrDefault e diff --git a/tests/lean/docStr.lean.expected.out b/tests/lean/docStr.lean.expected.out index e6f19ee4b4..194784d538 100644 --- a/tests/lean/docStr.lean.expected.out +++ b/tests/lean/docStr.lean.expected.out @@ -190,22 +190,22 @@ g.foo := endPos := { line := 40, column := 47 }, endCharUtf16 := 47 } } optParam := - { range := { pos := { line := 152, column := 13 }, + { range := { pos := { line := 158, column := 13 }, charUtf16 := 13, - endPos := { line := 152, column := 66 }, + endPos := { line := 158, column := 66 }, endCharUtf16 := 66 }, - selectionRange := { pos := { line := 152, column := 17 }, + selectionRange := { pos := { line := 158, column := 17 }, charUtf16 := 17, - endPos := { line := 152, column := 25 }, + endPos := { line := 158, column := 25 }, endCharUtf16 := 25 } } namedPattern := - { range := { pos := { line := 161, column := 13 }, + { range := { pos := { line := 167, column := 13 }, charUtf16 := 13, - endPos := { line := 161, column := 61 }, + endPos := { line := 167, column := 61 }, endCharUtf16 := 61 }, - selectionRange := { pos := { line := 161, column := 17 }, + selectionRange := { pos := { line := 167, column := 17 }, charUtf16 := 17, - endPos := { line := 161, column := 29 }, + endPos := { line := 167, column := 29 }, endCharUtf16 := 29 } } Lean.Meta.forallTelescopeReducing := { range := { pos := { line := 699, column := 0 }, diff --git a/tests/lean/run/simpMatchDiscr.lean b/tests/lean/run/simpMatchDiscr.lean new file mode 100644 index 0000000000..f67e8e8567 --- /dev/null +++ b/tests/lean/run/simpMatchDiscr.lean @@ -0,0 +1,46 @@ +inductive Vec (α : Type u) : Nat → Type u + | nil : Vec α 0 + | cons : α → {n : Nat} → Vec α n → Vec α (n+1) + +def Vec.repeat (a : α) (n : Nat) : Vec α n := + match n with + | 0 => nil + | n+1 => cons a (repeat a n) + +instance [Inhabited α] : Inhabited (Vec α n) where + default := Vec.repeat arbitrary n + +def Vec.map (v : Vec α n) (f : α → β) : Vec β n := + match n, v with + | _, nil => nil + | _, cons a as => cons (f a) (map as f) + +def Vec.reverse (v : Vec α n) : Vec α n := + let rec loop : {n m : Nat} → Vec α n → Vec α m → Vec α (n + m) + | _, _, nil, w => Nat.zero_add .. ▸ w + | _, _, cons a as, w => Nat.add_assoc .. ▸ loop as (Nat.add_comm .. ▸ cons a w) + loop v nil + +@[simp] theorem map_id (v : Vec α n) : v.map id = v := by + induction v with + | nil => rfl + | cons a as ih => simp [Vec.map, ih] + +def foo [Add α] (v w : Vec α n) (f : α → α) (a : α) : α := + match n, v.map f, w.map f with + | _, Vec.nil, Vec.nil => a + | _, Vec.cons a .., Vec.cons b .. => a + b + +theorem ex1 (a b : Nat) (as : Vec Nat n) : foo (Vec.cons a as) (Vec.cons b as) id 0 = a + b := by + simp [foo] + +#print ex1 + +def bla (b : Bool) (f g : α → β) (a : α) : β := + (match b with + | true => f | false => g) a + +theorem ex2 (h : b = false) : bla b (fun x => x + 1) id 10 = 10 := by + simp [bla, h] + +#print ex2