feat: add checkpoint using withSynthesize
This commit is contained in:
parent
89bd5d6da2
commit
8543a20b8f
2 changed files with 30 additions and 23 deletions
|
|
@ -11,19 +11,19 @@ namespace Lean
|
|||
namespace Elab
|
||||
namespace Term
|
||||
|
||||
structure LetRecDecl :=
|
||||
structure LetRecDeclView :=
|
||||
(attrs : Syntax)
|
||||
(decl : Syntax)
|
||||
|
||||
structure LetRecView :=
|
||||
(ref : Syntax)
|
||||
(isPartial : Bool)
|
||||
(decls : Array LetRecDecl)
|
||||
(decls : Array LetRecDeclView)
|
||||
(body : Syntax)
|
||||
|
||||
private def mkLetRecView (letRec : Syntax) : LetRecView :=
|
||||
let decls := (letRec.getArg 2).getArgs.getSepElems.map fun attrDeclSyntax =>
|
||||
{ attrs := attrDeclSyntax.getArg 0, decl := (attrDeclSyntax.getArg 1).getArg 0 : LetRecDecl };
|
||||
{ attrs := attrDeclSyntax.getArg 0, decl := (attrDeclSyntax.getArg 1).getArg 0 : LetRecDeclView };
|
||||
{ decls := decls,
|
||||
ref := letRec,
|
||||
isPartial := !(letRec.getArg 1).isNone,
|
||||
|
|
@ -45,19 +45,21 @@ let result := result.setArg 2 $ mkSepStx
|
|||
(mkAtomFrom result ", ");
|
||||
result
|
||||
|
||||
private def isLetEqnsDecl (d : LetRecDecl) : Bool :=
|
||||
private def isLetEqnsDecl (d : LetRecDeclView) : Bool :=
|
||||
d.decl.isOfKind `Lean.Parser.Term.letEqnsDecl
|
||||
|
||||
open Meta
|
||||
|
||||
structure LetRecDeclHeader :=
|
||||
(fnFVarId : FVarId)
|
||||
(declName : Name)
|
||||
(name : Name)
|
||||
(type : Expr)
|
||||
(numBinders : Nat)
|
||||
|
||||
instance LetRecDeclHeader.inhabited : Inhabited LetRecDeclHeader := ⟨⟨arbitrary _, arbitrary _⟩⟩
|
||||
instance LetRecDeclHeader.inhabited : Inhabited LetRecDeclHeader := ⟨⟨arbitrary _, arbitrary _, arbitrary _, arbitrary _⟩⟩
|
||||
|
||||
private partial def withLetRecDeclHeadersAux {α} (view : LetRecView) (k : Array LetRecDeclHeader → TermElabM α) : Nat → Array LetRecDeclHeader → TermElabM α
|
||||
| i, acc =>
|
||||
| i, headers =>
|
||||
if h : i < view.decls.size then
|
||||
let decl := (view.decls.get ⟨i, h⟩).decl;
|
||||
-- `decl` is a `letIdDecl` of the form `ident >> many bracketedBinder >> optType >> " := " >> termParser
|
||||
|
|
@ -68,10 +70,13 @@ private partial def withLetRecDeclHeadersAux {α} (view : LetRecView) (k : Array
|
|||
type ← mkForallFVars xs type;
|
||||
pure (type, xs.size)
|
||||
};
|
||||
currDeclName? ← getDeclName?;
|
||||
let declName := currDeclName?.getD Name.anonymous ++ declView.id;
|
||||
checkNotAlreadyDeclared declName;
|
||||
withLocalDeclD declView.id type fun fvar =>
|
||||
withLetRecDeclHeadersAux (i+1) (acc.push ⟨fvar.fvarId!, numBinders⟩)
|
||||
withLetRecDeclHeadersAux (i+1) (headers.push { declName := declName, name := declView.id, type := type, numBinders := numBinders })
|
||||
else
|
||||
k acc
|
||||
k headers
|
||||
|
||||
private def withLetRecDeclHeaders {α} (view : LetRecView) (k : Array LetRecDeclHeader → TermElabM α) : TermElabM α :=
|
||||
withLetRecDeclHeadersAux view k 0 #[]
|
||||
|
|
@ -81,26 +86,32 @@ view.decls.mapIdxM fun i d => do
|
|||
let decl := d.decl;
|
||||
let view := mkLetIdDeclView decl;
|
||||
let header := headers.get! i;
|
||||
headerLocalDecl ← getLocalDecl header.fnFVarId;
|
||||
forallBoundedTelescope headerLocalDecl.type header.numBinders fun xs type =>
|
||||
withDeclNameSuffix view.id do
|
||||
currDeclName? ← getDeclName?;
|
||||
let currDeclName := currDeclName?.get!;
|
||||
checkNotAlreadyDeclared currDeclName;
|
||||
forallBoundedTelescope header.type header.numBinders fun xs type =>
|
||||
withDeclName header.declName do
|
||||
value ← elabTermEnsuringType view.value type;
|
||||
mkLambdaFVars xs value
|
||||
|
||||
private def elabLetRecView (view : LetRecView) (expectedType? : Option Expr) : TermElabM Expr :=
|
||||
withLetRecDeclHeaders view fun headers => do
|
||||
structure LetRecDecl :=
|
||||
(name : Name)
|
||||
(type : Expr)
|
||||
(value : Expr)
|
||||
|
||||
private def elabLetRecView (view : LetRecView) (expectedType? : Option Expr) : TermElabM Expr := do
|
||||
decls ← withSynthesize $ withLetRecDeclHeaders view fun headers => do {
|
||||
values ← elabLetRecDeclValues view headers;
|
||||
synthesizeSyntheticMVars false;
|
||||
-- TODO
|
||||
values.forM IO.println;
|
||||
throwError ("WIP")
|
||||
values.mapIdxM fun i value => do
|
||||
let header := headers.get! i;
|
||||
pure { name := header.name, type := header.type, value := value : LetRecDecl }
|
||||
};
|
||||
throwError ("WIP")
|
||||
|
||||
@[builtinTermElab «letrec»] def elabLetRec : TermElab :=
|
||||
fun stx expectedType? => do
|
||||
let view := mkLetRecView stx;
|
||||
view.decls.forM fun (d : LetRecDecl) =>
|
||||
view.decls.forM fun (d : LetRecDeclView) =>
|
||||
when ((d.decl.getArg 0).isOfKind `Lean.Parser.Term.letPatDecl) $
|
||||
throwErrorAt d.decl "patterns are not allowed in letrec expressions";
|
||||
if view.decls.any isLetEqnsDecl then do
|
||||
|
|
|
|||
|
|
@ -227,10 +227,6 @@ def assignLevelMVar (mvarId : MVarId) (val : Level) : TermElabM Unit := modifyTh
|
|||
def withDeclName {α} (name : Name) (x : TermElabM α) : TermElabM α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with declName? := name }) x
|
||||
|
||||
def withDeclNameSuffix {α} (suffix : Name) (x : TermElabM α) : TermElabM α := do
|
||||
name? ← getDeclName?;
|
||||
withDeclName ((name?.getD Name.anonymous) ++ suffix) x
|
||||
|
||||
def logTrace (cls : Name) (msg : MessageData) : TermElabM Unit := do
|
||||
env ← getEnv;
|
||||
mctx ← getMCtx;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue