feat: implement let elaborators without using match_syntax

@Kha I had to do this because of the `ident` vs `Term.id` recurrent
issue. `match_syntax` fails if a `Term.id` is used at `Term.letIdDecl`
where an `ident` is expected.
We should try to remove `Term.id` in the future.
This commit is contained in:
Leonardo de Moura 2020-08-17 09:17:45 -07:00
parent 1ba3925740
commit 3342ba08d2
3 changed files with 113 additions and 94 deletions

View file

@ -212,6 +212,11 @@ match stx with
| Syntax.atom _ v => mkNameSimple v
| Syntax.ident _ _ _ _ => identKind
def updateKind (stx : Syntax) (k : SyntaxNodeKind) : Syntax :=
match stx with
| Syntax.node _ args => Syntax.node k args
| _ => stx
def isOfKind : Syntax → SyntaxNodeKind → Bool
| stx, k => stx.getKind == k
@ -770,6 +775,12 @@ match stx with
Name.anonymous
| _ => Name.anonymous
/- Given an `ident` or `Term.id`, return its name -/
def getRelaxedId (stx : Syntax) : Name :=
match stx with
| Syntax.ident _ _ _ _ => stx.getId
| _ => stx.getIdOfTermId
/-- Similar to `isTermId?`, but succeeds only if the optional part is a `none`. -/
def isSimpleTermId? (stx : Syntax) (relaxed : Bool := false) : Option Syntax :=
match stx.isTermId? relaxed with

View file

@ -416,41 +416,39 @@ else do
};
pure $ mkApp f val
-- letIdLhs := ident >> checkWsBefore "expected space before binders" >> many bracketedBinder >> optType
def expandLetIdLhs (letIdLhs : Syntax) : Name × Array Syntax × Syntax := do
let id := (letIdLhs.getArg 0).getRelaxedId; -- allow `Term.id` to be used as an id for convenience of macro writers
let binders := (letIdLhs.getArg 1).getArgs;
let optType := letIdLhs.getArg 2;
let type := expandOptType letIdLhs optType;
(id, binders, type)
def elabLetDeclCore (stx : Syntax) (expectedType? : Option Expr) (useLetExpr : Bool) : TermElabM Expr := do
let ref := stx;
let letDecl := (stx.getArg 1).getArg 0;
let body := stx.getArg 3;
if letDecl.getKind == `Lean.Parser.Term.letIdDecl then
let (id, binders, type) := expandLetIdLhs letDecl;
let val := letDecl.getArg 4;
elabLetDeclAux id binders type val body expectedType? useLetExpr
else if letDecl.getKind == `Lean.Parser.Term.letPatDecl then do
-- node `Lean.Parser.Term.letPatDecl $ try (termParser >> pushNone >> optType >> " := ") >> termParser
let pat := letDecl.getArg 0;
let optType := letDecl.getArg 2;
let type := expandOptType stx optType;
let val := letDecl.getArg 4;
stxNew ← `(let x : $type := $val; match x with $pat => $body);
let stxNew := if useLetExpr then stxNew else stxNew.updateKind `Lean.Parser.Term.«let!»;
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
else
throwError "WIP"
@[builtinTermElab «let»] def elabLetDecl : TermElab :=
fun stx expectedType? => match_syntax stx with
| `(let $id:ident $args* := $val; $body) =>
elabLetDeclAux id.getId args (mkHole stx) val body expectedType? true
| `(let $id:ident $args* : $type := $val; $body) =>
elabLetDeclAux id.getId args type val body expectedType? true
| `(let $id:ident $args* | $alts:matchAlt*; $body) =>
throwError "invalid let-expression with pattern matching, type must be provided"
| `(let $id:ident $args* : $type | $alts:matchAlt*; $body) =>
throwError "WIP" -- TODO
| `(let $pat:term := $val; $body) => do
stxNew ← `(let x := $val; match x with $pat => $body);
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| `(let $pat:term : $type := $val; $body) => do
stxNew ← `(let x : $type := $val; match x with $pat => $body);
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| _ => throwUnsupportedSyntax
fun stx expectedType? => elabLetDeclCore stx expectedType? true
@[builtinTermElab «let!»] def elabLetBangDecl : TermElab :=
fun stx expectedType? => match_syntax stx with
| `(let! $id:ident $args* := $val; $body) =>
elabLetDeclAux id.getId args (mkHole stx) val body expectedType? false
| `(let! $id:ident $args* : $type := $val; $body) =>
elabLetDeclAux id.getId args type val body expectedType? false
| `(let! $id:ident $args* | $alts:matchAlt*; $body) =>
throwError "invalid let-expression with pattern matching, type must be provided"
| `(let! $id:ident $args* : $type | $alts:matchAlt*; $body) =>
throwError "WIP" -- TODO
| `(let! $pat:term := $val; $body) => do
stxNew ← `(let! x := $val; match x with $pat => $body);
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| `(let! $pat:term : $type := $val; $body) => do
stxNew ← `(let! x : $type := $val; match x with $pat => $body);
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| _ => throwUnsupportedSyntax
fun stx expectedType? => elabLetDeclCore stx expectedType? false
@[init] private def regTraceClasses : IO Unit := do
registerTraceClass `Elab.let;

