chore: move to new frontend

This commit is contained in:
Leonardo de Moura 2020-10-16 08:40:42 -07:00
parent 34cddb334e
commit e02a06ad1c
3 changed files with 294 additions and 326 deletions

View file

@ -108,14 +108,14 @@ def filterMapM {m : Type u → Type v} [Monad m] {α β : Type u} (f : α → m
filterMapMAux f as.reverse []
@[specialize]
def foldlM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (s → α → m s) → s → List α → m s
def foldlM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : forall (f : s → α → m s) (init : s), List α → m s
| f, s, [] => pure s
| f, s, h :: r => do
s' ← f s h;
foldlM f s' r
@[specialize]
def foldrM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (α → s → m s) → s → List α → m s
def foldrM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : forall (f : α → s → m s) (init : s), List α → m s
| f, s, [] => pure s
| f, s, h :: r => do
s' ← foldrM f s r;

View file

@ -1,3 +1,4 @@
#lang lean4
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
@ -8,10 +9,7 @@ import Lean.Elab.App
import Lean.Elab.Binders
import Lean.Elab.Quotation
namespace Lean
namespace Elab
namespace Term
namespace StructInst
namespace Lean.Elab.Term.StructInst
open Std (HashMap)
open Meta
@ -20,12 +18,12 @@ open Meta
@[builtinMacro Lean.Parser.Term.structInst] def expandStructInstExpectedType : Macro :=
fun stx =>
let expectedArg := stx.getArg 4;
let expectedArg := stx[4]
if expectedArg.isNone
then Macro.throwUnsupported
else
let expected := expectedArg.getArg 1;
let stxNew := stx.setArg 4 mkNullNode;
let expected := expectedArg[1]
let stxNew := stx.setArg 4 mkNullNode
`(($stxNew : $expected))
/-
@ -34,17 +32,18 @@ If `stx` is of the form `{ s with ... }` and `s` is not a local variable, expand
Note that this one is not a `Macro` because we need to access the local context.
-/
private def expandNonAtomicExplicitSource (stx : Syntax) : TermElabM (Option Syntax) :=
withFreshMacroScope $
let sourceOpt := stx.getArg 1;
if sourceOpt.isNone then pure none else do
let source := sourceOpt.getArg 0;
fvar? ← isLocalIdent? source;
match fvar? with
withFreshMacroScope do
let sourceOpt := stx[1]
if sourceOpt.isNone then
pure none
else
let source := sourceOpt[0]
match (← isLocalIdent? source) with
| some _ => pure none
| none => do
src ← `(src);
let sourceOpt := sourceOpt.setArg 0 src;
let stxNew := stx.setArg 1 sourceOpt;
| none =>
let src ← `(src)
let sourceOpt := sourceOpt.setArg 0 src
let stxNew := stx.setArg 1 sourceOpt
`(let src := $source; $stxNew)
inductive Source
@ -64,15 +63,15 @@ def setStructSourceSyntax (structStx : Syntax) : Source → Syntax
| Source.explicit stx _ => (structStx.setArg 1 stx).setArg 3 mkNullNode
private def getStructSource (stx : Syntax) : TermElabM Source :=
withRef stx $
let explicitSource := stx.getArg 1;
let implicitSource := stx.getArg 3;
withRef stx do
let explicitSource := stx[1]
let implicitSource := stx[3]
if explicitSource.isNone && implicitSource.isNone then
pure Source.none
else if explicitSource.isNone then
pure $ Source.implicit implicitSource
else if implicitSource.isNone then do
fvar? ← isLocalIdent? (explicitSource.getArg 0);
else if implicitSource.isNone then
let fvar? ← isLocalIdent? explicitSource[0]
match fvar? with
| none => unreachable! -- expandNonAtomicExplicitSource must have been used when we get here
| some src => pure $ Source.explicit explicitSource src
@ -85,16 +84,16 @@ else
def structInstArrayRef := parser! "[" >> termParser >>"]"
``` -/
private def isModifyOp? (stx : Syntax) : TermElabM (Option Syntax) := do
let args := (stx.getArg 2).getArgs;
s? ← args.foldSepByM
let args := stx[2].getArgs
let s? ← args.foldSepByM
(fun arg s? =>
/- Remark: the syntax for `structInstField` is
```
def structInstLVal := (ident <|> numLit <|> structInstArrayRef) >> many (group ("." >> (ident <|> numLit)) <|> structInstArrayRef)
def structInstField := parser! structInstLVal >> " := " >> termParser
``` -/
let lval := arg.getArg 0;
let k := lval.getKind;
let lval := arg[0]
let k := lval.getKind
if k == `Lean.Parser.Term.structInstArrayRef then
match s? with
| none => pure (some arg)
@ -111,56 +110,55 @@ s? ← args.foldSepByM
throwErrorAt arg "invalid {...} notation, can't mix field and `[..]` at a given level"
else
pure s?)
none;
none
match s? with
| none => pure none
| some s => if (s.getArg 0).getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none
| some s => if s[0].getKind == `Lean.Parser.Term.structInstArrayRef then pure s? else pure none
private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr :=
let continue (val : Syntax) : TermElabM Expr := do {
let lval := modifyOp.getArg 0;
let idx := lval.getArg 1;
let self := source.getArg 0;
stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val));
trace `Elab.struct.modifyOp fun _ => stx ++ "\n===>\n" ++ stxNew;
private def elabModifyOp (stx modifyOp source : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
let cont (val : Syntax) : TermElabM Expr := do
let lval := modifyOp[0]
let idx := lval[1]
let self := source[0]
let stxNew ← `($(self).modifyOp (idx := $idx) (fun s => $val))
trace[Elab.struct.modifyOp]! "{stx}\n===>\n{stxNew}"
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
}; do
trace `Elab.struct.modifyOp fun _ => modifyOp ++ "\nSource: " ++ source;
let rest := modifyOp.getArg 1;
if rest.isNone then do
continue (modifyOp.getArg 3)
else do
s ← `(s);
let valFirst := rest.getArg 0;
let valFirst := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst.getArg 1;
let restArgs := rest.getArgs;
let valRest := mkNullNode (restArgs.extract 1 restArgs.size);
let valField := modifyOp.setArg 0 valFirst;
let valField := valField.setArg 1 valRest;
let valSource := source.modifyArg 0 $ fun _ => s;
let val := stx.setArg 1 valSource;
let val := val.setArg 2 $ mkNullNode #[valField];
trace `Elab.struct.modifyOp fun _ => stx ++ "\nval: " ++ val;
continue val
trace[Elab.struct.modifyOp]! "{modifyOp}\nSource: {source}"
let rest := modifyOp[1]
if rest.isNone then
cont modifyOp[3]
else
let s ← `(s)
let valFirst := rest[0]
let valFirst := if valFirst.getKind == `Lean.Parser.Term.structInstArrayRef then valFirst else valFirst[1]
let restArgs := rest.getArgs
let valRest := mkNullNode restArgs[1:restArgs.size]
let valField := modifyOp.setArg 0 valFirst
let valField := valField.setArg 1 valRest
let valSource := source.modifyArg 0 fun _ => s
let val := stx.setArg 1 valSource
let val := val.setArg 2 $ mkNullNode #[valField]
trace[Elab.struct.modifyOp]! "{stx}\nval: {val}"
cont val
/- Get structure name and elaborate explicit source (if available) -/
private def getStructName (stx : Syntax) (expectedType? : Option Expr) (sourceView : Source) : TermElabM Name := do
tryPostponeIfNoneOrMVar expectedType?;
tryPostponeIfNoneOrMVar expectedType?
let useSource : Unit → TermElabM Name := fun _ =>
match sourceView, expectedType? with
| Source.explicit _ src, _ => do
srcType ← inferType src;
srcType ← whnf srcType;
tryPostponeIfMVar srcType;
let srcType ← inferType src
let srcType ← whnf srcType
tryPostponeIfMVar srcType
match srcType.getAppFn with
| Expr.const constName _ _ => pure constName
| _ => throwError ("invalid {...} notation, source type is not of the form (C ...)" ++ indentExpr srcType)
| _, some expectedType => throwError ("invalid {...} notation, expected type is not of the form (C ...)" ++ indentExpr expectedType)
| _, none => throwError ("invalid {...} notation, expected type must be known");
| _ => throwError! "invalid \{...} notation, source type is not of the form (C ...){indentExpr srcType}"
| _, some expectedType => throwError! "invalid \{...} notation, expected type is not of the form (C ...){indentExpr expectedType}"
| _, none => throwError! "invalid \{...} notation, expected type must be known"
match expectedType? with
| none => useSource ()
| some expectedType => do
expectedType ← whnf expectedType;
| some expectedType =>
let expectedType ← whnf expectedType
match expectedType.getAppFn with
| Expr.const constName _ _ => pure constName
| _ => useSource ()
@ -178,9 +176,9 @@ instance FieldLHS.hasFormat : HasFormat FieldLHS :=
| FieldLHS.modifyOp _ i => "[" ++ i.prettyPrint ++ "]"⟩
inductive FieldVal (σ : Type)
| term (stx : Syntax) : FieldVal
| nested (s : σ) : FieldVal
| default : FieldVal -- mark that field must be synthesized using default value
| term (stx : Syntax) : FieldVal σ
| nested (s : σ) : FieldVal σ
| default : FieldVal σ -- mark that field must be synthesized using default value
structure Field (σ : Type) :=
(ref : Syntax) (lhs : List FieldLHS) (val : FieldVal σ) (expr? : Option Expr := none)
@ -203,7 +201,7 @@ partial def Struct.allDefault : Struct → Bool
| ⟨_, _, fields, _⟩ => fields.all fun ⟨_, _, val, _⟩ => match val with
| FieldVal.term _ => false
| FieldVal.default => true
| FieldVal.nested s => Struct.allDefault s
| FieldVal.nested s => allDefault s
def Struct.ref : Struct → Syntax
| ⟨ref, _, _, _⟩ => ref
@ -226,7 +224,7 @@ Format.joinSep field.lhs " . " ++ " := " ++
partial def formatStruct : Struct → Format
| ⟨_, structName, fields, source⟩ =>
let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", ";
let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", "
match source with
| Source.none => "{" ++ fieldsFmt ++ "}"
| Source.implicit _ => "{" ++ fieldsFmt ++ " .. }"
@ -258,35 +256,34 @@ def FieldVal.toSyntax : FieldVal Struct → Syntax
def Field.toSyntax : Field Struct → Syntax
| field =>
let stx := field.ref;
let stx := stx.setArg 3 field.val.toSyntax;
let stx := field.ref
let stx := stx.setArg 3 field.val.toSyntax
match field.lhs with
| first::rest =>
let stx := stx.setArg 0 $ first.toSyntax true;
let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false);
let stx := stx.setArg 0 $ first.toSyntax true
let stx := stx.setArg 1 $ mkNullNode $ rest.toArray.map (FieldLHS.toSyntax false)
stx
| _ => unreachable!
private def toFieldLHS (stx : Syntax) : Except String FieldLHS :=
if stx.getKind == `Lean.Parser.Term.structInstArrayRef then
pure $ FieldLHS.modifyOp stx (stx.getArg 1)
pure $ FieldLHS.modifyOp stx stx[1]
else
-- Note that the representation of the first field is different.
let stx := if stx.getKind == nullKind then stx.getArg 1 else stx;
let stx := if stx.getKind == nullKind then stx[1] else stx
if stx.isIdent then pure $ FieldLHS.fieldName stx stx.getId.eraseMacroScopes
else match stx.isFieldIdx? with
| some idx => pure $ FieldLHS.fieldIndex stx idx
| none => throw "unexpected structure syntax"
private def mkStructView (stx : Syntax) (structName : Name) (source : Source) : Except String Struct := do
let args := (stx.getArg 2).getArgs;
let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField;
fields ← fieldsStx.toList.mapM $ fun fieldStx => do {
let val := fieldStx.getArg 3;
first ← toFieldLHS (fieldStx.getArg 0);
rest ← (fieldStx.getArg 1).getArgs.toList.mapM toFieldLHS;
let args := stx[2].getArgs
let fieldsStx := args.filter $ fun arg => arg.getKind == `Lean.Parser.Term.structInstField
let fields ← fieldsStx.toList.mapM fun fieldStx => do
let val := fieldStx[3]
let first ← toFieldLHS fieldStx[0]
let rest ← fieldStx[1].getArgs.toList.mapM toFieldLHS
pure $ ({ref := fieldStx, lhs := first :: rest, val := FieldVal.term val } : Field Struct)
};
pure ⟨stx, structName, fields, source⟩
def Struct.modifyFieldsM {m : Type → Type} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct :=
@ -297,25 +294,25 @@ match s with
Id.run $ s.modifyFieldsM f
def Struct.setFields (s : Struct) (fields : Fields) : Struct :=
s.modifyFields $ fun _ => fields
s.modifyFields fun _ => fields
private def expandCompositeFields (s : Struct) : Struct :=
s.modifyFields $ fun fields => fields.map $ fun field => match field with
| { lhs := FieldLHS.fieldName ref (Name.str Name.anonymous _ _) :: rest, .. } => field
| { lhs := FieldLHS.fieldName ref n@(Name.str _ _ _) :: rest, .. } =>
let newEntries := n.components.map $ FieldLHS.fieldName ref;
let newEntries := n.components.map $ FieldLHS.fieldName ref
{ field with lhs := newEntries ++ rest }
| _ => field
private def expandNumLitFields (s : Struct) : TermElabM Struct :=
s.modifyFieldsM $ fun fields => do
env ← getEnv;
let fieldNames := getStructureFields env s.structName;
fields.mapM $ fun field => match field with
s.modifyFieldsM fun fields => do
let env ← getEnv
let fieldNames := getStructureFields env s.structName
fields.mapM fun field => match field with
| { lhs := FieldLHS.fieldIndex ref idx :: rest, .. } =>
if idx == 0 then throwErrorAt ref "invalid field index, index must be greater than 0"
else if idx > fieldNames.size then throwErrorAt ref ("invalid field index, structure has only #" ++ toString fieldNames.size ++ " fields")
else pure { field with lhs := FieldLHS.fieldName ref (fieldNames.get! $ idx - 1) :: rest }
else if idx > fieldNames.size then throwErrorAt! ref "invalid field index, structure has only #{fieldNames.size} fields"
else pure { field with lhs := FieldLHS.fieldName ref fieldNames[idx - 1] :: rest }
| _ => pure field
/- For example, consider the following structures:
@ -334,38 +331,36 @@ s.modifyFieldsM $ fun fields => do
{ toB.toA.x := 0, toB.y := 0, z := true : C }
``` -/
private def expandParentFields (s : Struct) : TermElabM Struct := do
env ← getEnv;
s.modifyFieldsM $ fun fields => fields.mapM $ fun field => match field with
let env ← getEnv
s.modifyFieldsM fun fields => fields.mapM fun field => match field with
| { lhs := FieldLHS.fieldName ref fieldName :: rest, .. } =>
match findField? env s.structName fieldName with
| none => throwErrorAt ref ("'" ++ fieldName ++ "' is not a field of structure '" ++ s.structName ++ "'")
| none => throwErrorAt! ref "'{fieldName}' is not a field of structure '{s.structName}'"
| some baseStructName =>
if baseStructName == s.structName then pure field
else match getPathToBaseStructure? env baseStructName s.structName with
| some path => do
let path := path.map $ fun funName => match funName with
| Name.str _ s _ => FieldLHS.fieldName ref (mkNameSimple s)
| _ => unreachable!;
| _ => unreachable!
pure { field with lhs := path ++ field.lhs }
| _ => throwErrorAt ref ("failed to access field '" ++ fieldName ++ "' in parent structure")
| _ => throwErrorAt! ref "failed to access field '{fieldName}' in parent structure"
| _ => pure field
private abbrev FieldMap := HashMap Name Fields
private def mkFieldMap (fields : Fields) : TermElabM FieldMap :=
fields.foldlM
(fun fieldMap field =>
match field.lhs with
| FieldLHS.fieldName _ fieldName :: rest =>
match fieldMap.find? fieldName with
| some (prevField::restFields) =>
if field.isSimple || prevField.isSimple then
throwErrorAt field.ref ("field '" ++ fieldName ++ "' has already beed specified")
else
pure $ fieldMap.insert fieldName (field::prevField::restFields)
| _ => pure $ fieldMap.insert fieldName [field]
| _ => unreachable!)
{}
fields.foldlM (init := {}) fun fieldMap field =>
match field.lhs with
| FieldLHS.fieldName _ fieldName :: rest =>
match fieldMap.find? fieldName with
| some (prevField::restFields) =>
if field.isSimple || prevField.isSimple then
throwErrorAt! field.ref "field '{fieldName}' has already beed specified"
else
pure $ fieldMap.insert fieldName (field::prevField::restFields)
| _ => pure $ fieldMap.insert fieldName [field]
| _ => unreachable!
private def isSimpleField? : Fields → Option (Field Struct)
| [field] => if field.isSimple then some field else none
@ -374,7 +369,7 @@ private def isSimpleField? : Fields → Option (Field Struct)
private def getFieldIdx (structName : Name) (fieldNames : Array Name) (fieldName : Name) : TermElabM Nat := do
match fieldNames.findIdx? $ fun n => n == fieldName with
| some idx => pure idx
| none => throwError ("field '" ++ fieldName ++ "' is not a valid field of '" ++ structName ++ "'")
| none => throwError! "field '{fieldName}' is not a valid field of '{structName}'"
private def mkProjStx (s : Syntax) (fieldName : Name) : Syntax :=
Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldName]
@ -382,81 +377,78 @@ Syntax.node `Lean.Parser.Term.proj #[s, mkAtomFrom s ".", mkIdentFrom s fieldNam
private def mkSubstructSource (structName : Name) (fieldNames : Array Name) (fieldName : Name) (src : Source) : TermElabM Source :=
match src with
| Source.explicit stx src => do
idx ← getFieldIdx structName fieldNames fieldName;
let stx := stx.modifyArg 0 $ fun stx => mkProjStx stx fieldName;
let idx ← getFieldIdx structName fieldNames fieldName
let stx := stx.modifyArg 0 fun stx => mkProjStx stx fieldName
pure $ Source.explicit stx (mkProj structName idx src)
| s => pure s
@[specialize] private def groupFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
env ← getEnv;
let fieldNames := getStructureFields env s.structName;
withRef s.ref $
s.modifyFieldsM $ fun fields => do
fieldMap ← mkFieldMap fields;
fieldMap.toList.mapM $ fun ⟨fieldName, fields⟩ =>
let env ← getEnv
let fieldNames := getStructureFields env s.structName
withRef s.ref do
s.modifyFieldsM fun fields => do
let fieldMap ← mkFieldMap fields
fieldMap.toList.mapM fun ⟨fieldName, fields⟩ => do
match isSimpleField? fields with
| some field => pure field
| none => do
let substructFields := fields.map $ fun field => { field with lhs := field.lhs.tail! };
substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source;
let field := fields.head!;
| none =>
let substructFields := fields.map fun field => { field with lhs := field.lhs.tail! }
let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source
let field := fields.head!
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName => do
let substruct := Struct.mk s.ref substructName substructFields substructSource;
substruct ← expandStruct substruct;
| some substructName =>
let substruct := Struct.mk s.ref substructName substructFields substructSource
let substruct ← expandStruct substruct
pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct }
| none => do
-- It is not a substructure field. Thus, we wrap fields using `Syntax`, and use `elabTerm` to process them.
let valStx := s.ref; -- construct substructure syntax using s.ref as template
let valStx := valStx.setArg 4 mkNullNode; -- erase optional expected type
let args := substructFields.toArray.map $ Field.toSyntax;
let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ","));
let valStx := setStructSourceSyntax valStx substructSource;
let valStx := s.ref -- construct substructure syntax using s.ref as template
let valStx := valStx.setArg 4 mkNullNode -- erase optional expected type
let args := substructFields.toArray.map Field.toSyntax
let valStx := valStx.setArg 2 (mkSepStx args (mkAtomFrom s.ref ","))
let valStx := setStructSourceSyntax valStx substructSource
pure { field with lhs := [field.lhs.head!], val := FieldVal.term valStx }
def findField? (fields : Fields) (fieldName : Name) : Option (Field Struct) :=
fields.find? $ fun field =>
fields.find? fun field =>
match field.lhs with
| [FieldLHS.fieldName _ n] => n == fieldName
| _ => false
@[specialize] private def addMissingFields (expandStruct : Struct → TermElabM Struct) (s : Struct) : TermElabM Struct := do
env ← getEnv;
let fieldNames := getStructureFields env s.structName;
let ref := s.ref;
let env ← getEnv
let fieldNames := getStructureFields env s.structName
let ref := s.ref
withRef ref do
fields ← fieldNames.foldlM
(fun fields fieldName => do
match findField? s.fields fieldName with
| some field => pure $ field::fields
| none =>
let addField (val : FieldVal Struct) : TermElabM Fields := do {
pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields
};
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName => do
substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source;
let substruct := Struct.mk s.ref substructName [] substructSource;
substruct ← expandStruct substruct;
addField (FieldVal.nested substruct)
| none =>
match s.source with
| Source.none => addField FieldVal.default
| Source.implicit _ => addField (FieldVal.term (mkHole s.ref))
| Source.explicit stx _ =>
-- stx is of the form `optional (try (termParser >> "with"))`
let src := stx.getArg 0;
let val := mkProjStx src fieldName;
addField (FieldVal.term val))
[];
let fields ← fieldNames.foldlM (init := []) fun fields fieldName => do
match findField? s.fields fieldName with
| some field => pure $ field::fields
| none =>
let addField (val : FieldVal Struct) : TermElabM Fields := do
pure $ { ref := s.ref, lhs := [FieldLHS.fieldName s.ref fieldName], val := val } :: fields
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName => do
let substructSource ← mkSubstructSource s.structName fieldNames fieldName s.source
let substruct := Struct.mk s.ref substructName [] substructSource
let substruct ← expandStruct substruct
addField (FieldVal.nested substruct)
| none =>
match s.source with
| Source.none => addField FieldVal.default
| Source.implicit _ => addField (FieldVal.term (mkHole s.ref))
| Source.explicit stx _ =>
-- stx is of the form `optional (try (termParser >> "with"))`
let src := stx[0]
let val := mkProjStx src fieldName
addField (FieldVal.term val)
pure $ s.setFields fields.reverse
private partial def expandStruct : Struct → TermElabM Struct
| s => do
let s := expandCompositeFields s;
s ← expandNumLitFields s;
s ← expandParentFields s;
s ← groupFields expandStruct s;
let s := expandCompositeFields s
let s ← expandNumLitFields s
let s ← expandParentFields s
let s ← groupFields expandStruct s
addMissingFields expandStruct s
structure CtorHeaderResult :=
@ -467,15 +459,15 @@ structure CtorHeaderResult :=
private def mkCtorHeaderAux : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult
| 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars }
| n+1, type, ctorFn, instMVars => do
type ← whnfForall type;
let type ← whnfForall type
match type with
| Expr.forallE _ d b c =>
match c.binderInfo with
| BinderInfo.instImplicit => do
a ← mkFreshExprMVar d MetavarKind.synthetic;
| BinderInfo.instImplicit =>
let a ← mkFreshExprMVar d MetavarKind.synthetic
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!)
| _ => do
a ← mkFreshExprMVar d;
| _ =>
let a ← mkFreshExprMVar d
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars
| _ => throwError "unexpected constructor type"
@ -487,21 +479,21 @@ private partial def getForallBody : Nat → Expr → Option Expr
private def propagateExpectedType (type : Expr) (numFields : Nat) (expectedType? : Option Expr) : TermElabM Unit :=
match expectedType? with
| none => pure ()
| some expectedType =>
| some expectedType => do
match getForallBody numFields type with
| none => pure ()
| some typeBody =>
unless typeBody.hasLooseBVars $ do
_ ← isDefEq expectedType typeBody;
unless typeBody.hasLooseBVars do
isDefEq expectedType typeBody
pure ()
private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do
lvls ← ctorVal.lparams.mapM $ fun _ => mkFreshLevelMVar;
let val := Lean.mkConst ctorVal.name lvls;
let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls;
r ← mkCtorHeaderAux ctorVal.nparams type val #[];
propagateExpectedType r.ctorFnType ctorVal.nfields expectedType?;
synthesizeAppInstMVars r.instMVars;
let lvls ← ctorVal.lparams.mapM fun _ => mkFreshLevelMVar
let val := Lean.mkConst ctorVal.name lvls
let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls
let r ← mkCtorHeaderAux ctorVal.nparams type val #[]
propagateExpectedType r.ctorFnType ctorVal.nfields expectedType?
synthesizeAppInstMVars r.instMVars
pure r
def markDefaultMissing (e : Expr) : Expr :=
@ -511,43 +503,40 @@ def defaultMissing? (e : Expr) : Option Expr :=
annotation? `structInstDefault e
def throwFailedToElabField {α} (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α :=
throwError ("failed to elaborate field '" ++ fieldName ++ "' of '" ++ structName ++ ", " ++ msgData)
throwError! "failed to elaborate field '{fieldName}' of '{structName}, {msgData}"
def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) :=
if !s.allDefault then pure none
def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Option Expr) := do
if !s.allDefault then
pure none
else
catch (synthInstance? expectedType) (fun _ => pure none)
try synthInstance? expectedType catch _ => pure none
private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × Struct)
| s, expectedType? => withRef s.ref do
env ← getEnv;
let ctorVal := getStructureCtor env s.structName;
{ ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType?;
(e, _, fields) ← s.fields.foldlM
(fun (acc : Expr × Expr × Fields) field =>
let (e, type, fields) := acc;
match field.lhs with
| [FieldLHS.fieldName ref fieldName] => do
type ← whnfForall type;
match type with
| Expr.forallE _ d b c =>
let continue (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do {
let e := mkApp e val;
let type := b.instantiate1 val;
let field := { field with expr? := some val };
pure (e, type, field::fields)
};
match field.val with
| FieldVal.term stx => do val ← elabTermEnsuringType stx d; continue val field
| FieldVal.nested s => do
val? ← trySynthStructInstance? s d; -- if all fields of `s` are marked as `default`, then try to synthesize instance
match val? with
| some val => continue val { field with val := FieldVal.term (mkHole field.ref) }
| none => do(val, sNew) ← elabStruct s (some d); val ← ensureHasType d val; continue val { field with val := FieldVal.nested sNew }
| FieldVal.default => do val ← withRef field.ref $ mkFreshExprMVar (some d); continue (markDefaultMissing val) field
| _ => withRef field.ref $ throwFailedToElabField fieldName s.structName ("unexpected constructor type" ++ indentExpr type)
| _ => throwErrorAt field.ref "unexpected unexpanded structure field")
(ctorFn, ctorFnType, []);
let env ← getEnv
let ctorVal := getStructureCtor env s.structName
let { ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType?
let (e, _, fields) ← s.fields.foldlM (init := (ctorFn, ctorFnType, [])) fun (e, type, fields) field =>
match field.lhs with
| [FieldLHS.fieldName ref fieldName] => do
let type ← whnfForall type
match type with
| Expr.forallE _ d b c =>
let cont (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do
let e := mkApp e val
let type := b.instantiate1 val
let field := { field with expr? := some val }
pure (e, type, field::fields)
match field.val with
| FieldVal.term stx => cont (← elabTermEnsuringType stx d) field
| FieldVal.nested s => do
-- if all fields of `s` are marked as `default`, then try to synthesize instance
match (← trySynthStructInstance? s d) with
| some val => cont val { field with val := FieldVal.term (mkHole field.ref) }
| none => do let (val, sNew) ← elabStruct s (some d); val ← ensureHasType d val; cont val { field with val := FieldVal.nested sNew }
| FieldVal.default => do let val ← withRef field.ref $ mkFreshExprMVar (some d); cont (markDefaultMissing val) field
| _ => withRef field.ref $ throwFailedToElabField fieldName s.structName msg!"unexpected constructor type{indentExpr type}"
| _ => throwErrorAt field.ref "unexpected unexpanded structure field"
pure (e, s.setFields fields.reverse)
namespace DefaultFields
@ -587,28 +576,24 @@ structure State :=
partial def collectStructNames : Struct → Array Name → Array Name
| struct, names =>
let names := names.push struct.structName;
struct.fields.foldl
(fun names field =>
match field.val with
| FieldVal.nested struct => collectStructNames struct names
| _ => names)
names
let names := names.push struct.structName
struct.fields.foldl (init := names) fun names field =>
match field.val with
| FieldVal.nested struct => collectStructNames struct names
| _ => names
partial def getHierarchyDepth : Struct → Nat
| struct =>
struct.fields.foldl
(fun max field =>
match field.val with
| FieldVal.nested struct => Nat.max max (getHierarchyDepth struct + 1)
| _ => max)
0
struct.fields.foldl (init := 0) fun max field =>
match field.val with
| FieldVal.nested struct => Nat.max max (getHierarchyDepth struct + 1)
| _ => max
partial def findDefaultMissing? (mctx : MetavarContext) : Struct → Option (Field Struct)
| struct =>
struct.fields.findSome? $ fun field =>
struct.fields.findSome? fun field =>
match field.val with
| FieldVal.nested struct => findDefaultMissing? struct
| FieldVal.nested struct => findDefaultMissing? mctx struct
| _ => match field.expr? with
| none => unreachable!
| some expr => match defaultMissing? expr with
@ -623,31 +608,30 @@ match field.lhs with
abbrev M := ReaderT Context (StateRefT State TermElabM)
def isRoundDone : M Bool := do
ctx ← read;
s ← get;
pure (s.progress && ctx.maxDistance > 0)
return (← get).progress && (← read).maxDistance > 0
def getFieldValue? (struct : Struct) (fieldName : Name) : Option Expr :=
struct.fields.findSome? $ fun field =>
struct.fields.findSome? fun field =>
if getFieldName field == fieldName then
field.expr?
else
none
partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Expr)
| Expr.lam n d b c => withRef struct.ref $
| Expr.lam n d b c => withRef struct.ref do
if c.binderInfo.isExplicit then
let fieldName := n;
let fieldName := n
match getFieldValue? struct fieldName with
| none => pure none
| some val => do
valType ← inferType val;
condM (isDefEq valType d)
(mkDefaultValueAux? (b.instantiate1 val))
(pure none)
else do
arg ← mkFreshExprMVar d;
mkDefaultValueAux? (b.instantiate1 arg)
| some val =>
let valType ← inferType val
if (← isDefEq valType d) then
mkDefaultValueAux? struct (b.instantiate1 val)
else
pure none
else
let arg ← mkFreshExprMVar d
mkDefaultValueAux? struct (b.instantiate1 arg)
| e =>
if e.isAppOfArity `id 2 then
pure (some e.appArg!)
@ -656,7 +640,7 @@ partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Ex
def mkDefaultValue? (struct : Struct) (cinfo : ConstantInfo) : TermElabM (Option Expr) :=
withRef struct.ref do
us ← cinfo.lparams.mapM $ fun _ => mkFreshLevelMVar;
let us ← cinfo.lparams.mapM fun _ => mkFreshLevelMVar
mkDefaultValueAux? struct (cinfo.instantiateValueLevelParams us)
/-- If `e` is a projection function of one of the given structures, then reduce it -/
@ -664,7 +648,7 @@ def reduceProjOf? (structNames : Array Name) (e : Expr) : MetaM (Option Expr) :=
if !e.isApp then pure none
else match e.getAppFn with
| Expr.const name _ _ => do
env ← getEnv;
let env ← getEnv
match env.getProjectionStructureName? name with
| some structName =>
if structNames.contains structName then
@ -676,142 +660,126 @@ else match e.getAppFn with
/-- Reduce default value. It performs beta reduction and projections of the given structures. -/
partial def reduce (structNames : Array Name) : Expr → MetaM Expr
| e@(Expr.lam _ _ _ _) => lambdaLetTelescope e $ fun xs b => do b ← reduce b; mkLambdaFVars xs b
| e@(Expr.forallE _ _ _ _) => forallTelescope e $ fun xs b => do b ← reduce b; mkForallFVars xs b
| e@(Expr.letE _ _ _ _ _) => lambdaLetTelescope e $ fun xs b => do b ← reduce b; mkLetFVars xs b
| e@(Expr.lam _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLambdaFVars xs (← reduce structNames b)
| e@(Expr.forallE _ _ _ _) => forallTelescope e fun xs b => do mkForallFVars xs (← reduce structNames b)
| e@(Expr.letE _ _ _ _ _) => lambdaLetTelescope e fun xs b => do mkLetFVars xs (← reduce structNames b)
| e@(Expr.proj _ i b _) => do
r? ← Meta.reduceProj? b i;
match r? with
| some r => reduce r
| none => do b ← reduce b; pure $ e.updateProj! b
match (← Meta.reduceProj? b i) with
| some r => reduce structNames r
| none => pure $ e.updateProj! (← reduce structNames b)
| e@(Expr.app f _ _) => do
r? ← reduceProjOf? structNames e;
match r? with
| some r => reduce r
| none => do
let f := f.getAppFn;
f' ← reduce f;
match (← reduceProjOf? structNames e) with
| some r => reduce structNames r
| none =>
let f := f.getAppFn
let f' ← reduce structNames f
if f'.isLambda then
let revArgs := e.getAppRevArgs;
reduce $ f'.betaRev revArgs
else do
args ← e.getAppArgs.mapM reduce;
let revArgs := e.getAppRevArgs
reduce structNames (f'.betaRev revArgs)
else
let args ← e.getAppArgs.mapM (reduce structNames)
pure (mkAppN f' args)
| e@(Expr.mdata _ b _) => do
b ← reduce b;
let b ← reduce structNames b
if (defaultMissing? e).isSome && !b.isMVar then
pure b
else
pure $ e.updateMData! b
| e@(Expr.mvar mvarId _) => do
val? ← getExprMVarAssignment? mvarId;
match val? with
| some val => if val.isMVar then reduce val else pure val
match (← getExprMVarAssignment? mvarId) with
| some val => if val.isMVar then reduce structNames val else pure val
| none => pure e
| e => pure e
partial def tryToSynthesizeDefaultAux (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat)
(fieldName : Name) (mvarId : MVarId) : Nat → Nat → TermElabM Bool
| i, dist =>
if dist > maxDistance then pure false
partial def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name) (maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool :=
let rec loop (i : Nat) (dist : Nat) := do
if dist > maxDistance then
pure false
else if h : i < structs.size then do
let struct := structs.get ⟨i, h⟩;
let defaultName := struct.structName ++ fieldName ++ `_default;
env ← getEnv;
let struct := structs.get ⟨i, h⟩
let defaultName := struct.structName ++ fieldName ++ `_default
let env ← getEnv
match env.find? defaultName with
| some cinfo@(ConstantInfo.defnInfo defVal) => do
mctx ← getMCtx;
val? ← mkDefaultValue? struct cinfo;
let mctx ← getMCtx
let val? ← mkDefaultValue? struct cinfo
match val? with
| none => do setMCtx mctx; tryToSynthesizeDefaultAux (i+1) (dist+1)
| none => do setMCtx mctx; loop (i+1) (dist+1)
| some val => do
val ← liftMetaM $ reduce allStructNames val;
match val.find? $ fun e => (defaultMissing? e).isSome with
| some _ => do setMCtx mctx; tryToSynthesizeDefaultAux (i+1) (dist+1)
| none => do
mvarDecl ← getMVarDecl mvarId;
val ← ensureHasType mvarDecl.type val;
assignExprMVar mvarId val;
let val ← liftMetaM $ reduce allStructNames val
match val.find? fun e => (defaultMissing? e).isSome with
| some _ => setMCtx mctx; loop (i+1) (dist+1)
| none =>
let mvarDecl ← getMVarDecl mvarId
let val ← ensureHasType mvarDecl.type val
assignExprMVar mvarId val
pure true
| _ => tryToSynthesizeDefaultAux (i+1) dist
| _ => loop (i+1) dist
else
pure false
def tryToSynthesizeDefault (structs : Array Struct) (allStructNames : Array Name)
(maxDistance : Nat) (fieldName : Name) (mvarId : MVarId) : TermElabM Bool :=
tryToSynthesizeDefaultAux structs allStructNames maxDistance fieldName mvarId 0 0
loop 0 0
partial def step : Struct → M Unit
| struct => unlessM isRoundDone $ withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) $ do
struct.fields.forM $ fun field =>
| struct => unlessM isRoundDone $ withReader (fun ctx => { ctx with structs := ctx.structs.push struct }) do
struct.fields.forM fun field => do
match field.val with
| FieldVal.nested struct => step struct
| _ => match field.expr? with
| none => unreachable!
| some expr => match defaultMissing? expr with
| some (Expr.mvar mvarId _) =>
unlessM (liftM $ isExprMVarAssigned mvarId) $ do
ctx ← read;
whenM (liftM $ withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) $ do
modify $ fun s => { s with progress := true }
unless (← isExprMVarAssigned mvarId) do
let ctx ← read
if (← withRef field.ref $ tryToSynthesizeDefault ctx.structs ctx.allStructNames ctx.maxDistance (getFieldName field) mvarId) then
modify fun s => { s with progress := true }
| _ => pure ()
partial def propagateLoop (hierarchyDepth : Nat) : Nat → Struct → M Unit
| d, struct => do
mctx ← getMCtx;
match findDefaultMissing? mctx struct with
match findDefaultMissing? (← getMCtx) struct with
| none => pure () -- Done
| some field =>
if d > hierarchyDepth then
throwErrorAt field.ref ("field '" ++ getFieldName field ++ "' is missing")
else withReader (fun ctx => { ctx with maxDistance := d }) $ do
modify $ fun (s : State) => { s with progress := false };
step struct;
s ← get;
if s.progress then do
propagateLoop 0 struct
throwErrorAt! field.ref "field '{getFieldName field}' is missing"
else withReader (fun ctx => { ctx with maxDistance := d }) do
modify fun s => { s with progress := false }
step struct
if (← get).progress then do
propagateLoop hierarchyDepth 0 struct
else
propagateLoop (d+1) struct
propagateLoop hierarchyDepth (d+1) struct
def propagate (struct : Struct) : TermElabM Unit :=
let hierarchyDepth := getHierarchyDepth struct;
let structNames := collectStructNames struct #[];
let hierarchyDepth := getHierarchyDepth struct
let structNames := collectStructNames struct #[]
(propagateLoop hierarchyDepth 0 struct { allStructNames := structNames }).run' {}
end DefaultFields
private def elabStructInstAux (stx : Syntax) (expectedType? : Option Expr) (source : Source) : TermElabM Expr := do
structName ← getStructName stx expectedType? source;
env ← getEnv;
unless (isStructureLike env structName) $
throwError ("invalid {...} notation, '" ++ structName ++ "' is not a structure");
let structName ← getStructName stx expectedType? source
unless isStructureLike (← getEnv) structName do
throwError! "invalid \{...} notation, '{structName}' is not a structure"
match mkStructView stx structName source with
| Except.error ex => throwError ex
| Except.ok struct => do
struct ← expandStruct struct;
trace `Elab.struct fun _ => toString struct;
(r, struct) ← elabStruct struct expectedType?;
DefaultFields.propagate struct;
| Except.ok struct =>
let struct ← expandStruct struct
trace[Elab.struct]! "{struct}"
let (r, struct) ← elabStruct struct expectedType?
DefaultFields.propagate struct
pure r
@[builtinTermElab structInst] def elabStructInst : TermElab :=
fun stx expectedType? => do
stxNew? ← expandNonAtomicExplicitSource stx;
match stxNew? with
match (← expandNonAtomicExplicitSource stx) with
| some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| none => do
sourceView ← getStructSource stx;
modifyOp? ← isModifyOp? stx;
match modifyOp?, sourceView with
| none =>
let sourceView ← getStructSource stx
match (← isModifyOp? stx), sourceView with
| some modifyOp, Source.explicit source _ => elabModifyOp stx modifyOp source expectedType?
| some _, _ => throwError ("invalid {...} notation, explicit source is required when using '[<index>] := <value>'")
| some _, _ => throwError "invalid {...} notation, explicit source is required when using '[<index>] := <value>'"
| _, _ => elabStructInstAux stx expectedType? sourceView
@[init] private def regTraceClasses : IO Unit := do
registerTraceClass `Elab.struct;
pure ()
initialize registerTraceClass `Elab.struct
end StructInst
end Term
end Elab
end Lean
end Lean.Elab.Term.StructInst

View file

@ -168,13 +168,13 @@ self.find? idx
match m with
| ⟨ m, _ ⟩ => m.contains a
@[inline] def foldM {δ : Type w} {m : Type w → Type w} [Monad m] (f : δ → α → β → m δ) (d : δ) (h : HashMap α β) : m δ :=
@[inline] def foldM {δ : Type w} {m : Type w → Type w} [Monad m] (f : δ → α → β → m δ) (init : δ) (h : HashMap α β) : m δ :=
match h with
| ⟨ h, _ ⟩ => h.foldM f d
| ⟨ h, _ ⟩ => h.foldM f init
@[inline] def fold {δ : Type w} (f : δ → α → β → δ) (d : δ) (m : HashMap α β) : δ :=
@[inline] def fold {δ : Type w} (f : δ → α → β → δ) (init : δ) (m : HashMap α β) : δ :=
match m with
| ⟨ m, _ ⟩ => m.fold f d
| ⟨ m, _ ⟩ => m.fold f init
@[inline] def size (m : HashMap α β) : Nat :=
match m with