From 0c13445da6535f47dec9540218c055f271f88525 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 18 Feb 2020 20:22:04 -0800 Subject: [PATCH] fix: `mkStructView` Reamrk: `1.2` is a numLit --- src/Init/Lean/Elab/StructInst.lean | 47 ++++++++++++++++++++---------- src/Init/LeanInit.lean | 2 +- tests/lean/run/structInst4.lean | 3 ++ 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/Init/Lean/Elab/StructInst.lean b/src/Init/Lean/Elab/StructInst.lean index 0d658cf4d1..7cac533930 100644 --- a/src/Init/Lean/Elab/StructInst.lean +++ b/src/Init/Lean/Elab/StructInst.lean @@ -298,25 +298,37 @@ def Field.toSyntax : Field Struct → Syntax stx | _ => unreachable! -private def toFieldLHS (stx : Syntax) : FieldLHS := -if stx.getKind == `Lean.Parser.Term.structInstArrayRef then FieldLHS.modifyOp stx (stx.getArg 1) +private def toFieldLHS (stx : Syntax) : Except String (List FieldLHS) := +if stx.getKind == `Lean.Parser.Term.structInstArrayRef then + pure $ [FieldLHS.modifyOp stx (stx.getArg 1)] else -- Note that the representation of the first field is different. let stx := if stx.getKind == nullKind then stx.getArg 1 else stx; - if stx.isIdent then FieldLHS.fieldName stx stx.getId + if stx.isIdent then pure $ [FieldLHS.fieldName stx stx.getId] else match stx.isNatLit? with - | some idx => FieldLHS.fieldIndex stx idx - | none => unreachable! + | some idx => pure $ [FieldLHS.fieldIndex stx idx] + | none => match stx.isLit? numLitKind with + | some val => + let parts := val.split $ fun c => c == '.'; + parts.mapM $ fun part => + match Syntax.decodeNatLitVal part with + | some idx => pure $ FieldLHS.fieldIndex (mkStxNumLit (toString idx) stx.getHeadInfo) idx + | none => throw "unexpected structure syntax" + | none => throw "unexpected structure syntax" -private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Struct := +private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Except String Struct := do let args := (stx.getArg 2).getArgs; let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField; -let fields := fieldsStx.toList.map $ fun fieldStx => +fields ← fieldsStx.toList.mapM $ fun fieldStx => do { let val := fieldStx.getArg 3; - let first := toFieldLHS (fieldStx.getArg 0); - let rest := (fieldStx.getArg 1).getArgs.toList.map $ toFieldLHS; - ({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct); -⟨stx, structName, fields, source⟩ + lhs ← toFieldLHS (fieldStx.getArg 0) | throw "unexpected structure syntax"; + lhs ← + (fieldStx.getArg 1).getArgs.toList.foldlM + (fun lhs lhsStx => do lhsNew ← toFieldLHS lhsStx; pure (lhs ++ lhsNew)) + lhs; + pure $ ({ref := fieldStx, lhs := lhs, val := FieldVal.term val } : Field Struct) +}; +pure ⟨stx, structName, fields, source⟩ def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct := match s with @@ -807,11 +819,14 @@ structName ← getStructName stx expectedType? source; env ← getEnv; unless (isStructureLike env structName) $ throwError stx ("invalid {...} notation, '" ++ structName ++ "' is not a structure"); -struct ← expandStruct $ mkStructView stx structName source; -trace `Elab.struct stx $ fun _ => toString struct; -(r, struct) ← elabStruct struct expectedType?; -DefaultFields.propagate struct; -pure r +match mkStructView stx structName source with +| Except.error ex => throwError stx ex +| Except.ok struct => do + struct ← expandStruct struct; + trace `Elab.struct stx $ fun _ => toString struct; + (r, struct) ← elabStruct struct expectedType?; + DefaultFields.propagate struct; + pure r @[builtinTermElab structInst] def elabStructInst : TermElab := fun stx expectedType? => do diff --git a/src/Init/LeanInit.lean b/src/Init/LeanInit.lean index f0a6866ef2..56593e3250 100644 --- a/src/Init/LeanInit.lean +++ b/src/Init/LeanInit.lean @@ -547,7 +547,7 @@ private partial def decodeDecimalLitAux (s : String) : String.Pos → Nat → Op if '0' ≤ c && c ≤ '9' then decodeDecimalLitAux (s.next i) (10*val + c.toNat - '0'.toNat) else none -private def decodeNatLitVal (s : String) : Option Nat := +def decodeNatLitVal (s : String) : Option Nat := let len := s.length; if len == 0 then none else diff --git a/tests/lean/run/structInst4.lean b/tests/lean/run/structInst4.lean index 4b17f13b46..2a8c585c8d 100644 --- a/tests/lean/run/structInst4.lean +++ b/tests/lean/run/structInst4.lean @@ -20,3 +20,6 @@ def foo : Foo := {} #check { x[1].2 := true, .. foo } #check { x[1].fst.snd := 1, .. foo } #check { x[1].1.fst := 1, .. foo } + +#check { x[1].1.1 := 5, .. foo } +#check { x[1].1.2 := 5, .. foo }