diff --git a/src/Lean/Elab/PreDefinition/Eqns.lean b/src/Lean/Elab/PreDefinition/Eqns.lean index a293772fda..c169d6d063 100644 --- a/src/Lean/Elab/PreDefinition/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Eqns.lean @@ -95,11 +95,11 @@ private def lhsDependsOn (type : Expr) (fvarId : FVarId) : MetaM Bool := def simpEqnType (eqnType : Expr) : MetaM Expr := do forallTelescopeReducing (← instantiateMVars eqnType) fun ys type => do let proofVars := collect type - trace[Meta.debug] "simpEqnType: {type}" + trace[Elab.definition] "simpEqnType type: {type}" let mut type ← Match.unfoldNamedPattern type let mut eliminated : FVarIdSet := {} for y in ys.reverse do - trace[Meta.debug] ">> simpEqnType: {← inferType y}, {type}" + trace[Elab.definition] ">> simpEqnType: {← inferType y}, {type}" if proofVars.contains y.fvarId! then let some (_, Expr.fvar fvarId _, rhs) ← matchEq? (← inferType y) | throwError "unexpected hypothesis in altenative{indentExpr eqnType}" eliminated := eliminated.insert fvarId @@ -129,7 +129,7 @@ where ST.Prim.Ref.get ref runST (go e) -private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do +private partial def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do let target ← getMVarType' mvarId let fvarState := collectFVars {} target let fvarState ← (← getLCtx).foldrM (init := fvarState) fun decl fvarState => do @@ -137,20 +137,51 @@ private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := wit return collectFVars fvarState (← instantiateMVars decl.type) else return fvarState - let mut fvarSet := fvarState.fvarSet + let mut fvarIdSet := fvarState.fvarSet let mut fvarIds ← sortFVarIds <| fvarState.fvarSet.toArray - -- Include propositions that are not in fvarState.fvarSet, and only contains variables in fvarSet - for decl in (← getLCtx) do - unless fvarSet.contains decl.fvarId do - if (← isProp decl.type) then - let type ← instantiateMVars decl.type - let missing? := type.find? fun e => e.isFVar && !fvarSet.contains e.fvarId! - if missing?.isNone then - fvarIds := fvarIds.push decl.fvarId - fvarSet := fvarSet.insert decl.fvarId + -- Include (relevant) propositions that are not already in `fvarIdSet` + let mut modified := false + repeat + modified := false + for decl in (← getLCtx) do + unless fvarIdSet.contains decl.fvarId do + if (← isProp decl.type) then + let type ← instantiateMVars decl.type + unless (← isIrrelevant fvarIdSet type) do + modified := true + (fvarIdSet, fvarIds) ← pushDecl fvarIdSet fvarIds decl + until !modified let type ← mkForallFVars (fvarIds.map mkFVar) target let type ← simpEqnType type modify (·.push type) +where + /-- + We say the type/proposition is "irrelevant" if + 1- It does not contain any variable in `fvarIdSet` OR + 2- It is of the form `x = t` or `t = x` where `x` is a free variable + that is not in `fvarIdSet`. This can of equality can be eliminated by substitution. -/ + isIrrelevant (fvarIdSet : FVarIdSet) (type : Expr) : MetaM Bool := do + if Option.isNone <| type.find? fun e => e.isFVar && fvarIdSet.contains e.fvarId! then + return true + else if let some (_, lhs, rhs) := type.eq? then + return (lhs.isFVar && !fvarIdSet.contains lhs.fvarId!) + || (rhs.isFVar && !fvarIdSet.contains rhs.fvarId!) + else + return false + + pushDecl (fvarIdSet : FVarIdSet) (fvarIds : Array FVarId) (localDecl : LocalDecl) : MetaM (FVarIdSet × Array FVarId) := do + let (fvarIdSet, fvarIds) ← collectDeps fvarIdSet fvarIds (← instantiateMVars localDecl.type) + return (fvarIdSet.insert localDecl.fvarId, fvarIds.push localDecl.fvarId) + + collectDeps (fvarIdSet : FVarIdSet) (fvarIds : Array FVarId) (type : Expr) : MetaM (FVarIdSet × Array FVarId) := do + let s := collectFVars {} type + let usedFVarIds ← sortFVarIds <| s.fvarSet.toArray + let mut fvarIdSet := fvarIdSet + let mut fvarIds := fvarIds + for fvarId in usedFVarIds do + unless fvarIdSet.contains fvarId do + (fvarIdSet, fvarIds) ← pushDecl fvarIdSet fvarIds (← getLocalDecl fvarId) + return (fvarIdSet, fvarIds) partial def mkEqnTypes (declNames : Array Name) (mvarId : MVarId) : MetaM (Array Expr) := do let (_, eqnTypes) ← go mvarId |>.run { declNames } |>.run #[] diff --git a/tests/lean/run/streamEqIssue.lean b/tests/lean/run/streamEqIssue.lean new file mode 100644 index 0000000000..6401c0016e --- /dev/null +++ b/tests/lean/run/streamEqIssue.lean @@ -0,0 +1,9 @@ +@[simp] def Stream.hasLength [Stream stream value] (n : Nat) (s : stream) : Bool := + match n, Stream.next? s with + | 0, none => true + | n + 1, some (_, s') => hasLength n s' + | _, _ => false + +#check @Stream.hasLength._eq_1 +#check @Stream.hasLength._eq_2 +#check @Stream.hasLength._eq_3