feat: refine auto bound implicit locals

This commit is contained in:
Leonardo de Moura 2021-03-23 17:03:33 -07:00
parent ec409a9bfc
commit 99cd4fa720
12 changed files with 146 additions and 169 deletions

View file

@ -157,33 +157,26 @@ private def ensureAtomicBinderName (binderView : BinderView) : TermElabM Unit :=
unless n.isAtomic do
throwErrorAt binderView.id "invalid binder name '{n}', it must be atomic"
private partial def elabBinderViews {α} (binderViews : Array BinderView) (catchAutoBoundImplicit : Bool) (fvars : Array Expr) (k : Array Expr → TermElabM α)
private partial def elabBinderViews {α} (binderViews : Array BinderView) (fvars : Array Expr) (k : Array Expr → TermElabM α)
: TermElabM α :=
let rec loop (i : Nat) (fvars : Array Expr) : TermElabM α := do
if h : i < binderViews.size then
let binderView := binderViews.get ⟨i, h⟩
ensureAtomicBinderName binderView
if catchAutoBoundImplicit then
elabTypeWithAutoBoundImplicit binderView.type fun type => do
registerFailedToInferBinderTypeInfo type binderView.type
withLocalDecl binderView.id.getId binderView.bi type fun fvar => do
addLocalVarInfo binderView.id fvar
loop (i+1) (fvars.push fvar)
else
let type ← elabType binderView.type
registerFailedToInferBinderTypeInfo type binderView.type
withLocalDecl binderView.id.getId binderView.bi type fun fvar => do
addLocalVarInfo binderView.id fvar
loop (i+1) (fvars.push fvar)
let type ← elabType binderView.type
registerFailedToInferBinderTypeInfo type binderView.type
withLocalDecl binderView.id.getId binderView.bi type fun fvar => do
addLocalVarInfo binderView.id fvar
loop (i+1) (fvars.push fvar)
else
k fvars
loop 0 fvars
private partial def elabBindersAux {α} (binders : Array Syntax) (catchAutoBoundImplicit : Bool) (k : Array Expr → TermElabM α) : TermElabM α :=
private partial def elabBindersAux {α} (binders : Array Syntax) (k : Array Expr → TermElabM α) : TermElabM α :=
let rec loop (i : Nat) (fvars : Array Expr) : TermElabM α := do
if h : i < binders.size then
let binderViews ← matchBinder (binders.get ⟨i, h⟩)
elabBinderViews binderViews catchAutoBoundImplicit fvars <| loop (i+1)
elabBinderViews binderViews fvars <| loop (i+1)
else
k fvars
loop 0 #[]
@ -192,15 +185,15 @@ private partial def elabBindersAux {α} (binders : Array Syntax) (catchAutoBound
Elaborate the given binders (i.e., `Syntax` objects for `simpleBinder <|> bracketedBinder`),
update the local context, set of local instances, reset instance chache (if needed), and then
execute `x` with the updated context. -/
def elabBinders {α} (binders : Array Syntax) (k : Array Expr → TermElabM α) (catchAutoBoundImplicit := false) : TermElabM α :=
def elabBinders {α} (binders : Array Syntax) (k : Array Expr → TermElabM α) : TermElabM α :=
withoutPostponingUniverseConstraints do
if binders.isEmpty then
k #[]
else
elabBindersAux binders catchAutoBoundImplicit k
elabBindersAux binders k
@[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) (catchAutoBoundImplicit := false) : TermElabM α :=
elabBinders #[binder] (catchAutoBoundImplicit := catchAutoBoundImplicit) (fun fvars => x (fvars.get! 0))
@[inline] def elabBinder {α} (binder : Syntax) (x : Expr → TermElabM α) : TermElabM α :=
elabBinders #[binder] fun fvars => x fvars[0]
@[builtinTermElab «forall»] def elabForall : TermElab := fun stx _ =>
match stx with

View file

@ -330,17 +330,20 @@ def liftTermElabM {α} (declName? : Option Name) (x : TermElabM α) : CommandEla
@[inline] def runTermElabM {α} (declName? : Option Name) (elabFn : Array Expr → TermElabM α) : CommandElabM α := do
let scope ← getScope
liftTermElabM declName? <|
-- We don't want to store messages produced when elaborating `(getVarDecls s)` because they have already been saved when we elaborated the `variable`(s) command.
-- So, we use `Term.resetMessageLog`.
Term.withAutoBoundImplicitLocal <|
Term.elabBinders scope.varDecls (catchAutoBoundImplicit := true) fun xs => do
Term.withAutoBoundImplicit <|
Term.elabBinders scope.varDecls fun xs => do
-- We need to synthesize postponed terms because this is a checkpoint for the auto-bound implicit feature
-- If we don't use this checkpoint here, then auto-bound implicits in the postponed terms will not be handled correctly.
Term.synthesizeSyntheticMVarsNoPostponing
let mut sectionFVars := {}
for uid in scope.varUIds, x in xs do
sectionFVars := sectionFVars.insert uid x
withReader ({ · with sectionFVars := sectionFVars }) do
-- We don't want to store messages produced when elaborating `(getVarDecls s)` because they have already been saved when we elaborated the `variable`(s) command.
-- So, we use `Term.resetMessageLog`.
Term.resetMessageLog
let xs ← Term.addAutoBoundImplicits xs
Term.withAutoBoundImplicitLocal (flag := false) <| elabFn xs
Term.withoutAutoBoundImplicit <| elabFn xs
@[inline] def catchExceptions (x : CommandElabM Unit) : CommandElabCoreM Empty Unit := fun ctx ref =>
EIO.catchExceptions (withLogging x ctx ref) (fun _ => pure ())
@ -485,8 +488,8 @@ partial def elabChoiceAux (cmds : Array Syntax) (i : Nat) : CommandElabM Unit :=
@[builtinCommandElab «variable»] def elabVariable : CommandElab
| `(variable $binders*) => do
-- Try to elaborate `binders` for sanity checking
runTermElabM none fun _ => Term.withAutoBoundImplicitLocal <|
Term.elabBinders binders (catchAutoBoundImplicit := true) fun _ => pure ()
runTermElabM none fun _ => Term.withAutoBoundImplicit <|
Term.elabBinders binders fun _ => pure ()
let varUIds ← binders.concatMap getBracketedBinderIds |>.mapM (withFreshMacroScope ∘ MonadQuotation.addMacroScope)
modifyScope fun scope => { scope with varDecls := scope.varDecls ++ binders, varUIds := scope.varUIds ++ varUIds }
| _ => throwUnsupportedSyntax

View file

@ -70,19 +70,21 @@ structure ElabHeaderResult where
private partial def elabHeaderAux (views : Array InductiveView) (i : Nat) (acc : Array ElabHeaderResult) : TermElabM (Array ElabHeaderResult) := do
if h : i < views.size then
let view := views.get ⟨i, h⟩
let acc ← Term.withAutoBoundImplicitLocal <| Term.elabBinders view.binders.getArgs (catchAutoBoundImplicit := true) fun params => do
let acc ← Term.withAutoBoundImplicit <| Term.elabBinders view.binders.getArgs fun params => do
match view.type? with
| none =>
let u ← mkFreshLevelMVar
let type := mkSort u
Term.synthesizeSyntheticMVarsNoPostponing
let params ← Term.addAutoBoundImplicits params
pure <| acc.push { lctx := (← getLCtx), localInsts := (← getLocalInstances), params := params, type := type, view := view }
| some typeStx =>
Term.elabTypeWithAutoBoundImplicit typeStx fun type => do
unless (← isTypeFormerType type) do
throwErrorAt typeStx "invalid inductive type, resultant type is not a sort"
let params ← Term.addAutoBoundImplicits params
pure <| acc.push { lctx := (← getLCtx), localInsts := (← getLocalInstances), params := params, type := type, view := view }
let type ← Term.elabType typeStx
unless (← isTypeFormerType type) do
throwErrorAt typeStx "invalid inductive type, resultant type is not a sort"
Term.synthesizeSyntheticMVarsNoPostponing
let params ← Term.addAutoBoundImplicits params
pure <| acc.push { lctx := (← getLCtx), localInsts := (← getLocalInstances), params := params, type := type, view := view }
elabHeaderAux views (i+1) acc
else
pure acc
@ -192,7 +194,7 @@ private def isInductiveFamily (numParams : Nat) (indFVar : Expr) : TermElabM Boo
private def elabCtors (indFVars : Array Expr) (indFVar : Expr) (params : Array Expr) (r : ElabHeaderResult) : TermElabM (List Constructor) := withRef r.view.ref do
let indFamily ← isInductiveFamily params.size indFVar
r.view.ctors.toList.mapM fun ctorView =>
Term.withAutoBoundImplicitLocal <| Term.elabBinders (catchAutoBoundImplicit := true) ctorView.binders.getArgs fun ctorParams =>
Term.withAutoBoundImplicit <| Term.elabBinders ctorView.binders.getArgs fun ctorParams =>
withRef ctorView.ref do
let rec elabCtorType (k : Expr → TermElabM Constructor) : TermElabM Constructor := do
match ctorView.type? with
@ -201,17 +203,18 @@ private def elabCtors (indFVars : Array Expr) (indFVar : Expr) (params : Array E
throwError "constructor resulting type must be specified in inductive family declaration"
k <| mkAppN indFVar params
| some ctorType =>
Term.elabTypeWithAutoBoundImplicit ctorType fun type => do
Term.synthesizeSyntheticMVars (mayPostpone := true)
let type ← instantiateMVars type
let type ← checkParamOccs type
forallTelescopeReducing type fun _ resultingType => do
unless resultingType.getAppFn == indFVar do
throwError "unexpected constructor resulting type{indentExpr resultingType}"
unless (← isType resultingType) do
throwError "unexpected constructor resulting type, type expected{indentExpr resultingType}"
k type
let type ← Term.elabType ctorType
Term.synthesizeSyntheticMVars (mayPostpone := true)
let type ← instantiateMVars type
let type ← checkParamOccs type
forallTelescopeReducing type fun _ resultingType => do
unless resultingType.getAppFn == indFVar do
throwError "unexpected constructor resulting type{indentExpr resultingType}"
unless (← isType resultingType) do
throwError "unexpected constructor resulting type, type expected{indentExpr resultingType}"
k type
elabCtorType fun type => do
Term.synthesizeSyntheticMVarsNoPostponing
let ctorParams ← Term.addAutoBoundImplicits ctorParams
let type ← mkForallFVars ctorParams type
let type ← mkForallFVars params type

View file

@ -74,20 +74,6 @@ private def check (prevHeaders : Array DefViewElabHeader) (newHeader : DefViewEl
private def registerFailedToInferDefTypeInfo (type : Expr) (ref : Syntax) : TermElabM Unit :=
registerCustomErrorIfMVar type ref "failed to infer definition type"
private def elabFunType (ref : Syntax) (xs : Array Expr) (view : DefView) (k : Array Expr → Expr → TermElabM α) : TermElabM α := do
match view.type? with
| some typeStx =>
elabTypeWithAutoBoundImplicit typeStx fun type => do
synthesizeSyntheticMVarsNoPostponing
let type ← instantiateMVars type
registerFailedToInferDefTypeInfo type typeStx
k xs (← mkForallFVars xs type)
| none =>
let hole := mkHole ref
let type ← elabType hole
registerFailedToInferDefTypeInfo type ref
k xs (← mkForallFVars xs type)
private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHeader) := do
let mut headers := #[]
for view in views do
@ -95,31 +81,41 @@ private def elabHeaders (views : Array DefView) : TermElabM (Array DefViewElabHe
let ⟨shortDeclName, declName, levelNames⟩ ← expandDeclId (← getCurrNamespace) (← getLevelNames) view.declId view.modifiers
addDeclarationRanges declName view.ref
applyAttributesAt declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration
withDeclName declName <| withAutoBoundImplicitLocal <| withLevelNames levelNames <|
elabBinders (catchAutoBoundImplicit := true) view.binders.getArgs fun xs => do
withDeclName declName <| withAutoBoundImplicit <| withLevelNames levelNames <|
elabBinders view.binders.getArgs fun xs => do
let refForElabFunType := view.value
elabFunType refForElabFunType xs view fun xs type => do
let mut type ← mkForallFVars (← read).autoBoundImplicits.toArray type
let xs ← addAutoBoundImplicits xs
let levelNames ← getLevelNames
if view.type?.isSome then
Term.synthesizeSyntheticMVarsNoPostponing
type ← instantiateMVars type
let pendingMVarIds ← getMVars type
discard <| logUnassignedUsingErrorInfos pendingMVarIds <|
m!"\nwhen the resulting type of a declaration is explicitly provided, all holes (e.g., `_`) in the header are resolved before the declaration body is processed"
let newHeader := {
ref := view.ref,
modifiers := view.modifiers,
kind := view.kind,
shortDeclName := shortDeclName,
declName := declName,
levelNames := levelNames,
numParams := xs.size,
type := type,
valueStx := view.value : DefViewElabHeader }
check headers newHeader
pure newHeader
let type ← match view.type? with
| some typeStx =>
let type ← elabType typeStx
registerFailedToInferDefTypeInfo type typeStx
pure type
| none =>
let hole := mkHole refForElabFunType
let type ← elabType hole
registerFailedToInferDefTypeInfo type refForElabFunType
pure type
Term.synthesizeSyntheticMVarsNoPostponing
let type ← mkForallFVars xs type
let type ← mkForallFVars (← read).autoBoundImplicits.toArray type
let type ← instantiateMVars type
let xs ← addAutoBoundImplicits xs
let levelNames ← getLevelNames
if view.type?.isSome then
let pendingMVarIds ← getMVars type
discard <| logUnassignedUsingErrorInfos pendingMVarIds <|
m!"\nwhen the resulting type of a declaration is explicitly provided, all holes (e.g., `_`) in the header are resolved before the declaration body is processed"
let newHeader := {
ref := view.ref,
modifiers := view.modifiers,
kind := view.kind,
shortDeclName := shortDeclName,
declName := declName,
levelNames := levelNames,
numParams := xs.size,
type := type,
valueStx := view.value : DefViewElabHeader }
check headers newHeader
pure newHeader
headers := headers.push newHeader
pure headers

View file

@ -259,28 +259,31 @@ private partial def withParents (view : StructView) (i : Nat) (infos : Array Str
k infos
private def elabFieldTypeValue (view : StructFieldView) : TermElabM (Option Expr × Option Expr) := do
Term.withAutoBoundImplicitLocal <| Term.elabBinders (catchAutoBoundImplicit := true) view.binders.getArgs fun params => do
Term.withAutoBoundImplicit <| Term.elabBinders view.binders.getArgs fun params => do
match view.type? with
| none =>
match view.value? with
| none => return (none, none)
| some valStx =>
Term.synthesizeSyntheticMVarsNoPostponing
let params ← Term.addAutoBoundImplicits params
let value ← Term.elabTerm valStx none
let value ← mkLambdaFVars params value
return (none, value)
| some typeStx =>
Term.elabTypeWithAutoBoundImplicit typeStx fun type => do
let params ← Term.addAutoBoundImplicits params
match view.value? with
| none =>
let type ← mkForallFVars params type
return (type, none)
| some valStx =>
let value ← Term.elabTermEnsuringType valStx type
let type ← mkForallFVars params type
let value ← mkLambdaFVars params value
return (type, value)
let type ← Term.elabType typeStx
Term.synthesizeSyntheticMVarsNoPostponing
let params ← Term.addAutoBoundImplicits params
match view.value? with
| none =>
let type ← mkForallFVars params type
return (type, none)
| some valStx =>
let value ← Term.elabTermEnsuringType valStx type
Term.synthesizeSyntheticMVarsNoPostponing
let type ← mkForallFVars params type
let value ← mkLambdaFVars params value
return (type, value)
private partial def withFields
(views : Array StructFieldView) (i : Nat) (infos : Array StructFieldInfo) (k : Array StructFieldInfo → TermElabM α) : TermElabM α := do
@ -557,25 +560,25 @@ def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit :=
Term.withDeclName declName do
let ctor ← expandCtor stx modifiers declName
let fields ← expandFields stx modifiers declName
Term.withLevelNames allUserLevelNames <| Term.withAutoBoundImplicitLocal <|
Term.elabBinders params (catchAutoBoundImplicit := true) fun params => do
Term.withLevelNames allUserLevelNames <| Term.withAutoBoundImplicit <|
Term.elabBinders params fun params => do
Term.synthesizeSyntheticMVarsNoPostponing
let params ← Term.addAutoBoundImplicits params
let allUserLevelNames ← Term.getLevelNames
Term.withAutoBoundImplicitLocal (flag := false) do
elabStructureView {
ref := stx
modifiers := modifiers
scopeLevelNames := scopeLevelNames
allUserLevelNames := allUserLevelNames
declName := declName
isClass := isClass
scopeVars := scopeVars
params := params
parents := parents
type := type
ctor := ctor
fields := fields
}
elabStructureView {
ref := stx
modifiers := modifiers
scopeLevelNames := scopeLevelNames
allUserLevelNames := allUserLevelNames
declName := declName
isClass := isClass
scopeVars := scopeVars
params := params
parents := parents
type := type
ctor := ctor
fields := fields
}
unless isClass do
mkSizeOfInstances declName
return declName

View file

@ -294,11 +294,7 @@ mutual
loop ()
else
reportStuckSyntheticMVars
/- Disable `autoBoundImplicit` to avoid nontermination.
The postponed terms have a fixed `localContext`, i.e. the context of the metavariable representing the "hole".
`throwAutoBoundImplicit` exception will have not effect. -/
withReader (fun ctx => { ctx with autoBoundImplicit := false }) do
loop ()
loop ()
end
def synthesizeSyntheticMVarsNoPostponing : TermElabM Unit :=

View file

@ -342,11 +342,6 @@ def withLevelNames (levelNames : List Name) (x : TermElabM α) : TermElabM α :=
def withoutErrToSorry (x : TermElabM α) : TermElabM α :=
withReader (fun ctx => { ctx with errToSorry := false }) x
/-- Execute `x` with `autoBoundImplicit := (autoBoundImplicitLocal.get options) && flag` -/
def withAutoBoundImplicitLocal (x : TermElabM α) (flag := true) : TermElabM α := do
let flag := autoBoundImplicitLocal.get (← getOptions) && flag
withReader (fun ctx => { ctx with autoBoundImplicit := flag, autoBoundImplicits := {} }) x
/-- For testing `TermElabM` methods. The #eval command will sign the error. -/
def throwErrorIfErrors : TermElabM Unit := do
if (← get).messages.hasErrors then
@ -1136,50 +1131,30 @@ def elabType (stx : Syntax) : TermElabM Expr := do
withRef stx $ ensureType type
/--
Execute `k` while catching auto bound implicit exceptions. When an exception is caught,
Enable auto-bound implicits, and execute `k` while catching auto bound implicit exceptions. When an exception is caught,
a new local declaration is created, registered, and `k` is tried to be executed again. -/
partial def withCatchingAutoBoundExceptions (k : TermElabM α) : TermElabM α := do
if (← read).autoBoundImplicit then
let rec loop : TermElabM α := do
let s ← saveAllState
try
k
catch
| ex => match isAutoBoundImplicitLocalException? ex with
| some n =>
-- Restore state, declare `n`, and try again
s.restore
withLocalDecl n BinderInfo.implicit (← mkFreshTypeMVar) fun x =>
withReader (fun ctx => { ctx with autoBoundImplicits := ctx.autoBoundImplicits.push x } )
loop
| none => throw ex
loop
partial def withAutoBoundImplicit (k : TermElabM α) : TermElabM α := do
let flag := autoBoundImplicitLocal.get (← getOptions)
if flag then
withReader (fun ctx => { ctx with autoBoundImplicit := flag, autoBoundImplicits := {} }) do
let rec loop (s : SavedState) : TermElabM α := do
try
k
catch
| ex => match isAutoBoundImplicitLocalException? ex with
| some n =>
-- Restore state, declare `n`, and try again
s.restore
withLocalDecl n BinderInfo.implicit (← mkFreshTypeMVar) fun x =>
withReader (fun ctx => { ctx with autoBoundImplicits := ctx.autoBoundImplicits.push x } ) do
loop (← saveAllState)
| none => throw ex
loop (← saveAllState)
else
k
/--
Elaborate `stx` creating new implicit variables for unbound ones when `autoBoundImplicit == true`, and then
execute the continuation `k` in the potentially extended local context.
The auto bound implicit locals are stored in the context variable `autoBoundImplicits`
-/
partial def elabTypeWithAutoBoundImplicit (stx : Syntax) (k : Expr → TermElabM α) : TermElabM α := do
if (← read).autoBoundImplicit then
let rec loop : TermElabM α := do
let s ← saveAllState
try
k (← elabType stx)
catch
| ex => match isAutoBoundImplicitLocalException? ex with
| some n =>
-- Restore state, declare `n`, and try again
s.restore
withLocalDecl n BinderInfo.implicit (← mkFreshTypeMVar) fun x =>
withReader (fun ctx => { ctx with autoBoundImplicits := ctx.autoBoundImplicits.push x } )
loop
| none => throw ex
loop
else
k (← elabType stx)
def withoutAutoBoundImplicit (k : TermElabM α) : TermElabM α := do
withReader (fun ctx => { ctx with autoBoundImplicit := false, autoBoundImplicits := {} }) k
/--
Return `autoBoundImplicits ++ xs.

View file

@ -1,8 +1,2 @@
301.lean:1:9-1:17: error: missing cases:
(Nat.succ _)
301.lean:1:21-1:24: error: type mismatch
( fun (x : Nat) => ?m x)
has type
(n : Nat) → ?m ( fun (x : Nat) => ?m x) n : Type ?u
but is expected to have type
?m x n : Type ?u

View file

@ -6,7 +6,7 @@ def BV (n : Nat) := { a : Array Bool // a.size = n }
def allZero (bv : BV n) : Prop :=
∀ i, i < n → bv.val[i] = false
def foo (n : Nat) (h : allZero b) : BV n :=
def foo (b : BV n) (h : allZero b) : BV n :=
b
def optbind (x : Option α) (f : α → Option β) : Option β :=

View file

@ -1,4 +1,3 @@
autoBoundImplicits2.lean:9:0-10:3: error: invalid auto implicit argument 'b', it depends on explicitly provided argument 'n'
g1 : ?m → ?m
autoBoundImplicits2.lean:30:17-30:18: error: unknown universe level 'u'
autoBoundImplicits2.lean:33:17-33:18: error: unknown universe level 'β'

View file

@ -1 +1,6 @@
autoBoundPostponeLoop.lean:5:12-5:13: error: unknown identifier 'h'
autoBoundPostponeLoop.lean:5:12-5:18: error: invalid `▸` notation, argument
h
has type
?m
equality expected
autoBoundPostponeLoop.lean:1:8-1:10: error: (kernel) declaration has metavariables 'ex'

View file

@ -0,0 +1,10 @@
example : n.succ = 1 → n = 0 := by
intros h; injection h; assumption
example (h : n.succ = 1) : n = 0 := by
injection h; assumption
constant T : Type
constant T.Pred : T → T → Prop
example {ρ} (hρ : ρ.Pred σ) : T.Pred ρ ρ := sorry