From e02a06ad1c4b05ee7242e60969e81fd56e93cd41 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 16 Oct 2020 08:40:42 -0700 Subject: [PATCH] chore: move to new frontend --- src/Init/Data/List/Control.lean | 4 +- src/Lean/Elab/StructInst.lean | 608 +++++++++++++++----------------- src/Std/Data/HashMap.lean | 8 +- 3 files changed, 294 insertions(+), 326 deletions(-) diff --git a/src/Init/Data/List/Control.lean b/src/Init/Data/List/Control.lean index 7f2912f5bd..6c4911de0c 100644 --- a/src/Init/Data/List/Control.lean +++ b/src/Init/Data/List/Control.lean @@ -108,14 +108,14 @@ def filterMapM {m : Type u → Type v} [Monad m] {α β : Type u} (f : α → m filterMapMAux f as.reverse [] @[specialize] -def foldlM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (s → α → m s) → s → List α → m s +def foldlM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : forall (f : s → α → m s) (init : s), List α → m s | f, s, [] => pure s | f, s, h :: r => do s' ← f s h; foldlM f s' r @[specialize] -def foldrM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (α → s → m s) → s → List α → m s +def foldrM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : forall (f : α → s → m s) (init : s), List α → m s | f, s, [] => pure s | f, s, h :: r => do s' ← foldrM f s r; diff --git a/src/Lean/Elab/StructInst.lean b/src/Lean/Elab/StructInst.lean index 986213456c..3b0d1689f3 100644 --- a/src/Lean/Elab/StructInst.lean +++ b/src/Lean/Elab/StructInst.lean @@ -1,3 +1,4 @@ +#lang lean4 /- Copyright (c) 2020 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. @@ -8,10 +9,7 @@ import Lean.Elab.App import Lean.Elab.Binders import Lean.Elab.Quotation -namespace Lean -namespace Elab -namespace Term -namespace StructInst +namespace Lean.Elab.Term.StructInst open Std (HashMap) open Meta @@ -20,12 +18,12 @@ open Meta @[builtinMacro Lean.Parser.Term.structInst] def expandStructInstExpectedType : Macro := fun stx => - let expectedArg := stx.getArg 4; + let expectedArg := stx[4] if expectedArg.isNone then Macro.throwUnsupported else - let expected := expectedArg.getArg 1; - let stxNew := stx.setArg 4 mkNullNode; + let expected := expectedArg[1] + let stxNew := stx.setArg 4 mkNullNode `(($stxNew : $expected)) /- @@ -34,17 +32,18 @@ If `stx` is of the form `{ s with ... }` and `s` is not a local variable, expand Note that this one is not a `Macro` because we need to access the local context. -/ private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) := -withFreshMacroScope $ - let sourceOpt := stx.getArg 1; - if sourceOpt.isNone then pure none else do - let source := sourceOpt.getArg 0; - fvar? ← isLocalIdent? source; - match fvar? with +withFreshMacroScope do + let sourceOpt := stx[1] + if sourceOpt.isNone then + pure none + else + let source := sourceOpt[0] + match (← isLocalIdent? source) with | some _ => pure none - | none => do - src ← `(src); - let sourceOpt := sourceOpt.setArg 0 src; - let stxNew := stx.setArg 1 sourceOpt; + | none => + let src ← `(src) + let sourceOpt := sourceOpt.setArg 0 src + let stxNew := stx.setArg 1 sourceOpt `(let src := $source; $stxNew) inductive Source @@ -64,15 +63,15 @@ def setStructSourceSyntax (structStx : Syntax) : Source → Syntax | Source.explicit stx _ => (structStx.setArg 1 stx).setArg 3 mkNullNode private def getStructSource (stx : Syntax) : TermElabM Source := -withRef stx $ -let explicitSource := stx.getArg 1; -let implicitSource := stx.getArg 3; +withRef stx do +let explicitSource := stx[1] +let implicitSource := stx[3] if explicitSource.isNone && implicitSource.isNone then pure Source.none else if explicitSource.isNone then pure $ Source.implicit implicitSource -else if implicitSource.isNone then do - fvar? ← isLocalIdent? (explicitSource.getArg 0); +else if implicitSource.isNone then + let fvar? ← isLocalIdent? explicitSource[0] match fvar? with | none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here | some src => pure $ Source.explicit explicitSource src @@ -85,16 +84,16 @@ else def structInstArrayRef := parser! "[" >> termParser >>"]" ``` -/ private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do -let args := (stx.getArg 2).getArgs; -s? ← args.foldSepByM +let args := stx[2].getArgs +let s? ← args.foldSepByM (fun arg s? => /- Remark: the syntax for `structInstField` is ``` def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef) def structInstField := parser! structInstLVal >> " := " >> termParser ``` -/ - let lval := arg.getArg 0; - let k := lval.getKind; + let lval := arg[0] + let k := lval.getKind if k == `Lean.Parser.Term.structInstArrayRef then match s? with | none => pure (some arg) @@ -111,56 +110,55 @@ s? ← args.foldSepByM throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level" else pure s?) - none; + none match s? with | none => pure none -| some s => if (s.getArg 0).getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none +| some s => if s[0].getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none -private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr := -let continue (val : Syntax) : TermElabM Expr := do { - let lval := modifyOp.getArg 0; - let idx := lval.getArg 1; - let self := source.getArg 0; - stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val)); - trace `Elab.struct.modifyOp fun _ => stx ++ "\n===>\n" ++ stxNew; +private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do +let cont (val : Syntax) : TermElabM Expr := do + let lval := modifyOp[0] + let idx := lval[1] + let self := source[0] + let stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val)) + trace[Elab.struct.modifyOp]! "{stx}\n===>\n{stxNew}" withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? -}; do -trace `Elab.struct.modifyOp fun _ => modifyOp ++ "\nSource: " ++ source; -let rest := modifyOp.getArg 1; -if rest.isNone then do - continue (modifyOp.getArg 3) -else do - s ← `(s); - let valFirst := rest.getArg 0; - let valFirst := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst.getArg 1; - let restArgs := rest.getArgs; - let valRest := mkNullNode (restArgs.extract 1 restArgs.size); - let valField := modifyOp.setArg 0 valFirst; - let valField := valField.setArg 1 valRest; - let valSource := source.modifyArg 0 $ fun _ => s; - let val := stx.setArg 1 valSource; - let val := val.setArg 2 $ mkNullNode #[valField]; - trace `Elab.struct.modifyOp fun _ => stx ++ "\nval: " ++ val; - continue val +trace[Elab.struct.modifyOp]! "{modifyOp}\nSource: {source}" +let rest := modifyOp[1] +if rest.isNone then + cont modifyOp[3] +else + let s ← `(s) + let valFirst := rest[0] + let valFirst := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst[1] + let restArgs := rest.getArgs + let valRest := mkNullNode restArgs[1:restArgs.size] + let valField := modifyOp.setArg 0 valFirst + let valField := valField.setArg 1 valRest + let valSource := source.modifyArg 0 fun _ => s + let val := stx.setArg 1 valSource + let val := val.setArg 2 $ mkNullNode #[valField] + trace[Elab.struct.modifyOp]! "{stx}\nval: {val}" + cont val /- Get structure name and elaborate explicit source (if available) -/ private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do -tryPostponeIfNoneOrMVar expectedType?; +tryPostponeIfNoneOrMVar expectedType? let useSource : Unit → TermElabM Name := fun _ => match sourceView, expectedType? with | Source.explicit _ src, _ => do - srcType ← inferType src; - srcType ← whnf srcType; - tryPostponeIfMVar srcType; + let srcType ← inferType src + let srcType ← whnf srcType + tryPostponeIfMVar srcType match srcType.getAppFn with | Expr.const constName _ _ => pure constName - | _ => throwError ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr srcType) - | _, some expectedType => throwError ("invalid {...} notation, expected type is not of the form (C ...)" ++ indentExpr expectedType) - | _, none => throwError ("invalid {...} notation, expected type must be known"); + | _ => throwError! "invalid \{...} notation, source type is not of the form (C ...){indentExpr srcType}" + | _, some expectedType => throwError! "invalid \{...} notation, expected type is not of the form (C ...){indentExpr expectedType}" + | _, none => throwError! "invalid \{...} notation, expected type must be known" match expectedType? with | none => useSource () -| some expectedType => do - expectedType ← whnf expectedType; +| some expectedType => + let expectedType ← whnf expectedType match expectedType.getAppFn with | Expr.const constName _ _ => pure constName | _ => useSource () @@ -178,9 +176,9 @@ instance FieldLHS.hasFormat : HasFormat FieldLHS := | FieldLHS.modifyOp _ i => "[" ++ i.prettyPrint ++ "]"⟩ inductive FieldVal (σ : Type) -| term (stx : Syntax) : FieldVal -| nested (s : σ) : FieldVal -| default : FieldVal -- mark that field must be synthesized using default value +| term (stx : Syntax) : FieldVal σ +| nested (s : σ) : FieldVal σ +| default : FieldVal σ -- mark that field must be synthesized using default value structure Field (σ : Type) := (ref : Syntax) (lhs : List FieldLHS) (val : FieldVal σ) (expr? : Option Expr := none) @@ -203,7 +201,7 @@ partial def Struct.allDefault : Struct → Bool | ⟨_, _, fields, _⟩ => fields.all fun ⟨_, _, val, _⟩ => match val with | FieldVal.term _ => false | FieldVal.default => true - | FieldVal.nested s => Struct.allDefault s + | FieldVal.nested s => allDefault s def Struct.ref : Struct → Syntax | ⟨ref, _, _, _⟩ => ref @@ -226,7 +224,7 @@ Format.joinSep field.lhs " . " ++ " := " ++ partial def formatStruct : Struct → Format | ⟨_, structName, fields, source⟩ => - let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", "; + let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", " match source with | Source.none => "{" ++ fieldsFmt ++ "}" | Source.implicit _ => "{" ++ fieldsFmt ++ " .. }" @@ -258,35 +256,34 @@ def FieldVal.toSyntax : FieldVal Struct → Syntax def Field.toSyntax : Field Struct → Syntax | field => - let stx := field.ref; - let stx := stx.setArg 3 field.val.toSyntax; + let stx := field.ref + let stx := stx.setArg 3 field.val.toSyntax match field.lhs with | first::rest => - let stx := stx.setArg 0 $ first.toSyntax true; - let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false); + let stx := stx.setArg 0 $ first.toSyntax true + let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false) stx | _ => unreachable! private def toFieldLHS (stx : Syntax) : Except String FieldLHS := if stx.getKind == `Lean.Parser.Term.structInstArrayRef then - pure $ FieldLHS.modifyOp stx (stx.getArg 1) + pure $ FieldLHS.modifyOp stx stx[1] else -- Note that the representation of the first field is different. - let stx := if stx.getKind == nullKind then stx.getArg 1 else stx; + let stx := if stx.getKind == nullKind then stx[1] else stx if stx.isIdent then pure $ FieldLHS.fieldName stx stx.getId.eraseMacroScopes else match stx.isFieldIdx? with | some idx => pure $ FieldLHS.fieldIndex stx idx | none => throw "unexpected structure syntax" private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Except String Struct := do -let args := (stx.getArg 2).getArgs; -let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField; -fields ← fieldsStx.toList.mapM $ fun fieldStx => do { - let val := fieldStx.getArg 3; - first ← toFieldLHS (fieldStx.getArg 0); - rest ← (fieldStx.getArg 1).getArgs.toList.mapM toFieldLHS; +let args := stx[2].getArgs +let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField +let fields ← fieldsStx.toList.mapM fun fieldStx => do + let val := fieldStx[3] + let first ← toFieldLHS fieldStx[0] + let rest ← fieldStx[1].getArgs.toList.mapM toFieldLHS pure $ ({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct) -}; pure ⟨stx, structName, fields, source⟩ def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct := @@ -297,25 +294,25 @@ match s with Id.run $ s.modifyFieldsM f def Struct.setFields (s : Struct) (fields : Fields) : Struct := -s.modifyFields $ fun _ => fields +s.modifyFields fun _ => fields private def expandCompositeFields (s : Struct) : Struct := s.modifyFields $ fun fields => fields.map $ fun field => match field with | { lhs := FieldLHS.fieldName ref (Name.str Name.anonymous _ _) :: rest, .. } => field | { lhs := FieldLHS.fieldName ref n@(Name.str _ _ _) :: rest, .. } => - let newEntries := n.components.map $ FieldLHS.fieldName ref; + let newEntries := n.components.map $ FieldLHS.fieldName ref { field with lhs := newEntries ++ rest } | _ => field private def expandNumLitFields (s : Struct) : TermElabM Struct := -s.modifyFieldsM $ fun fields => do - env ← getEnv; - let fieldNames := getStructureFields env s.structName; - fields.mapM $ fun field => match field with +s.modifyFieldsM fun fields => do + let env ← getEnv + let fieldNames := getStructureFields env s.structName + fields.mapM fun field => match field with | { lhs := FieldLHS.fieldIndex ref idx :: rest, .. } => if idx == 0 then throwErrorAt ref "invalid field index, index must be greater than 0" - else if idx > fieldNames.size then throwErrorAt ref ("invalid field index, structure has only #" ++ toString fieldNames.size ++ " fields") - else pure { field with lhs := FieldLHS.fieldName ref (fieldNames.get! $ idx - 1) :: rest } + else if idx > fieldNames.size then throwErrorAt! ref "invalid field index, structure has only #{fieldNames.size} fields" + else pure { field with lhs := FieldLHS.fieldName ref fieldNames[idx - 1] :: rest } | _ => pure field /- For example, consider the following structures: @@ -334,38 +331,36 @@ s.modifyFieldsM $ fun fields => do { toB.toA.x := 0, toB.y := 0, z := true : C } ``` -/ private def expandParentFields (s : Struct) : TermElabM Struct := do -env ← getEnv; -s.modifyFieldsM $ fun fields => fields.mapM $ fun field => match field with +let env ← getEnv +s.modifyFieldsM fun fields => fields.mapM fun field => match field with | { lhs := FieldLHS.fieldName ref fieldName :: rest, .. } => match findField? env s.structName fieldName with - | none => throwErrorAt ref ("'" ++ fieldName ++ "' is not a field of structure '" ++ s.structName ++ "'") + | none => throwErrorAt! ref "'{fieldName}' is not a field of structure '{s.structName}'" | some baseStructName => if baseStructName == s.structName then pure field else match getPathToBaseStructure? env baseStructName s.structName with | some path => do let path := path.map $ fun funName => match funName with | Name.str _ s _ => FieldLHS.fieldName ref (mkNameSimple s) - | _ => unreachable!; + | _ => unreachable! pure { field with lhs := path ++ field.lhs } - | _ => throwErrorAt ref ("failed to access field '" ++ fieldName ++ "' in parent structure") + | _ => throwErrorAt! ref "failed to access field '{fieldName}' in parent structure" | _ => pure field private abbrev FieldMap := HashMap Name Fields private def mkFieldMap (fields : Fields) : TermElabM FieldMap := -fields.foldlM - (fun fieldMap field => - match field.lhs with - | FieldLHS.fieldName _ fieldName :: rest => - match fieldMap.find? fieldName with - | some (prevField::restFields) => - if field.isSimple || prevField.isSimple then - throwErrorAt field.ref ("field '" ++ fieldName ++ "' has already beed specified") - else - pure $ fieldMap.insert fieldName (field::prevField::restFields) - | _ => pure $ fieldMap.insert fieldName [field] - | _ => unreachable!) - {} +fields.foldlM (init := {}) fun fieldMap field => + match field.lhs with + | FieldLHS.fieldName _ fieldName :: rest => + match fieldMap.find? fieldName with + | some (prevField::restFields) => + if field.isSimple || prevField.isSimple then + throwErrorAt! field.ref "field '{fieldName}' has already beed specified" + else + pure $ fieldMap.insert fieldName (field::prevField::restFields) + | _ => pure $ fieldMap.insert fieldName [field] + | _ => unreachable! private def isSimpleField? : Fields → Option (Field Struct) | [field] => if field.isSimple then some field else none @@ -374,7 +369,7 @@ private def isSimpleField? : Fields → Option (Field Struct) private def getFieldIdx (structName : Name) (fieldNames : Array Name) (fieldName : Name) : TermElabM Nat := do match fieldNames.findIdx? $ fun n => n == fieldName with | some idx => pure idx -| none => throwError ("field '" ++ fieldName ++ "' is not a valid field of '" ++ structName ++ "'") +| none => throwError! "field '{fieldName}' is not a valid field of '{structName}'" private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax := Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName] @@ -382,81 +377,78 @@ Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldNam private def mkSubstructSource (structName : Name) (fieldNames : Array Name) (fieldName : Name) (src : Source) : TermElabM Source := match src with | Source.explicit stx src => do - idx ← getFieldIdx structName fieldNames fieldName; - let stx := stx.modifyArg 0 $ fun stx => mkProjStx stx fieldName; + let idx ← getFieldIdx structName fieldNames fieldName + let stx := stx.modifyArg 0 fun stx => mkProjStx stx fieldName pure $ Source.explicit stx (mkProj structName idx src) | s => pure s @[specialize] private def groupFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do -env ← getEnv; -let fieldNames := getStructureFields env s.structName; -withRef s.ref $ -s.modifyFieldsM $ fun fields => do - fieldMap ← mkFieldMap fields; - fieldMap.toList.mapM $ fun ⟨fieldName, fields⟩ => +let env ← getEnv +let fieldNames := getStructureFields env s.structName +withRef s.ref do +s.modifyFieldsM fun fields => do + let fieldMap ← mkFieldMap fields + fieldMap.toList.mapM fun ⟨fieldName, fields⟩ => do match isSimpleField? fields with | some field => pure field - | none => do - let substructFields := fields.map $ fun field => { field with lhs := field.lhs.tail! }; - substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source; - let field := fields.head!; + | none => + let substructFields := fields.map fun field => { field with lhs := field.lhs.tail! } + let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source + let field := fields.head! match Lean.isSubobjectField? env s.structName fieldName with - | some substructName => do - let substruct := Struct.mk s.ref substructName substructFields substructSource; - substruct ← expandStruct substruct; + | some substructName => + let substruct := Struct.mk s.ref substructName substructFields substructSource + let substruct ← expandStruct substruct pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct } | none => do -- It is not a substructure field. Thus, we wrap fields using `Syntax`, and use `elabTerm` to process them. - let valStx := s.ref; -- construct substructure syntax using s.ref as template - let valStx := valStx.setArg 4 mkNullNode; -- erase optional expected type - let args := substructFields.toArray.map $ Field.toSyntax; - let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ",")); - let valStx := setStructSourceSyntax valStx substructSource; + let valStx := s.ref -- construct substructure syntax using s.ref as template + let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type + let args := substructFields.toArray.map Field.toSyntax + let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ",")) + let valStx := setStructSourceSyntax valStx substructSource pure { field with lhs := [field.lhs.head!], val := FieldVal.term valStx } def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) := -fields.find? $ fun field => +fields.find? fun field => match field.lhs with | [FieldLHS.fieldName _ n] => n == fieldName | _ => false @[specialize] private def addMissingFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do -env ← getEnv; -let fieldNames := getStructureFields env s.structName; -let ref := s.ref; +let env ← getEnv +let fieldNames := getStructureFields env s.structName +let ref := s.ref withRef ref do -fields ← fieldNames.foldlM - (fun fields fieldName => do - match findField? s.fields fieldName with - | some field => pure $ field::fields - | none => - let addField (val : FieldVal Struct) : TermElabM Fields := do { - pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields - }; - match Lean.isSubobjectField? env s.structName fieldName with - | some substructName => do - substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source; - let substruct := Struct.mk s.ref substructName [] substructSource; - substruct ← expandStruct substruct; - addField (FieldVal.nested substruct) - | none => - match s.source with - | Source.none => addField FieldVal.default - | Source.implicit _ => addField (FieldVal.term (mkHole s.ref)) - | Source.explicit stx _ => - -- stx is of the form `optional (try (termParser >> "with"))` - let src := stx.getArg 0; - let val := mkProjStx src fieldName; - addField (FieldVal.term val)) - []; +let fields ← fieldNames.foldlM (init := []) fun fields fieldName => do + match findField? s.fields fieldName with + | some field => pure $ field::fields + | none => + let addField (val : FieldVal Struct) : TermElabM Fields := do + pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields + match Lean.isSubobjectField? env s.structName fieldName with + | some substructName => do + let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source + let substruct := Struct.mk s.ref substructName [] substructSource + let substruct ← expandStruct substruct + addField (FieldVal.nested substruct) + | none => + match s.source with + | Source.none => addField FieldVal.default + | Source.implicit _ => addField (FieldVal.term (mkHole s.ref)) + | Source.explicit stx _ => + -- stx is of the form `optional (try (termParser >> "with"))` + let src := stx[0] + let val := mkProjStx src fieldName + addField (FieldVal.term val) pure $ s.setFields fields.reverse private partial def expandStruct : Struct → TermElabM Struct | s => do - let s := expandCompositeFields s; - s ← expandNumLitFields s; - s ← expandParentFields s; - s ← groupFields expandStruct s; + let s := expandCompositeFields s + let s ← expandNumLitFields s + let s ← expandParentFields s + let s ← groupFields expandStruct s addMissingFields expandStruct s structure CtorHeaderResult := @@ -467,15 +459,15 @@ structure CtorHeaderResult := private def mkCtorHeaderAux : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult | 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars } | n+1, type, ctorFn, instMVars => do - type ← whnfForall type; + let type ← whnfForall type match type with | Expr.forallE _ d b c => match c.binderInfo with - | BinderInfo.instImplicit => do - a ← mkFreshExprMVar d MetavarKind.synthetic; + | BinderInfo.instImplicit => + let a ← mkFreshExprMVar d MetavarKind.synthetic mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!) - | _ => do - a ← mkFreshExprMVar d; + | _ => + let a ← mkFreshExprMVar d mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars | _ => throwError "unexpected constructor type" @@ -487,21 +479,21 @@ private partial def getForallBody : Nat → Expr → Option Expr private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType? : Option Expr) : TermElabM Unit := match expectedType? with | none => pure () -| some expectedType => +| some expectedType => do match getForallBody numFields type with | none => pure () | some typeBody => - unless typeBody.hasLooseBVars $ do - _ ← isDefEq expectedType typeBody; + unless typeBody.hasLooseBVars do + isDefEq expectedType typeBody pure () private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do -lvls ← ctorVal.lparams.mapM $ fun _ => mkFreshLevelMVar; -let val := Lean.mkConst ctorVal.name lvls; -let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls; -r ← mkCtorHeaderAux ctorVal.nparams type val #[]; -propagateExpectedType r.ctorFnType ctorVal.nfields expectedType?; -synthesizeAppInstMVars r.instMVars; +let lvls ← ctorVal.lparams.mapM fun _ => mkFreshLevelMVar +let val := Lean.mkConst ctorVal.name lvls +let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls +let r ← mkCtorHeaderAux ctorVal.nparams type val #[] +propagateExpectedType r.ctorFnType ctorVal.nfields expectedType? +synthesizeAppInstMVars r.instMVars pure r def markDefaultMissing (e : Expr) : Expr := @@ -511,43 +503,40 @@ def defaultMissing? (e : Expr) : Option Expr := annotation? `structInstDefault e def throwFailedToElabField {α} (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α := -throwError ("failed to elaborate field '" ++ fieldName ++ "' of '" ++ structName ++ ", " ++ msgData) +throwError! "failed to elaborate field '{fieldName}' of '{structName}, {msgData}" -def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) := -if !s.allDefault then pure none +def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) := do +if !s.allDefault then + pure none else - catch (synthInstance? expectedType) (fun _ => pure none) + try synthInstance? expectedType catch _ => pure none private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × Struct) | s, expectedType? => withRef s.ref do - env ← getEnv; - let ctorVal := getStructureCtor env s.structName; - { ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType?; - (e, _, fields) ← s.fields.foldlM - (fun (acc : Expr × Expr × Fields) field => - let (e, type, fields) := acc; - match field.lhs with - | [FieldLHS.fieldName ref fieldName] => do - type ← whnfForall type; - match type with - | Expr.forallE _ d b c => - let continue (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do { - let e := mkApp e val; - let type := b.instantiate1 val; - let field := { field with expr? := some val }; - pure (e, type, field::fields) - }; - match field.val with - | FieldVal.term stx => do val ← elabTermEnsuringType stx d; continue val field - | FieldVal.nested s => do - val? ← trySynthStructInstance? s d; -- if all fields of `s` are marked as `default`, then try to synthesize instance - match val? with - | some val => continue val { field with val := FieldVal.term (mkHole field.ref) } - | none => do(val, sNew) ← elabStruct s (some d); val ← ensureHasType d val; continue val { field with val := FieldVal.nested sNew } - | FieldVal.default => do val ← withRef field.ref $ mkFreshExprMVar (some d); continue (markDefaultMissing val) field - | _ => withRef field.ref $ throwFailedToElabField fieldName s.structName ("unexpected constructor type" ++ indentExpr type) - | _ => throwErrorAt field.ref "unexpected unexpanded structure field") - (ctorFn, ctorFnType, []); + let env ← getEnv + let ctorVal := getStructureCtor env s.structName + let { ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType? + let (e, _, fields) ← s.fields.foldlM (init := (ctorFn, ctorFnType, [])) fun (e, type, fields) field => + match field.lhs with + | [FieldLHS.fieldName ref fieldName] => do + let type ← whnfForall type + match type with + | Expr.forallE _ d b c => + let cont (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do + let e := mkApp e val + let type := b.instantiate1 val + let field := { field with expr? := some val } + pure (e, type, field::fields) + match field.val with + | FieldVal.term stx => cont (← elabTermEnsuringType stx d) field + | FieldVal.nested s => do + -- if all fields of `s` are marked as `default`, then try to synthesize instance + match (← trySynthStructInstance? s d) with + | some val => cont val { field with val := FieldVal.term (mkHole field.ref) } + | none => do let (val, sNew) ← elabStruct s (some d); val ← ensureHasType d val; cont val { field with val := FieldVal.nested sNew } + | FieldVal.default => do let val ← withRef field.ref $ mkFreshExprMVar (some d); cont (markDefaultMissing val) field + | _ => withRef field.ref $ throwFailedToElabField fieldName s.structName msg!"unexpected constructor type{indentExpr type}" + | _ => throwErrorAt field.ref "unexpected unexpanded structure field" pure (e, s.setFields fields.reverse) namespace DefaultFields @@ -587,28 +576,24 @@ structure State := partial def collectStructNames : Struct → Array Name → Array Name | struct, names => - let names := names.push struct.structName; - struct.fields.foldl - (fun names field => - match field.val with - | FieldVal.nested struct => collectStructNames struct names - | _ => names) - names + let names := names.push struct.structName + struct.fields.foldl (init := names) fun names field => + match field.val with + | FieldVal.nested struct => collectStructNames struct names + | _ => names partial def getHierarchyDepth : Struct → Nat | struct => - struct.fields.foldl - (fun max field => - match field.val with - | FieldVal.nested struct => Nat.max max (getHierarchyDepth struct + 1) - | _ => max) - 0 + struct.fields.foldl (init := 0) fun max field => + match field.val with + | FieldVal.nested struct => Nat.max max (getHierarchyDepth struct + 1) + | _ => max partial def findDefaultMissing? (mctx : MetavarContext) : Struct → Option (Field Struct) | struct => - struct.fields.findSome? $ fun field => + struct.fields.findSome? fun field => match field.val with - | FieldVal.nested struct => findDefaultMissing? struct + | FieldVal.nested struct => findDefaultMissing? mctx struct | _ => match field.expr? with | none => unreachable! | some expr => match defaultMissing? expr with @@ -623,31 +608,30 @@ match field.lhs with abbrev M := ReaderT Context (StateRefT State TermElabM) def isRoundDone : M Bool := do -ctx ← read; -s ← get; -pure (s.progress && ctx.maxDistance > 0) +return (← get).progress && (← read).maxDistance > 0 def getFieldValue? (struct : Struct) (fieldName : Name) : Option Expr := -struct.fields.findSome? $ fun field => +struct.fields.findSome? fun field => if getFieldName field == fieldName then field.expr? else none partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Expr) -| Expr.lam n d b c => withRef struct.ref $ +| Expr.lam n d b c => withRef struct.ref do if c.binderInfo.isExplicit then - let fieldName := n; + let fieldName := n match getFieldValue? struct fieldName with | none => pure none - | some val => do - valType ← inferType val; - condM (isDefEq valType d) - (mkDefaultValueAux? (b.instantiate1 val)) - (pure none) - else do - arg ← mkFreshExprMVar d; - mkDefaultValueAux? (b.instantiate1 arg) + | some val => + let valType ← inferType val + if (← isDefEq valType d) then + mkDefaultValueAux? struct (b.instantiate1 val) + else + pure none + else + let arg ← mkFreshExprMVar d + mkDefaultValueAux? struct (b.instantiate1 arg) | e => if e.isAppOfArity `id 2 then pure (some e.appArg!) @@ -656,7 +640,7 @@ partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Ex def mkDefaultValue? (struct : Struct) (cinfo : ConstantInfo) : TermElabM (Option Expr) := withRef struct.ref do -us ← cinfo.lparams.mapM $ fun _ => mkFreshLevelMVar; +let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar mkDefaultValueAux? struct (cinfo.instantiateValueLevelParams us) /-- If `e` is a projection function of one of the given structures, then reduce it -/ @@ -664,7 +648,7 @@ def reduceProjOf? (structNames : Array Name) (e : Expr) : MetaM (Option Expr) := if !e.isApp then pure none else match e.getAppFn with | Expr.const name _ _ => do - env ← getEnv; + let env ← getEnv match env.getProjectionStructureName? name with | some structName => if structNames.contains structName then @@ -676,142 +660,126 @@ else match e.getAppFn with /-- Reduce default value. It performs beta reduction and projections of the given structures. -/ partial def reduce (structNames : Array Name) : Expr → MetaM Expr -| e@(Expr.lam _ _ _ _) => lambdaLetTelescope e $ fun xs b => do b ← reduce b; mkLambdaFVars xs b -| e@(Expr.forallE _ _ _ _) => forallTelescope e $ fun xs b => do b ← reduce b; mkForallFVars xs b -| e@(Expr.letE _ _ _ _ _) => lambdaLetTelescope e $ fun xs b => do b ← reduce b; mkLetFVars xs b +| e@(Expr.lam _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLambdaFVars xs (← reduce structNames b) +| e@(Expr.forallE _ _ _ _) => forallTelescope e fun xs b => do mkForallFVars xs (← reduce structNames b) +| e@(Expr.letE _ _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLetFVars xs (← reduce structNames b) | e@(Expr.proj _ i b _) => do - r? ← Meta.reduceProj? b i; - match r? with - | some r => reduce r - | none => do b ← reduce b; pure $ e.updateProj! b + match (← Meta.reduceProj? b i) with + | some r => reduce structNames r + | none => pure $ e.updateProj! (← reduce structNames b) | e@(Expr.app f _ _) => do - r? ← reduceProjOf? structNames e; - match r? with - | some r => reduce r - | none => do - let f := f.getAppFn; - f' ← reduce f; + match (← reduceProjOf? structNames e) with + | some r => reduce structNames r + | none => + let f := f.getAppFn + let f' ← reduce structNames f if f'.isLambda then - let revArgs := e.getAppRevArgs; - reduce $ f'.betaRev revArgs - else do - args ← e.getAppArgs.mapM reduce; + let revArgs := e.getAppRevArgs + reduce structNames (f'.betaRev revArgs) + else + let args ← e.getAppArgs.mapM (reduce structNames) pure (mkAppN f' args) | e@(Expr.mdata _ b _) => do - b ← reduce b; + let b ← reduce structNames b if (defaultMissing? e).isSome && !b.isMVar then pure b else pure $ e.updateMData! b | e@(Expr.mvar mvarId _) => do - val? ← getExprMVarAssignment? mvarId; - match val? with - | some val => if val.isMVar then reduce val else pure val + match (← getExprMVarAssignment? mvarId) with + | some val => if val.isMVar then reduce structNames val else pure val | none => pure e | e => pure e -partial def tryToSynthesizeDefaultAux (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat) - (fieldName : Name) (mvarId : MVarId) : Nat → Nat → TermElabM Bool -| i, dist => - if dist > maxDistance then pure false +partial def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool := +let rec loop (i : Nat) (dist : Nat) := do + if dist > maxDistance then + pure false else if h : i < structs.size then do - let struct := structs.get ⟨i, h⟩; - let defaultName := struct.structName ++ fieldName ++ `_default; - env ← getEnv; + let struct := structs.get ⟨i, h⟩ + let defaultName := struct.structName ++ fieldName ++ `_default + let env ← getEnv match env.find? defaultName with | some cinfo@(ConstantInfo.defnInfo defVal) => do - mctx ← getMCtx; - val? ← mkDefaultValue? struct cinfo; + let mctx ← getMCtx + let val? ← mkDefaultValue? struct cinfo match val? with - | none => do setMCtx mctx; tryToSynthesizeDefaultAux (i+1) (dist+1) + | none => do setMCtx mctx; loop (i+1) (dist+1) | some val => do - val ← liftMetaM $ reduce allStructNames val; - match val.find? $ fun e => (defaultMissing? e).isSome with - | some _ => do setMCtx mctx; tryToSynthesizeDefaultAux (i+1) (dist+1) - | none => do - mvarDecl ← getMVarDecl mvarId; - val ← ensureHasType mvarDecl.type val; - assignExprMVar mvarId val; + let val ← liftMetaM $ reduce allStructNames val + match val.find? fun e => (defaultMissing? e).isSome with + | some _ => setMCtx mctx; loop (i+1) (dist+1) + | none => + let mvarDecl ← getMVarDecl mvarId + let val ← ensureHasType mvarDecl.type val + assignExprMVar mvarId val pure true - | _ => tryToSynthesizeDefaultAux (i+1) dist + | _ => loop (i+1) dist else pure false - -def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name) - (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool := -tryToSynthesizeDefaultAux structs allStructNames maxDistance fieldName mvarId 0 0 +loop 0 0 partial def step : Struct → M Unit -| struct => unlessM isRoundDone $ withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) $ do - struct.fields.forM $ fun field => +| struct => unlessM isRoundDone $ withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) do + struct.fields.forM fun field => do match field.val with | FieldVal.nested struct => step struct | _ => match field.expr? with | none => unreachable! | some expr => match defaultMissing? expr with | some (Expr.mvar mvarId _) => - unlessM (liftM $ isExprMVarAssigned mvarId) $ do - ctx ← read; - whenM (liftM $ withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) $ do - modify $ fun s => { s with progress := true } + unless (← isExprMVarAssigned mvarId) do + let ctx ← read + if (← withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) then + modify fun s => { s with progress := true } | _ => pure () partial def propagateLoop (hierarchyDepth : Nat) : Nat → Struct → M Unit | d, struct => do - mctx ← getMCtx; - match findDefaultMissing? mctx struct with + match findDefaultMissing? (← getMCtx) struct with | none => pure () -- Done | some field => if d > hierarchyDepth then - throwErrorAt field.ref ("field '" ++ getFieldName field ++ "' is missing") - else withReader (fun ctx => { ctx with maxDistance := d }) $ do - modify $ fun (s : State) => { s with progress := false }; - step struct; - s ← get; - if s.progress then do - propagateLoop 0 struct + throwErrorAt! field.ref "field '{getFieldName field}' is missing" + else withReader (fun ctx => { ctx with maxDistance := d }) do + modify fun s => { s with progress := false } + step struct + if (← get).progress then do + propagateLoop hierarchyDepth 0 struct else - propagateLoop (d+1) struct + propagateLoop hierarchyDepth (d+1) struct def propagate (struct : Struct) : TermElabM Unit := -let hierarchyDepth := getHierarchyDepth struct; -let structNames := collectStructNames struct #[]; +let hierarchyDepth := getHierarchyDepth struct +let structNames := collectStructNames struct #[] (propagateLoop hierarchyDepth 0 struct { allStructNames := structNames }).run' {} end DefaultFields private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do -structName ← getStructName stx expectedType? source; -env ← getEnv; -unless (isStructureLike env structName) $ - throwError ("invalid {...} notation, '" ++ structName ++ "' is not a structure"); +let structName ← getStructName stx expectedType? source +unless isStructureLike (← getEnv) structName do + throwError! "invalid \{...} notation, '{structName}' is not a structure" match mkStructView stx structName source with | Except.error ex => throwError ex -| Except.ok struct => do - struct ← expandStruct struct; - trace `Elab.struct fun _ => toString struct; - (r, struct) ← elabStruct struct expectedType?; - DefaultFields.propagate struct; +| Except.ok struct => + let struct ← expandStruct struct + trace[Elab.struct]! "{struct}" + let (r, struct) ← elabStruct struct expectedType? + DefaultFields.propagate struct pure r @[builtinTermElab structInst] def elabStructInst : TermElab := fun stx expectedType? => do - stxNew? ← expandNonAtomicExplicitSource stx; - match stxNew? with + match (← expandNonAtomicExplicitSource stx) with | some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? - | none => do - sourceView ← getStructSource stx; - modifyOp? ← isModifyOp? stx; - match modifyOp?, sourceView with + | none => + let sourceView ← getStructSource stx + match (← isModifyOp? stx), sourceView with | some modifyOp, Source.explicit source _ => elabModifyOp stx modifyOp source expectedType? - | some _, _ => throwError ("invalid {...} notation, explicit source is required when using '[] := '") + | some _, _ => throwError "invalid {...} notation, explicit source is required when using '[] := '" | _, _ => elabStructInstAux stx expectedType? sourceView -@[init] private def regTraceClasses : IO Unit := do -registerTraceClass `Elab.struct; -pure () +initialize registerTraceClass `Elab.struct -end StructInst -end Term -end Elab -end Lean +end Lean.Elab.Term.StructInst diff --git a/src/Std/Data/HashMap.lean b/src/Std/Data/HashMap.lean index d32de5db72..151d09b56e 100644 --- a/src/Std/Data/HashMap.lean +++ b/src/Std/Data/HashMap.lean @@ -168,13 +168,13 @@ self.find? idx match m with | ⟨ m, _ ⟩ => m.contains a -@[inline] def foldM {δ : Type w} {m : Type w → Type w} [Monad m] (f : δ → α → β → m δ) (d : δ) (h : HashMap α β) : m δ := +@[inline] def foldM {δ : Type w} {m : Type w → Type w} [Monad m] (f : δ → α → β → m δ) (init : δ) (h : HashMap α β) : m δ := match h with -| ⟨ h, _ ⟩ => h.foldM f d +| ⟨ h, _ ⟩ => h.foldM f init -@[inline] def fold {δ : Type w} (f : δ → α → β → δ) (d : δ) (m : HashMap α β) : δ := +@[inline] def fold {δ : Type w} (f : δ → α → β → δ) (init : δ) (m : HashMap α β) : δ := match m with -| ⟨ m, _ ⟩ => m.fold f d +| ⟨ m, _ ⟩ => m.fold f init @[inline] def size (m : HashMap α β) : Nat := match m with