feat: in an inductive family the longest fixed prefix of indices is now promoted to parameters

This modification is relevant for fixing regressions on recent changes
to the auto implicit behavior for inductive families.

The following declarations are now accepted:
```lean
inductive HasType : Fin n → Vector Ty n → Ty → Type where
  | stop : HasType 0 (ty :: ctx) ty
  | pop  : HasType k ctx ty → HasType k.succ (u :: ctx) ty

inductive Sublist : List α → List α → Prop
  | slnil : Sublist [] []
  | cons l₁ l₂ a : Sublist l₁ l₂ → Sublist l₁ (a :: l₂)
  | cons2 l₁ l₂ a : Sublist l₁ l₂ → Sublist (a :: l₁) (a :: l₂)

inductive Lst : Type u → Type u
  | nil  : Lst α
  | cons : α → Lst α → Lst α
```

TODO: universe inference for `inductive` should be improved. The
current approach is not good enough when we have auto implicits.

TODO: allow implicit fixed indices that do not depend on indices that
cannot be moved to become parameters.
This commit is contained in:
Leonardo de Moura 2022-03-08 17:45:59 -08:00
parent 234046521a
commit 4ee131981d
10 changed files with 141 additions and 78 deletions

View file

@ -118,7 +118,7 @@ def Array.insertAtAux (i : Nat) (as : Array α) (j : Nat) : Array α :=
* Add support for `for h : x in xs do ...` notation where `h : x ∈ xs`. This is mainly useful for showing termination.
* Auto implicit behavior changed for inductive families. An auto implicit argument occurring in inductive family index is also treated as an index. For example
* Auto implicit behavior changed for inductive families. An auto implicit argument occurring in inductive family index is also treated as an index (IF it is not fixed, see next item). For example
```lean
inductive HasType : Index n → Vector Ty n → Ty → Type where
```
@ -127,6 +127,20 @@ is now interpreted as
inductive HasType : {n : Nat} → Index n → Vector Ty n → Ty → Type where
```
* To make the previous feature more convenient to use, we promote a fixed prefix of inductive family indices to parameters. For example, the following declaration is now accepted by Lean
```lean
inductive Lst : Type u → Type u
| nil : Lst α
| cons : α → Lst α → Lst α
```
and `α` in `Lst α` is a parameter. The actual number of parameters can be inspected using the command `#print Lst`. This feature also makes sure we still accept the declaration
```lean
inductive Sublist : List α → List α → Prop
| slnil : Sublist [] []
| cons l₁ l₂ a : Sublist l₁ l₂ → Sublist l₁ (a :: l₂)
| cons2 l₁ l₂ a : Sublist l₁ l₂ → Sublist (a :: l₁) (a :: l₂)
```
* Added auto implicit "chaining". Unassigned metavariables occurring in the auto implicit types now become new auto implicit locals. Consider the following example:
```lean
inductive HasType : Fin n → Vector Ty n → Ty → Type where

View file

@ -464,6 +464,84 @@ private def mkAuxConstructions (views : Array InductiveView) : TermElabM Unit :=
if hasUnit && hasProd then mkBRecOn n
if hasUnit && hasProd then mkBInductionOn n
private def getArity (indType : InductiveType) : MetaM Nat :=
forallTelescopeReducing indType.type fun xs _ => return xs.size
/--
Compute a bit-mask that for `indType`. The size of the resulting array `result` is the arity of `indType`.
The first `numParams` elements are `false` since they are parameters.
For `i ∈ [numParams, arity)`, we have that `result[i]` if this index of the inductive family is fixed.
-/
private def computeFixedIndexBitMask (numParams : Nat) (indType : InductiveType) (indFVars : Array Expr) : MetaM (Array Bool) := do
let arity ← getArity indType
if arity ≤ numParams then
return mkArray arity false
else
let maskRef ← IO.mkRef (mkArray numParams false ++ mkArray (arity - numParams) true)
let rec go (ctors : List Constructor) : MetaM (Array Bool) := do
match ctors with
| [] => maskRef.get
| ctor :: ctors =>
forallTelescopeReducing ctor.type fun xs type => do
let I := type.getAppFn.constName!
let typeArgs := type.getAppArgs
for i in [numParams:arity] do
unless i < xs.size && xs[i] == typeArgs[i] do -- Remark: if we want to allow arguments to be rearranged, this test should be xs.contains typeArgs[i]
maskRef.modify fun mask => mask.set! i false
for x in xs[numParams:] do
let xType ← inferType x
xType.forEach fun e => do
if indFVars.any (fun indFVar => e.getAppFn == indFVar) && e.getAppNumArgs > numParams then
let eArgs := e.getAppArgs
for i in [numParams:eArgs.size] do
if i >= typeArgs.size then
maskRef.modify fun mask => mask.set! i false
else
unless eArgs[i] == typeArgs[i] do
maskRef.modify fun mask => mask.set! i false
go ctors
go indType.ctors
/-- Return true iff `arrowType` is an arrow and its domain is defeq to `type` -/
private def isDomainDefEq (arrowType : Expr) (type : Expr) : MetaM Bool := do
if !arrowType.isForall then
return false
else
withNewMCtxDepth do -- Make sure we do not assign univers metavariables
isDefEq arrowType.bindingDomain! type
/--
Convert fixed indices to parameters.
TODO: we currently only convert a prefix of the indices, and we do not try to reorder binders.
-/
private partial def fixedIndicesToParams (numParams : Nat) (indTypes : Array InductiveType) (indFVars : Array Expr) : MetaM (Nat × List InductiveType) := do
let masks ← indTypes.mapM (computeFixedIndexBitMask numParams . indFVars)
if masks.all fun mask => !mask.contains true then
return (numParams, indTypes.toList)
-- We process just a non-fixed prefix of the indices for now. Reason: we don't want to change the order.
-- TODO: extend it in the future. For example, it should be reasonable to change
-- the order of indices generated by the auto implicit feature.
let mask := masks[0]
forallBoundedTelescope indTypes[0].type numParams fun params type => do
let otherTypes ← indTypes[1:].toArray.mapM fun indType => do whnfD (← instantiateForall indType.type params)
let ctorTypes ← indTypes.toList.mapM fun indType => indType.ctors.mapM fun ctor => do whnfD (← instantiateForall ctor.type params)
let typesToCheck := otherTypes.toList ++ ctorTypes.join
let rec go (i : Nat) (typesToCheck : List Expr) : MetaM Nat := do
if i < mask.size then
if !masks.all fun mask => i < mask.size && mask[i] then
return i
if !type.isForall then
return i
let paramType := type.bindingDomain!
if !(← typesToCheck.allM fun type => isDomainDefEq type paramType) then
return i
withLocalDeclD `a paramType fun paramNew => do
let typesToCheck ← typesToCheck.mapM fun type => whnfD (type.bindingBody!.instantiate1 paramNew)
go (i+1) typesToCheck
else
return i
return (← go numParams typesToCheck, indTypes.toList)
private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : TermElabM Unit := do
let view0 := views[0]
let scopeLevelNames ← Term.getLevelNames
@ -473,7 +551,6 @@ private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) :
withRef view0.ref <| Term.withLevelNames allUserLevelNames do
let rs ← elabHeader views
withInductiveLocalDecls rs fun params indFVars => do
let numExplicitParams := params.size
let mut indTypesArray := #[]
for i in [:views.size] do
let indFVar := indFVars[i]
@ -481,9 +558,9 @@ private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) :
let type ← mkForallFVars params r.type
let ctors ← elabCtors indFVars indFVar params r
indTypesArray := indTypesArray.push { name := r.view.declName, type := type, ctors := ctors : InductiveType }
let indTypes := indTypesArray.toList
-- TODO: convert fixed indices to parameters
Term.synthesizeSyntheticMVarsNoPostponing
let (numExplicitParams, indTypes) ← fixedIndicesToParams params.size indTypesArray indFVars
trace[Meta.debug] "numExplicitParams: {numExplicitParams}"
let u ← getResultingUniverse indTypes
let inferLevel ← shouldInferResultUniverse u
withUsed vars indTypes fun vars => do

View file

@ -14,7 +14,7 @@ a : α
α : Type @ ⟨7, 13⟩-⟨7, 14⟩
a (isBinder := true) : α @ ⟨7, 9⟩-⟨7, 10⟩
Fam2 α β : Type 1 @ ⟨7, 21⟩-⟨7, 29⟩ @ Lean.Elab.Term.elabApp
[.] `Fam2 : some Sort.{?_uniq.284} @ ⟨7, 21⟩-⟨7, 25⟩
[.] `Fam2 : some Sort.{?_uniq.288} @ ⟨7, 21⟩-⟨7, 25⟩
Fam2 : Type → Type → Type 1 @ ⟨7, 21⟩-⟨7, 25⟩
α : Type @ ⟨7, 26⟩-⟨7, 27⟩ @ Lean.Elab.Term.elabIdent
α : Type @ ⟨7, 26⟩-⟨7, 27⟩
@ -43,7 +43,7 @@ a : α
α : Type @ ⟨1, 0⟩†-⟨1, 0⟩† @ Lean.Elab.Term.elabMVarWithIdKind
α : Type @ ⟨1, 0⟩†-⟨1, 0⟩† @ Lean.Elab.Term.elabMVarWithIdKind
Fam2.any : Fam2 α α @ ⟨9, 4⟩†-⟨9, 12⟩ @ Lean.Elab.Term.elabApp
[.] `Fam2.any : some Fam2 ([mdata _inaccessible:1 ?_uniq.618]) ([mdata _inaccessible:1 ?_uniq.619]) @ ⟨9, 4⟩-⟨9, 12⟩
[.] `Fam2.any : some Fam2 ([mdata _inaccessible:1 ?_uniq.622]) ([mdata _inaccessible:1 ?_uniq.623]) @ ⟨9, 4⟩-⟨9, 12⟩
@Fam2.any : {α : Type} → Fam2 α α @ ⟨9, 4⟩-⟨9, 12⟩
α : Type @ ⟨1, 0⟩†-⟨1, 0⟩† @ Lean.Elab.Term.elabMVarWithIdKind
a : α @ ⟨8, 2⟩†-⟨10, 19⟩† @ Lean.Elab.Term.elabIdent
@ -56,7 +56,7 @@ a : α
Nat : Type @ ⟨1, 0⟩†-⟨1, 0⟩† @ Lean.Elab.Term.elabMVarWithIdKind
Nat : Type @ ⟨1, 0⟩†-⟨1, 0⟩† @ Lean.Elab.Term.elabMVarWithIdKind
Fam2.nat n : Fam2 Nat Nat @ ⟨10, 4⟩†-⟨10, 14⟩ @ Lean.Elab.Term.elabApp
[.] `Fam2.nat : some Fam2 ([mdata _inaccessible:1 ?_uniq.636]) ([mdata _inaccessible:1 ?_uniq.637]) @ ⟨10, 4⟩-⟨10, 12⟩
[.] `Fam2.nat : some Fam2 ([mdata _inaccessible:1 ?_uniq.640]) ([mdata _inaccessible:1 ?_uniq.641]) @ ⟨10, 4⟩-⟨10, 12⟩
Fam2.nat : Nat → Fam2 Nat Nat @ ⟨10, 4⟩-⟨10, 12⟩
n : Nat @ ⟨10, 13⟩-⟨10, 14⟩ @ Lean.Elab.Term.elabIdent
n : Nat @ ⟨10, 13⟩-⟨10, 14⟩

