fix: mkStructView

Reamrk: `1.2` is a numLit
This commit is contained in:
Leonardo de Moura 2020-02-18 20:22:04 -08:00
parent c434066f45
commit 0c13445da6
3 changed files with 35 additions and 17 deletions

View file

@ -298,25 +298,37 @@ def Field.toSyntax : Field Struct → Syntax
stx
| _ => unreachable!
private def toFieldLHS (stx : Syntax) : FieldLHS :=
if stx.getKind == `Lean.Parser.Term.structInstArrayRef then FieldLHS.modifyOp stx (stx.getArg 1)
private def toFieldLHS (stx : Syntax) : Except String (List FieldLHS) :=
if stx.getKind == `Lean.Parser.Term.structInstArrayRef then
pure $ [FieldLHS.modifyOp stx (stx.getArg 1)]
else
-- Note that the representation of the first field is different.
let stx := if stx.getKind == nullKind then stx.getArg 1 else stx;
if stx.isIdent then FieldLHS.fieldName stx stx.getId
if stx.isIdent then pure $ [FieldLHS.fieldName stx stx.getId]
else match stx.isNatLit? with
| some idx => FieldLHS.fieldIndex stx idx
| none => unreachable!
| some idx => pure $ [FieldLHS.fieldIndex stx idx]
| none => match stx.isLit? numLitKind with
| some val =>
let parts := val.split $ fun c => c == '.';
parts.mapM $ fun part =>
match Syntax.decodeNatLitVal part with
| some idx => pure $ FieldLHS.fieldIndex (mkStxNumLit (toString idx) stx.getHeadInfo) idx
| none => throw "unexpected structure syntax"
| none => throw "unexpected structure syntax"
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Struct :=
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Except String Struct := do
let args := (stx.getArg 2).getArgs;
let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField;
let fields := fieldsStx.toList.map $ fun fieldStx =>
fields ← fieldsStx.toList.mapM $ fun fieldStx => do {
let val := fieldStx.getArg 3;
let first := toFieldLHS (fieldStx.getArg 0);
let rest := (fieldStx.getArg 1).getArgs.toList.map $ toFieldLHS;
({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct);
⟨stx, structName, fields, source⟩
lhs ← toFieldLHS (fieldStx.getArg 0) | throw "unexpected structure syntax";
lhs ←
(fieldStx.getArg 1).getArgs.toList.foldlM
(fun lhs lhsStx => do lhsNew ← toFieldLHS lhsStx; pure (lhs ++ lhsNew))
lhs;
pure $ ({ref := fieldStx, lhs := lhs, val := FieldVal.term val } : Field Struct)
};
pure ⟨stx, structName, fields, source⟩
def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct :=
match s with
@ -807,11 +819,14 @@ structName ← getStructName stx expectedType? source;
env ← getEnv;
unless (isStructureLike env structName) $
throwError stx ("invalid {...} notation, '" ++ structName ++ "' is not a structure");
struct ← expandStruct $ mkStructView stx structName source;
trace `Elab.struct stx $ fun _ => toString struct;
(r, struct) ← elabStruct struct expectedType?;
DefaultFields.propagate struct;
pure r
match mkStructView stx structName source with
| Except.error ex => throwError stx ex
| Except.ok struct => do
struct ← expandStruct struct;
trace `Elab.struct stx $ fun _ => toString struct;
(r, struct) ← elabStruct struct expectedType?;
DefaultFields.propagate struct;
pure r
@[builtinTermElab structInst] def elabStructInst : TermElab :=
fun stx expectedType? => do

View file

@ -547,7 +547,7 @@ private partial def decodeDecimalLitAux (s : String) : String.Pos → Nat → Op
if '0' ≤ c && c ≤ '9' then decodeDecimalLitAux (s.next i) (10*val + c.toNat - '0'.toNat)
else none
private def decodeNatLitVal (s : String) : Option Nat :=
def decodeNatLitVal (s : String) : Option Nat :=
let len := s.length;
if len == 0 then none
else

View file

@ -20,3 +20,6 @@ def foo : Foo := {}
#check { x[1].2 := true, .. foo }
#check { x[1].fst.snd := 1, .. foo }
#check { x[1].1.fst := 1, .. foo }
#check { x[1].1.1 := 5, .. foo }
#check { x[1].1.2 := 5, .. foo }