fix: heuristic for deciding which additional propositions must be included in equality theorems
This commit is contained in:
parent
bdea43a52a
commit
2961e9cbf0
2 changed files with 53 additions and 13 deletions
|
|
@ -95,11 +95,11 @@ private def lhsDependsOn (type : Expr) (fvarId : FVarId) : MetaM Bool :=
|
|||
def simpEqnType (eqnType : Expr) : MetaM Expr := do
|
||||
forallTelescopeReducing (← instantiateMVars eqnType) fun ys type => do
|
||||
let proofVars := collect type
|
||||
trace[Meta.debug] "simpEqnType: {type}"
|
||||
trace[Elab.definition] "simpEqnType type: {type}"
|
||||
let mut type ← Match.unfoldNamedPattern type
|
||||
let mut eliminated : FVarIdSet := {}
|
||||
for y in ys.reverse do
|
||||
trace[Meta.debug] ">> simpEqnType: {← inferType y}, {type}"
|
||||
trace[Elab.definition] ">> simpEqnType: {← inferType y}, {type}"
|
||||
if proofVars.contains y.fvarId! then
|
||||
let some (_, Expr.fvar fvarId _, rhs) ← matchEq? (← inferType y) | throwError "unexpected hypothesis in altenative{indentExpr eqnType}"
|
||||
eliminated := eliminated.insert fvarId
|
||||
|
|
@ -129,7 +129,7 @@ where
|
|||
ST.Prim.Ref.get ref
|
||||
runST (go e)
|
||||
|
||||
private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do
|
||||
private partial def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do
|
||||
let target ← getMVarType' mvarId
|
||||
let fvarState := collectFVars {} target
|
||||
let fvarState ← (← getLCtx).foldrM (init := fvarState) fun decl fvarState => do
|
||||
|
|
@ -137,20 +137,51 @@ private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := wit
|
|||
return collectFVars fvarState (← instantiateMVars decl.type)
|
||||
else
|
||||
return fvarState
|
||||
let mut fvarSet := fvarState.fvarSet
|
||||
let mut fvarIdSet := fvarState.fvarSet
|
||||
let mut fvarIds ← sortFVarIds <| fvarState.fvarSet.toArray
|
||||
-- Include propositions that are not in fvarState.fvarSet, and only contains variables in fvarSet
|
||||
for decl in (← getLCtx) do
|
||||
unless fvarSet.contains decl.fvarId do
|
||||
if (← isProp decl.type) then
|
||||
let type ← instantiateMVars decl.type
|
||||
let missing? := type.find? fun e => e.isFVar && !fvarSet.contains e.fvarId!
|
||||
if missing?.isNone then
|
||||
fvarIds := fvarIds.push decl.fvarId
|
||||
fvarSet := fvarSet.insert decl.fvarId
|
||||
-- Include (relevant) propositions that are not already in `fvarIdSet`
|
||||
let mut modified := false
|
||||
repeat
|
||||
modified := false
|
||||
for decl in (← getLCtx) do
|
||||
unless fvarIdSet.contains decl.fvarId do
|
||||
if (← isProp decl.type) then
|
||||
let type ← instantiateMVars decl.type
|
||||
unless (← isIrrelevant fvarIdSet type) do
|
||||
modified := true
|
||||
(fvarIdSet, fvarIds) ← pushDecl fvarIdSet fvarIds decl
|
||||
until !modified
|
||||
let type ← mkForallFVars (fvarIds.map mkFVar) target
|
||||
let type ← simpEqnType type
|
||||
modify (·.push type)
|
||||
where
|
||||
/--
|
||||
We say the type/proposition is "irrelevant" if
|
||||
1- It does not contain any variable in `fvarIdSet` OR
|
||||
2- It is of the form `x = t` or `t = x` where `x` is a free variable
|
||||
that is not in `fvarIdSet`. This can of equality can be eliminated by substitution. -/
|
||||
isIrrelevant (fvarIdSet : FVarIdSet) (type : Expr) : MetaM Bool := do
|
||||
if Option.isNone <| type.find? fun e => e.isFVar && fvarIdSet.contains e.fvarId! then
|
||||
return true
|
||||
else if let some (_, lhs, rhs) := type.eq? then
|
||||
return (lhs.isFVar && !fvarIdSet.contains lhs.fvarId!)
|
||||
|| (rhs.isFVar && !fvarIdSet.contains rhs.fvarId!)
|
||||
else
|
||||
return false
|
||||
|
||||
pushDecl (fvarIdSet : FVarIdSet) (fvarIds : Array FVarId) (localDecl : LocalDecl) : MetaM (FVarIdSet × Array FVarId) := do
|
||||
let (fvarIdSet, fvarIds) ← collectDeps fvarIdSet fvarIds (← instantiateMVars localDecl.type)
|
||||
return (fvarIdSet.insert localDecl.fvarId, fvarIds.push localDecl.fvarId)
|
||||
|
||||
collectDeps (fvarIdSet : FVarIdSet) (fvarIds : Array FVarId) (type : Expr) : MetaM (FVarIdSet × Array FVarId) := do
|
||||
let s := collectFVars {} type
|
||||
let usedFVarIds ← sortFVarIds <| s.fvarSet.toArray
|
||||
let mut fvarIdSet := fvarIdSet
|
||||
let mut fvarIds := fvarIds
|
||||
for fvarId in usedFVarIds do
|
||||
unless fvarIdSet.contains fvarId do
|
||||
(fvarIdSet, fvarIds) ← pushDecl fvarIdSet fvarIds (← getLocalDecl fvarId)
|
||||
return (fvarIdSet, fvarIds)
|
||||
|
||||
partial def mkEqnTypes (declNames : Array Name) (mvarId : MVarId) : MetaM (Array Expr) := do
|
||||
let (_, eqnTypes) ← go mvarId |>.run { declNames } |>.run #[]
|
||||
|
|
|
|||
9
tests/lean/run/streamEqIssue.lean
Normal file
9
tests/lean/run/streamEqIssue.lean
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
@[simp] def Stream.hasLength [Stream stream value] (n : Nat) (s : stream) : Bool :=
|
||||
match n, Stream.next? s with
|
||||
| 0, none => true
|
||||
| n + 1, some (_, s') => hasLength n s'
|
||||
| _, _ => false
|
||||
|
||||
#check @Stream.hasLength._eq_1
|
||||
#check @Stream.hasLength._eq_2
|
||||
#check @Stream.hasLength._eq_3
|
||||
Loading…
Add table
Reference in a new issue