From 81a19c8554179745358a68b62914dec260419b53 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 17 Aug 2020 16:23:49 -0700 Subject: [PATCH] feat: structure instances with `..` in patterns --- src/Lean/Elab/Match.lean | 113 +++++++++++++++++++--------- tests/lean/match4.lean | 24 +++++- tests/lean/match4.lean.expected.out | 2 + 3 files changed, 99 insertions(+), 40 deletions(-) diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 4df18f844e..c6b3e3661e 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -385,10 +385,11 @@ private partial def elabPatternsAux (patternStxs : Array Syntax) : Nat → Expr def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := patternVarDecls.foldlM - (fun (decls : Array LocalDecl) pdecl => + (fun (decls : Array LocalDecl) pdecl => do match pdecl with | PatternVarDecl.localVar fvarId => do decl ← getLocalDecl fvarId; + decl ← liftMetaM $ Meta.instantiateLocalDeclMVars decl; pure $ decls.push decl | PatternVarDecl.anonymousVar mvarId fvarId => do e ← instantiateMVars (mkMVar mvarId); @@ -400,14 +401,19 @@ patternVarDecls.foldlM assignExprMVar newMVarId (mkFVar fvarId); trace `Elab.match fun _ => "finalizePatternDecls: " ++ mkMVar newMVarId ++ " := " ++ mkFVar fvarId; decl ← getLocalDecl fvarId; + decl ← liftMetaM $ Meta.instantiateLocalDeclMVars decl; pure $ decls.push decl | _ => pure decls) #[] +open Meta.DepElim (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS mkElim ElimResult) + namespace ToDepElimPattern structure State := -(found : NameSet := {}) +(found : NameSet := {}) +(localDecls : Array LocalDecl) +(newLocals : NameSet := {}) abbrev M := StateT State TermElabM @@ -429,27 +435,52 @@ private def getFieldsBinderInfoAux (ctorVal : ConstructorVal) : Nat → Expr → getFieldsBinderInfoAux (i+1) b (bis.push c.binderInfo) | _, _, bis => bis +/- Create a new LocalDecl `x` for the metavariable `mvar`, and return `Pattern.var x` -/ +private def mkLocalDeclFor (mvar : Expr) : M Pattern := do +let mvarId := mvar.mvarId!; +s ← get; +val? ← liftM $ liftMetaM $ Meta.getExprMVarAssignment? mvarId; +match val? with +| some val => pure $ Pattern.inaccessible val +| none => do + fvarId ← liftM $ mkFreshId; + type ← liftM $ inferType mvar; + /- HACK: `fvarId` is not in the scope of `mvarId` + If this generates problems in the future, we should update the metavariable declarations. -/ + liftM $ assignExprMVar mvarId (mkFVar fvarId); + let userName := (`_x).appendIndexAfter (s.localDecls.size+1); + let newDecl := LocalDecl.cdecl (arbitrary _) fvarId userName type BinderInfo.default; + modify $ fun s => + { s with + newLocals := s.newLocals.insert fvarId, + localDecls := + match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with + | none => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar` + | some i => s.localDecls.insertAt i newDecl }; + pure $ Pattern.var fvarId + private def getFieldsBinderInfo (ctorVal : ConstructorVal) : Array BinderInfo := getFieldsBinderInfoAux ctorVal 0 ctorVal.type #[] -partial def main (localDecls : Array LocalDecl) : Expr → M Meta.DepElim.Pattern +partial def main : Expr → M Pattern | e => - let isLocalDecl (fvarId : FVarId) : Bool := - localDecls.any fun d => d.fvarId == fvarId; - let mkPatternVar (fvarId : FVarId) (e : Expr) : M Meta.DepElim.Pattern := do { - condM (alreadyVisited fvarId) - (pure $ Meta.DepElim.Pattern.inaccessible e) - (do markAsVisited fvarId; pure $ Meta.DepElim.Pattern.var e.fvarId!) + let isLocalDecl (fvarId : FVarId) : M Bool := do { + s ← get; + pure $ s.localDecls.any fun d => d.fvarId == fvarId }; - let mkInaccessible (e : Expr) : M Meta.DepElim.Pattern := do { + let mkPatternVar (fvarId : FVarId) (e : Expr) : M Pattern := do { + condM (alreadyVisited fvarId) + (pure $ Pattern.inaccessible e) + (do markAsVisited fvarId; pure $ Pattern.var e.fvarId!) + }; + let mkInaccessible (e : Expr) : M Pattern := do { match e with | Expr.fvar fvarId _ => - if isLocalDecl fvarId then - mkPatternVar fvarId e - else - pure $ Meta.DepElim.Pattern.inaccessible e + condM (isLocalDecl fvarId) + (mkPatternVar fvarId e) + (pure $ Pattern.inaccessible e) | _ => - pure $ Meta.DepElim.Pattern.inaccessible e + pure $ Pattern.inaccessible e }; match inaccessible? e with | some t => mkInaccessible t @@ -457,19 +488,21 @@ partial def main (localDecls : Array LocalDecl) : Expr → M Meta.DepElim.Patter match e.arrayLit? with | some (α, lits) => do ps ← lits.mapM main; - pure $ Meta.DepElim.Pattern.arrayLit α ps + pure $ Pattern.arrayLit α ps | none => if e.isAppOfArity `namedPattern 3 then do p ← main $ e.getArg! 2; match e.getArg! 1 with - | Expr.fvar fvarId _ => pure $ Meta.DepElim.Pattern.as fvarId p + | Expr.fvar fvarId _ => pure $ Pattern.as fvarId p | _ => liftM $ throwError "unexpected occurrence of auxiliary declaration 'namedPattern'" else if e.isNatLit || e.isStringLit || e.isCharLit then - pure $ Meta.DepElim.Pattern.val e + pure $ Pattern.val e else if e.isFVar then do let fvarId := e.fvarId!; - unless (isLocalDecl fvarId) $ throwInvalidPattern e; + unlessM (isLocalDecl fvarId) $ throwInvalidPattern e; mkPatternVar fvarId e + else if e.isMVar then do + mkLocalDeclFor e else do newE ← liftM $ whnf e; if newE != e then @@ -491,45 +524,51 @@ partial def main (localDecls : Array LocalDecl) : Expr → M Meta.DepElim.Patter else mkInaccessible field }; - pure $ Meta.DepElim.Pattern.ctor declName us params.toList fields.toList + pure $ Pattern.ctor declName us params.toList fields.toList | _ => throwInvalidPattern e | _ => throwInvalidPattern e end ToDepElimPattern -def toDepElimPattern (localDecls : Array LocalDecl) (e : Expr) : TermElabM Meta.DepElim.Pattern := -(ToDepElimPattern.main localDecls e).run' {} +def withDepElimPatterns {α} (localDecls : Array LocalDecl) (ps : Array Expr) (k : Array LocalDecl → Array Pattern → TermElabM α) : TermElabM α := do +(patterns, s) ← (ps.mapM ToDepElimPattern.main).run { localDecls := localDecls }; +localDecls ← s.localDecls.mapM fun d => liftMetaM $ Meta.instantiateLocalDeclMVars d; +/- toDepElimPatterns may have added new localDecls. Thus, we must update the local context before we execute `k` -/ +lctx ← getLCtx; +let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.erase d.fvarId) lctx; +let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx; +adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $ k localDecls patterns -private def elabPatterns (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) : TermElabM (Meta.DepElim.AltLHS × Expr) := do +private def withElaboratedLHS {α} (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) + (k : AltLHS → Expr → TermElabM α) : TermElabM α := do (patterns, matchType) ← withSynthesize $ elabPatternsAux patternStxs 0 matchType #[]; localDecls ← finalizePatternDecls patternVarDecls; patterns ← patterns.mapM instantiateMVars; -patterns.forM $ fun pattern => when pattern.hasExprMVar $ throwError ("pattern contains metavariables " ++ indentExpr pattern); -patterns ← patterns.mapM $ toDepElimPattern localDecls; -trace `Elab.match fun _ => "patterns: " ++ MessageData.ofArray (patterns.map fun (p : Meta.DepElim.Pattern) => p.toMessageData); -pure ({ fvarDecls := localDecls.toList, patterns := patterns.toList }, matchType) +withDepElimPatterns localDecls patterns fun localDecls patterns => + k { fvarDecls := localDecls.toList, patterns := patterns.toList } matchType -def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (Meta.DepElim.AltLHS × Expr) := +def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS × Expr) := withRef alt.ref do (patternVars, alt) ← collectPatternVars alt; trace `Elab.match fun _ => "patternVars: " ++ toString patternVars; withPatternVars patternVars fun patternVarDecls => do - (altLHS, matchType) ← elabPatterns patternVarDecls alt.patterns matchType; - rhs ← elabTermEnsuringType alt.rhs matchType; - let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr; - rhs ← if xs.isEmpty then pure $ mkThunk rhs else mkLambda xs rhs; - trace `Elab.match fun _ => "rhs: " ++ rhs; - pure (altLHS, rhs) + withElaboratedLHS patternVarDecls alt.patterns matchType fun altLHS matchType => do + rhs ← elabTermEnsuringType alt.rhs matchType; + let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr; + rhs ← if xs.isEmpty then pure $ mkThunk rhs else mkLambda xs rhs; + trace `Elab.match fun _ => "rhs: " ++ rhs; + -- TODO: check whether altLHS still has metavariables + pure (altLHS, rhs) def mkMotiveType (matchType : Expr) (expectedType : Expr) : TermElabM Expr := do liftMetaM $ Meta.forallTelescopeReducing matchType fun xs matchType => do u ← Meta.getLevel matchType; Meta.mkForall xs (mkSort u) -def mkElim (elimName : Name) (motiveType : Expr) (lhss : List Meta.DepElim.AltLHS) : TermElabM Meta.DepElim.ElimResult := -liftMetaM $ Meta.DepElim.mkElim elimName motiveType lhss +def mkElim (elimName : Name) (motiveType : Expr) (lhss : List AltLHS) : TermElabM ElimResult := +liftMetaM $ mkElim elimName motiveType lhss -def reportElimResultErrors (result : Meta.DepElim.ElimResult) : TermElabM Unit := do +def reportElimResultErrors (result : ElimResult) : TermElabM Unit := do -- TODO: improve error messages unless result.counterExamples.isEmpty $ throwError ("missing cases:" ++ Format.line ++ Meta.DepElim.counterExamplesToMessageData result.counterExamples); diff --git a/tests/lean/match4.lean b/tests/lean/match4.lean index fabcc78f79..529f235534 100644 --- a/tests/lean/match4.lean +++ b/tests/lean/match4.lean @@ -1,3 +1,15 @@ +def Vector (α : Type) (n : Nat) := { a : Array α // a.size = n } + +def mkVec {α : Type} (n : Nat) (a : α) : Vector α n := +⟨mkArray n a, rfl⟩ + +structure S := +(n : Nat) +(y : Vector Nat n) +(z : Vector Nat n) +(h : y = z) +(m : { v : Nat // v = y.val.size }) + new_frontend def f1 (x : Nat × Nat) : Nat := @@ -37,8 +49,14 @@ h x y #eval f5 0 10 #eval f5 20 10 -/- -def f2 (x : Nat × Nat) : Nat := +def f6 (x : Nat × Nat) : Nat := match x with | { fst := x, .. } => x * 10 --/ + +#eval f6 (5, 20) + +def f7 (s : S) : Nat := +match s with +| { n := n, m := m, .. } => n + m.val + +#eval f7 { n := 10, y := mkVec 10 0, z := mkVec 10 0, h := rfl, m := ⟨10, rfl⟩ } diff --git a/tests/lean/match4.lean.expected.out b/tests/lean/match4.lean.expected.out index 9a2ed83dd7..675184e682 100644 --- a/tests/lean/match4.lean.expected.out +++ b/tests/lean/match4.lean.expected.out @@ -4,3 +4,5 @@ 30 10 200 +50 +20