From f657aed79894401d7fa593f70f3812805614555b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 11 Jul 2022 17:21:31 -0700 Subject: [PATCH] feat: store `outParam` positions --- src/Lean/Class.lean | 78 ++++++++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/src/Lean/Class.lean b/src/Lean/Class.lean index e63eca0642..eb1642ac18 100644 --- a/src/Lean/Class.lean +++ b/src/Lean/Class.lean @@ -8,8 +8,16 @@ import Lean.Attributes namespace Lean structure ClassEntry where - name : Name - hasOutParam : Bool + name : Name + /-- + Position of the class `outParams`. + For example, for class + ``` + class GetElem (Cont : Type u) (Idx : Type v) (Elem : outParam (Type w)) (Dom : outParam (Cont → Idx → Prop)) where + ``` + `outParams := #[2, 3]` + -/ + outParams : Array Nat namespace ClassEntry @@ -19,16 +27,16 @@ def lt (a b : ClassEntry) : Bool := end ClassEntry structure ClassState where - hasOutParam : SMap Name Bool := SMap.empty + outParamMap : SMap Name (Array Nat) := SMap.empty deriving Inhabited namespace ClassState def addEntry (s : ClassState) (entry : ClassEntry) : ClassState := - { s with hasOutParam := s.hasOutParam.insert entry.name entry.hasOutParam } + { s with outParamMap := s.outParamMap.insert entry.name entry.outParams } def switch (s : ClassState) : ClassState := - { s with hasOutParam := s.hasOutParam.switch } + { s with outParamMap := s.outParamMap.switch } end ClassState @@ -42,17 +50,23 @@ builtin_initialize classExtension : SimplePersistentEnvExtension ClassEntry Clas @[export lean_is_class] def isClass (env : Environment) (n : Name) : Bool := - (classExtension.getState env).hasOutParam.contains n - -@[export lean_has_out_params] -def hasOutParams (env : Environment) (n : Name) : Bool := - match (classExtension.getState env).hasOutParam.find? n with - | some b => b - | none => false + (classExtension.getState env).outParamMap.contains n /-- - Auxiliary function for checking whether a class has `outParam`, and - whether they are being correctly used. + If `declName` is a class, return the position of its `outParams`. +-/ +def getOutParamPositions? (env : Environment) (declName : Name) : Option (Array Nat) := + (classExtension.getState env).outParamMap.find? declName + +@[export lean_has_out_params] +def hasOutParams (env : Environment) (declName : Name) : Bool := + match getOutParamPositions? env declName with + | some outParams => !outParams.isEmpty + | none => false + +/-- + Auxiliary function for collection the position class `outParams`, and + checking whether they are being correctly used. A regular (i.e., non `outParam`) must not depend on an `outParam`. Reason for this restriction: When performing type class resolution, we replace arguments that @@ -62,19 +76,19 @@ def hasOutParams (env : Environment) (n : Name) : Bool := incorrect. This transformation would be counterintuitive to users since we would implicitly treat these regular parameters as `outParam`s. -/ -private partial def checkOutParam : Nat → Array FVarId → Expr → Except String Bool - | i, outParams, Expr.forallE _ d b _ => +private partial def checkOutParam (i : Nat) (outParamFVarIds : Array FVarId) (outParams : Array Nat) (type : Expr) : Except String (Array Nat) := + match type with + | .forallE _ d b _ => if d.isOutParam then - let fvarId := { name := Name.mkNum `_fvar outParams.size } - let outParams := outParams.push fvarId + let fvarId := { name := Name.mkNum `_fvar outParamFVarIds.size } let fvar := mkFVar fvarId let b := b.instantiate1 fvar - checkOutParam (i+1) outParams b - else if d.hasAnyFVar fun fvarId => outParams.contains fvarId then - Except.error s!"invalid class, parameter #{i} depends on `outParam`, but it is not an `outParam`" + checkOutParam (i+1) (outParamFVarIds.push fvarId) (outParams.push i) b + else if d.hasAnyFVar fun fvarId => outParamFVarIds.contains fvarId then + Except.error s!"invalid class, parameter #{i+1} depends on `outParam`, but it is not an `outParam`" else - checkOutParam (i+1) outParams b - | _, outParams, _ => pure (outParams.size > 0) + checkOutParam (i+1) outParamFVarIds outParams b + | _ => return outParams def addClass (env : Environment) (clsName : Name) : Except String Environment := do if isClass env clsName then @@ -83,17 +97,17 @@ def addClass (env : Environment) (clsName : Name) : Except String Environment := | throw s!"unknown declaration '{clsName}'" unless decl matches .inductInfo .. | .axiomInfo .. do throw s!"invalid 'class', declaration '{clsName}' must be inductive datatype, structure, or constant" - let b ← checkOutParam 1 #[] decl.type - return classExtension.addEntry env { name := clsName, hasOutParam := b } + let outParams ← checkOutParam 0 #[] #[] decl.type + return classExtension.addEntry env { name := clsName, outParams } private def consumeNLambdas : Nat → Expr → Option Expr - | 0, e => some e - | i+1, Expr.lam _ _ b _ => consumeNLambdas i b - | _, _ => none + | 0, e => some e + | i+1, .lam _ _ b _ => consumeNLambdas i b + | _, _ => none partial def getClassName (env : Environment) : Expr → Option Name - | Expr.forallE _ _ b _ => getClassName env b - | e => do + | .forallE _ _ b _ => getClassName env b + | e => do let Expr.const c _ ← pure e.getAppFn | none let info ← env.find? c match info.value? with @@ -106,8 +120,8 @@ partial def getClassName (env : Environment) : Expr → Option Name builtin_initialize registerBuiltinAttribute { - name := `class, - descr := "type class", + name := `class + descr := "type class" add := fun decl stx kind => do let env ← getEnv Attribute.Builtin.ensureNoArgs stx