feat: support for simplifying match discriminants

This commit is contained in:
Leonardo de Moura 2021-03-16 15:51:36 -07:00
parent 2970c6ca79
commit 8227d3afcd
4 changed files with 145 additions and 35 deletions

View file

@ -23,9 +23,14 @@ structure MatcherInfo where
altNumParams : Array Nat
uElimPos? : Option Nat
def MatcherInfo.numAlts (matcherInfo : MatcherInfo) : Nat :=
matcherInfo.altNumParams.size
def MatcherInfo.numAlts (info : MatcherInfo) : Nat :=
info.altNumParams.size
def MatcherInfo.arity (info : MatcherInfo) : Nat :=
info.numParams + 1 + info.numDiscrs + info.numAlts
def MatcherInfo.getMotivePos (info : MatcherInfo) : Nat :=
info.numParams
namespace Extension
structure Entry where
@ -42,8 +47,8 @@ def State.switch (s : State) : State := { s with map := s.map.switch }
builtin_initialize extension : SimplePersistentEnvExtension Entry State ←
registerSimplePersistentEnvExtension {
name := `matcher,
addEntryFn := State.addEntry,
name := `matcher
addEntryFn := State.addEntry
addImportedFn := fun es => (mkStateFromImportedEntries State.addEntry {} es).switch
}
@ -62,13 +67,11 @@ end Match
export Match (MatcherInfo)
def getMatcherInfo? (declName : Name) : MetaM (Option MatcherInfo) := do
let env ← getEnv
return Match.Extension.getMatcherInfo? env declName
def getMatcherInfo? (declName : Name) : MetaM (Option MatcherInfo) :=
return Match.Extension.getMatcherInfo? (← getEnv) declName
def isMatcher (declName : Name) : MetaM Bool := do
let info? ← getMatcherInfo? declName
return info?.isSome
def isMatcher (declName : Name) : MetaM Bool :=
return (← getMatcherInfo? declName).isSome
structure MatcherApp where
matcherName : Name
@ -86,19 +89,19 @@ def matchMatcherApp? (e : Expr) : MetaM (Option MatcherApp) :=
| Expr.const declName declLevels _ => do
let some info ← getMatcherInfo? declName | pure none
let args := e.getAppArgs
if args.size < info.numParams + 1 + info.numDiscrs + info.numAlts then
if args.size < info.arity then
return none
else
return some {
matcherName := declName,
matcherLevels := declLevels.toArray,
uElimPos? := info.uElimPos?,
params := args.extract 0 info.numParams,
motive := args.get! info.numParams,
discrs := args.extract (info.numParams + 1) (info.numParams + 1 + info.numDiscrs),
altNumParams := info.altNumParams,
alts := args.extract (info.numParams + 1 + info.numDiscrs) (info.numParams + 1 + info.numDiscrs + info.numAlts),
remaining := args.extract (info.numParams + 1 + info.numDiscrs + info.numAlts) args.size
matcherName := declName
matcherLevels := declLevels.toArray
uElimPos? := info.uElimPos?
params := args.extract 0 info.numParams
motive := args[info.getMotivePos]
discrs := args[info.numParams + 1 : info.numParams + 1 + info.numDiscrs]
altNumParams := info.altNumParams
alts := args[info.numParams + 1 + info.numDiscrs : info.numParams + 1 + info.numDiscrs + info.numAlts]
remaining := args[info.numParams + 1 + info.numDiscrs + info.numAlts : args.size]
}
| _ => return none

View file