View file

@ -278,47 +278,50 @@ f Type 1
let x := 0; x + 1
(Term.let
"let"
(Term.letDecl `x [] [] ":=" (Term.num (numLit "0")))
(Term.letDecl (Term.letIdDecl `x [] [] ":=" (Term.num (numLit "0"))))
";"
(Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))
let x : Nat := 0; x + 1
(Term.let
"let"
(Term.letDecl `x [] [(Term.typeSpec ":" (Term.id `Nat []))] ":=" (Term.num (numLit "0")))
(Term.letDecl (Term.letIdDecl `x [] [(Term.typeSpec ":" (Term.id `Nat []))] ":=" (Term.num (numLit "0"))))
";"
(Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))
let f (x : Nat) := x + 1; f 0
(Term.let
"let"
(Term.letDecl
`f
[(Term.explicitBinder "(" [`x] [":" (Term.id `Nat [])] [] ")")]
[]
":="
(Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))
(Term.letIdDecl
`f
[(Term.explicitBinder "(" [`x] [":" (Term.id `Nat [])] [] ")")]
[]
":="
(Term.add (Term.id `x []) "+" (Term.num (numLit "1")))))
";"
(Term.app (Term.id `f []) [(Term.num (numLit "0"))]))
let f {α : Type} (a : α) : α := a; f 10
(Term.let
"let"
(Term.letDecl
`f
[(Term.implicitBinder "{" [`α] [":" (Term.type "Type" [])] "}")
(Term.explicitBinder "(" [`a] [":" (Term.id `α [])] [] ")")]
[(Term.typeSpec ":" (Term.id `α []))]
":="
(Term.id `a []))
(Term.letIdDecl
`f
[(Term.implicitBinder "{" [`α] [":" (Term.type "Type" [])] "}")
(Term.explicitBinder "(" [`a] [":" (Term.id `α [])] [] ")")]
[(Term.typeSpec ":" (Term.id `α []))]
":="
(Term.id `a [])))
";"
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
let f (x) := x + 1; f 10 + f 20
(Term.let
"let"
(Term.letDecl
`f
[(Term.explicitBinder "(" [`x] [] [] ")")]
[]
":="
(Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))
(Term.letIdDecl
`f
[(Term.explicitBinder "(" [`x] [] [] ")")]
[]
":="
(Term.add (Term.id `x []) "+" (Term.num (numLit "1")))))
";"
(Term.add
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])
@ -328,61 +331,66 @@ let (x, y) := f 10; x + y
(Term.let
"let"
(Term.letDecl
(Term.paren "(" [(Term.id `x []) [(Term.tupleTail "," [(Term.id `y [])])]] ")")
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
(Term.letPatDecl
(Term.paren "(" [(Term.id `x []) [(Term.tupleTail "," [(Term.id `y [])])]] ")")
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
";"
(Term.add (Term.id `x []) "+" (Term.id `y [])))
let { fst := x, .. } := f 10; x + x
(Term.let
"let"
(Term.letDecl
(Term.structInst "{" [] [(Term.structInstField `fst [] ":=" (Term.id `x [])) ","] [".."] [] "}")
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
(Term.letPatDecl
(Term.structInst "{" [] [(Term.structInstField `fst [] ":=" (Term.id `x [])) ","] [".."] [] "}")
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
";"
(Term.add (Term.id `x []) "+" (Term.id `x [])))
let x.y := f 10; x
(Term.let
"let"
(Term.letDecl `x.y [] [] ":=" (Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
(Term.letDecl (Term.letIdDecl `x.y [] [] ":=" (Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
";"
(Term.id `x []))
let x.1 := f 10; x
(Term.let
"let"
(Term.letDecl
(Term.proj (Term.id `x []) "." (fieldIdx "1"))
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
(Term.letPatDecl
(Term.proj (Term.id `x []) "." (fieldIdx "1"))
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
";"
(Term.id `x []))
let x[i].y := f 10; x
(Term.let
"let"
(Term.letDecl
(Term.proj (Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]") "." `y)
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "10"))]))
(Term.letPatDecl
(Term.proj (Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]") "." `y)
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "10"))])))
";"
(Term.id `x []))
let x[i] := f 20; x
(Term.let
"let"
(Term.letDecl
(Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]")
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "20"))]))
(Term.letPatDecl
(Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]")
[]
[]
":="
(Term.app (Term.id `f []) [(Term.num (numLit "20"))])))
";"
(Term.id `x []))
-x + y
@ -408,7 +416,7 @@ do
";"
(Term.doExpr (Term.app (Term.id `g []) [(Term.id `x [])]))
";"
(Term.doLet "let" (Term.letDecl `y [] [] ":=" (Term.app (Term.id `g []) [(Term.id `x [])])))
(Term.doLet "let" (Term.letDecl (Term.letIdDecl `y [] [] ":=" (Term.app (Term.id `g []) [(Term.id `x [])]))))
";"
(Term.doPat
(Term.paren "(" [(Term.id `a []) [(Term.tupleTail "," [(Term.id `b [])])]] ")")
@ -419,11 +427,12 @@ do
(Term.doLet
"let"
(Term.letDecl
(Term.paren "(" [(Term.id `a []) [(Term.tupleTail "," [(Term.id `b [])])]] ")")
[]
[]
":="
(Term.paren "(" [(Term.id `b []) [(Term.tupleTail "," [(Term.id `a [])])]] ")")))
(Term.letPatDecl
(Term.paren "(" [(Term.id `a []) [(Term.tupleTail "," [(Term.id `b [])])]] ")")
[]
[]
":="
(Term.paren "(" [(Term.id `b []) [(Term.tupleTail "," [(Term.id `a [])])]] ")"))))
";"
(Term.doExpr (Term.app (Term.id `pure []) [(Term.paren "(" [(Term.add (Term.id `a []) "+" (Term.id `b [])) []] ")")]))])
do { x ← f a; pure $ a + a }
@ -442,19 +451,20 @@ f 20
(Term.let
"let"
(Term.letDecl
`f
[]
[(Term.typeSpec ":" (Term.arrow (Term.id `Nat []) "→" (Term.arrow (Term.id `Nat []) "→" (Term.id `Nat []))))]
"|"
[(Term.matchAlt
[(Term.num (numLit "0")) "," (Term.id `a [])]
"=>"
(Term.add (Term.id `a []) "+" (Term.num (numLit "10"))))
(Term.letEqnsDecl
`f
[]
[(Term.typeSpec ":" (Term.arrow (Term.id `Nat []) "→" (Term.arrow (Term.id `Nat []) "→" (Term.id `Nat []))))]
"|"
(Term.matchAlt
[(Term.add (Term.id `n []) "+" (Term.num (numLit "1"))) "," (Term.id `b [])]
"=>"
(Term.mul (Term.id `n []) "*" (Term.id `b [])))])
[(Term.matchAlt
[(Term.num (numLit "0")) "," (Term.id `a [])]
"=>"
(Term.add (Term.id `a []) "+" (Term.num (numLit "10"))))
"|"
(Term.matchAlt
[(Term.add (Term.id `n []) "+" (Term.num (numLit "1"))) "," (Term.id `b [])]
"=>"
(Term.mul (Term.id `n []) "*" (Term.id `b [])))]))
";"
(Term.app (Term.id `f []) [(Term.num (numLit "20"))]))
max a b