feat: support for simplifying match discriminants
This commit is contained in:
parent
2970c6ca79
commit
8227d3afcd
4 changed files with 145 additions and 35 deletions
|
|
@ -23,9 +23,14 @@ structure MatcherInfo where
|
|||
altNumParams : Array Nat
|
||||
uElimPos? : Option Nat
|
||||
|
||||
def MatcherInfo.numAlts (matcherInfo : MatcherInfo) : Nat :=
|
||||
matcherInfo.altNumParams.size
|
||||
def MatcherInfo.numAlts (info : MatcherInfo) : Nat :=
|
||||
info.altNumParams.size
|
||||
|
||||
def MatcherInfo.arity (info : MatcherInfo) : Nat :=
|
||||
info.numParams + 1 + info.numDiscrs + info.numAlts
|
||||
|
||||
def MatcherInfo.getMotivePos (info : MatcherInfo) : Nat :=
|
||||
info.numParams
|
||||
namespace Extension
|
||||
|
||||
structure Entry where
|
||||
|
|
@ -42,8 +47,8 @@ def State.switch (s : State) : State := { s with map := s.map.switch }
|
|||
|
||||
builtin_initialize extension : SimplePersistentEnvExtension Entry State ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
name := `matcher,
|
||||
addEntryFn := State.addEntry,
|
||||
name := `matcher
|
||||
addEntryFn := State.addEntry
|
||||
addImportedFn := fun es => (mkStateFromImportedEntries State.addEntry {} es).switch
|
||||
}
|
||||
|
||||
|
|
@ -62,13 +67,11 @@ end Match
|
|||
|
||||
export Match (MatcherInfo)
|
||||
|
||||
def getMatcherInfo? (declName : Name) : MetaM (Option MatcherInfo) := do
|
||||
let env ← getEnv
|
||||
return Match.Extension.getMatcherInfo? env declName
|
||||
def getMatcherInfo? (declName : Name) : MetaM (Option MatcherInfo) :=
|
||||
return Match.Extension.getMatcherInfo? (← getEnv) declName
|
||||
|
||||
def isMatcher (declName : Name) : MetaM Bool := do
|
||||
let info? ← getMatcherInfo? declName
|
||||
return info?.isSome
|
||||
def isMatcher (declName : Name) : MetaM Bool :=
|
||||
return (← getMatcherInfo? declName).isSome
|
||||
|
||||
structure MatcherApp where
|
||||
matcherName : Name
|
||||
|
|
@ -86,19 +89,19 @@ def matchMatcherApp? (e : Expr) : MetaM (Option MatcherApp) :=
|
|||
| Expr.const declName declLevels _ => do
|
||||
let some info ← getMatcherInfo? declName | pure none
|
||||
let args := e.getAppArgs
|
||||
if args.size < info.numParams + 1 + info.numDiscrs + info.numAlts then
|
||||
if args.size < info.arity then
|
||||
return none
|
||||
else
|
||||
return some {
|
||||
matcherName := declName,
|
||||
matcherLevels := declLevels.toArray,
|
||||
uElimPos? := info.uElimPos?,
|
||||
params := args.extract 0 info.numParams,
|
||||
motive := args.get! info.numParams,
|
||||
discrs := args.extract (info.numParams + 1) (info.numParams + 1 + info.numDiscrs),
|
||||
altNumParams := info.altNumParams,
|
||||
alts := args.extract (info.numParams + 1 + info.numDiscrs) (info.numParams + 1 + info.numDiscrs + info.numAlts),
|
||||
remaining := args.extract (info.numParams + 1 + info.numDiscrs + info.numAlts) args.size
|
||||
matcherName := declName
|
||||
matcherLevels := declLevels.toArray
|
||||
uElimPos? := info.uElimPos?
|
||||
params := args.extract 0 info.numParams
|
||||
motive := args[info.getMotivePos]
|
||||
discrs := args[info.numParams + 1 : info.numParams + 1 + info.numDiscrs]
|
||||
altNumParams := info.altNumParams
|
||||
alts := args[info.numParams + 1 + info.numDiscrs : info.numParams + 1 + info.numDiscrs + info.numAlts]
|
||||
remaining := args[info.numParams + 1 + info.numDiscrs + info.numAlts : args.size]
|
||||
}
|
||||
| _ => return none
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,14 @@ private def mkCongr (r₁ r₂ : Result) : MetaM Result :=
|
|||
| none, some h => return { expr := e, proof? := (← Meta.mkCongrArg r₁.expr h) }
|
||||
| some h₁, some h₂ => return { expr := e, proof? := (← Meta.mkCongr h₁ h₂) }
|
||||
|
||||
private def mkCongrDep (r₁ r₂ : Result) : MetaM Result := do
|
||||
let e := mkApp r₁.expr r₂.expr
|
||||
match r₁.proof?, r₂.proof? with
|
||||
| none, none => return { expr := e }
|
||||
| some h, none => return { expr := e, proof? := (← Meta.mkCongrFun h r₂.expr) }
|
||||
| none, some h => return { expr := e, proof? := (← Meta.mkCongrDepArg r₁.expr h) }
|
||||
| some h₁, some h₂ => return { expr := e, proof? := (← Meta.mkCongrDep h₁ h₂) }
|
||||
|
||||
private def mkImpCongr (r₁ r₂ : Result) : MetaM Result := do
|
||||
let e ← mkArrow r₁.expr r₂.expr
|
||||
match r₁.proof?, r₂.proof? with
|
||||
|
|
@ -112,6 +120,28 @@ private partial def reduce (e : Expr) : SimpM Expr := withIncRecDepth do
|
|||
| some e => reduce e
|
||||
| none => return e
|
||||
|
||||
/--
|
||||
Compute a mask `m` of size `numDiscrs` s.t. discriminant `i` can be rewritten if `m[i] == true`.
|
||||
`motive` is the motive of the matcher application.
|
||||
|
||||
For example, suppose the motive is `fun (n : Nat) (v : Vec α n) => Nat`, then
|
||||
we can rewrite the second discriminant. We cannot rewrite the first because the type
|
||||
of the second depends on it. -/
|
||||
private def getMatcherDiscrCongrMask (motive : Expr) (numDiscrs : Nat) : Array Bool :=
|
||||
let updateMask (e : Expr) (i : Nat) (mask : Array Bool) : Array Bool :=
|
||||
if !e.hasLooseBVars then mask
|
||||
else i.fold (init := mask) fun j mask => if e.hasLooseBVar j then mask.set! (i - j - 1) false else mask
|
||||
let rec loop (e : Expr) (i : Nat) (mask : Array Bool): Array Bool :=
|
||||
match e with
|
||||
| Expr.lam _ d b _ => loop b (i+1) (updateMask e i mask)
|
||||
| _ =>
|
||||
if i != numDiscrs then
|
||||
-- This is an ill-formed matcher application, it should not happen in practice.
|
||||
mkArray numDiscrs false
|
||||
else
|
||||
updateMask e i mask
|
||||
loop motive 0 (mkArray numDiscrs true)
|
||||
|
||||
private partial def dsimp (e : Expr) : M Expr := do
|
||||
transform e (post := fun e => return TransformStep.done (← reduce e))
|
||||
|
||||
|
|
@ -221,16 +251,47 @@ where
|
|||
else
|
||||
return none
|
||||
|
||||
congrMatch (info : MatcherInfo) (e : Expr) : M Result :=
|
||||
withParent e <| e.withApp fun f args => do
|
||||
if args.size < info.arity then
|
||||
congrDefault e -- partially applied matcher application
|
||||
else
|
||||
let mask := getMatcherDiscrCongrMask args[info.getMotivePos] info.numDiscrs
|
||||
-- We don't rewrite the parameters nor the motive
|
||||
let mut r : Result := { expr := mkAppN f args[:info.numParams+1] }
|
||||
-- process discriminants
|
||||
let firstDiscrPos := info.numParams+1
|
||||
for i in [firstDiscrPos : firstDiscrPos + info.numDiscrs] do
|
||||
if mask[i - firstDiscrPos] then
|
||||
r ← mkCongrDep r (← simp args[i])
|
||||
else
|
||||
r ← mkCongrFun r (← dsimp args[i])
|
||||
-- process alternatives
|
||||
let firstAltPos := firstDiscrPos + info.numDiscrs
|
||||
for i in [firstAltPos : firstAltPos + info.numAlts] do
|
||||
r ← mkCongr r (← simp args[i])
|
||||
-- process over-application arguments
|
||||
for i in [info.arity : args.size] do
|
||||
let fType ← whnfD (← inferType r.expr)
|
||||
if fType.isArrow then
|
||||
r ← mkCongr r (← simp args[i])
|
||||
else
|
||||
r ← mkCongrFun r (← dsimp args[i])
|
||||
return r
|
||||
|
||||
congr (e : Expr) : M Result := do
|
||||
let f := e.getAppFn
|
||||
if f.isConst then
|
||||
let congrLemmas ← getCongrLemmas
|
||||
let cs := congrLemmas.get f.constName!
|
||||
for c in cs do
|
||||
match (← tryCongrLemma? c e) with
|
||||
| none => pure ()
|
||||
| some r => return r
|
||||
congrDefault e
|
||||
if let some matcherInfo ← getMatcherInfo? f.constName! then
|
||||
congrMatch matcherInfo e
|
||||
else
|
||||
let congrLemmas ← getCongrLemmas
|
||||
let cs := congrLemmas.get f.constName!
|
||||
for c in cs do
|
||||
match (← tryCongrLemma? c e) with
|
||||
| none => pure ()
|
||||
| some r => return r
|
||||
congrDefault e
|
||||
else
|
||||
congrDefault e
|
||||
|
||||
|
|
|
|||
|
|
@ -190,22 +190,22 @@ g.foo :=
|
|||
endPos := { line := 40, column := 47 },
|
||||
endCharUtf16 := 47 } }
|
||||
optParam :=
|
||||
{ range := { pos := { line := 152, column := 13 },
|
||||
{ range := { pos := { line := 158, column := 13 },
|
||||
charUtf16 := 13,
|
||||
endPos := { line := 152, column := 66 },
|
||||
endPos := { line := 158, column := 66 },
|
||||
endCharUtf16 := 66 },
|
||||
selectionRange := { pos := { line := 152, column := 17 },
|
||||
selectionRange := { pos := { line := 158, column := 17 },
|
||||
charUtf16 := 17,
|
||||
endPos := { line := 152, column := 25 },
|
||||
endPos := { line := 158, column := 25 },
|
||||
endCharUtf16 := 25 } }
|
||||
namedPattern :=
|
||||
{ range := { pos := { line := 161, column := 13 },
|
||||
{ range := { pos := { line := 167, column := 13 },
|
||||
charUtf16 := 13,
|
||||
endPos := { line := 161, column := 61 },
|
||||
endPos := { line := 167, column := 61 },
|
||||
endCharUtf16 := 61 },
|
||||
selectionRange := { pos := { line := 161, column := 17 },
|
||||
selectionRange := { pos := { line := 167, column := 17 },
|
||||
charUtf16 := 17,
|
||||
endPos := { line := 161, column := 29 },
|
||||
endPos := { line := 167, column := 29 },
|
||||
endCharUtf16 := 29 } }
|
||||
Lean.Meta.forallTelescopeReducing :=
|
||||
{ range := { pos := { line := 699, column := 0 },
|
||||
|
|
|
|||
46
tests/lean/run/simpMatchDiscr.lean
Normal file
46
tests/lean/run/simpMatchDiscr.lean
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
inductive Vec (α : Type u) : Nat → Type u
|
||||
| nil : Vec α 0
|
||||
| cons : α → {n : Nat} → Vec α n → Vec α (n+1)
|
||||
|
||||
def Vec.repeat (a : α) (n : Nat) : Vec α n :=
|
||||
match n with
|
||||
| 0 => nil
|
||||
| n+1 => cons a (repeat a n)
|
||||
|
||||
instance [Inhabited α] : Inhabited (Vec α n) where
|
||||
default := Vec.repeat arbitrary n
|
||||
|
||||
def Vec.map (v : Vec α n) (f : α → β) : Vec β n :=
|
||||
match n, v with
|
||||
| _, nil => nil
|
||||
| _, cons a as => cons (f a) (map as f)
|
||||
|
||||
def Vec.reverse (v : Vec α n) : Vec α n :=
|
||||
let rec loop : {n m : Nat} → Vec α n → Vec α m → Vec α (n + m)
|
||||
| _, _, nil, w => Nat.zero_add .. ▸ w
|
||||
| _, _, cons a as, w => Nat.add_assoc .. ▸ loop as (Nat.add_comm .. ▸ cons a w)
|
||||
loop v nil
|
||||
|
||||
@[simp] theorem map_id (v : Vec α n) : v.map id = v := by
|
||||
induction v with
|
||||
| nil => rfl
|
||||
| cons a as ih => simp [Vec.map, ih]
|
||||
|
||||
def foo [Add α] (v w : Vec α n) (f : α → α) (a : α) : α :=
|
||||
match n, v.map f, w.map f with
|
||||
| _, Vec.nil, Vec.nil => a
|
||||
| _, Vec.cons a .., Vec.cons b .. => a + b
|
||||
|
||||
theorem ex1 (a b : Nat) (as : Vec Nat n) : foo (Vec.cons a as) (Vec.cons b as) id 0 = a + b := by
|
||||
simp [foo]
|
||||
|
||||
#print ex1
|
||||
|
||||
def bla (b : Bool) (f g : α → β) (a : α) : β :=
|
||||
(match b with
|
||||
| true => f | false => g) a
|
||||
|
||||
theorem ex2 (h : b = false) : bla b (fun x => x + 1) id 10 = 10 := by
|
||||
simp [bla, h]
|
||||
|
||||
#print ex2
|
||||
Loading…
Add table
Reference in a new issue