View file

@ -0,0 +1,16 @@
inductive sublist : List α → List α → Prop
| slnil : sublist [] []
| cons l₁ l₂ a : sublist l₁ l₂ → sublist l₁ (a :: l₂)
| cons2 l₁ l₂ a : sublist l₁ l₂ → sublist (a :: l₁) (a :: l₂)
#print sublist
inductive Foo : List α → Type _ -- Make sure we don't need to use `_` or can use `u`
| mk₁ : Foo []
| mk₂ : (a : α) → Foo as → Foo (a::as)
inductive Bla : Foo as → Type _
| mk₁ : Bla Foo.mk₁
#print Foo
#print Bla

View file

@ -0,0 +1,15 @@
inductive sublist.{u_1} : {α : Type u_1} → List α → List α → Prop
number of parameters: 1
constructors:
sublist.slnil : ∀ {a : Type u_1}, sublist [] []
sublist.cons : ∀ {a : Type u_1} (l₁ l₂ : List a) (a_1 : a), sublist l₁ l₂ → sublist l₁ (a_1 :: l₂)
sublist.cons2 : ∀ {a : Type u_1} (l₁ l₂ : List a) (a_1 : a), sublist l₁ l₂ → sublist (a_1 :: l₁) (a_1 :: l₂)
inductive Foo.{u_1} : {α : Type u_1} → List α → Type u_1
number of parameters: 1
constructors:
Foo.mk₁ : {a : Type u_1} → Foo []
Foo.mk₂ : {α : Type u_1} → {as : List α} → (a : α) → Foo as → Foo (a :: as)
inductive Bla.{u_1} : {a : Type u_1} → {as : List a} → Foo as → Type
number of parameters: 1
constructors:
Bla.mk₁ : {a : Type u_1} → Bla Foo.mk₁

View file

@ -10,10 +10,10 @@ structure Node : Type where
def h1 (x : List Node) : Bool :=
match x with
| _ :: Node.mk _ _ (Op.mk 0) :: _ => true
| _ => false
| _ :: Node.mk 0 _ Op.mk :: _ => true
| _ => false
def mkNode (n : Nat) : Node := { id₁ := n, id₂ := n, o := Op.mk n }
def mkNode (n : Nat) : Node := { id₁ := n, id₂ := n, o := Op.mk }
#eval h1 [mkNode 1, mkNode 0, mkNode 3]
#eval h1 [mkNode 1, mkNode 1, mkNode 3]
@ -53,7 +53,7 @@ def h3 {b : Bool} (x : Foo b) : Bool :=
def h4 (x : List Node) : Bool :=
match x with
| _ :: ⟨1, 1, Op.mk 1⟩ :: _ => true
| _ :: ⟨1, 1, Op.mk⟩ :: _ => true
| _ => false
#eval h4 [mkNode 1, mkNode 0, mkNode 3]

View file

@ -3,8 +3,8 @@ abbrev semantics (α:Type) := StateM (List Nat) α
inductive expression : Nat → Type
| const : (n : Nat) → expression n
def uext {w:Nat} (x: expression w) (o:Nat) : expression w := expression.const w
def eval {n : Nat} (v:expression n) : semantics (expression n) := pure (expression.const n)
def uext {w:Nat} (x: expression w) (o:Nat) : expression w := expression.const
def eval {n : Nat} (v:expression n) : semantics (expression n) := pure expression.const
def set_overflow {w : Nat} (e : expression w) : semantics Unit := pure ()
structure instruction :=
@ -13,7 +13,7 @@ structure instruction :=
def definst (mnem:String) (body: expression 8 -> semantics Unit) : instruction :=
{ mnemonic := mnem
, patterns := ((body (expression.const 8)).run []).snd.reverse
, patterns := ((body expression.const).run []).snd.reverse
}
def mul : instruction := Id.run <| do -- this is a "pure" do block (as in it is the Id monad)

View file

@ -284,27 +284,7 @@ elimTest8 _ (fun _ _ => Option (Nat × Nat)) n xs (fun a b => some (a, b)) (fun
inductive Op : Nat → Nat → Type
| mk : ∀ n, Op n n
structure Node : Type :=
(id₁ id₂ : Nat)
(o : Op id₁ id₂)
def ex9 (xs : List Node) :
LHS (forall (h : Node) (t : List Node), Pat (h :: Node.mk 1 1 (Op.mk 1) :: t))
× LHS (forall (ys : List Node), Pat ys) :=
default
#eval test `ex9 1 `elimTest9
#print elimTest9
def f (xs : List Node) : Bool :=
elimTest9 (fun _ => Bool) xs
(fun _ _ => true)
(fun _ => false)
#eval check (f [] == false)
#eval check (f [⟨0, 0, Op.mk 0⟩] == false)
#eval check (f [⟨0, 0, Op.mk 0⟩, ⟨1, 1, Op.mk 1⟩])
#eval check (f [⟨0, 0, Op.mk 0⟩, ⟨2, 2, Op.mk 2⟩] == false)
#print Op
inductive Foo : Bool → Prop
| bar : Foo false
@ -317,12 +297,6 @@ default
#eval test `ex10 2 `elimTest10 true
def ex11 (xs : List Node) :
LHS (forall (h : Node) (t : List Node), Pat (h :: Node.mk 1 1 (Op.mk 1) :: t))
× LHS (Pat ([] : List Node)) :=
default
#eval testFailure `ex11 1 `elimTest11 -- should produce error message
def ex12 (x y z : Bool) :
LHS (forall (x y : Bool), Pat x × Pat y × Pat true)
@ -332,14 +306,6 @@ default
#eval testFailure `ex12 3 `elimTest12 -- should produce error message
def ex13 (xs : List Node) :
LHS (forall (h : Node) (t : List Node), Pat (h :: Node.mk 1 1 (Op.mk 1) :: t))
× LHS (forall (ys : List Node), Pat ys)
× LHS (forall (ys : List Node), Pat ys) :=
default
#eval testFailure `ex13 1 `elimTest13 -- should produce error message
def ex14 (x y : Nat) :
LHS (Pat (val 1) × Pat (val 2))
× LHS (Pat (val 2) × Pat (val 3))

View file

@ -1,30 +1,14 @@
import Lean
import Lean
open Lean
def checkGetBelowIndices (ctorName : Name) (indices : Array Nat) : MetaM Unit := do
let actualIndices ← Meta.IndPredBelow.getBelowIndices ctorName
if actualIndices != indices then
throwError "wrong indices for {ctorName}: {actualIndices} ≟ {indices}"
namespace Ex
inductive LE : Nat → Nat → Prop
| refl : LE n n
| succ : LE n m → LE n m.succ
#eval checkGetBelowIndices ``LE.refl #[1]
#eval checkGetBelowIndices ``LE.succ #[1, 2, 3]
def typeOf {α : Sort u} (a : α) := α
theorem LE_brecOn : typeOf @LE.brecOn =
∀ {motive : (a a_1 : Nat) → LE a a_1 → Prop} {a a_1 : Nat} (x : LE a a_1),
(∀ (a a_2 : Nat) (x : LE a a_2), @LE.below motive a a_2 x → motive a a_2 x) → motive a a_1 x := rfl
theorem LE.trans : LE m n → LE n o → LE m o := by
intro h1 h2
induction h2 with
| refl => assumption
| succ h2 ih => exact succ (ih h1)
theorem LE.trans' : LE m n → LE n o → LE m o
| h1, refl => h1
| h1, succ h2 => succ (trans' h1 h2) -- the structural recursion in being performed on the implicit `Nat` parameter
@ -32,8 +16,6 @@ theorem LE.trans' : LE m n → LE n o → LE m o
inductive Even : Nat → Prop
| zero : Even 0
| ss : Even n → Even n.succ.succ
#eval checkGetBelowIndices ``Even.zero #[]
#eval checkGetBelowIndices ``Even.ss #[1, 2]
theorem Even_brecOn : typeOf @Even.brecOn = ∀ {motive : (a : Nat) → Even a → Prop} {a : Nat} (x : Even a),
(∀ (a : Nat) (x : Even a), @Even.below motive a x → motive a x) → motive a x := rfl
@ -54,8 +36,6 @@ theorem mul_left_comm (n m o : Nat) : n * (m * o) = m * (n * o) := by
inductive Power2 : Nat → Prop
| base : Power2 1
| ind : Power2 n → Power2 (2*n) -- Note that index here is not a constructor
#eval checkGetBelowIndices ``Power2.base #[]
#eval checkGetBelowIndices ``Power2.ind #[1, 2]
theorem Power2_brecOn : typeOf @Power2.brecOn = ∀ {motive : (a : Nat) → Power2 a → Prop} {a : Nat} (x : Power2 a),
(∀ (a : Nat) (x : Power2 a), @Power2.below motive a x → motive a x) → motive a x := rfl
@ -92,9 +72,6 @@ inductive step : tm → tm → Prop :=
| ST_Plus2 : ∀ n1 t2 t2',
t2 ==> t2' →
P (C n1) t2 ==> P (C n1) t2'
#eval checkGetBelowIndices ``step.ST_PlusConstConst #[1, 2]
#eval checkGetBelowIndices ``step.ST_Plus1 #[1, 2, 3, 4]
#eval checkGetBelowIndices ``step.ST_Plus2 #[1, 2, 3, 4]
def deterministic {X : Type} (R : X → X → Prop) :=
∀ x y1 y2 : X, R x y1 → R x y2 → y1 = y2
@ -116,8 +93,6 @@ axiom f : Nat → Nat
inductive is_nat : Nat -> Prop
| Z : is_nat 0
| S {n} : is_nat n → is_nat (f n)
#eval checkGetBelowIndices ``is_nat.Z #[]
#eval checkGetBelowIndices ``is_nat.S #[1, 2]
axiom P : Nat → Prop
axiom F0 : P 0

View file

@ -1,10 +1,10 @@
inductive Foo (β : Type u) : Sort v → Type (max u v)
| mk {α : Sort v} (b : β) (a : α) : Foo β α
| mk₂ : Foo β PUnit
inductive Bla (α : Type u) : Type (u+1) where
| mk₁ (x : Foo (Bla α) Nat)
| mk₂ (n m : Nat) (x : Foo (Bla α) (n = m))
#exit
#print Bla.rec
#print Bla._sizeOf_1