feat: structure instances with .. in patterns

This commit is contained in:
Leonardo de Moura 2020-08-17 16:23:49 -07:00
parent b47f530db5
commit 81a19c8554
3 changed files with 99 additions and 40 deletions

View file

@ -385,10 +385,11 @@ private partial def elabPatternsAux (patternStxs : Array Syntax) : Nat → Expr
def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) :=
patternVarDecls.foldlM
(fun (decls : Array LocalDecl) pdecl =>
(fun (decls : Array LocalDecl) pdecl => do
match pdecl with
| PatternVarDecl.localVar fvarId => do
decl ← getLocalDecl fvarId;
decl ← liftMetaM $ Meta.instantiateLocalDeclMVars decl;
pure $ decls.push decl
| PatternVarDecl.anonymousVar mvarId fvarId => do
e ← instantiateMVars (mkMVar mvarId);
@ -400,14 +401,19 @@ patternVarDecls.foldlM
assignExprMVar newMVarId (mkFVar fvarId);
trace `Elab.match fun _ => "finalizePatternDecls: " ++ mkMVar newMVarId ++ " := " ++ mkFVar fvarId;
decl ← getLocalDecl fvarId;
decl ← liftMetaM $ Meta.instantiateLocalDeclMVars decl;
pure $ decls.push decl
| _ => pure decls)
#[]
open Meta.DepElim (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS mkElim ElimResult)
namespace ToDepElimPattern
structure State :=
(found : NameSet := {})
(found : NameSet := {})
(localDecls : Array LocalDecl)
(newLocals : NameSet := {})
abbrev M := StateT State TermElabM
@ -429,27 +435,52 @@ private def getFieldsBinderInfoAux (ctorVal : ConstructorVal) : Nat → Expr →
getFieldsBinderInfoAux (i+1) b (bis.push c.binderInfo)
| _, _, bis => bis
/- Create a new LocalDecl `x` for the metavariable `mvar`, and return `Pattern.var x` -/
private def mkLocalDeclFor (mvar : Expr) : M Pattern := do
let mvarId := mvar.mvarId!;
s ← get;
val? ← liftM $ liftMetaM $ Meta.getExprMVarAssignment? mvarId;
match val? with
| some val => pure $ Pattern.inaccessible val
| none => do
fvarId ← liftM $ mkFreshId;
type ← liftM $ inferType mvar;
/- HACK: `fvarId` is not in the scope of `mvarId`
If this generates problems in the future, we should update the metavariable declarations. -/
liftM $ assignExprMVar mvarId (mkFVar fvarId);
let userName := (`_x).appendIndexAfter (s.localDecls.size+1);
let newDecl := LocalDecl.cdecl (arbitrary _) fvarId userName type BinderInfo.default;
modify $ fun s =>
{ s with
newLocals := s.newLocals.insert fvarId,
localDecls :=
match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with
| none => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar`
| some i => s.localDecls.insertAt i newDecl };
pure $ Pattern.var fvarId
private def getFieldsBinderInfo (ctorVal : ConstructorVal) : Array BinderInfo :=
getFieldsBinderInfoAux ctorVal 0 ctorVal.type #[]
partial def main (localDecls : Array LocalDecl) : Expr → M Meta.DepElim.Pattern
partial def main : Expr → M Pattern
| e =>
let isLocalDecl (fvarId : FVarId) : Bool :=
localDecls.any fun d => d.fvarId == fvarId;
let mkPatternVar (fvarId : FVarId) (e : Expr) : M Meta.DepElim.Pattern := do {
condM (alreadyVisited fvarId)
(pure $ Meta.DepElim.Pattern.inaccessible e)
(do markAsVisited fvarId; pure $ Meta.DepElim.Pattern.var e.fvarId!)
let isLocalDecl (fvarId : FVarId) : M Bool := do {
s ← get;
pure $ s.localDecls.any fun d => d.fvarId == fvarId
};
let mkInaccessible (e : Expr) : M Meta.DepElim.Pattern := do {
let mkPatternVar (fvarId : FVarId) (e : Expr) : M Pattern := do {
condM (alreadyVisited fvarId)
(pure $ Pattern.inaccessible e)
(do markAsVisited fvarId; pure $ Pattern.var e.fvarId!)
};
let mkInaccessible (e : Expr) : M Pattern := do {
match e with
| Expr.fvar fvarId _ =>
if isLocalDecl fvarId then
mkPatternVar fvarId e
else
pure $ Meta.DepElim.Pattern.inaccessible e
condM (isLocalDecl fvarId)
(mkPatternVar fvarId e)
(pure $ Pattern.inaccessible e)
| _ =>
pure $ Meta.DepElim.Pattern.inaccessible e
pure $ Pattern.inaccessible e
};
match inaccessible? e with
| some t => mkInaccessible t
@ -457,19 +488,21 @@ partial def main (localDecls : Array LocalDecl) : Expr → M Meta.DepElim.Patter
match e.arrayLit? with
| some (α, lits) => do
ps ← lits.mapM main;
pure $ Meta.DepElim.Pattern.arrayLit α ps
pure $ Pattern.arrayLit α ps
| none =>
if e.isAppOfArity `namedPattern 3 then do
p ← main $ e.getArg! 2;
match e.getArg! 1 with
| Expr.fvar fvarId _ => pure $ Meta.DepElim.Pattern.as fvarId p
| Expr.fvar fvarId _ => pure $ Pattern.as fvarId p
| _ => liftM $ throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
else if e.isNatLit || e.isStringLit || e.isCharLit then
pure $ Meta.DepElim.Pattern.val e
pure $ Pattern.val e
else if e.isFVar then do
let fvarId := e.fvarId!;
unless (isLocalDecl fvarId) $ throwInvalidPattern e;
unlessM (isLocalDecl fvarId) $ throwInvalidPattern e;
mkPatternVar fvarId e
else if e.isMVar then do
mkLocalDeclFor e
else do
newE ← liftM $ whnf e;
if newE != e then
@ -491,45 +524,51 @@ partial def main (localDecls : Array LocalDecl) : Expr → M Meta.DepElim.Patter
else
mkInaccessible field
};
pure $ Meta.DepElim.Pattern.ctor declName us params.toList fields.toList
pure $ Pattern.ctor declName us params.toList fields.toList
| _ => throwInvalidPattern e
| _ => throwInvalidPattern e
end ToDepElimPattern
def toDepElimPattern (localDecls : Array LocalDecl) (e : Expr) : TermElabM Meta.DepElim.Pattern :=
(ToDepElimPattern.main localDecls e).run' {}
def withDepElimPatterns {α} (localDecls : Array LocalDecl) (ps : Array Expr) (k : Array LocalDecl → Array Pattern → TermElabM α) : TermElabM α := do
(patterns, s) ← (ps.mapM ToDepElimPattern.main).run { localDecls := localDecls };
localDecls ← s.localDecls.mapM fun d => liftMetaM $ Meta.instantiateLocalDeclMVars d;
/- toDepElimPatterns may have added new localDecls. Thus, we must update the local context before we execute `k` -/
lctx ← getLCtx;
let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.erase d.fvarId) lctx;
let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx;
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $ k localDecls patterns
private def elabPatterns (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) : TermElabM (Meta.DepElim.AltLHS × Expr) := do
private def withElaboratedLHS {α} (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr)
(k : AltLHS → Expr → TermElabM α) : TermElabM α := do
(patterns, matchType) ← withSynthesize $ elabPatternsAux patternStxs 0 matchType #[];
localDecls ← finalizePatternDecls patternVarDecls;
patterns ← patterns.mapM instantiateMVars;
patterns.forM $ fun pattern => when pattern.hasExprMVar $ throwError ("pattern contains metavariables " ++ indentExpr pattern);
patterns ← patterns.mapM $ toDepElimPattern localDecls;
trace `Elab.match fun _ => "patterns: " ++ MessageData.ofArray (patterns.map fun (p : Meta.DepElim.Pattern) => p.toMessageData);
pure ({ fvarDecls := localDecls.toList, patterns := patterns.toList }, matchType)
withDepElimPatterns localDecls patterns fun localDecls patterns =>
k { fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (Meta.DepElim.AltLHS × Expr) :=
def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS × Expr) :=
withRef alt.ref do
(patternVars, alt) ← collectPatternVars alt;
trace `Elab.match fun _ => "patternVars: " ++ toString patternVars;
withPatternVars patternVars fun patternVarDecls => do
(altLHS, matchType) ← elabPatterns patternVarDecls alt.patterns matchType;
rhs ← elabTermEnsuringType alt.rhs matchType;
let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr;
rhs ← if xs.isEmpty then pure $ mkThunk rhs else mkLambda xs rhs;
trace `Elab.match fun _ => "rhs: " ++ rhs;
pure (altLHS, rhs)
withElaboratedLHS patternVarDecls alt.patterns matchType fun altLHS matchType => do
rhs ← elabTermEnsuringType alt.rhs matchType;
let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr;
rhs ← if xs.isEmpty then pure $ mkThunk rhs else mkLambda xs rhs;
trace `Elab.match fun _ => "rhs: " ++ rhs;
-- TODO: check whether altLHS still has metavariables
pure (altLHS, rhs)
def mkMotiveType (matchType : Expr) (expectedType : Expr) : TermElabM Expr := do
liftMetaM $ Meta.forallTelescopeReducing matchType fun xs matchType => do
u ← Meta.getLevel matchType;
Meta.mkForall xs (mkSort u)
def mkElim (elimName : Name) (motiveType : Expr) (lhss : List Meta.DepElim.AltLHS) : TermElabM Meta.DepElim.ElimResult :=
liftMetaM $ Meta.DepElim.mkElim elimName motiveType lhss
def mkElim (elimName : Name) (motiveType : Expr) (lhss : List AltLHS) : TermElabM ElimResult :=
liftMetaM $ mkElim elimName motiveType lhss
def reportElimResultErrors (result : Meta.DepElim.ElimResult) : TermElabM Unit := do
def reportElimResultErrors (result : ElimResult) : TermElabM Unit := do
-- TODO: improve error messages
unless result.counterExamples.isEmpty $
throwError ("missing cases:" ++ Format.line ++ Meta.DepElim.counterExamplesToMessageData result.counterExamples);

View file

@ -1,3 +1,15 @@
def Vector (α : Type) (n : Nat) := { a : Array α // a.size = n }
def mkVec {α : Type} (n : Nat) (a : α) : Vector α n :=
⟨mkArray n a, rfl⟩
structure S :=
(n : Nat)
(y : Vector Nat n)
(z : Vector Nat n)
(h : y = z)
(m : { v : Nat // v = y.val.size })
new_frontend
def f1 (x : Nat × Nat) : Nat :=
@ -37,8 +49,14 @@ h x y
#eval f5 0 10
#eval f5 20 10
/-
def f2 (x : Nat × Nat) : Nat :=
def f6 (x : Nat × Nat) : Nat :=
match x with
| { fst := x, .. } => x * 10
-/
#eval f6 (5, 20)
def f7 (s : S) : Nat :=
match s with
| { n := n, m := m, .. } => n + m.val
#eval f7 { n := 10, y := mkVec 10 0, z := mkVec 10 0, h := rfl, m := ⟨10, rfl⟩ }

View file

@ -4,3 +4,5 @@
30
10
200
50
20