diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 224d28da6d..653d4288b0 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -650,14 +650,10 @@ private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array Patte If `type` or another local variables depends on a free variable in `toClear`, then it is not cleared. -/ private def withToClear (toClear : Array FVarId) (type : Expr) (k : TermElabM α) : TermElabM α := do - let mut toClear := toClear - for localDecl in (← getLCtx) do - if isAuxDiscrName localDecl.userName || isAuxFunDiscrName localDecl.userName then - toClear := toClear.push localDecl.fvarId if toClear.isEmpty then k else - toClear ← sortFVarIds toClear + let toClear ← sortFVarIds toClear trace[Elab.match] ">> toClear {toClear.map mkFVar}" let mut lctx ← getLCtx let mut localInsts ← getLocalInstances @@ -668,31 +664,40 @@ private def withToClear (toClear : Array FVarId) (type : Expr) (k : TermElabM α localInsts := localInsts.filter fun localInst => localInst.fvar.fvarId! != fvarId withLCtx lctx localInsts k + +private def withoutAuxDiscrs (matchType : Expr) (k : TermElabM α) : TermElabM α := do + let mut toClear := #[] + for localDecl in (← getLCtx) do + if isAuxDiscrName localDecl.userName || isAuxFunDiscrName localDecl.userName then + toClear := toClear.push localDecl.fvarId + withToClear toClear matchType k + /-- Elaborate the `match` alternative `alt` using the given `matchType`. The array `toClear` contains variables that must be cleared before elaborating the `rhs` because they have been generalized/refined. -/ private def elabMatchAltView (alt : MatchAltView) (matchType : Expr) (toClear : Array FVarId) : ExceptT PatternElabException TermElabM (AltLHS × Expr) := withRef alt.ref do - let (patternVars, alt) ← collectPatternVars alt - trace[Elab.match] "patternVars: {patternVars}" - withPatternVars patternVars fun patternVarDecls => do - withElaboratedLHS alt.ref patternVarDecls alt.patterns matchType fun altLHS matchType => do - withLocalInstances altLHS.fvarDecls do - trace[Elab.match] "elabMatchAltView: {matchType}" - let matchType ← instantiateMVars matchType - -- If `matchType` is of the form `@m ...`, we create a new metavariable with the current scope. - -- This improves the effectiveness of the `isDefEq` default approximations - let matchType' ← if matchType.getAppFn.isMVar then mkFreshTypeMVar else pure matchType - withToClear toClear matchType' do - let rhs ← elabTermEnsuringType alt.rhs matchType' - -- We use all approximations to ensure the auxiliary type is defeq to the original one. - unless (← fullApproxDefEq <| isDefEq matchType' matchType) do - throwError "type mistmatch, alternative {← mkHasTypeButIsExpectedMsg matchType' matchType}" - let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr - let rhs ← if xs.isEmpty then pure <| mkSimpleThunk rhs else mkLambdaFVars xs rhs - trace[Elab.match] "rhs: {rhs}" - return (altLHS, rhs) + withoutAuxDiscrs matchType do + let (patternVars, alt) ← collectPatternVars alt + trace[Elab.match] "patternVars: {patternVars}" + withPatternVars patternVars fun patternVarDecls => do + withElaboratedLHS alt.ref patternVarDecls alt.patterns matchType fun altLHS matchType => do + withLocalInstances altLHS.fvarDecls do + trace[Elab.match] "elabMatchAltView: {matchType}" + let matchType ← instantiateMVars matchType + -- If `matchType` is of the form `@m ...`, we create a new metavariable with the current scope. + -- This improves the effectiveness of the `isDefEq` default approximations + let matchType' ← if matchType.getAppFn.isMVar then mkFreshTypeMVar else pure matchType + withToClear toClear matchType' do + let rhs ← elabTermEnsuringType alt.rhs matchType' + -- We use all approximations to ensure the auxiliary type is defeq to the original one. + unless (← fullApproxDefEq <| isDefEq matchType' matchType) do + throwError "type mistmatch, alternative {← mkHasTypeButIsExpectedMsg matchType' matchType}" + let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr + let rhs ← if xs.isEmpty then pure <| mkSimpleThunk rhs else mkLambdaFVars xs rhs + trace[Elab.match] "rhs: {rhs}" + return (altLHS, rhs) /-- Collect problematic index for the "discriminant refinement feature". This method is invoked diff --git a/tests/lean/interactive/discrsIssue.lean b/tests/lean/interactive/discrsIssue.lean new file mode 100644 index 0000000000..df05f1ce14 --- /dev/null +++ b/tests/lean/interactive/discrsIssue.lean @@ -0,0 +1,39 @@ +inductive Expr where + | nat : Nat → Expr + | plus : Expr → Expr → Expr + | bool : Bool → Expr + | and : Expr → Expr → Expr + +inductive Ty where + | nat + | bool + deriving DecidableEq + +inductive HasType : Expr → Ty → Prop + | nat : HasType (.nat v) .nat + | plus : HasType a .nat → HasType b .nat → HasType (.plus a b) .nat + | bool : HasType (.bool v) .bool + | and : HasType a .bool → HasType b .bool → HasType (.and a b) .bool + +theorem HasType.det (h₁ : HasType e t₁) (h₂ : HasType e t₂) : t₁ = t₂ := by + cases h₁ <;> cases h₂ <;> rfl + +inductive Maybe (p : α → Prop) where + | found : (a : α) → p a → Maybe p + | unknown + +notation "{{ " x " | " p " }}" => Maybe (fun x => p) + +def Expr.typeCheck (e : Expr) : {{ ty | HasType e ty }} := + match e with + | nat .. => .found .nat .nat + | bool .. => .found .bool .bool + | plus a b => + match a.typeCheck, b.typeCheck with + | .found .nat h₁, .found .nat h₂ => .found .nat (.plus h₁ h₂) + --^ $/lean/plainTermGoal + | _, _ => .unknown + | and a b => + match a.typeCheck, b.typeCheck with + | .found .bool h₁, .found .bool h₂ => .found .bool (.and h₁ h₂) + | _, _ => .unknown diff --git a/tests/lean/interactive/discrsIssue.lean.expected.out b/tests/lean/interactive/discrsIssue.lean.expected.out new file mode 100644 index 0000000000..49fc6a0e4f --- /dev/null +++ b/tests/lean/interactive/discrsIssue.lean.expected.out @@ -0,0 +1,6 @@ +{"textDocument": {"uri": "file://discrsIssue.lean"}, + "position": {"line": 32, "character": 18}} +{"range": + {"start": {"line": 32, "character": 18}, "end": {"line": 32, "character": 20}}, + "goal": + "e a b : Expr\nh₁ : HasType a Ty.nat\nh₂ : HasType b Ty.nat\n⊢ HasType a Ty.nat"}