refactor: remove unnecessary complexity

This commit is contained in:
Leonardo de Moura 2020-02-17 11:43:55 -08:00
parent 9b23c6cd10
commit bd951b40ce

View file

@ -454,28 +454,25 @@ private partial def expandStruct : Struct → TermElabM Struct
s ← groupFields expandStruct s;
addMissingFields expandStruct s
structure State :=
(instMVars : Array MVarId := #[])
structure CtorHeaderResult :=
(ctorFn : Expr)
(ctorFnType : Expr)
(instMVars : Array MVarId := #[])
private def mkCtorHeaderAux (ref : Syntax) : Nat → Expr → Expr → StateT State TermElabM CtorHeaderResult
| 0, type, ctorFn => pure { ctorFn := ctorFn, ctorFnType := type }
| n+1, type, ctorFn => do
type ← liftM $ whnfForall ref type;
private def mkCtorHeaderAux (ref : Syntax) : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult
| 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars }
| n+1, type, ctorFn, instMVars => do
type ← whnfForall ref type;
match type with
| Expr.forallE _ d b c =>
match c.binderInfo with
| BinderInfo.instImplicit => do
a ← liftM $ mkFreshExprMVar ref d MetavarKind.synthetic;
modify $ fun s => { instMVars := s.instMVars.push a.mvarId!, .. s };
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a)
a ← mkFreshExprMVar ref d MetavarKind.synthetic;
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!)
| _ => do
a ← liftM $ mkFreshExprMVar ref d;
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a)
| _ => liftM $ throwError ref "unexpected constructor type"
a ← mkFreshExprMVar ref d;
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars
| _ => throwError ref "unexpected constructor type"
private partial def getForallBody : Nat → Expr → Option Expr
| i+1, Expr.forallE _ _ b _ => getForallBody i b
@ -493,45 +490,46 @@ match expectedType? with
isDefEq ref expectedType typeBody;
pure ()
private def mkCtorHeader (ref : Syntax) (ctorVal : ConstructorVal) (expectedType? : Option Expr) : StateT State TermElabM CtorHeaderResult := do
lvls ← ctorVal.lparams.mapM $ fun _ => liftM $ mkFreshLevelMVar ref;
private def mkCtorHeader (ref : Syntax) (ctorVal : ConstructorVal) (expectedType? : Option Expr) : TermElabM CtorHeaderResult := do
lvls ← ctorVal.lparams.mapM $ fun _ => mkFreshLevelMVar ref;
let val := Lean.mkConst ctorVal.name lvls;
let type := (ConstantInfo.ctorInfo ctorVal).instantiateTypeLevelParams lvls;
r ← mkCtorHeaderAux ref ctorVal.nparams type val;
liftM $ propagateExpectedType ref r.ctorFnType ctorVal.nfields expectedType?;
r ← mkCtorHeaderAux ref ctorVal.nparams type val #[];
propagateExpectedType ref r.ctorFnType ctorVal.nfields expectedType?;
synthesizeAppInstMVars ref r.instMVars;
pure r
def markDefaultMissing (e : Expr) : Expr :=
mkMData (KVMap.empty.insert `structInstDefault (DataValue.ofBool true)) e
def throwFailedToElabField {α} (ref : Syntax) (fieldName : Name) (structName : Name) (msgData : MessageData) : StateT State TermElabM α :=
liftM $ throwError ref ("failed to elaborate field '" ++ fieldName ++ "' of '" ++ structName ++ ", " ++ msgData)
def throwFailedToElabField {α} (ref : Syntax) (fieldName : Name) (structName : Name) (msgData : MessageData) : TermElabM α :=
throwError ref ("failed to elaborate field '" ++ fieldName ++ "' of '" ++ structName ++ ", " ++ msgData)
private partial def elabStruct : Struct → Option Expr → StateT State TermElabM (Expr × Struct)
private partial def elabStruct : Struct → Option Expr → TermElabM (Expr × Struct)
| s, expectedType? => do
env ← liftM $ getEnv;
env ← getEnv;
let ctorVal := getStructureCtor env s.structName;
{ ctorFn := ctorFn, ctorFnType := ctorFnType } ← mkCtorHeader s.ref ctorVal expectedType?;
{ ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader s.ref ctorVal expectedType?;
(e, _, fields) ← s.fields.foldlM
(fun (acc : Expr × Expr × Fields) field =>
let (e, type, fields) := acc;
match field.lhs with
| [FieldLHS.fieldName ref fieldName] => do
type ← liftM $ whnfForall field.ref type;
type ← whnfForall field.ref type;
match type with
| Expr.forallE _ d b c =>
let continue (val : Expr) (field : Field Struct) : StateT State TermElabM (Expr × Expr × Fields) := do {
let continue (val : Expr) (field : Field Struct) : TermElabM (Expr × Expr × Fields) := do {
let e := mkApp e val;
let type := b.instantiate1 val;
let field := { expr := some val, .. field };
pure (e, type, field::fields)
};
match field.val with
| FieldVal.term stx => do val ← liftM $ elabTerm stx (some d); continue val field
| FieldVal.term stx => do val ← elabTerm stx (some d); continue val field
| FieldVal.nested s => do (val, sNew) ← elabStruct s (some d); continue val { val := FieldVal.nested sNew, .. field }
| FieldVal.default => do val ← liftM $ mkFreshExprMVar field.ref (some d); continue (markDefaultMissing val) field
| FieldVal.default => do val ← mkFreshExprMVar field.ref (some d); continue (markDefaultMissing val) field
| _ => throwFailedToElabField field.ref fieldName s.structName ("unexpected constructor type" ++ indentExpr type)
| _ => liftM $ throwError field.ref "unexpected unexpanded structure field")
| _ => throwError field.ref "unexpected unexpanded structure field")
(ctorFn, ctorFnType, []);
pure (e, s.setFields fields.reverse)
@ -542,7 +540,7 @@ unless (isStructureLike env structName) $
throwError stx ("invalid {...} notation, '" ++ structName ++ "' is not a structure");
struct ← expandStruct $ mkStructView stx structName source;
trace `Elab.struct stx $ fun _ => toString struct;
((r, struct), s)(elabStruct struct expectedType?).run {};
(r, struct) ← elabStruct struct expectedType?;
-- TODO: resolve missing default
pure r