feat: structure instances with .. in patterns
This commit is contained in:
parent
b47f530db5
commit
81a19c8554
3 changed files with 99 additions and 40 deletions
|
|
@ -385,10 +385,11 @@ private partial def elabPatternsAux (patternStxs : Array Syntax) : Nat → Expr
|
|||
|
||||
def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) :=
|
||||
patternVarDecls.foldlM
|
||||
(fun (decls : Array LocalDecl) pdecl =>
|
||||
(fun (decls : Array LocalDecl) pdecl => do
|
||||
match pdecl with
|
||||
| PatternVarDecl.localVar fvarId => do
|
||||
decl ← getLocalDecl fvarId;
|
||||
decl ← liftMetaM $ Meta.instantiateLocalDeclMVars decl;
|
||||
pure $ decls.push decl
|
||||
| PatternVarDecl.anonymousVar mvarId fvarId => do
|
||||
e ← instantiateMVars (mkMVar mvarId);
|
||||
|
|
@ -400,14 +401,19 @@ patternVarDecls.foldlM
|
|||
assignExprMVar newMVarId (mkFVar fvarId);
|
||||
trace `Elab.match fun _ => "finalizePatternDecls: " ++ mkMVar newMVarId ++ " := " ++ mkFVar fvarId;
|
||||
decl ← getLocalDecl fvarId;
|
||||
decl ← liftMetaM $ Meta.instantiateLocalDeclMVars decl;
|
||||
pure $ decls.push decl
|
||||
| _ => pure decls)
|
||||
#[]
|
||||
|
||||
open Meta.DepElim (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS mkElim ElimResult)
|
||||
|
||||
namespace ToDepElimPattern
|
||||
|
||||
structure State :=
|
||||
(found : NameSet := {})
|
||||
(found : NameSet := {})
|
||||
(localDecls : Array LocalDecl)
|
||||
(newLocals : NameSet := {})
|
||||
|
||||
abbrev M := StateT State TermElabM
|
||||
|
||||
|
|
@ -429,27 +435,52 @@ private def getFieldsBinderInfoAux (ctorVal : ConstructorVal) : Nat → Expr →
|
|||
getFieldsBinderInfoAux (i+1) b (bis.push c.binderInfo)
|
||||
| _, _, bis => bis
|
||||
|
||||
/- Create a new LocalDecl `x` for the metavariable `mvar`, and return `Pattern.var x` -/
|
||||
private def mkLocalDeclFor (mvar : Expr) : M Pattern := do
|
||||
let mvarId := mvar.mvarId!;
|
||||
s ← get;
|
||||
val? ← liftM $ liftMetaM $ Meta.getExprMVarAssignment? mvarId;
|
||||
match val? with
|
||||
| some val => pure $ Pattern.inaccessible val
|
||||
| none => do
|
||||
fvarId ← liftM $ mkFreshId;
|
||||
type ← liftM $ inferType mvar;
|
||||
/- HACK: `fvarId` is not in the scope of `mvarId`
|
||||
If this generates problems in the future, we should update the metavariable declarations. -/
|
||||
liftM $ assignExprMVar mvarId (mkFVar fvarId);
|
||||
let userName := (`_x).appendIndexAfter (s.localDecls.size+1);
|
||||
let newDecl := LocalDecl.cdecl (arbitrary _) fvarId userName type BinderInfo.default;
|
||||
modify $ fun s =>
|
||||
{ s with
|
||||
newLocals := s.newLocals.insert fvarId,
|
||||
localDecls :=
|
||||
match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with
|
||||
| none => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar`
|
||||
| some i => s.localDecls.insertAt i newDecl };
|
||||
pure $ Pattern.var fvarId
|
||||
|
||||
private def getFieldsBinderInfo (ctorVal : ConstructorVal) : Array BinderInfo :=
|
||||
getFieldsBinderInfoAux ctorVal 0 ctorVal.type #[]
|
||||
|
||||
partial def main (localDecls : Array LocalDecl) : Expr → M Meta.DepElim.Pattern
|
||||
partial def main : Expr → M Pattern
|
||||
| e =>
|
||||
let isLocalDecl (fvarId : FVarId) : Bool :=
|
||||
localDecls.any fun d => d.fvarId == fvarId;
|
||||
let mkPatternVar (fvarId : FVarId) (e : Expr) : M Meta.DepElim.Pattern := do {
|
||||
condM (alreadyVisited fvarId)
|
||||
(pure $ Meta.DepElim.Pattern.inaccessible e)
|
||||
(do markAsVisited fvarId; pure $ Meta.DepElim.Pattern.var e.fvarId!)
|
||||
let isLocalDecl (fvarId : FVarId) : M Bool := do {
|
||||
s ← get;
|
||||
pure $ s.localDecls.any fun d => d.fvarId == fvarId
|
||||
};
|
||||
let mkInaccessible (e : Expr) : M Meta.DepElim.Pattern := do {
|
||||
let mkPatternVar (fvarId : FVarId) (e : Expr) : M Pattern := do {
|
||||
condM (alreadyVisited fvarId)
|
||||
(pure $ Pattern.inaccessible e)
|
||||
(do markAsVisited fvarId; pure $ Pattern.var e.fvarId!)
|
||||
};
|
||||
let mkInaccessible (e : Expr) : M Pattern := do {
|
||||
match e with
|
||||
| Expr.fvar fvarId _ =>
|
||||
if isLocalDecl fvarId then
|
||||
mkPatternVar fvarId e
|
||||
else
|
||||
pure $ Meta.DepElim.Pattern.inaccessible e
|
||||
condM (isLocalDecl fvarId)
|
||||
(mkPatternVar fvarId e)
|
||||
(pure $ Pattern.inaccessible e)
|
||||
| _ =>
|
||||
pure $ Meta.DepElim.Pattern.inaccessible e
|
||||
pure $ Pattern.inaccessible e
|
||||
};
|
||||
match inaccessible? e with
|
||||
| some t => mkInaccessible t
|
||||
|
|
@ -457,19 +488,21 @@ partial def main (localDecls : Array LocalDecl) : Expr → M Meta.DepElim.Patter
|
|||
match e.arrayLit? with
|
||||
| some (α, lits) => do
|
||||
ps ← lits.mapM main;
|
||||
pure $ Meta.DepElim.Pattern.arrayLit α ps
|
||||
pure $ Pattern.arrayLit α ps
|
||||
| none =>
|
||||
if e.isAppOfArity `namedPattern 3 then do
|
||||
p ← main $ e.getArg! 2;
|
||||
match e.getArg! 1 with
|
||||
| Expr.fvar fvarId _ => pure $ Meta.DepElim.Pattern.as fvarId p
|
||||
| Expr.fvar fvarId _ => pure $ Pattern.as fvarId p
|
||||
| _ => liftM $ throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
|
||||
else if e.isNatLit || e.isStringLit || e.isCharLit then
|
||||
pure $ Meta.DepElim.Pattern.val e
|
||||
pure $ Pattern.val e
|
||||
else if e.isFVar then do
|
||||
let fvarId := e.fvarId!;
|
||||
unless (isLocalDecl fvarId) $ throwInvalidPattern e;
|
||||
unlessM (isLocalDecl fvarId) $ throwInvalidPattern e;
|
||||
mkPatternVar fvarId e
|
||||
else if e.isMVar then do
|
||||
mkLocalDeclFor e
|
||||
else do
|
||||
newE ← liftM $ whnf e;
|
||||
if newE != e then
|
||||
|
|
@ -491,45 +524,51 @@ partial def main (localDecls : Array LocalDecl) : Expr → M Meta.DepElim.Patter
|
|||
else
|
||||
mkInaccessible field
|
||||
};
|
||||
pure $ Meta.DepElim.Pattern.ctor declName us params.toList fields.toList
|
||||
pure $ Pattern.ctor declName us params.toList fields.toList
|
||||
| _ => throwInvalidPattern e
|
||||
| _ => throwInvalidPattern e
|
||||
|
||||
end ToDepElimPattern
|
||||
|
||||
def toDepElimPattern (localDecls : Array LocalDecl) (e : Expr) : TermElabM Meta.DepElim.Pattern :=
|
||||
(ToDepElimPattern.main localDecls e).run' {}
|
||||
def withDepElimPatterns {α} (localDecls : Array LocalDecl) (ps : Array Expr) (k : Array LocalDecl → Array Pattern → TermElabM α) : TermElabM α := do
|
||||
(patterns, s) ← (ps.mapM ToDepElimPattern.main).run { localDecls := localDecls };
|
||||
localDecls ← s.localDecls.mapM fun d => liftMetaM $ Meta.instantiateLocalDeclMVars d;
|
||||
/- toDepElimPatterns may have added new localDecls. Thus, we must update the local context before we execute `k` -/
|
||||
lctx ← getLCtx;
|
||||
let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.erase d.fvarId) lctx;
|
||||
let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx;
|
||||
adaptReader (fun (ctx : Context) => { ctx with lctx := lctx }) $ k localDecls patterns
|
||||
|
||||
private def elabPatterns (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr) : TermElabM (Meta.DepElim.AltLHS × Expr) := do
|
||||
private def withElaboratedLHS {α} (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr)
|
||||
(k : AltLHS → Expr → TermElabM α) : TermElabM α := do
|
||||
(patterns, matchType) ← withSynthesize $ elabPatternsAux patternStxs 0 matchType #[];
|
||||
localDecls ← finalizePatternDecls patternVarDecls;
|
||||
patterns ← patterns.mapM instantiateMVars;
|
||||
patterns.forM $ fun pattern => when pattern.hasExprMVar $ throwError ("pattern contains metavariables " ++ indentExpr pattern);
|
||||
patterns ← patterns.mapM $ toDepElimPattern localDecls;
|
||||
trace `Elab.match fun _ => "patterns: " ++ MessageData.ofArray (patterns.map fun (p : Meta.DepElim.Pattern) => p.toMessageData);
|
||||
pure ({ fvarDecls := localDecls.toList, patterns := patterns.toList }, matchType)
|
||||
withDepElimPatterns localDecls patterns fun localDecls patterns =>
|
||||
k { fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
|
||||
|
||||
def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (Meta.DepElim.AltLHS × Expr) :=
|
||||
def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS × Expr) :=
|
||||
withRef alt.ref do
|
||||
(patternVars, alt) ← collectPatternVars alt;
|
||||
trace `Elab.match fun _ => "patternVars: " ++ toString patternVars;
|
||||
withPatternVars patternVars fun patternVarDecls => do
|
||||
(altLHS, matchType) ← elabPatterns patternVarDecls alt.patterns matchType;
|
||||
rhs ← elabTermEnsuringType alt.rhs matchType;
|
||||
let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr;
|
||||
rhs ← if xs.isEmpty then pure $ mkThunk rhs else mkLambda xs rhs;
|
||||
trace `Elab.match fun _ => "rhs: " ++ rhs;
|
||||
pure (altLHS, rhs)
|
||||
withElaboratedLHS patternVarDecls alt.patterns matchType fun altLHS matchType => do
|
||||
rhs ← elabTermEnsuringType alt.rhs matchType;
|
||||
let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr;
|
||||
rhs ← if xs.isEmpty then pure $ mkThunk rhs else mkLambda xs rhs;
|
||||
trace `Elab.match fun _ => "rhs: " ++ rhs;
|
||||
-- TODO: check whether altLHS still has metavariables
|
||||
pure (altLHS, rhs)
|
||||
|
||||
def mkMotiveType (matchType : Expr) (expectedType : Expr) : TermElabM Expr := do
|
||||
liftMetaM $ Meta.forallTelescopeReducing matchType fun xs matchType => do
|
||||
u ← Meta.getLevel matchType;
|
||||
Meta.mkForall xs (mkSort u)
|
||||
|
||||
def mkElim (elimName : Name) (motiveType : Expr) (lhss : List Meta.DepElim.AltLHS) : TermElabM Meta.DepElim.ElimResult :=
|
||||
liftMetaM $ Meta.DepElim.mkElim elimName motiveType lhss
|
||||
def mkElim (elimName : Name) (motiveType : Expr) (lhss : List AltLHS) : TermElabM ElimResult :=
|
||||
liftMetaM $ mkElim elimName motiveType lhss
|
||||
|
||||
def reportElimResultErrors (result : Meta.DepElim.ElimResult) : TermElabM Unit := do
|
||||
def reportElimResultErrors (result : ElimResult) : TermElabM Unit := do
|
||||
-- TODO: improve error messages
|
||||
unless result.counterExamples.isEmpty $
|
||||
throwError ("missing cases:" ++ Format.line ++ Meta.DepElim.counterExamplesToMessageData result.counterExamples);
|
||||
|
|
|
|||
|
|
@ -1,3 +1,15 @@
|
|||
def Vector (α : Type) (n : Nat) := { a : Array α // a.size = n }
|
||||
|
||||
def mkVec {α : Type} (n : Nat) (a : α) : Vector α n :=
|
||||
⟨mkArray n a, rfl⟩
|
||||
|
||||
structure S :=
|
||||
(n : Nat)
|
||||
(y : Vector Nat n)
|
||||
(z : Vector Nat n)
|
||||
(h : y = z)
|
||||
(m : { v : Nat // v = y.val.size })
|
||||
|
||||
new_frontend
|
||||
|
||||
def f1 (x : Nat × Nat) : Nat :=
|
||||
|
|
@ -37,8 +49,14 @@ h x y
|
|||
#eval f5 0 10
|
||||
#eval f5 20 10
|
||||
|
||||
/-
|
||||
def f2 (x : Nat × Nat) : Nat :=
|
||||
def f6 (x : Nat × Nat) : Nat :=
|
||||
match x with
|
||||
| { fst := x, .. } => x * 10
|
||||
-/
|
||||
|
||||
#eval f6 (5, 20)
|
||||
|
||||
def f7 (s : S) : Nat :=
|
||||
match s with
|
||||
| { n := n, m := m, .. } => n + m.val
|
||||
|
||||
#eval f7 { n := 10, y := mkVec 10 0, z := mkVec 10 0, h := rfl, m := ⟨10, rfl⟩ }
|
||||
|
|
|
|||
|
|
@ -4,3 +4,5 @@
|
|||
30
|
||||
10
|
||||
200
|
||||
50
|
||||
20
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue