feat: elaborate strict implicit binders
This commit is contained in:
parent
9988264345
commit
4cd7e359df
5 changed files with 69 additions and 36 deletions
|
|
@ -388,6 +388,11 @@ private def processExplictArg (k : M Expr) : M Expr := do
|
|||
else
|
||||
finalize
|
||||
|
||||
/- Return true if there are regular or named arguments to be processed. -/
|
||||
private def hasArgsToProcess : M Bool := do
|
||||
let s ← get
|
||||
return !s.args.isEmpty || !s.namedArgs.isEmpty
|
||||
|
||||
/-
|
||||
Process a `fType` of the form `{x : A} → B x`.
|
||||
This method assume `fType` is a function type -/
|
||||
|
|
@ -397,6 +402,17 @@ private def processImplicitArg (k : M Expr) : M Expr := do
|
|||
else
|
||||
addImplicitArg k
|
||||
|
||||
/-
|
||||
Process a `fType` of the form `{{x : A}} → B x`.
|
||||
This method assume `fType` is a function type -/
|
||||
private def processStrictImplicitArg (k : M Expr) : M Expr := do
|
||||
if (← get).explicit then
|
||||
processExplictArg k
|
||||
else if (← hasArgsToProcess) then
|
||||
addImplicitArg k
|
||||
else
|
||||
finalize
|
||||
|
||||
/- Return true if the next argument at `args` is of the form `_` -/
|
||||
private def isNextArgHole : M Bool := do
|
||||
match (← get).args with
|
||||
|
|
@ -423,11 +439,6 @@ private def processInstImplicitArg (k : M Expr) : M Expr := do
|
|||
addNewArg arg
|
||||
k
|
||||
|
||||
/- Return true if there are regular or named arguments to be processed. -/
|
||||
private def hasArgsToProcess : M Bool := do
|
||||
let s ← get
|
||||
pure $ !s.args.isEmpty || !s.namedArgs.isEmpty
|
||||
|
||||
/- Elaborate function application arguments. -/
|
||||
partial def main : M Expr := do
|
||||
let s ← get
|
||||
|
|
@ -444,9 +455,10 @@ partial def main : M Expr := do
|
|||
main
|
||||
| none =>
|
||||
match binfo with
|
||||
| BinderInfo.implicit => processImplicitArg main
|
||||
| BinderInfo.instImplicit => processInstImplicitArg main
|
||||
| _ => processExplictArg main
|
||||
| BinderInfo.implicit => processImplicitArg main
|
||||
| BinderInfo.instImplicit => processInstImplicitArg main
|
||||
| BinderInfo.strictImplicit => processStrictImplicitArg main
|
||||
| _ => processExplictArg main
|
||||
else if (← hasArgsToProcess) then
|
||||
synthesizePendingAndNormalizeFunType
|
||||
main
|
||||
|
|
@ -572,25 +584,25 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
|
|||
|
||||
/- whnfCore + implicit consumption.
|
||||
Example: given `e` with `eType := {α : Type} → (fun β => List β) α `, it produces `(e ?m, List ?m)` where `?m` is fresh metavariable. -/
|
||||
private partial def consumeImplicits (stx : Syntax) (e eType : Expr) : TermElabM (Expr × Expr) := do
|
||||
private partial def consumeImplicits (stx : Syntax) (e eType : Expr) (hasArgs : Bool) : TermElabM (Expr × Expr) := do
|
||||
let eType ← whnfCore eType
|
||||
match eType with
|
||||
| Expr.forallE n d b c =>
|
||||
if c.binderInfo.isImplicit then
|
||||
if c.binderInfo.isImplicit || (hasArgs && c.binderInfo.isStrictImplicit) then
|
||||
let mvar ← mkFreshExprMVar d
|
||||
registerMVarErrorHoleInfo mvar.mvarId! stx
|
||||
consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar)
|
||||
consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar) hasArgs
|
||||
else if c.binderInfo.isInstImplicit then
|
||||
let mvar ← mkInstMVar d
|
||||
consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar)
|
||||
consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar) hasArgs
|
||||
else match d.getOptParamDefault? with
|
||||
| some defVal => consumeImplicits stx (mkApp e defVal) (b.instantiate1 defVal)
|
||||
| some defVal => consumeImplicits stx (mkApp e defVal) (b.instantiate1 defVal) hasArgs
|
||||
-- TODO: we do not handle autoParams here.
|
||||
| _ => pure (e, eType)
|
||||
| _ => pure (e, eType)
|
||||
|
||||
private partial def resolveLValLoop (lval : LVal) (e eType : Expr) (previousExceptions : Array Exception) : TermElabM (Expr × LValResolution) := do
|
||||
let (e, eType) ← consumeImplicits lval.getRef e eType
|
||||
private partial def resolveLValLoop (lval : LVal) (e eType : Expr) (previousExceptions : Array Exception) (hasArgs : Bool) : TermElabM (Expr × LValResolution) := do
|
||||
let (e, eType) ← consumeImplicits lval.getRef e eType hasArgs
|
||||
tryPostponeIfMVar eType
|
||||
try
|
||||
let lvalRes ← resolveLValAux e eType lval
|
||||
|
|
@ -599,15 +611,15 @@ private partial def resolveLValLoop (lval : LVal) (e eType : Expr) (previousExce
|
|||
| ex@(Exception.error _ _) =>
|
||||
let eType? ← unfoldDefinition? eType
|
||||
match eType? with
|
||||
| some eType => resolveLValLoop lval e eType (previousExceptions.push ex)
|
||||
| some eType => resolveLValLoop lval e eType (previousExceptions.push ex) hasArgs
|
||||
| none =>
|
||||
previousExceptions.forM fun ex => logException ex
|
||||
throw ex
|
||||
| ex@(Exception.internal _ _) => throw ex
|
||||
|
||||
private def resolveLVal (e : Expr) (lval : LVal) : TermElabM (Expr × LValResolution) := do
|
||||
private def resolveLVal (e : Expr) (lval : LVal) (hasArgs : Bool) : TermElabM (Expr × LValResolution) := do
|
||||
let eType ← inferType e
|
||||
resolveLValLoop lval e eType #[]
|
||||
resolveLValLoop lval e eType #[] hasArgs
|
||||
|
||||
private partial def mkBaseProjections (baseStructName : Name) (structName : Name) (e : Expr) : TermElabM Expr := do
|
||||
let env ← getEnv
|
||||
|
|
@ -675,7 +687,8 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
|
|||
| f, lval::lvals => do
|
||||
if let LVal.fieldName (ref := fieldStx) (targetStx := targetStx) .. := lval then
|
||||
addDotCompletionInfo targetStx f expectedType? fieldStx
|
||||
let (f, lvalRes) ← resolveLVal f lval
|
||||
let hasArgs := !namedArgs.isEmpty || !args.isEmpty
|
||||
let (f, lvalRes) ← resolveLVal f lval hasArgs
|
||||
match lvalRes with
|
||||
| LValResolution.projIdx structName idx =>
|
||||
let f := mkProj structName idx f
|
||||
|
|
|
|||
|
|
@ -104,24 +104,29 @@ private def getBinderIds (ids : Syntax) : TermElabM (Array Syntax) :=
|
|||
|
||||
private def matchBinder (stx : Syntax) : TermElabM (Array BinderView) := do
|
||||
let k := stx.getKind
|
||||
if k == `Lean.Parser.Term.simpleBinder then
|
||||
if k == ``Lean.Parser.Term.simpleBinder then
|
||||
-- binderIdent+ >> optType
|
||||
let ids ← getBinderIds stx[0]
|
||||
let type := expandOptType (mkNullNode ids) stx[1]
|
||||
ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.default }
|
||||
else if k == `Lean.Parser.Term.explicitBinder then
|
||||
else if k == ``Lean.Parser.Term.explicitBinder then
|
||||
-- `(` binderIdent+ binderType (binderDefault <|> binderTactic)? `)`
|
||||
let ids ← getBinderIds stx[1]
|
||||
let type := expandBinderType (mkNullNode ids) stx[2]
|
||||
let optModifier := stx[3]
|
||||
let type ← expandBinderModifier type optModifier
|
||||
ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.default }
|
||||
else if k == `Lean.Parser.Term.implicitBinder then
|
||||
else if k == ``Lean.Parser.Term.implicitBinder then
|
||||
-- `{` binderIdent+ binderType `}`
|
||||
let ids ← getBinderIds stx[1]
|
||||
let type := expandBinderType (mkNullNode ids) stx[2]
|
||||
ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.implicit }
|
||||
else if k == `Lean.Parser.Term.instBinder then
|
||||
else if k == ``Lean.Parser.Term.strictImplicitBinder then
|
||||
-- `⦃` binderIdent+ binderType `⦄`
|
||||
let ids ← getBinderIds stx[1]
|
||||
let type := expandBinderType (mkNullNode ids) stx[2]
|
||||
ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.strictImplicit }
|
||||
else if k == ``Lean.Parser.Term.instBinder then
|
||||
-- `[` optIdent type `]`
|
||||
let id ← expandOptIdent stx[1]
|
||||
let type := stx[2]
|
||||
|
|
@ -256,15 +261,16 @@ partial def expandFunBinders (binders : Array Syntax) (body : Syntax) : MacroM (
|
|||
let newBody ← `(match $major:ident with | $pattern => $newBody)
|
||||
pure (binders, newBody, true)
|
||||
match binder with
|
||||
| Syntax.node `Lean.Parser.Term.implicitBinder _ => loop body (i+1) (newBinders.push binder)
|
||||
| Syntax.node `Lean.Parser.Term.instBinder _ => loop body (i+1) (newBinders.push binder)
|
||||
| Syntax.node `Lean.Parser.Term.explicitBinder _ => loop body (i+1) (newBinders.push binder)
|
||||
| Syntax.node `Lean.Parser.Term.simpleBinder _ => loop body (i+1) (newBinders.push binder)
|
||||
| Syntax.node `Lean.Parser.Term.hole _ =>
|
||||
| Syntax.node ``Lean.Parser.Term.implicitBinder _ => loop body (i+1) (newBinders.push binder)
|
||||
| Syntax.node ``Lean.Parser.Term.strictImplicitBinder _ => loop body (i+1) (newBinders.push binder)
|
||||
| Syntax.node ``Lean.Parser.Term.instBinder _ => loop body (i+1) (newBinders.push binder)
|
||||
| Syntax.node ``Lean.Parser.Term.explicitBinder _ => loop body (i+1) (newBinders.push binder)
|
||||
| Syntax.node ``Lean.Parser.Term.simpleBinder _ => loop body (i+1) (newBinders.push binder)
|
||||
| Syntax.node ``Lean.Parser.Term.hole _ =>
|
||||
let ident ← mkFreshIdent binder
|
||||
let type := binder
|
||||
loop body (i+1) (newBinders.push <| mkExplicitBinder ident type)
|
||||
| Syntax.node `Lean.Parser.Term.paren args =>
|
||||
| Syntax.node ``Lean.Parser.Term.paren args =>
|
||||
-- `(` (termParser >> parenSpecial)? `)`
|
||||
-- parenSpecial := (tupleTail <|> typeAscription)?
|
||||
let binderBody := binder[1]
|
||||
|
|
|
|||
|
|
@ -506,10 +506,12 @@ def delabLam : Delab :=
|
|||
else
|
||||
pure $ curNames.get! 0;
|
||||
`(funBinder| ($stxCurNames : $stxT))
|
||||
| BinderInfo.default, false => pure curNames.back -- here `curNames.size == 1`
|
||||
| BinderInfo.implicit, true => `(funBinder| {$curNames* : $stxT})
|
||||
| BinderInfo.implicit, false => `(funBinder| {$curNames*})
|
||||
| BinderInfo.instImplicit, _ =>
|
||||
| BinderInfo.default, false => pure curNames.back -- here `curNames.size == 1`
|
||||
| BinderInfo.implicit, true => `(funBinder| {$curNames* : $stxT})
|
||||
| BinderInfo.implicit, false => `(funBinder| {$curNames*})
|
||||
| BinderInfo.strictImplicit, true => `(funBinder| ⦃$curNames* : $stxT⦄)
|
||||
| BinderInfo.strictImplicit, false => `(funBinder| ⦃$curNames*⦄)
|
||||
| BinderInfo.instImplicit, _ =>
|
||||
if usedDownstream then `(funBinder| [$curNames.back : $stxT]) -- here `curNames.size == 1`
|
||||
else `(funBinder| [$stxT])
|
||||
| _ , _ => unreachable!;
|
||||
|
|
@ -524,10 +526,11 @@ def delabForall : Delab :=
|
|||
let prop ← try isProp e catch _ => false
|
||||
let stxT ← withBindingDomain delab
|
||||
let group ← match e.binderInfo with
|
||||
| BinderInfo.implicit => `(bracketedBinderF|{$curNames* : $stxT})
|
||||
| BinderInfo.implicit => `(bracketedBinderF|{$curNames* : $stxT})
|
||||
| BinderInfo.strictImplicit => `(bracketedBinderF|⦃$curNames* : $stxT⦄)
|
||||
-- here `curNames.size == 1`
|
||||
| BinderInfo.instImplicit => `(bracketedBinderF|[$curNames.back : $stxT])
|
||||
| _ =>
|
||||
| BinderInfo.instImplicit => `(bracketedBinderF|[$curNames.back : $stxT])
|
||||
| _ =>
|
||||
-- heuristic: use non-dependent arrows only if possible for whole group to avoid
|
||||
-- noisy mix like `(α : Type) → Type → (γ : Type) → ...`.
|
||||
let dependent := curNames.any $ fun n => hasIdent n.getId stxBody
|
||||
|
|
|
|||
7
tests/lean/strictImplicit.lean
Normal file
7
tests/lean/strictImplicit.lean
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
def g {α : Type} (a : α) := a
|
||||
def f {{α : Type}} (a : α) := a
|
||||
|
||||
#check g
|
||||
#check f
|
||||
#check g 1
|
||||
#check f 1
|
||||
4
tests/lean/strictImplicit.lean.expected.out
Normal file
4
tests/lean/strictImplicit.lean.expected.out
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
g : ?m → ?m
|
||||
f : ⦃α : Type⦄ → α → α
|
||||
g 1 : Nat
|
||||
f 1 : Nat
|
||||
Loading…
Add table
Reference in a new issue