feat: implement let elaborators without using match_syntax
@Kha I had to do this because of the `ident` vs `Term.id` recurrent issue. `match_syntax` fails if a `Term.id` is used at `Term.letIdDecl` where an `ident` is expected. We should try to remove `Term.id` in the future.
This commit is contained in:
parent
1ba3925740
commit
3342ba08d2
3 changed files with 113 additions and 94 deletions
|
|
@ -212,6 +212,11 @@ match stx with
|
|||
| Syntax.atom _ v => mkNameSimple v
|
||||
| Syntax.ident _ _ _ _ => identKind
|
||||
|
||||
def updateKind (stx : Syntax) (k : SyntaxNodeKind) : Syntax :=
|
||||
match stx with
|
||||
| Syntax.node _ args => Syntax.node k args
|
||||
| _ => stx
|
||||
|
||||
def isOfKind : Syntax → SyntaxNodeKind → Bool
|
||||
| stx, k => stx.getKind == k
|
||||
|
||||
|
|
@ -770,6 +775,12 @@ match stx with
|
|||
Name.anonymous
|
||||
| _ => Name.anonymous
|
||||
|
||||
/- Given an `ident` or `Term.id`, return its name -/
|
||||
def getRelaxedId (stx : Syntax) : Name :=
|
||||
match stx with
|
||||
| Syntax.ident _ _ _ _ => stx.getId
|
||||
| _ => stx.getIdOfTermId
|
||||
|
||||
/-- Similar to `isTermId?`, but succeeds only if the optional part is a `none`. -/
|
||||
def isSimpleTermId? (stx : Syntax) (relaxed : Bool := false) : Option Syntax :=
|
||||
match stx.isTermId? relaxed with
|
||||
|
|
|
|||
|
|
@ -416,41 +416,39 @@ else do
|
|||
};
|
||||
pure $ mkApp f val
|
||||
|
||||
-- letIdLhs := ident >> checkWsBefore "expected space before binders" >> many bracketedBinder >> optType
|
||||
def expandLetIdLhs (letIdLhs : Syntax) : Name × Array Syntax × Syntax := do
|
||||
let id := (letIdLhs.getArg 0).getRelaxedId; -- allow `Term.id` to be used as an id for convenience of macro writers
|
||||
let binders := (letIdLhs.getArg 1).getArgs;
|
||||
let optType := letIdLhs.getArg 2;
|
||||
let type := expandOptType letIdLhs optType;
|
||||
(id, binders, type)
|
||||
|
||||
def elabLetDeclCore (stx : Syntax) (expectedType? : Option Expr) (useLetExpr : Bool) : TermElabM Expr := do
|
||||
let ref := stx;
|
||||
let letDecl := (stx.getArg 1).getArg 0;
|
||||
let body := stx.getArg 3;
|
||||
if letDecl.getKind == `Lean.Parser.Term.letIdDecl then
|
||||
let (id, binders, type) := expandLetIdLhs letDecl;
|
||||
let val := letDecl.getArg 4;
|
||||
elabLetDeclAux id binders type val body expectedType? useLetExpr
|
||||
else if letDecl.getKind == `Lean.Parser.Term.letPatDecl then do
|
||||
-- node `Lean.Parser.Term.letPatDecl $ try (termParser >> pushNone >> optType >> " := ") >> termParser
|
||||
let pat := letDecl.getArg 0;
|
||||
let optType := letDecl.getArg 2;
|
||||
let type := expandOptType stx optType;
|
||||
let val := letDecl.getArg 4;
|
||||
stxNew ← `(let x : $type := $val; match x with $pat => $body);
|
||||
let stxNew := if useLetExpr then stxNew else stxNew.updateKind `Lean.Parser.Term.«let!»;
|
||||
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
|
||||
else
|
||||
throwError "WIP"
|
||||
|
||||
@[builtinTermElab «let»] def elabLetDecl : TermElab :=
|
||||
fun stx expectedType? => match_syntax stx with
|
||||
| `(let $id:ident $args* := $val; $body) =>
|
||||
elabLetDeclAux id.getId args (mkHole stx) val body expectedType? true
|
||||
| `(let $id:ident $args* : $type := $val; $body) =>
|
||||
elabLetDeclAux id.getId args type val body expectedType? true
|
||||
| `(let $id:ident $args* | $alts:matchAlt*; $body) =>
|
||||
throwError "invalid let-expression with pattern matching, type must be provided"
|
||||
| `(let $id:ident $args* : $type | $alts:matchAlt*; $body) =>
|
||||
throwError "WIP" -- TODO
|
||||
| `(let $pat:term := $val; $body) => do
|
||||
stxNew ← `(let x := $val; match x with $pat => $body);
|
||||
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
|
||||
| `(let $pat:term : $type := $val; $body) => do
|
||||
stxNew ← `(let x : $type := $val; match x with $pat => $body);
|
||||
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
|
||||
| _ => throwUnsupportedSyntax
|
||||
fun stx expectedType? => elabLetDeclCore stx expectedType? true
|
||||
|
||||
@[builtinTermElab «let!»] def elabLetBangDecl : TermElab :=
|
||||
fun stx expectedType? => match_syntax stx with
|
||||
| `(let! $id:ident $args* := $val; $body) =>
|
||||
elabLetDeclAux id.getId args (mkHole stx) val body expectedType? false
|
||||
| `(let! $id:ident $args* : $type := $val; $body) =>
|
||||
elabLetDeclAux id.getId args type val body expectedType? false
|
||||
| `(let! $id:ident $args* | $alts:matchAlt*; $body) =>
|
||||
throwError "invalid let-expression with pattern matching, type must be provided"
|
||||
| `(let! $id:ident $args* : $type | $alts:matchAlt*; $body) =>
|
||||
throwError "WIP" -- TODO
|
||||
| `(let! $pat:term := $val; $body) => do
|
||||
stxNew ← `(let! x := $val; match x with $pat => $body);
|
||||
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
|
||||
| `(let! $pat:term : $type := $val; $body) => do
|
||||
stxNew ← `(let! x : $type := $val; match x with $pat => $body);
|
||||
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
|
||||
| _ => throwUnsupportedSyntax
|
||||
fun stx expectedType? => elabLetDeclCore stx expectedType? false
|
||||
|
||||
@[init] private def regTraceClasses : IO Unit := do
|
||||
registerTraceClass `Elab.let;
|
||||
|
|
|
|||
|
|
@ -278,47 +278,50 @@ f Type 1
|
|||
let x := 0; x + 1
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl `x [] [] ":=" (Term.num (numLit "0")))
|
||||
(Term.letDecl (Term.letIdDecl `x [] [] ":=" (Term.num (numLit "0"))))
|
||||
";"
|
||||
(Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))
|
||||
let x : Nat := 0; x + 1
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl `x [] [(Term.typeSpec ":" (Term.id `Nat []))] ":=" (Term.num (numLit "0")))
|
||||
(Term.letDecl (Term.letIdDecl `x [] [(Term.typeSpec ":" (Term.id `Nat []))] ":=" (Term.num (numLit "0"))))
|
||||
";"
|
||||
(Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))
|
||||
let f (x : Nat) := x + 1; f 0
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl
|
||||
`f
|
||||
[(Term.explicitBinder "(" [`x] [":" (Term.id `Nat [])] [] ")")]
|
||||
[]
|
||||
":="
|
||||
(Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))
|
||||
(Term.letIdDecl
|
||||
`f
|
||||
[(Term.explicitBinder "(" [`x] [":" (Term.id `Nat [])] [] ")")]
|
||||
[]
|
||||
":="
|
||||
(Term.add (Term.id `x []) "+" (Term.num (numLit "1")))))
|
||||
";"
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "0"))]))
|
||||
let f {α : Type} (a : α) : α := a; f 10
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl
|
||||
`f
|
||||
[(Term.implicitBinder "{" [`α] [":" (Term.type "Type" [])] "}")
|
||||
(Term.explicitBinder "(" [`a] [":" (Term.id `α [])] [] ")")]
|
||||
[(Term.typeSpec ":" (Term.id `α []))]
|
||||
":="
|
||||
(Term.id `a []))
|
||||
(Term.letIdDecl
|
||||
`f
|
||||
[(Term.implicitBinder "{" [`α] [":" (Term.type "Type" [])] "}")
|
||||
(Term.explicitBinder "(" [`a] [":" (Term.id `α [])] [] ")")]
|
||||
[(Term.typeSpec ":" (Term.id `α []))]
|
||||
":="
|
||||
(Term.id `a [])))
|
||||
";"
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
|
||||
let f (x) := x + 1; f 10 + f 20
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl
|
||||
`f
|
||||
[(Term.explicitBinder "(" [`x] [] [] ")")]
|
||||
[]
|
||||
":="
|
||||
(Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))
|
||||
(Term.letIdDecl
|
||||
`f
|
||||
[(Term.explicitBinder "(" [`x] [] [] ")")]
|
||||
[]
|
||||
":="
|
||||
(Term.add (Term.id `x []) "+" (Term.num (numLit "1")))))
|
||||
";"
|
||||
(Term.add
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])
|
||||
|
|
@ -328,61 +331,66 @@ let (x, y) := f 10; x + y
|
|||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl
|
||||
(Term.paren "(" [(Term.id `x []) [(Term.tupleTail "," [(Term.id `y [])])]] ")")
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
|
||||
(Term.letPatDecl
|
||||
(Term.paren "(" [(Term.id `x []) [(Term.tupleTail "," [(Term.id `y [])])]] ")")
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
|
||||
";"
|
||||
(Term.add (Term.id `x []) "+" (Term.id `y [])))
|
||||
let { fst := x, .. } := f 10; x + x
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl
|
||||
(Term.structInst "{" [] [(Term.structInstField `fst [] ":=" (Term.id `x [])) ","] [".."] [] "}")
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
|
||||
(Term.letPatDecl
|
||||
(Term.structInst "{" [] [(Term.structInstField `fst [] ":=" (Term.id `x [])) ","] [".."] [] "}")
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
|
||||
";"
|
||||
(Term.add (Term.id `x []) "+" (Term.id `x [])))
|
||||
let x.y := f 10; x
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl `x.y [] [] ":=" (Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
|
||||
(Term.letDecl (Term.letIdDecl `x.y [] [] ":=" (Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
|
||||
";"
|
||||
(Term.id `x []))
|
||||
let x.1 := f 10; x
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl
|
||||
(Term.proj (Term.id `x []) "." (fieldIdx "1"))
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
|
||||
(Term.letPatDecl
|
||||
(Term.proj (Term.id `x []) "." (fieldIdx "1"))
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
|
||||
";"
|
||||
(Term.id `x []))
|
||||
let x[i].y := f 10; x
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl
|
||||
(Term.proj (Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]") "." `y)
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
|
||||
(Term.letPatDecl
|
||||
(Term.proj (Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]") "." `y)
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
|
||||
";"
|
||||
(Term.id `x []))
|
||||
let x[i] := f 20; x
|
||||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl
|
||||
(Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]")
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "20"))]))
|
||||
(Term.letPatDecl
|
||||
(Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]")
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "20"))])))
|
||||
";"
|
||||
(Term.id `x []))
|
||||
-x + y
|
||||
|
|
@ -408,7 +416,7 @@ do
|
|||
";"
|
||||
(Term.doExpr (Term.app (Term.id `g []) [(Term.id `x [])]))
|
||||
";"
|
||||
(Term.doLet "let" (Term.letDecl `y [] [] ":=" (Term.app (Term.id `g []) [(Term.id `x [])])))
|
||||
(Term.doLet "let" (Term.letDecl (Term.letIdDecl `y [] [] ":=" (Term.app (Term.id `g []) [(Term.id `x [])]))))
|
||||
";"
|
||||
(Term.doPat
|
||||
(Term.paren "(" [(Term.id `a []) [(Term.tupleTail "," [(Term.id `b [])])]] ")")
|
||||
|
|
@ -419,11 +427,12 @@ do
|
|||
(Term.doLet
|
||||
"let"
|
||||
(Term.letDecl
|
||||
(Term.paren "(" [(Term.id `a []) [(Term.tupleTail "," [(Term.id `b [])])]] ")")
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.paren "(" [(Term.id `b []) [(Term.tupleTail "," [(Term.id `a [])])]] ")")))
|
||||
(Term.letPatDecl
|
||||
(Term.paren "(" [(Term.id `a []) [(Term.tupleTail "," [(Term.id `b [])])]] ")")
|
||||
[]
|
||||
[]
|
||||
":="
|
||||
(Term.paren "(" [(Term.id `b []) [(Term.tupleTail "," [(Term.id `a [])])]] ")"))))
|
||||
";"
|
||||
(Term.doExpr (Term.app (Term.id `pure []) [(Term.paren "(" [(Term.add (Term.id `a []) "+" (Term.id `b [])) []] ")")]))])
|
||||
do { x ← f a; pure $ a + a }
|
||||
|
|
@ -442,19 +451,20 @@ f 20
|
|||
(Term.let
|
||||
"let"
|
||||
(Term.letDecl
|
||||
`f
|
||||
[]
|
||||
[(Term.typeSpec ":" (Term.arrow (Term.id `Nat []) "→" (Term.arrow (Term.id `Nat []) "→" (Term.id `Nat []))))]
|
||||
"|"
|
||||
[(Term.matchAlt
|
||||
[(Term.num (numLit "0")) "," (Term.id `a [])]
|
||||
"=>"
|
||||
(Term.add (Term.id `a []) "+" (Term.num (numLit "10"))))
|
||||
(Term.letEqnsDecl
|
||||
`f
|
||||
[]
|
||||
[(Term.typeSpec ":" (Term.arrow (Term.id `Nat []) "→" (Term.arrow (Term.id `Nat []) "→" (Term.id `Nat []))))]
|
||||
"|"
|
||||
(Term.matchAlt
|
||||
[(Term.add (Term.id `n []) "+" (Term.num (numLit "1"))) "," (Term.id `b [])]
|
||||
"=>"
|
||||
(Term.mul (Term.id `n []) "*" (Term.id `b [])))])
|
||||
[(Term.matchAlt
|
||||
[(Term.num (numLit "0")) "," (Term.id `a [])]
|
||||
"=>"
|
||||
(Term.add (Term.id `a []) "+" (Term.num (numLit "10"))))
|
||||
"|"
|
||||
(Term.matchAlt
|
||||
[(Term.add (Term.id `n []) "+" (Term.num (numLit "1"))) "," (Term.id `b [])]
|
||||
"=>"
|
||||
(Term.mul (Term.id `n []) "*" (Term.id `b [])))]))
|
||||
";"
|
||||
(Term.app (Term.id `f []) [(Term.num (numLit "20"))]))
|
||||
max a b
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue