feat: add expandNonAtomicExplicitSource and getStructName

This commit is contained in:
Leonardo de Moura 2020-02-02 21:24:14 -08:00
parent 7b1a4272f8
commit 2ad065eb93
3 changed files with 172 additions and 1 deletions

View file

@ -12,6 +12,167 @@ namespace Lean
namespace Elab
namespace Term
/- parser! symbol "{" appPrec >> optional (try (ident >> " . ")) >> sepBy (structInstField <|> structInstSource) ", " true >> "}" -/
namespace ExpandNonAtomicExplicitSource
structure State :=
(found : Bool := false)
(source? : Option Syntax := none)
/- Auxiliary function for `expandNonAtomicExplicitSource` -/
def main (stx : Syntax) : StateT State TermElabM (Option Syntax) := do
let args := (stx.getArg 2).getArgs;
args ← args.mapM $ fun arg =>
if arg.getKind == `Lean.Parser.Term.structInstSource then do
-- parser! ".." >> optional termParser
s ← get;
if s.found then
liftM $ throwError arg "source has already been specified"
else
let optSource := arg.getArg 1;
if optSource.isNone then do
modify $ fun s => { found := true, .. s };
pure arg
else do
let source := optSource.getArg 0;
fvar? ← liftM $ isLocalTermId? source;
match fvar? with
| some _ => do
-- it is already a local variable
modify $ fun s => { found := true, .. s };
pure arg
| none => do
src ← `(src);
modify $ fun s => { found := true, source? := source, .. s };
let optSource := optSource.setArg 0 src;
let arg := arg.setArg 1 optSource;
pure arg
else
pure arg;
s ← get;
match s.source? with
| none => pure none
| some source => do
let newStx := stx.setArg 2 (mkNullNode args);
`(let src := $source; $newStx)
end ExpandNonAtomicExplicitSource
/-
If `stx` is of the form `{ ... .. s }` and `s` is not a local variable, expand into `let src := s; { ... .. src}`.
-/
def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) :=
withFreshMacroScope $ (ExpandNonAtomicExplicitSource.main stx).run' {}
inductive SourceView
| none -- structure instance source has not been provieded
| implicit -- `..`
| explicit (stx : Syntax) -- `.. term`
def SourceView.isNone : SourceView → Bool
| SourceView.none => true
| _ => false
private def getStructSource (stx : Syntax) : TermElabM SourceView :=
let args := (stx.getArg 2).getArgs;
args.foldSepByM
(fun arg r =>
if arg.getKind == `Lean.Parser.Term.structInstSource then
-- parser! ".." >> optional termParser
if !r.isNone then throwError arg "source has already been specified"
else
let arg := arg.getArg 1;
if arg.isNone then pure SourceView.implicit
else pure $ SourceView.explicit (arg.getArg 0)
else
pure r)
SourceView.none
/-
We say a `{ ... }` notation is a `modifyOp` if it contains only one
```
def structInstArrayRef := parser! "[" >> termParser >>"]"
``` -/
private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
let args := (stx.getArg 2).getArgs;
s? ← args.foldSepByM
(fun arg s? =>
let k := arg.getKind;
if k == `Lean.Parser.Term.structInstSource then pure s?
else if k == `Lean.Parser.Term.structInstArrayRef then
match s? with
| none => pure (some arg)
| some s =>
if s.getKind == `Lean.Parser.Term.structInstArrayRef then
throwError arg "invalid {...} notation, at most one `[..]` at a given level"
else
throwError arg "invalid {...} notation, can't mix field and `[..]` at a given level"
else
match s? with
| none => pure (some arg)
| some s =>
if s.getKind == `Lean.Parser.Term.structInstArrayRef then
throwError arg "invalid {...} notation, can't mix field and `[..]` at a given level"
else
pure s?)
none;
match s? with
| none => pure none
| some s => if s.getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none
private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=
throwError stx ("WIP " ++ stx)
/- Get structure name and elaborate explicit source (if avialable) -/
private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : SourceView) : TermElabM Name :=
let arg := stx.getArg 1;
if !arg.isNone then do
pure $ arg.getIdAt 0
else do
let ref := stx;
tryPostponeIfNoneOrMVar expectedType?;
let useSource : Unit → TermElabM Name := fun _ =>
match sourceView with
| SourceView.explicit sourceStx => do
fvar? ← isLocalTermId? sourceStx;
match fvar? with
| none => unreachable!
| some fvar => do
fvarType ← inferType stx fvar;
fvarType ← whnf stx fvarType;
tryPostponeIfMVar fvarType;
match fvarType.getAppFn with
| Expr.const constName _ _ => pure constName
| _ => throwError stx ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr fvarType)
| _ => throwError ref ("invalid {...} notation, expected type is not of the form (C ...)" ++ indentExpr expectedType?.get!);
match expectedType? with
| none => useSource ()
| some expectedType => do
expectedType ← whnf ref expectedType;
match expectedType.getAppFn with
| Expr.const constName _ _ => pure constName
| _ => useSource ()
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (sourceView : SourceView) : TermElabM Expr := do
structName ← getStructName stx expectedType? sourceView;
env ← getEnv;
unless (isStructureLike env structName) $
throwError stx ("invalid {...} notation, '" ++ structName ++ "' is not a structure");
throwError stx ("WIP " ++ toString structName ++ toString stx)
@[builtinTermElab structInst] def elabStructInst : TermElab :=
fun stx expectedType? => do
stxNew? ← expandNonAtomicExplicitSource stx;
match stxNew? with
| some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| none => do
sourceView ← getStructSource stx;
modifyOp? ← isModifyOp? stx;
match modifyOp?, sourceView with
| some modifyOp, SourceView.explicit source => elabModifyOp stx modifyOp source expectedType?
| some _, _ => throwError stx ("invalid {...} notation, explicit source is required when using '[<index>] := <value>'")
| _, _ => elabStructInstAux stx expectedType? sourceView
end Term
end Elab

View file

@ -802,6 +802,16 @@ private def resolveLocalName (n : Name) : TermElabM (Option (Expr × List String
lctx ← getLCtx;
pure $ resolveLocalNameAux lctx n []
/- Return true iff `stx` is a `Term.id`, and it is local variable. -/
def isLocalTermId? (stx : Syntax) : TermElabM (Option Expr) :=
match stx.isTermId? with
| some (Syntax.ident _ _ val _, _) => do
r? ← resolveLocalName val;
match r? with
| some (fvar, []) => pure (some fvar)
| _ => pure none
| _ => pure none
private def mkFreshLevelMVars (ref : Syntax) (num : Nat) : TermElabM (List Level) :=
num.foldM (fun _ us => do u ← mkFreshLevelMVar ref; pure $ u::us) []

View file

@ -59,7 +59,7 @@ def haveAssign := parser! " := " >> termParser
@[builtinTermParser] def «show» := parser! "show " >> termParser >> fromTerm
@[builtinTermParser] def «fun» := parser! unicodeSymbol "λ" "fun" >> many1 (termParser appPrec) >> darrow >> termParser
def structInstArrayRef := parser! "[" >> termParser >>"]"
def structInstLVal := (ident <|> structInstArrayRef) >> many (("." >> ident) <|> structInstArrayRef)
def structInstLVal := (ident <|> structInstArrayRef) >> many (("." >> (ident <|> numLit)) <|> structInstArrayRef)
def structInstField := parser! structInstLVal >> " := " >> termParser
def structInstSource := parser! ".." >> optional termParser
@[builtinTermParser] def structInst := parser! symbol "{" appPrec >> optional (try (ident >> " . ")) >> sepBy (structInstField <|> structInstSource) ", " true >> "}"