diff --git a/src/Init/LeanInit.lean b/src/Init/LeanInit.lean index de7ecbe6b6..2ef1ac3543 100644 --- a/src/Init/LeanInit.lean +++ b/src/Init/LeanInit.lean @@ -212,6 +212,11 @@ match stx with | Syntax.atom _ v => mkNameSimple v | Syntax.ident _ _ _ _ => identKind +def updateKind (stx : Syntax) (k : SyntaxNodeKind) : Syntax := +match stx with +| Syntax.node _ args => Syntax.node k args +| _ => stx + def isOfKind : Syntax → SyntaxNodeKind → Bool | stx, k => stx.getKind == k @@ -770,6 +775,12 @@ match stx with Name.anonymous | _ => Name.anonymous +/- Given an `ident` or `Term.id`, return its name -/ +def getRelaxedId (stx : Syntax) : Name := +match stx with +| Syntax.ident _ _ _ _ => stx.getId +| _ => stx.getIdOfTermId + /-- Similar to `isTermId?`, but succeeds only if the optional part is a `none`. -/ def isSimpleTermId? (stx : Syntax) (relaxed : Bool := false) : Option Syntax := match stx.isTermId? relaxed with diff --git a/src/Lean/Elab/Binders.lean b/src/Lean/Elab/Binders.lean index 583147b66d..85efc747b4 100644 --- a/src/Lean/Elab/Binders.lean +++ b/src/Lean/Elab/Binders.lean @@ -416,41 +416,39 @@ else do }; pure $ mkApp f val +-- letIdLhs := ident >> checkWsBefore "expected space before binders" >> many bracketedBinder >> optType +def expandLetIdLhs (letIdLhs : Syntax) : Name × Array Syntax × Syntax := do +let id := (letIdLhs.getArg 0).getRelaxedId; -- allow `Term.id` to be used as an id for convenience of macro writers +let binders := (letIdLhs.getArg 1).getArgs; +let optType := letIdLhs.getArg 2; +let type := expandOptType letIdLhs optType; +(id, binders, type) + +def elabLetDeclCore (stx : Syntax) (expectedType? : Option Expr) (useLetExpr : Bool) : TermElabM Expr := do +let ref := stx; +let letDecl := (stx.getArg 1).getArg 0; +let body := stx.getArg 3; +if letDecl.getKind == `Lean.Parser.Term.letIdDecl then + let (id, binders, type) := expandLetIdLhs letDecl; + let val := letDecl.getArg 4; + elabLetDeclAux id binders type val body expectedType? useLetExpr +else if letDecl.getKind == `Lean.Parser.Term.letPatDecl then do + -- node `Lean.Parser.Term.letPatDecl $ try (termParser >> pushNone >> optType >> " := ") >> termParser + let pat := letDecl.getArg 0; + let optType := letDecl.getArg 2; + let type := expandOptType stx optType; + let val := letDecl.getArg 4; + stxNew ← `(let x : $type := $val; match x with $pat => $body); + let stxNew := if useLetExpr then stxNew else stxNew.updateKind `Lean.Parser.Term.«let!»; + withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? +else + throwError "WIP" + @[builtinTermElab «let»] def elabLetDecl : TermElab := -fun stx expectedType? => match_syntax stx with -| `(let $id:ident $args* := $val; $body) => - elabLetDeclAux id.getId args (mkHole stx) val body expectedType? true -| `(let $id:ident $args* : $type := $val; $body) => - elabLetDeclAux id.getId args type val body expectedType? true -| `(let $id:ident $args* | $alts:matchAlt*; $body) => - throwError "invalid let-expression with pattern matching, type must be provided" -| `(let $id:ident $args* : $type | $alts:matchAlt*; $body) => - throwError "WIP" -- TODO -| `(let $pat:term := $val; $body) => do - stxNew ← `(let x := $val; match x with $pat => $body); - withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? -| `(let $pat:term : $type := $val; $body) => do - stxNew ← `(let x : $type := $val; match x with $pat => $body); - withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? -| _ => throwUnsupportedSyntax +fun stx expectedType? => elabLetDeclCore stx expectedType? true @[builtinTermElab «let!»] def elabLetBangDecl : TermElab := -fun stx expectedType? => match_syntax stx with -| `(let! $id:ident $args* := $val; $body) => - elabLetDeclAux id.getId args (mkHole stx) val body expectedType? false -| `(let! $id:ident $args* : $type := $val; $body) => - elabLetDeclAux id.getId args type val body expectedType? false -| `(let! $id:ident $args* | $alts:matchAlt*; $body) => - throwError "invalid let-expression with pattern matching, type must be provided" -| `(let! $id:ident $args* : $type | $alts:matchAlt*; $body) => - throwError "WIP" -- TODO -| `(let! $pat:term := $val; $body) => do - stxNew ← `(let! x := $val; match x with $pat => $body); - withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? -| `(let! $pat:term : $type := $val; $body) => do - stxNew ← `(let! x : $type := $val; match x with $pat => $body); - withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? -| _ => throwUnsupportedSyntax +fun stx expectedType? => elabLetDeclCore stx expectedType? false @[init] private def regTraceClasses : IO Unit := do registerTraceClass `Elab.let; diff --git a/tests/compiler/termparsertest1.lean.expected.out b/tests/compiler/termparsertest1.lean.expected.out index 667cbc9372..4e1ee853c0 100644 --- a/tests/compiler/termparsertest1.lean.expected.out +++ b/tests/compiler/termparsertest1.lean.expected.out @@ -278,47 +278,50 @@ f Type 1 let x := 0; x + 1 (Term.let "let" - (Term.letDecl `x [] [] ":=" (Term.num (numLit "0"))) + (Term.letDecl (Term.letIdDecl `x [] [] ":=" (Term.num (numLit "0")))) ";" (Term.add (Term.id `x []) "+" (Term.num (numLit "1")))) let x : Nat := 0; x + 1 (Term.let "let" - (Term.letDecl `x [] [(Term.typeSpec ":" (Term.id `Nat []))] ":=" (Term.num (numLit "0"))) + (Term.letDecl (Term.letIdDecl `x [] [(Term.typeSpec ":" (Term.id `Nat []))] ":=" (Term.num (numLit "0")))) ";" (Term.add (Term.id `x []) "+" (Term.num (numLit "1")))) let f (x : Nat) := x + 1; f 0 (Term.let "let" (Term.letDecl - `f - [(Term.explicitBinder "(" [`x] [":" (Term.id `Nat [])] [] ")")] - [] - ":=" - (Term.add (Term.id `x []) "+" (Term.num (numLit "1")))) + (Term.letIdDecl + `f + [(Term.explicitBinder "(" [`x] [":" (Term.id `Nat [])] [] ")")] + [] + ":=" + (Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))) ";" (Term.app (Term.id `f []) [(Term.num (numLit "0"))])) let f {α : Type} (a : α) : α := a; f 10 (Term.let "let" (Term.letDecl - `f - [(Term.implicitBinder "{" [`α] [":" (Term.type "Type" [])] "}") - (Term.explicitBinder "(" [`a] [":" (Term.id `α [])] [] ")")] - [(Term.typeSpec ":" (Term.id `α []))] - ":=" - (Term.id `a [])) + (Term.letIdDecl + `f + [(Term.implicitBinder "{" [`α] [":" (Term.type "Type" [])] "}") + (Term.explicitBinder "(" [`a] [":" (Term.id `α [])] [] ")")] + [(Term.typeSpec ":" (Term.id `α []))] + ":=" + (Term.id `a []))) ";" (Term.app (Term.id `f []) [(Term.num (numLit "10"))])) let f (x) := x + 1; f 10 + f 20 (Term.let "let" (Term.letDecl - `f - [(Term.explicitBinder "(" [`x] [] [] ")")] - [] - ":=" - (Term.add (Term.id `x []) "+" (Term.num (numLit "1")))) + (Term.letIdDecl + `f + [(Term.explicitBinder "(" [`x] [] [] ")")] + [] + ":=" + (Term.add (Term.id `x []) "+" (Term.num (numLit "1"))))) ";" (Term.add (Term.app (Term.id `f []) [(Term.num (numLit "10"))]) @@ -328,61 +331,66 @@ let (x, y) := f 10; x + y (Term.let "let" (Term.letDecl - (Term.paren "(" [(Term.id `x []) [(Term.tupleTail "," [(Term.id `y [])])]] ")") - [] - [] - ":=" - (Term.app (Term.id `f []) [(Term.num (numLit "10"))])) + (Term.letPatDecl + (Term.paren "(" [(Term.id `x []) [(Term.tupleTail "," [(Term.id `y [])])]] ")") + [] + [] + ":=" + (Term.app (Term.id `f []) [(Term.num (numLit "10"))]))) ";" (Term.add (Term.id `x []) "+" (Term.id `y []))) let { fst := x, .. } := f 10; x + x (Term.let "let" (Term.letDecl - (Term.structInst "{" [] [(Term.structInstField `fst [] ":=" (Term.id `x [])) ","] [".."] [] "}") - [] - [] - ":=" - (Term.app (Term.id `f []) [(Term.num (numLit "10"))])) + (Term.letPatDecl + (Term.structInst "{" [] [(Term.structInstField `fst [] ":=" (Term.id `x [])) ","] [".."] [] "}") + [] + [] + ":=" + (Term.app (Term.id `f []) [(Term.num (numLit "10"))]))) ";" (Term.add (Term.id `x []) "+" (Term.id `x []))) let x.y := f 10; x (Term.let "let" - (Term.letDecl `x.y [] [] ":=" (Term.app (Term.id `f []) [(Term.num (numLit "10"))])) + (Term.letDecl (Term.letIdDecl `x.y [] [] ":=" (Term.app (Term.id `f []) [(Term.num (numLit "10"))]))) ";" (Term.id `x [])) let x.1 := f 10; x (Term.let "let" (Term.letDecl - (Term.proj (Term.id `x []) "." (fieldIdx "1")) - [] - [] - ":=" - (Term.app (Term.id `f []) [(Term.num (numLit "10"))])) + (Term.letPatDecl + (Term.proj (Term.id `x []) "." (fieldIdx "1")) + [] + [] + ":=" + (Term.app (Term.id `f []) [(Term.num (numLit "10"))]))) ";" (Term.id `x [])) let x[i].y := f 10; x (Term.let "let" (Term.letDecl - (Term.proj (Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]") "." `y) - [] - [] - ":=" - (Term.app (Term.id `f []) [(Term.num (numLit "10"))])) + (Term.letPatDecl + (Term.proj (Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]") "." `y) + [] + [] + ":=" + (Term.app (Term.id `f []) [(Term.num (numLit "10"))]))) ";" (Term.id `x [])) let x[i] := f 20; x (Term.let "let" (Term.letDecl - (Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]") - [] - [] - ":=" - (Term.app (Term.id `f []) [(Term.num (numLit "20"))])) + (Term.letPatDecl + (Term.arrayRef (Term.id `x []) "[" (Term.id `i []) "]") + [] + [] + ":=" + (Term.app (Term.id `f []) [(Term.num (numLit "20"))]))) ";" (Term.id `x [])) -x + y @@ -408,7 +416,7 @@ do ";" (Term.doExpr (Term.app (Term.id `g []) [(Term.id `x [])])) ";" - (Term.doLet "let" (Term.letDecl `y [] [] ":=" (Term.app (Term.id `g []) [(Term.id `x [])]))) + (Term.doLet "let" (Term.letDecl (Term.letIdDecl `y [] [] ":=" (Term.app (Term.id `g []) [(Term.id `x [])])))) ";" (Term.doPat (Term.paren "(" [(Term.id `a []) [(Term.tupleTail "," [(Term.id `b [])])]] ")") @@ -419,11 +427,12 @@ do (Term.doLet "let" (Term.letDecl - (Term.paren "(" [(Term.id `a []) [(Term.tupleTail "," [(Term.id `b [])])]] ")") - [] - [] - ":=" - (Term.paren "(" [(Term.id `b []) [(Term.tupleTail "," [(Term.id `a [])])]] ")"))) + (Term.letPatDecl + (Term.paren "(" [(Term.id `a []) [(Term.tupleTail "," [(Term.id `b [])])]] ")") + [] + [] + ":=" + (Term.paren "(" [(Term.id `b []) [(Term.tupleTail "," [(Term.id `a [])])]] ")")))) ";" (Term.doExpr (Term.app (Term.id `pure []) [(Term.paren "(" [(Term.add (Term.id `a []) "+" (Term.id `b [])) []] ")")]))]) do { x ← f a; pure $ a + a } @@ -442,19 +451,20 @@ f 20 (Term.let "let" (Term.letDecl - `f - [] - [(Term.typeSpec ":" (Term.arrow (Term.id `Nat []) "→" (Term.arrow (Term.id `Nat []) "→" (Term.id `Nat []))))] - "|" - [(Term.matchAlt - [(Term.num (numLit "0")) "," (Term.id `a [])] - "=>" - (Term.add (Term.id `a []) "+" (Term.num (numLit "10")))) + (Term.letEqnsDecl + `f + [] + [(Term.typeSpec ":" (Term.arrow (Term.id `Nat []) "→" (Term.arrow (Term.id `Nat []) "→" (Term.id `Nat []))))] "|" - (Term.matchAlt - [(Term.add (Term.id `n []) "+" (Term.num (numLit "1"))) "," (Term.id `b [])] - "=>" - (Term.mul (Term.id `n []) "*" (Term.id `b [])))]) + [(Term.matchAlt + [(Term.num (numLit "0")) "," (Term.id `a [])] + "=>" + (Term.add (Term.id `a []) "+" (Term.num (numLit "10")))) + "|" + (Term.matchAlt + [(Term.add (Term.id `n []) "+" (Term.num (numLit "1"))) "," (Term.id `b [])] + "=>" + (Term.mul (Term.id `n []) "*" (Term.id `b [])))])) ";" (Term.app (Term.id `f []) [(Term.num (numLit "20"))])) max a b