From 2ad065eb93511cb2bf78562dac0fe638b6099002 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 2 Feb 2020 21:24:14 -0800 Subject: [PATCH] feat: add `expandNonAtomicExplicitSource` and `getStructName` --- src/Init/Lean/Elab/StructInst.lean | 161 +++++++++++++++++++++++++++++ src/Init/Lean/Elab/Term.lean | 10 ++ src/Init/Lean/Parser/Term.lean | 2 +- 3 files changed, 172 insertions(+), 1 deletion(-) diff --git a/src/Init/Lean/Elab/StructInst.lean b/src/Init/Lean/Elab/StructInst.lean index 11ca1d6595..24b424cc56 100644 --- a/src/Init/Lean/Elab/StructInst.lean +++ b/src/Init/Lean/Elab/StructInst.lean @@ -12,6 +12,167 @@ namespace Lean namespace Elab namespace Term +/- parser! symbol "{" appPrec >> optional (try (ident >> " . ")) >> sepBy (structInstField <|> structInstSource) ", " true >> "}" -/ + +namespace ExpandNonAtomicExplicitSource + +structure State := +(found : Bool := false) +(source? : Option Syntax := none) + +/- Auxiliary function for `expandNonAtomicExplicitSource` -/ +def main (stx : Syntax) : StateT State TermElabM (Option Syntax) := do +let args := (stx.getArg 2).getArgs; +args ← args.mapM $ fun arg => + if arg.getKind == `Lean.Parser.Term.structInstSource then do + -- parser! ".." >> optional termParser + s ← get; + if s.found then + liftM $ throwError arg "source has already been specified" + else + let optSource := arg.getArg 1; + if optSource.isNone then do + modify $ fun s => { found := true, .. s }; + pure arg + else do + let source := optSource.getArg 0; + fvar? ← liftM $ isLocalTermId? source; + match fvar? with + | some _ => do + -- it is already a local variable + modify $ fun s => { found := true, .. s }; + pure arg + | none => do + src ← `(src); + modify $ fun s => { found := true, source? := source, .. s }; + let optSource := optSource.setArg 0 src; + let arg := arg.setArg 1 optSource; + pure arg + else + pure arg; +s ← get; +match s.source? with +| none => pure none +| some source => do + let newStx := stx.setArg 2 (mkNullNode args); + `(let src := $source; $newStx) + +end ExpandNonAtomicExplicitSource + +/- +If `stx` is of the form `{ ... .. s }` and `s` is not a local variable, expand into `let src := s; { ... .. src}`. +-/ +def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) := +withFreshMacroScope $ (ExpandNonAtomicExplicitSource.main stx).run' {} + +inductive SourceView +| none -- structure instance source has not been provieded +| implicit -- `..` +| explicit (stx : Syntax) -- `.. term` + +def SourceView.isNone : SourceView → Bool +| SourceView.none => true +| _ => false + +private def getStructSource (stx : Syntax) : TermElabM SourceView := +let args := (stx.getArg 2).getArgs; +args.foldSepByM + (fun arg r => + if arg.getKind == `Lean.Parser.Term.structInstSource then + -- parser! ".." >> optional termParser + if !r.isNone then throwError arg "source has already been specified" + else + let arg := arg.getArg 1; + if arg.isNone then pure SourceView.implicit + else pure $ SourceView.explicit (arg.getArg 0) + else + pure r) + SourceView.none + +/- + We say a `{ ... }` notation is a `modifyOp` if it contains only one + ``` + def structInstArrayRef := parser! "[" >> termParser >>"]" + ``` -/ +private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do +let args := (stx.getArg 2).getArgs; +s? ← args.foldSepByM + (fun arg s? => + let k := arg.getKind; + if k == `Lean.Parser.Term.structInstSource then pure s? + else if k == `Lean.Parser.Term.structInstArrayRef then + match s? with + | none => pure (some arg) + | some s => + if s.getKind == `Lean.Parser.Term.structInstArrayRef then + throwError arg "invalid {...} notation, at most one `[..]` at a given level" + else + throwError arg "invalid {...} notation, can't mix field and `[..]` at a given level" + else + match s? with + | none => pure (some arg) + | some s => + if s.getKind == `Lean.Parser.Term.structInstArrayRef then + throwError arg "invalid {...} notation, can't mix field and `[..]` at a given level" + else + pure s?) + none; +match s? with +| none => pure none +| some s => if s.getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none + +private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr := +throwError stx ("WIP " ++ stx) + +/- Get structure name and elaborate explicit source (if avialable) -/ +private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : SourceView) : TermElabM Name := +let arg := stx.getArg 1; +if !arg.isNone then do + pure $ arg.getIdAt 0 +else do + let ref := stx; + tryPostponeIfNoneOrMVar expectedType?; + let useSource : Unit → TermElabM Name := fun _ => + match sourceView with + | SourceView.explicit sourceStx => do + fvar? ← isLocalTermId? sourceStx; + match fvar? with + | none => unreachable! + | some fvar => do + fvarType ← inferType stx fvar; + fvarType ← whnf stx fvarType; + tryPostponeIfMVar fvarType; + match fvarType.getAppFn with + | Expr.const constName _ _ => pure constName + | _ => throwError stx ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr fvarType) + | _ => throwError ref ("invalid {...} notation, expected type is not of the form (C ...)" ++ indentExpr expectedType?.get!); + match expectedType? with + | none => useSource () + | some expectedType => do + expectedType ← whnf ref expectedType; + match expectedType.getAppFn with + | Expr.const constName _ _ => pure constName + | _ => useSource () + +private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sourceView : SourceView) : TermElabM Expr := do +structName ← getStructName stx expectedType? sourceView; +env ← getEnv; +unless (isStructureLike env structName) $ + throwError stx ("invalid {...} notation, '" ++ structName ++ "' is not a structure"); +throwError stx ("WIP " ++ toString structName ++ toString stx) + +@[builtinTermElab structInst] def elabStructInst : TermElab := +fun stx expectedType? => do + stxNew? ← expandNonAtomicExplicitSource stx; + match stxNew? with + | some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? + | none => do + sourceView ← getStructSource stx; + modifyOp? ← isModifyOp? stx; + match modifyOp?, sourceView with + | some modifyOp, SourceView.explicit source => elabModifyOp stx modifyOp source expectedType? + | some _, _ => throwError stx ("invalid {...} notation, explicit source is required when using '[] := '") + | _, _ => elabStructInstAux stx expectedType? sourceView end Term end Elab diff --git a/src/Init/Lean/Elab/Term.lean b/src/Init/Lean/Elab/Term.lean index 92fbf31099..7845caecee 100644 --- a/src/Init/Lean/Elab/Term.lean +++ b/src/Init/Lean/Elab/Term.lean @@ -802,6 +802,16 @@ private def resolveLocalName (n : Name) : TermElabM (Option (Expr × List String lctx ← getLCtx; pure $ resolveLocalNameAux lctx n [] +/- Return true iff `stx` is a `Term.id`, and it is local variable. -/ +def isLocalTermId? (stx : Syntax) : TermElabM (Option Expr) := +match stx.isTermId? with +| some (Syntax.ident _ _ val _, _) => do + r? ← resolveLocalName val; + match r? with + | some (fvar, []) => pure (some fvar) + | _ => pure none +| _ => pure none + private def mkFreshLevelMVars (ref : Syntax) (num : Nat) : TermElabM (List Level) := num.foldM (fun _ us => do u ← mkFreshLevelMVar ref; pure $ u::us) [] diff --git a/src/Init/Lean/Parser/Term.lean b/src/Init/Lean/Parser/Term.lean index b7f97e177e..6f96dea0b3 100644 --- a/src/Init/Lean/Parser/Term.lean +++ b/src/Init/Lean/Parser/Term.lean @@ -59,7 +59,7 @@ def haveAssign := parser! " := " >> termParser @[builtinTermParser] def «show» := parser! "show " >> termParser >> fromTerm @[builtinTermParser] def «fun» := parser! unicodeSymbol "λ" "fun" >> many1 (termParser appPrec) >> darrow >> termParser def structInstArrayRef := parser! "[" >> termParser >>"]" -def structInstLVal := (ident <|> structInstArrayRef) >> many (("." >> ident) <|> structInstArrayRef) +def structInstLVal := (ident <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef) def structInstField := parser! structInstLVal >> " := " >> termParser def structInstSource := parser! ".." >> optional termParser @[builtinTermParser] def structInst := parser! symbol "{" appPrec >> optional (try (ident >> " . ")) >> sepBy (structInstField <|> structInstSource) ", " true >> "}"