diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index cbdeb02461..44c0a81966 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -621,20 +621,20 @@ let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.erase d.fvarId let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx; adaptTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) $ k localDecls patterns -private def withElaboratedLHS {α} (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) +private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) (k : AltLHS → Expr → TermElabM α) : TermElabM α := do (patterns, matchType) ← withSynthesize $ elabPatternsAux patternStxs 0 matchType #[]; localDecls ← finalizePatternDecls patternVarDecls; patterns ← patterns.mapM instantiateMVars; withDepElimPatterns localDecls patterns fun localDecls patterns => - k { fvarDecls := localDecls.toList, patterns := patterns.toList } matchType + k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS × Expr) := withRef alt.ref do (patternVars, alt) ← collectPatternVars alt; trace `Elab.match fun _ => "patternVars: " ++ toString patternVars; withPatternVars patternVars fun patternVarDecls => do - withElaboratedLHS patternVarDecls alt.patterns matchType fun altLHS matchType => do + withElaboratedLHS alt.ref patternVarDecls alt.patterns matchType fun altLHS matchType => do rhs ← elabTermEnsuringType alt.rhs matchType; let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr; rhs ← if xs.isEmpty then pure $ mkThunk rhs else mkLambdaFVars xs rhs; diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index 559b252eed..153eabb7e2 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -69,10 +69,12 @@ p.applyFVarSubst (s.insert fvarId v) end Pattern structure AltLHS := +(ref : Syntax) (fvarDecls : List LocalDecl) -- Free variables used in the patterns. (patterns : List Pattern) -- We use `List Pattern` since we have nary match-expressions. structure Alt := +(ref : Syntax) (idx : Nat) -- for generating error messages (rhs : Expr) (fvarDecls : List LocalDecl) @@ -80,7 +82,7 @@ structure Alt := namespace Alt -instance : Inhabited Alt := ⟨⟨0, arbitrary _, [], []⟩⟩ +instance : Inhabited Alt := ⟨⟨arbitrary _, 0, arbitrary _, [], []⟩⟩ partial def toMessageData (alt : Alt) : MetaM MessageData := do withExistingLocalDecls alt.fvarDecls do @@ -102,6 +104,62 @@ def replaceFVarId (fvarId : FVarId) (v : Expr) (alt : Alt) : Alt := decls.map $ replaceFVarIdAtLocalDecl fvarId v, rhs := alt.rhs.replaceFVarId fvarId v } +def isDefEqGuarded (a b : Expr) : MetaM Bool := +catch (isDefEq a b) (fun _ => pure false) + +/- + Similar to `checkAndReplaceFVarId`, but ensures type of `v` is definitionally equal to type of `fvarId`. + This extra check is necessary when performing dependent elimination and inaccessible terms have been used. + For example, consider the following code fragment: + +``` +inductive Vec (α : Type u) : Nat → Type u +| nil : Vec α 0 +| cons {n} (head : α) (tail : Vec α n) : Vec α (n+1) + +inductive VecPred {α : Type u} (P : α → Prop) : {n : Nat} → Vec α n → Prop +| nil : VecPred P Vec.nil +| cons {n : Nat} {head : α} {tail : Vec α n} : P head → VecPred P tail → VecPred P (Vec.cons head tail) + +theorem ex {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P +| _, Vec.cons head _, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩ +``` +Recall that `_` in a pattern can be elaborated into pattern variable or an inaccessible term. +The elaborator uses an inaccessible term when typing constraints restrict its value. +Thus, in the example above, the `_` at `Vec.cons head _` becomes the inaccessible pattern `.(Vec.nil)` +because the type ascription `(w : VecPred P Vec.nil)` propagates typing constraints that restrict its value to be `Vec.nil`. +After elaboration the alternative becomes: +``` +| .(0), @Vec.cons .(α) .(0) head .(Vec.nil), @VecPred.cons .(α) .(P) .(0) .(head) .(Vec.nil) h w => ⟨head, h⟩ +``` +where +``` +(head : α), (h: P head), (w : VecPred P Vec.nil) +``` +Then, when we process this alternative in this module, the following check will detect that +`w` has type `VecPred P Vec.nil`, when it is supposed to have type `VecPred P tail`. +Note that if we had written +``` +theorem ex {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P +| _, Vec.cons head Vec.nil, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩ +``` +we would get the easier to digest error message +``` +missing cases: +_, (Vec.cons _ _ (Vec.cons _ _ _)), _ +``` +-/ +def checkAndReplaceFVarId (fvarId : FVarId) (v : Expr) (alt : Alt) : MetaM Alt := do +match alt.fvarDecls.find? fun (fvarDecl : LocalDecl) => fvarDecl.fvarId == fvarId with +| none => throwErrorAt alt.ref "unknown free pattern variable" +| some fvarDecl => do + vType ← inferType v; + unlessM (isDefEqGuarded fvarDecl.type vType) $ + throwErrorAt alt.ref $ + "type mismatch during dependent match-elimination at pattern variable '" ++ fvarDecl.userName.simpMacroScopes ++ "' with type" ++ indentExpr fvarDecl.type ++ + Format.line ++ "expected type" ++ indentExpr vType; + pure $ replaceFVarId fvarId v alt + end Alt inductive Example @@ -207,7 +265,7 @@ private partial def withAltsAux {α} (motive : Expr) : List AltLHS → List Alt let rhs := if xs.isEmpty then mkApp minor (mkConst `Unit.unit) else mkAppN minor xs; let minors := minors.push minor; fvarDecls ← lhs.fvarDecls.mapM instantiateLocalDeclMVars; - let alts := { idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns : Alt } :: alts; + let alts := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns : Alt } :: alts; withAltsAux lhss alts minors k /- Given a list of `AltLHS`, create a minor premise for each one, convert them into `Alt`, and then execute `k` -/ @@ -315,24 +373,24 @@ match p.alts with liftM $ assignGoalOf p alt.rhs; modify fun s => { s with used := s.used.insert alt.idx } -private def processAsPattern (p : Problem) : Problem := +private def processAsPattern (p : Problem) : MetaM Problem := match p.vars with | [] => unreachable! -| x :: xs => do - let alts := p.alts.map fun alt => match alt.patterns with - | Pattern.as fvarId p :: ps => { alt with patterns := p :: ps }.replaceFVarId fvarId x - | _ => alt; - { p with alts := alts } +| x :: xs => withGoalOf p do + alts ← p.alts.mapM fun alt => match alt.patterns with + | Pattern.as fvarId p :: ps => { alt with patterns := p :: ps }.checkAndReplaceFVarId fvarId x + | _ => pure alt; + pure { p with alts := alts } -private def processVariable (p : Problem) : Problem := +private def processVariable (p : Problem) : MetaM Problem := match p.vars with | [] => unreachable! -| x :: xs => do - let alts := p.alts.map fun alt => match alt.patterns with - | Pattern.inaccessible _ :: ps => { alt with patterns := ps } - | Pattern.var fvarId :: ps => { alt with patterns := ps }.replaceFVarId fvarId x +| x :: xs => withGoalOf p do + alts ← p.alts.mapM fun alt => match alt.patterns with + | Pattern.inaccessible _ :: ps => pure { alt with patterns := ps } + | Pattern.var fvarId :: ps => { alt with patterns := ps }.checkAndReplaceFVarId fvarId x | _ => unreachable!; - { p with alts := alts, vars := xs } + pure { p with alts := alts, vars := xs } private def throwInductiveTypeExpected {α} (e : Expr) : MetaM α := do t ← inferType e; @@ -416,7 +474,7 @@ withExistingLocalDecls alt.fvarDecls do ctorType ← inferType ctor; forallTelescopeReducing ctorType fun ctorFields resultType => do let ctor := mkAppN ctor ctorFields; - let alt := alt.replaceFVarId fvarId ctor; + let alt := alt.replaceFVarId fvarId ctor; ctorFieldDecls ← ctorFields.mapM fun ctorField => getLocalDecl ctorField.fvarId!; let newAltDecls := ctorFieldDecls.toList ++ alt.fvarDecls; subst? ← unify? newAltDecls resultType expectedType; @@ -541,7 +599,7 @@ match p.vars with let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst; let newAlts := newAlts.map fun alt => match alt.patterns with | Pattern.val _ :: ps => { alt with patterns := ps } - | Pattern.var fvarId :: ps => + | Pattern.var fvarId :: ps => do let alt := { alt with patterns := ps }; alt.replaceFVarId fvarId value | _ => unreachable!; @@ -640,7 +698,8 @@ private partial def process : Problem → StateRefT State MetaM Unit processLeaf p else if hasAsPattern p then do traceStep ("as-pattern"); - process (processAsPattern p) + p ← liftM $ processAsPattern p; + process p else if !isNextVar p then do traceStep ("non variable"); process (processNonVariable p) @@ -649,7 +708,8 @@ private partial def process : Problem → StateRefT State MetaM Unit ps.forM process else if isVariableTransition p then do traceStep ("variable"); - process (processVariable p) + p ← liftM $ processVariable p; + process p else if isValueTransition p then do ps ← liftM $ processValue p; ps.forM process diff --git a/tests/lean/match1.lean b/tests/lean/match1.lean index ab33a9ba05..d11a21a258 100644 --- a/tests/lean/match1.lean +++ b/tests/lean/match1.lean @@ -64,3 +64,19 @@ fun { x := x, ..} => { y := x } theorem ex2 : f1 { x := 10 } = { y := 10 } := rfl + +universes u + +inductive Vec (α : Type u) : Nat → Type u +| nil : Vec α 0 +| cons {n} (head : α) (tail : Vec α n) : Vec α (n+1) + +inductive VecPred {α : Type u} (P : α → Prop) : {n : Nat} → Vec α n → Prop +| nil : VecPred P Vec.nil +| cons {n : Nat} {head : α} {tail : Vec α n} : P head → VecPred P tail → VecPred P (Vec.cons head tail) + +theorem ex3 {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P +| _, Vec.cons head _, VecPred.cons h _ => ⟨head, h⟩ + +theorem ex4 {α : Type u} (P : α → Prop) : {n : Nat} → (v : Vec α (n+1)) → VecPred P v → Exists P +| _, Vec.cons head _, VecPred.cons h (w : VecPred P Vec.nil) => ⟨head, h⟩ -- ERROR diff --git a/tests/lean/match1.lean.expected.out b/tests/lean/match1.lean.expected.out index 5ccc50407c..962bc9505e 100644 --- a/tests/lean/match1.lean.expected.out +++ b/tests/lean/match1.lean.expected.out @@ -10,3 +10,7 @@ 4 ---- inv 10 +match1.lean:82:2: error: type mismatch during dependent match-elimination at pattern variable 'w' with type + VecPred P Vec.nil +expected type + VecPred P tail diff --git a/tests/lean/run/depElim1.lean b/tests/lean/run/depElim1.lean index c87dbfdbd4..273d186476 100644 --- a/tests/lean/run/depElim1.lean +++ b/tests/lean/run/depElim1.lean @@ -99,7 +99,7 @@ partial def decodeAltLHS (e : Expr) : MetaM AltLHS := forallTelescopeReducing e fun args body => do decls ← args.toList.mapM (fun arg => getLocalDecl arg.fvarId!); pats ← decodePats body; - pure { fvarDecls := decls, patterns := pats } + pure { ref := Syntax.missing, fvarDecls := decls, patterns := pats } partial def decodeAltLHSs : Expr → MetaM (List AltLHS) | e =>