refactor: avoid nested sequence in simpleBinder

This commit is contained in:
Sebastian Ullrich 2022-07-07 16:37:38 +02:00
parent 75b0b50983
commit d7bcc271be
10 changed files with 51 additions and 48 deletions

View file

@ -106,11 +106,9 @@ private def getBinderIds (ids : Syntax) : TermElabM (Array Syntax) :=
private def matchBinder (stx : Syntax) : TermElabM (Array BinderView) := do
let k := stx.getKind
if k == ``Lean.Parser.Term.simpleBinder then
-- binderIdent+ >> optType
let ids ← getBinderIds stx[0]
let optType := stx[1]
ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := expandOptType id optType, bi := BinderInfo.default }
if stx.isIdent || k == ``hole then
-- binderIdent
pure #[{ id := (← expandBinderIdent stx), type := mkHole stx, bi := BinderInfo.default }]
else if k == ``Lean.Parser.Term.explicitBinder then
-- `(` binderIdent+ binderType (binderDefault <|> binderTactic)? `)`
let ids ← getBinderIds stx[1]
@ -191,7 +189,7 @@ def elabBindersEx {α} (binders : Array Syntax) (k : Array (Syntax × Expr) →
elabBindersAux binders k
/--
Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracketedBinder`),
Elaborate the given binders (i.e., `Syntax` objects for `bracketedBinder`),
update the local context, set of local instances, reset instance chache (if needed), and then
execute `k` with the updated context.
The local context will only be included inside `k`.
@ -206,6 +204,18 @@ def elabBinders (binders : Array Syntax) (k : Array Expr → TermElabM α) : Ter
def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α :=
elabBinders #[binder] fun fvars => x fvars[0]!
def expandSimpleBinderWithType (type : Term) (binder : Syntax) : MacroM Syntax :=
if binder.isOfKind ``hole || binder.isIdent then
`(bracketedBinder| ($binder : $type))
else
Macro.throwErrorAt type "unexpected type ascription"
@[builtinMacro Lean.Parser.Term.forall] def expandForall : Macro
| `(forall $binders* : $ty, $term) => do
let binders ← binders.mapM (expandSimpleBinderWithType ty)
`(forall $binders*, $term)
| _ => Macro.throwUnsupported
@[builtinTermElab «forall»] def elabForall : TermElab := fun stx _ =>
match stx with
| `(forall $binders*, $term) =>
@ -286,17 +296,13 @@ partial def expandFunBinders (binders : Array Syntax) (body : Syntax) : MacroM (
let (binders, newBody, _) ← loop body (i+1) (newBinders.push $ mkExplicitBinder major (mkHole binder))
let newBody ← `(match $major:ident with | $pattern => $newBody)
pure (binders, newBody, true)
match binder with
| Syntax.node _ ``Lean.Parser.Term.implicitBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node _ ``Lean.Parser.Term.strictImplicitBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node _ ``Lean.Parser.Term.instBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node _ ``Lean.Parser.Term.explicitBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node _ ``Lean.Parser.Term.simpleBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node _ ``Lean.Parser.Term.hole _ =>
let ident ← mkFreshIdent binder
let type := binder
loop body (i+1) (newBinders.push <| mkExplicitBinder ident type)
| Syntax.node _ ``Lean.Parser.Term.paren _ =>
match binder.getKind with
| ``Lean.Parser.Term.implicitBinder
| ``Lean.Parser.Term.strictImplicitBinder
| ``Lean.Parser.Term.instBinder
| ``Lean.Parser.Term.explicitBinder
| ``Lean.Parser.Term.hole | `ident => loop body (i+1) (newBinders.push binder)
| ``Lean.Parser.Term.paren =>
-- `(` (termParser >> parenSpecial)? `)`
-- parenSpecial := (tupleTail <|> typeAscription)?
let binderBody := binder[1]
@ -327,9 +333,6 @@ partial def expandFunBinders (binders : Array Syntax) (body : Syntax) : MacroM (
match (← getFunBinderIds? term) with
| some idents => loop body (i+1) (newBinders ++ idents.map (fun ident => mkExplicitBinder ident type))
| none => processAsPattern ()
| Syntax.ident .. =>
let type := mkHole binder
loop body (i+1) (newBinders.push <| mkExplicitBinder binder type)
| _ => processAsPattern ()
else
pure (newBinders, body, false)
@ -541,7 +544,10 @@ def expandMatchAltsWhereDecls (matchAltsWhereDecls : Syntax) : MacroM Syntax :=
loop (getMatchAltsNumPatterns matchAlts) #[]
@[builtinMacro Lean.Parser.Term.fun] partial def expandFun : Macro
| `(fun $binders* => $body) => do
| `(fun $binders* : $ty => $body) => do
let binders ← binders.mapM (expandSimpleBinderWithType ty)
`(fun $binders* => $body)
| `(fun $binders* => $body) => do -- if there is a type ascription, we assume all binders are already simple
let (binders, body, expandedPattern) ← expandFunBinders binders body
if expandedPattern then
`(fun $binders* => $body)
@ -552,7 +558,7 @@ def expandMatchAltsWhereDecls (matchAltsWhereDecls : Syntax) : MacroM Syntax :=
open Lean.Elab.Term.Quotation in
@[builtinQuotPrecheck Lean.Parser.Term.fun] def precheckFun : Precheck
| `(fun $binders* => $body) => do
| `(fun $binders* $[: $ty?]? => $body) => do
let (binders, body, _) ← liftMacroM <| expandFunBinders binders body
let mut ids := #[]
for b in binders do

View file

@ -49,7 +49,7 @@ private def letDeclArgHasBinders (letDeclArg : Syntax) : Bool :=
else if k == ``Lean.Parser.Term.letEqnsDecl then
true
else if k == ``Lean.Parser.Term.letIdDecl then
-- letIdLhs := ident >> checkWsBefore "expected space before binders" >> many (ppSpace >> (simpleBinderWithoutType <|> bracketedBinder)) >> optType
-- letIdLhs := ident >> checkWsBefore "expected space before binders" >> many (ppSpace >> letIdBinder)) >> optType
let binders := letDeclArg[1]
binders.getNumArgs > 0
else
@ -615,7 +615,7 @@ def getDoHaveVars (doHave : Syntax) : TermElabM (Array Var) := do
let arg := doHave[1][0]
if arg.getKind == ``Lean.Parser.Term.haveIdDecl then
-- haveIdDecl := leading_parser atomic (haveIdLhs >> " := ") >> termParser
-- haveIdLhs := optional (ident >> many (ppSpace >> (simpleBinderWithoutType <|> bracketedBinder))) >> optType
-- haveIdLhs := optional (ident >> many (ppSpace >> letIdBinder)) >> optType
return #[← getHaveIdLhsVar arg[0]]
else if arg.getKind == ``Lean.Parser.Term.letPatDecl then
getLetPatDeclVars arg

View file

@ -85,8 +85,8 @@ private def isMultiConstant? (views : Array DefView) : Option (List Name) :=
if views.size == 1 &&
views[0]!.kind == DefKind.opaque &&
views[0]!.binders.getArgs.size > 0 &&
views[0]!.binders.getArgs.all (·.getKind == ``Parser.Term.simpleBinder) then
some <| (views[0]!.binders.getArgs.toList.map (fun stx => stx[0].getArgs.toList.map (·.getId))).join
views[0]!.binders.getArgs.all (·.isIdent) then
some (views[0]!.binders.getArgs.toList.map (·.getId))
else
none

View file

@ -55,8 +55,8 @@ def «partial» := leading_parser "partial "
def «nonrec» := leading_parser "nonrec "
def declModifiers (inline : Bool) := leading_parser optional docComment >> optional (Term.«attributes» >> if inline then skip else ppDedent ppLine) >> optional visibility >> optional «noncomputable» >> optional «unsafe» >> optional («partial» <|> «nonrec»)
def declId := leading_parser ident >> optional (".{" >> sepBy1 ident ", " >> "}")
def declSig := leading_parser many (ppSpace >> (Term.simpleBinderWithoutType <|> Term.bracketedBinder)) >> Term.typeSpec
def optDeclSig := leading_parser many (ppSpace >> (Term.simpleBinderWithoutType <|> Term.bracketedBinder)) >> Term.optType
def declSig := leading_parser many (ppSpace >> (Term.binderIdent <|> Term.bracketedBinder)) >> Term.typeSpec
def optDeclSig := leading_parser many (ppSpace >> (Term.binderIdent <|> Term.bracketedBinder)) >> Term.optType
def declValSimple := leading_parser " :=" >> ppHardLineUnlessUngrouped >> termParser >> optional Term.whereDecls
def declValEqns := leading_parser Term.matchAltsWhereDecls
def whereStructField := leading_parser Term.letDecl

View file

@ -131,9 +131,8 @@ Note that we did not add a `explicitShortBinder` parser since `(α) → α
-/
@[builtinTermParser] def depArrow := leading_parser:25 bracketedBinder true >> unicodeSymbol " → " " -> " >> termParser
def simpleBinder := leading_parser many1 binderIdent >> optType
@[builtinTermParser]
def «forall» := leading_parser:leadPrec unicodeSymbol "∀" "forall" >> many1 (ppSpace >> (simpleBinder <|> bracketedBinder)) >> ", " >> termParser
def «forall» := leading_parser:leadPrec unicodeSymbol "∀" "forall" >> many1 (ppSpace >> (binderIdent <|> bracketedBinder)) >> optType >> ", " >> termParser
def matchAlt (rhsParser : Parser := termParser) : Parser :=
leading_parser (withAnonymousAntiquot := false)
@ -162,10 +161,9 @@ def motive := leading_parser atomic ("(" >> nonReservedSymbol "motive" >> " := "
def funImplicitBinder := withAntiquot (mkAntiquot "implicitBinder" ``implicitBinder) <| atomic (lookahead ("{" >> many1 binderIdent >> (symbol " : " <|> "}"))) >> implicitBinder
def funStrictImplicitBinder := atomic (lookahead (strictImplicitLeftBracket >> many1 binderIdent >> (symbol " : " <|> strictImplicitRightBracket))) >> strictImplicitBinder
def funSimpleBinder := withAntiquot (mkAntiquot "simpleBinder" ``simpleBinder) <| atomic (lookahead (many1 binderIdent >> " : ")) >> simpleBinder
def funBinder : Parser := withAntiquot (mkAntiquot "funBinder" `Lean.Parser.Term.funBinder (isPseudoKind := true)) (funStrictImplicitBinder <|> funImplicitBinder <|> instBinder <|> funSimpleBinder <|> termParser maxPrec)
def funBinder : Parser := withAntiquot (mkAntiquot "funBinder" `Lean.Parser.Term.funBinder (isPseudoKind := true)) (funStrictImplicitBinder <|> funImplicitBinder <|> instBinder <|> termParser maxPrec)
-- NOTE: we disable anonymous antiquotations to ensure that `fun $b => ...` remains a `term` antiquotation
def basicFun : Parser := leading_parser (withAnonymousAntiquot := false) ppGroup (many1 (ppSpace >> funBinder) >> " =>") >> ppSpace >> termParser
def basicFun : Parser := leading_parser (withAnonymousAntiquot := false) ppGroup (many1 (ppSpace >> funBinder) >> optType >> " =>") >> ppSpace >> termParser
@[builtinTermParser] def «fun» := leading_parser:maxPrec ppAllowUngrouped >> unicodeSymbol "λ" "fun" >> (basicFun <|> matchAlts)
def optExprPrecedence := optional (atomic ":" >> termParser maxPrec)
@ -179,11 +177,7 @@ def withAnonymousAntiquot := leading_parser atomic ("(" >> nonReservedSymbol "wi
-- note that we cannot use ```"``"``` as a new token either because it would break `precheckedQuot`
@[builtinTermParser] def doubleQuotedName := leading_parser "`" >> checkNoWsBefore >> rawCh '`' (trailingWs := false) >> ident
-- same shape and (antiquotation) kind as `simpleBinder`
def simpleBinderWithoutType := nodeWithAntiquot "simpleBinder" ``simpleBinder (anonymous := true)
(many1 binderIdent >> pushNone)
def letIdBinder := withAntiquot (mkAntiquot "letIdBinder" `Lean.Parser.Term.letIdBinder (isPseudoKind := true)) (simpleBinderWithoutType <|> bracketedBinder)
def letIdBinder := withAntiquot (mkAntiquot "letIdBinder" `Lean.Parser.Term.letIdBinder (isPseudoKind := true)) (binderIdent <|> bracketedBinder)
/- Remark: we use `checkWsBefore` to ensure `let x[i] := e; b` is not parsed as `let x [i] := e; b` where `[i]` is an `instBinder`. -/
def letIdLhs : Parser := ident >> notFollowedBy (checkNoWsBefore "" >> "[") "space is required before instance '[...]' binders to distinguish them from array updates `let x[i] := e; ...`" >> many (ppSpace >> letIdBinder) >> optType
def letIdDecl := leading_parser (withAnonymousAntiquot := false) atomic (letIdLhs >> " := ") >> termParser
@ -210,7 +204,7 @@ def letDecl := leading_parser (withAnonymousAntiquot := false) notFollowedBy
@[builtinTermParser] def «let_tmp» := leading_parser:leadPrec withPosition ("let_tmp " >> letDecl) >> optSemicolon termParser
instance : Coe (TSyntax ``letIdBinder) (TSyntax ``funBinder) where
coe stx := ⟨stx⟩ -- `simpleBinderWithoutType` prevents using a proper quotation for this
coe stx := ⟨stx⟩
-- like `let_fun` but with optional name
def haveIdLhs := optional (ident >> many (ppSpace >> letIdBinder)) >> optType

@ -1 +1 @@
Subproject commit b651e87444265cc698ee33eb0a379ac0a38562d5
Subproject commit bb2c8669791fbb3916dbd8b82a48abdc9b127ef1

View file

@ -8,7 +8,7 @@ options get_default_options() {
// switch to `true` for ABI-breaking changes affecting meta code
opts = opts.update({"interpreter", "prefer_native"}, false);
// switch to `true` for changing built-in parsers used in quotations
opts = opts.update({"internal", "parseQuotWithCurrentStage"}, false);
opts = opts.update({"internal", "parseQuotWithCurrentStage"}, true);
opts = opts.update({"pp", "rawOnError"}, true);
#endif
return opts;

View file

@ -7,7 +7,7 @@ StxQuot.lean:8:12: error: expected identifier or term
"(«term_+_» <missing> \"+\" (num \"1\"))"
"(«term_+_» (num \"1\") \"+\" (num \"1\"))"
StxQuot.lean:19:15: error: expected term
"(Term.fun \"fun\" (Term.basicFun [`a._@.UnhygienicMain._hyg.1] \"=>\" `a._@.UnhygienicMain._hyg.1))"
"(Term.fun \"fun\" (Term.basicFun [`a._@.UnhygienicMain._hyg.1] [] \"=>\" `a._@.UnhygienicMain._hyg.1))"
"(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (num \"1\") [])\n []\n []\n []))"
"[(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (num \"1\") [])\n []\n []\n []))\n (Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `bar._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (num \"2\") [])\n []\n []\n []))]"
"`Nat.one._@.UnhygienicMain._hyg.1"
@ -24,11 +24,11 @@ StxQuot.lean:19:15: error: expected term
0
1
"1"
"(Term.fun\n \"fun\"\n (Term.basicFun\n [`a._@.UnhygienicMain._hyg.1\n (Term.paren \"(\" [`b._@.UnhygienicMain._hyg.1 [(Term.typeAscription \":\" `Nat._@.UnhygienicMain._hyg.1)]] \")\")]\n \"=>\"\n (num \"1\")))"
"(Term.fun\n \"fun\"\n (Term.basicFun\n [`a._@.UnhygienicMain._hyg.1\n (Term.paren \"(\" [`b._@.UnhygienicMain._hyg.1 [(Term.typeAscription \":\" `Nat._@.UnhygienicMain._hyg.1)]] \")\")]\n []\n \"=>\"\n (num \"1\")))"
"#[(Term.paren \"(\" [`a._@.UnhygienicMain._hyg.1 [(Term.typeAscription \":\" `Nat._@.UnhygienicMain._hyg.1)]] \")\"), `b._@.UnhygienicMain._hyg.1]"
"`a._@.UnhygienicMain._hyg.1"
"(Term.forall \"∀\" [(Term.simpleBinder [(Term.hole \"_\")] [])] \",\" `c._@.UnhygienicMain._hyg.1)"
"(Term.simpleBinder [(Term.hole \"_\")] [])"
"(Term.forall \"∀\" [(Term.hole \"_\")] [] \",\" `c._@.UnhygienicMain._hyg.1)"
"(Term.hole \"_\")"
"`a._@.UnhygienicMain._hyg.1"
"(Term.explicitUniv `a._@.UnhygienicMain._hyg.1 \".{\" [(num \"0\")] \"}\")"
"#[(Term.matchAlt \"|\" [[`a._@.UnhygienicMain._hyg.1]] \"=>\" (num \"1\")), (Term.matchAlt \"|\" [[(Term.hole \"_\")]] \"=>\" (num \"2\"))]"
@ -50,8 +50,8 @@ StxQuot.lean:102:13-102:14: error: unknown identifier 'x' at quotation precheck;
"`id._@.UnhygienicMain._hyg.1"
"`pure._@.UnhygienicMain._hyg.1"
"(termFoo_ \"foo\" <missing>)"
"(Term.fun \"fun\" (Term.basicFun [`x._@.UnhygienicMain._hyg.1] \"=>\" `x._@.UnhygienicMain._hyg.1))"
"(Term.fun \"fun\" (Term.basicFun [`x._@.UnhygienicMain._hyg.1] [] \"=>\" `x._@.UnhygienicMain._hyg.1))"
StxQuot.lean:108:22-108:23: error: unknown identifier 'y' at quotation precheck; you can use `set_option quotPrecheck false` to disable this check.
"(Term.fun\n \"fun\"\n (Term.basicFun\n [`x._@.UnhygienicMain._hyg.1 `y._@.UnhygienicMain._hyg.1]\n \"=>\"\n (Term.app `x._@.UnhygienicMain._hyg.1 [`y._@.UnhygienicMain._hyg.1])))"
"(Term.fun\n \"fun\"\n (Term.basicFun\n [(Term.anonymousCtor \"⟨\" [`x._@.UnhygienicMain._hyg.1 \",\" `y._@.UnhygienicMain._hyg.1] \"⟩\")]\n \"=>\"\n `x._@.UnhygienicMain._hyg.1))"
"(Term.fun\n \"fun\"\n (Term.basicFun\n [`x._@.UnhygienicMain._hyg.1 `y._@.UnhygienicMain._hyg.1]\n []\n \"=>\"\n (Term.app `x._@.UnhygienicMain._hyg.1 [`y._@.UnhygienicMain._hyg.1])))"
"(Term.fun\n \"fun\"\n (Term.basicFun\n [(Term.anonymousCtor \"⟨\" [`x._@.UnhygienicMain._hyg.1 \",\" `y._@.UnhygienicMain._hyg.1] \"⟩\")]\n []\n \"=>\"\n `x._@.UnhygienicMain._hyg.1))"
"1"

2
tests/lean/fun.lean Normal file
View file

@ -0,0 +1,2 @@
-- reject to avoid confusion with `fun x : Nat =>`
#check fun (x : Nat) : Nat => x

View file

@ -0,0 +1 @@
fun.lean:2:23-2:26: error: unexpected type ascription