feat: store outParam positions
This commit is contained in:
parent
0b1fde64ee
commit
f657aed798
1 changed files with 46 additions and 32 deletions
|
|
@ -8,8 +8,16 @@ import Lean.Attributes
|
|||
namespace Lean
|
||||
|
||||
structure ClassEntry where
|
||||
name : Name
|
||||
hasOutParam : Bool
|
||||
name : Name
|
||||
/--
|
||||
Position of the class `outParams`.
|
||||
For example, for class
|
||||
```
|
||||
class GetElem (Cont : Type u) (Idx : Type v) (Elem : outParam (Type w)) (Dom : outParam (Cont → Idx → Prop)) where
|
||||
```
|
||||
`outParams := #[2, 3]`
|
||||
-/
|
||||
outParams : Array Nat
|
||||
|
||||
namespace ClassEntry
|
||||
|
||||
|
|
@ -19,16 +27,16 @@ def lt (a b : ClassEntry) : Bool :=
|
|||
end ClassEntry
|
||||
|
||||
structure ClassState where
|
||||
hasOutParam : SMap Name Bool := SMap.empty
|
||||
outParamMap : SMap Name (Array Nat) := SMap.empty
|
||||
deriving Inhabited
|
||||
|
||||
namespace ClassState
|
||||
|
||||
def addEntry (s : ClassState) (entry : ClassEntry) : ClassState :=
|
||||
{ s with hasOutParam := s.hasOutParam.insert entry.name entry.hasOutParam }
|
||||
{ s with outParamMap := s.outParamMap.insert entry.name entry.outParams }
|
||||
|
||||
def switch (s : ClassState) : ClassState :=
|
||||
{ s with hasOutParam := s.hasOutParam.switch }
|
||||
{ s with outParamMap := s.outParamMap.switch }
|
||||
|
||||
end ClassState
|
||||
|
||||
|
|
@ -42,17 +50,23 @@ builtin_initialize classExtension : SimplePersistentEnvExtension ClassEntry Clas
|
|||
|
||||
@[export lean_is_class]
|
||||
def isClass (env : Environment) (n : Name) : Bool :=
|
||||
(classExtension.getState env).hasOutParam.contains n
|
||||
|
||||
@[export lean_has_out_params]
|
||||
def hasOutParams (env : Environment) (n : Name) : Bool :=
|
||||
match (classExtension.getState env).hasOutParam.find? n with
|
||||
| some b => b
|
||||
| none => false
|
||||
(classExtension.getState env).outParamMap.contains n
|
||||
|
||||
/--
|
||||
Auxiliary function for checking whether a class has `outParam`, and
|
||||
whether they are being correctly used.
|
||||
If `declName` is a class, return the position of its `outParams`.
|
||||
-/
|
||||
def getOutParamPositions? (env : Environment) (declName : Name) : Option (Array Nat) :=
|
||||
(classExtension.getState env).outParamMap.find? declName
|
||||
|
||||
@[export lean_has_out_params]
|
||||
def hasOutParams (env : Environment) (declName : Name) : Bool :=
|
||||
match getOutParamPositions? env declName with
|
||||
| some outParams => !outParams.isEmpty
|
||||
| none => false
|
||||
|
||||
/--
|
||||
Auxiliary function for collection the position class `outParams`, and
|
||||
checking whether they are being correctly used.
|
||||
A regular (i.e., non `outParam`) must not depend on an `outParam`.
|
||||
Reason for this restriction:
|
||||
When performing type class resolution, we replace arguments that
|
||||
|
|
@ -62,19 +76,19 @@ def hasOutParams (env : Environment) (n : Name) : Bool :=
|
|||
incorrect. This transformation would be counterintuitive to users since
|
||||
we would implicitly treat these regular parameters as `outParam`s.
|
||||
-/
|
||||
private partial def checkOutParam : Nat → Array FVarId → Expr → Except String Bool
|
||||
| i, outParams, Expr.forallE _ d b _ =>
|
||||
private partial def checkOutParam (i : Nat) (outParamFVarIds : Array FVarId) (outParams : Array Nat) (type : Expr) : Except String (Array Nat) :=
|
||||
match type with
|
||||
| .forallE _ d b _ =>
|
||||
if d.isOutParam then
|
||||
let fvarId := { name := Name.mkNum `_fvar outParams.size }
|
||||
let outParams := outParams.push fvarId
|
||||
let fvarId := { name := Name.mkNum `_fvar outParamFVarIds.size }
|
||||
let fvar := mkFVar fvarId
|
||||
let b := b.instantiate1 fvar
|
||||
checkOutParam (i+1) outParams b
|
||||
else if d.hasAnyFVar fun fvarId => outParams.contains fvarId then
|
||||
Except.error s!"invalid class, parameter #{i} depends on `outParam`, but it is not an `outParam`"
|
||||
checkOutParam (i+1) (outParamFVarIds.push fvarId) (outParams.push i) b
|
||||
else if d.hasAnyFVar fun fvarId => outParamFVarIds.contains fvarId then
|
||||
Except.error s!"invalid class, parameter #{i+1} depends on `outParam`, but it is not an `outParam`"
|
||||
else
|
||||
checkOutParam (i+1) outParams b
|
||||
| _, outParams, _ => pure (outParams.size > 0)
|
||||
checkOutParam (i+1) outParamFVarIds outParams b
|
||||
| _ => return outParams
|
||||
|
||||
def addClass (env : Environment) (clsName : Name) : Except String Environment := do
|
||||
if isClass env clsName then
|
||||
|
|
@ -83,17 +97,17 @@ def addClass (env : Environment) (clsName : Name) : Except String Environment :=
|
|||
| throw s!"unknown declaration '{clsName}'"
|
||||
unless decl matches .inductInfo .. | .axiomInfo .. do
|
||||
throw s!"invalid 'class', declaration '{clsName}' must be inductive datatype, structure, or constant"
|
||||
let b ← checkOutParam 1 #[] decl.type
|
||||
return classExtension.addEntry env { name := clsName, hasOutParam := b }
|
||||
let outParams ← checkOutParam 0 #[] #[] decl.type
|
||||
return classExtension.addEntry env { name := clsName, outParams }
|
||||
|
||||
private def consumeNLambdas : Nat → Expr → Option Expr
|
||||
| 0, e => some e
|
||||
| i+1, Expr.lam _ _ b _ => consumeNLambdas i b
|
||||
| _, _ => none
|
||||
| 0, e => some e
|
||||
| i+1, .lam _ _ b _ => consumeNLambdas i b
|
||||
| _, _ => none
|
||||
|
||||
partial def getClassName (env : Environment) : Expr → Option Name
|
||||
| Expr.forallE _ _ b _ => getClassName env b
|
||||
| e => do
|
||||
| .forallE _ _ b _ => getClassName env b
|
||||
| e => do
|
||||
let Expr.const c _ ← pure e.getAppFn | none
|
||||
let info ← env.find? c
|
||||
match info.value? with
|
||||
|
|
@ -106,8 +120,8 @@ partial def getClassName (env : Environment) : Expr → Option Name
|
|||
|
||||
builtin_initialize
|
||||
registerBuiltinAttribute {
|
||||
name := `class,
|
||||
descr := "type class",
|
||||
name := `class
|
||||
descr := "type class"
|
||||
add := fun decl stx kind => do
|
||||
let env ← getEnv
|
||||
Attribute.Builtin.ensureNoArgs stx
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue