fix: heuristic for deciding which additional propositions must be included in equality theorems

This commit is contained in:
Leonardo de Moura 2022-02-24 17:17:07 -08:00
parent bdea43a52a
commit 2961e9cbf0
2 changed files with 53 additions and 13 deletions

View file

@ -95,11 +95,11 @@ private def lhsDependsOn (type : Expr) (fvarId : FVarId) : MetaM Bool :=
def simpEqnType (eqnType : Expr) : MetaM Expr := do
forallTelescopeReducing (← instantiateMVars eqnType) fun ys type => do
let proofVars := collect type
trace[Meta.debug] "simpEqnType: {type}"
trace[Elab.definition] "simpEqnType type: {type}"
let mut type ← Match.unfoldNamedPattern type
let mut eliminated : FVarIdSet := {}
for y in ys.reverse do
trace[Meta.debug] ">> simpEqnType: {← inferType y}, {type}"
trace[Elab.definition] ">> simpEqnType: {← inferType y}, {type}"
if proofVars.contains y.fvarId! then
let some (_, Expr.fvar fvarId _, rhs) ← matchEq? (← inferType y) | throwError "unexpected hypothesis in altenative{indentExpr eqnType}"
eliminated := eliminated.insert fvarId
@ -129,7 +129,7 @@ where
ST.Prim.Ref.get ref
runST (go e)
private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do
private partial def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do
let target ← getMVarType' mvarId
let fvarState := collectFVars {} target
let fvarState ← (← getLCtx).foldrM (init := fvarState) fun decl fvarState => do
@ -137,20 +137,51 @@ private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := wit
return collectFVars fvarState (← instantiateMVars decl.type)
else
return fvarState
let mut fvarSet := fvarState.fvarSet
let mut fvarIdSet := fvarState.fvarSet
let mut fvarIds ← sortFVarIds <| fvarState.fvarSet.toArray
-- Include propositions that are not in fvarState.fvarSet, and only contains variables in fvarSet
for decl in (← getLCtx) do
unless fvarSet.contains decl.fvarId do
if (← isProp decl.type) then
let type ← instantiateMVars decl.type
let missing? := type.find? fun e => e.isFVar && !fvarSet.contains e.fvarId!
if missing?.isNone then
fvarIds := fvarIds.push decl.fvarId
fvarSet := fvarSet.insert decl.fvarId
-- Include (relevant) propositions that are not already in `fvarIdSet`
let mut modified := false
repeat
modified := false
for decl in (← getLCtx) do
unless fvarIdSet.contains decl.fvarId do
if (← isProp decl.type) then
let type ← instantiateMVars decl.type
unless (← isIrrelevant fvarIdSet type) do
modified := true
(fvarIdSet, fvarIds) ← pushDecl fvarIdSet fvarIds decl
until !modified
let type ← mkForallFVars (fvarIds.map mkFVar) target
let type ← simpEqnType type
modify (·.push type)
where
/--
We say the type/proposition is "irrelevant" if
1- It does not contain any variable in `fvarIdSet` OR
2- It is of the form `x = t` or `t = x` where `x` is a free variable
that is not in `fvarIdSet`. This can of equality can be eliminated by substitution. -/
isIrrelevant (fvarIdSet : FVarIdSet) (type : Expr) : MetaM Bool := do
if Option.isNone <| type.find? fun e => e.isFVar && fvarIdSet.contains e.fvarId! then
return true
else if let some (_, lhs, rhs) := type.eq? then
return (lhs.isFVar && !fvarIdSet.contains lhs.fvarId!)
|| (rhs.isFVar && !fvarIdSet.contains rhs.fvarId!)
else
return false
pushDecl (fvarIdSet : FVarIdSet) (fvarIds : Array FVarId) (localDecl : LocalDecl) : MetaM (FVarIdSet × Array FVarId) := do
let (fvarIdSet, fvarIds) ← collectDeps fvarIdSet fvarIds (← instantiateMVars localDecl.type)
return (fvarIdSet.insert localDecl.fvarId, fvarIds.push localDecl.fvarId)
collectDeps (fvarIdSet : FVarIdSet) (fvarIds : Array FVarId) (type : Expr) : MetaM (FVarIdSet × Array FVarId) := do
let s := collectFVars {} type
let usedFVarIds ← sortFVarIds <| s.fvarSet.toArray
let mut fvarIdSet := fvarIdSet
let mut fvarIds := fvarIds
for fvarId in usedFVarIds do
unless fvarIdSet.contains fvarId do
(fvarIdSet, fvarIds) ← pushDecl fvarIdSet fvarIds (← getLocalDecl fvarId)
return (fvarIdSet, fvarIds)
partial def mkEqnTypes (declNames : Array Name) (mvarId : MVarId) : MetaM (Array Expr) := do
let (_, eqnTypes) ← go mvarId |>.run { declNames } |>.run #[]

View file

@ -0,0 +1,9 @@
@[simp] def Stream.hasLength [Stream stream value] (n : Nat) (s : stream) : Bool :=
match n, Stream.next? s with
| 0, none => true
| n + 1, some (_, s') => hasLength n s'
| _, _ => false
#check @Stream.hasLength._eq_1
#check @Stream.hasLength._eq_2
#check @Stream.hasLength._eq_3