@ -41,6 +41,14 @@ private def mkCongr (r₁ r₂ : Result) : MetaM Result :=
| none, some h => return { expr := e, proof? := (← Meta.mkCongrArg r₁.expr h) }
| some h₁, some h₂ => return { expr := e, proof? := (← Meta.mkCongr h₁ h₂) }
private def mkCongrDep (r₁ r₂ : Result) : MetaM Result := do
let e := mkApp r₁.expr r₂.expr
match r₁.proof?, r₂.proof? with
| none, none => return { expr := e }
| some h, none => return { expr := e, proof? := (← Meta.mkCongrFun h r₂.expr) }
| none, some h => return { expr := e, proof? := (← Meta.mkCongrDepArg r₁.expr h) }
| some h₁, some h₂ => return { expr := e, proof? := (← Meta.mkCongrDep h₁ h₂) }
private def mkImpCongr (r₁ r₂ : Result) : MetaM Result := do
let e ← mkArrow r₁.expr r₂.expr
match r₁.proof?, r₂.proof? with
@ -112,6 +120,28 @@ private partial def reduce (e : Expr) : SimpM Expr := withIncRecDepth do
| some e => reduce e
| none => return e
/--
Compute a mask `m` of size `numDiscrs` s.t. discriminant `i` can be rewritten if `m[i] == true`.
`motive` is the motive of the matcher application.
For example, suppose the motive is `fun (n : Nat) (v : Vec α n) => Nat`, then
we can rewrite the second discriminant. We cannot rewrite the first because the type
of the second depends on it. -/
private def getMatcherDiscrCongrMask (motive : Expr) (numDiscrs : Nat) : Array Bool :=
let updateMask (e : Expr) (i : Nat) (mask : Array Bool) : Array Bool :=
if !e.hasLooseBVars then mask
else i.fold (init := mask) fun j mask => if e.hasLooseBVar j then mask.set! (i - j - 1) false else mask
let rec loop (e : Expr) (i : Nat) (mask : Array Bool): Array Bool :=
match e with
| Expr.lam _ d b _ => loop b (i+1) (updateMask e i mask)
| _ =>
if i != numDiscrs then
-- This is an ill-formed matcher application, it should not happen in practice.
mkArray numDiscrs false
else
updateMask e i mask
loop motive 0 (mkArray numDiscrs true)
private partial def dsimp (e : Expr) : M Expr := do
transform e (post := fun e => return TransformStep.done (← reduce e))
@ -221,16 +251,47 @@ where
else
return none
congrMatch (info : MatcherInfo) (e : Expr) : M Result :=
withParent e <| e.withApp fun f args => do
if args.size < info.arity then
congrDefault e -- partially applied matcher application
else
let mask := getMatcherDiscrCongrMask args[info.getMotivePos] info.numDiscrs
-- We don't rewrite the parameters nor the motive
let mut r : Result := { expr := mkAppN f args[:info.numParams+1] }
-- process discriminants
let firstDiscrPos := info.numParams+1
for i in [firstDiscrPos : firstDiscrPos + info.numDiscrs] do
if mask[i - firstDiscrPos] then
r ← mkCongrDep r (← simp args[i])
else
r ← mkCongrFun r (← dsimp args[i])
-- process alternatives
let firstAltPos := firstDiscrPos + info.numDiscrs
for i in [firstAltPos : firstAltPos + info.numAlts] do
r ← mkCongr r (← simp args[i])
-- process over-application arguments
for i in [info.arity : args.size] do
let fType ← whnfD (← inferType r.expr)
if fType.isArrow then
r ← mkCongr r (← simp args[i])
else
r ← mkCongrFun r (← dsimp args[i])
return r
congr (e : Expr) : M Result := do
let f := e.getAppFn
if f.isConst then
let congrLemmas ← getCongrLemmas
let cs := congrLemmas.get f.constName!
for c in cs do
match (← tryCongrLemma? c e) with
| none => pure ()
| some r => return r
congrDefault e
if let some matcherInfo ← getMatcherInfo? f.constName! then
congrMatch matcherInfo e
else
let congrLemmas ← getCongrLemmas
let cs := congrLemmas.get f.constName!
for c in cs do
match (← tryCongrLemma? c e) with
| none => pure ()
| some r => return r
congrDefault e
else
congrDefault e

View file

@ -190,22 +190,22 @@ g.foo :=
endPos := { line := 40, column := 47 },
endCharUtf16 := 47 } }
optParam :=
{ range := { pos := { line := 152, column := 13 },
{ range := { pos := { line := 158, column := 13 },
charUtf16 := 13,
endPos := { line := 152, column := 66 },
endPos := { line := 158, column := 66 },
endCharUtf16 := 66 },
selectionRange := { pos := { line := 152, column := 17 },
selectionRange := { pos := { line := 158, column := 17 },
charUtf16 := 17,
endPos := { line := 152, column := 25 },
endPos := { line := 158, column := 25 },
endCharUtf16 := 25 } }
namedPattern :=
{ range := { pos := { line := 161, column := 13 },
{ range := { pos := { line := 167, column := 13 },
charUtf16 := 13,
endPos := { line := 161, column := 61 },
endPos := { line := 167, column := 61 },
endCharUtf16 := 61 },
selectionRange := { pos := { line := 161, column := 17 },
selectionRange := { pos := { line := 167, column := 17 },
charUtf16 := 17,
endPos := { line := 161, column := 29 },
endPos := { line := 167, column := 29 },
endCharUtf16 := 29 } }
Lean.Meta.forallTelescopeReducing :=
{ range := { pos := { line := 699, column := 0 },

View file

@ -0,0 +1,46 @@
inductive Vec (α : Type u) : Nat → Type u
| nil : Vec α 0
| cons : α → {n : Nat} → Vec α n → Vec α (n+1)
def Vec.repeat (a : α) (n : Nat) : Vec α n :=
match n with
| 0 => nil
| n+1 => cons a (repeat a n)
instance [Inhabited α] : Inhabited (Vec α n) where
default := Vec.repeat arbitrary n
def Vec.map (v : Vec α n) (f : α → β) : Vec β n :=
match n, v with
| _, nil => nil
| _, cons a as => cons (f a) (map as f)
def Vec.reverse (v : Vec α n) : Vec α n :=
let rec loop : {n m : Nat} → Vec α n → Vec α m → Vec α (n + m)
| _, _, nil, w => Nat.zero_add .. ▸ w
| _, _, cons a as, w => Nat.add_assoc .. ▸ loop as (Nat.add_comm .. ▸ cons a w)
loop v nil
@[simp] theorem map_id (v : Vec α n) : v.map id = v := by
induction v with
| nil => rfl
| cons a as ih => simp [Vec.map, ih]
def foo [Add α] (v w : Vec α n) (f : αα) (a : α) : α :=
match n, v.map f, w.map f with
| _, Vec.nil, Vec.nil => a
| _, Vec.cons a .., Vec.cons b .. => a + b
theorem ex1 (a b : Nat) (as : Vec Nat n) : foo (Vec.cons a as) (Vec.cons b as) id 0 = a + b := by
simp [foo]
#print ex1
def bla (b : Bool) (f g : α → β) (a : α) : β :=
(match b with
| true => f | false => g) a
theorem ex2 (h : b = false) : bla b (fun x => x + 1) id 10 = 10 := by
simp [bla, h]
#print ex2