fix: mkStructView
Reamrk: `1.2` is a numLit
This commit is contained in:
parent
c434066f45
commit
0c13445da6
3 changed files with 35 additions and 17 deletions
|
|
@ -298,25 +298,37 @@ def Field.toSyntax : Field Struct → Syntax
|
|||
stx
|
||||
| _ => unreachable!
|
||||
|
||||
private def toFieldLHS (stx : Syntax) : FieldLHS :=
|
||||
if stx.getKind == `Lean.Parser.Term.structInstArrayRef then FieldLHS.modifyOp stx (stx.getArg 1)
|
||||
private def toFieldLHS (stx : Syntax) : Except String (List FieldLHS) :=
|
||||
if stx.getKind == `Lean.Parser.Term.structInstArrayRef then
|
||||
pure $ [FieldLHS.modifyOp stx (stx.getArg 1)]
|
||||
else
|
||||
-- Note that the representation of the first field is different.
|
||||
let stx := if stx.getKind == nullKind then stx.getArg 1 else stx;
|
||||
if stx.isIdent then FieldLHS.fieldName stx stx.getId
|
||||
if stx.isIdent then pure $ [FieldLHS.fieldName stx stx.getId]
|
||||
else match stx.isNatLit? with
|
||||
| some idx => FieldLHS.fieldIndex stx idx
|
||||
| none => unreachable!
|
||||
| some idx => pure $ [FieldLHS.fieldIndex stx idx]
|
||||
| none => match stx.isLit? numLitKind with
|
||||
| some val =>
|
||||
let parts := val.split $ fun c => c == '.';
|
||||
parts.mapM $ fun part =>
|
||||
match Syntax.decodeNatLitVal part with
|
||||
| some idx => pure $ FieldLHS.fieldIndex (mkStxNumLit (toString idx) stx.getHeadInfo) idx
|
||||
| none => throw "unexpected structure syntax"
|
||||
| none => throw "unexpected structure syntax"
|
||||
|
||||
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Struct :=
|
||||
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Except String Struct := do
|
||||
let args := (stx.getArg 2).getArgs;
|
||||
let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField;
|
||||
let fields := fieldsStx.toList.map $ fun fieldStx =>
|
||||
fields ← fieldsStx.toList.mapM $ fun fieldStx => do {
|
||||
let val := fieldStx.getArg 3;
|
||||
let first := toFieldLHS (fieldStx.getArg 0);
|
||||
let rest := (fieldStx.getArg 1).getArgs.toList.map $ toFieldLHS;
|
||||
({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct);
|
||||
⟨stx, structName, fields, source⟩
|
||||
lhs ← toFieldLHS (fieldStx.getArg 0) | throw "unexpected structure syntax";
|
||||
lhs ←
|
||||
(fieldStx.getArg 1).getArgs.toList.foldlM
|
||||
(fun lhs lhsStx => do lhsNew ← toFieldLHS lhsStx; pure (lhs ++ lhsNew))
|
||||
lhs;
|
||||
pure $ ({ref := fieldStx, lhs := lhs, val := FieldVal.term val } : Field Struct)
|
||||
};
|
||||
pure ⟨stx, structName, fields, source⟩
|
||||
|
||||
def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct :=
|
||||
match s with
|
||||
|
|
@ -807,11 +819,14 @@ structName ← getStructName stx expectedType? source;
|
|||
env ← getEnv;
|
||||
unless (isStructureLike env structName) $
|
||||
throwError stx ("invalid {...} notation, '" ++ structName ++ "' is not a structure");
|
||||
struct ← expandStruct $ mkStructView stx structName source;
|
||||
trace `Elab.struct stx $ fun _ => toString struct;
|
||||
(r, struct) ← elabStruct struct expectedType?;
|
||||
DefaultFields.propagate struct;
|
||||
pure r
|
||||
match mkStructView stx structName source with
|
||||
| Except.error ex => throwError stx ex
|
||||
| Except.ok struct => do
|
||||
struct ← expandStruct struct;
|
||||
trace `Elab.struct stx $ fun _ => toString struct;
|
||||
(r, struct) ← elabStruct struct expectedType?;
|
||||
DefaultFields.propagate struct;
|
||||
pure r
|
||||
|
||||
@[builtinTermElab structInst] def elabStructInst : TermElab :=
|
||||
fun stx expectedType? => do
|
||||
|
|
|
|||
|
|
@ -547,7 +547,7 @@ private partial def decodeDecimalLitAux (s : String) : String.Pos → Nat → Op
|
|||
if '0' ≤ c && c ≤ '9' then decodeDecimalLitAux (s.next i) (10*val + c.toNat - '0'.toNat)
|
||||
else none
|
||||
|
||||
private def decodeNatLitVal (s : String) : Option Nat :=
|
||||
def decodeNatLitVal (s : String) : Option Nat :=
|
||||
let len := s.length;
|
||||
if len == 0 then none
|
||||
else
|
||||
|
|
|
|||
|
|
@ -20,3 +20,6 @@ def foo : Foo := {}
|
|||
#check { x[1].2 := true, .. foo }
|
||||
#check { x[1].fst.snd := 1, .. foo }
|
||||
#check { x[1].1.fst := 1, .. foo }
|
||||
|
||||
#check { x[1].1.1 := 5, .. foo }
|
||||
#check { x[1].1.2 := 5, .. foo }
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue