From dfa392fa17f4a17ffb0ceef0d9fa4f72178ef29f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 4 Mar 2020 16:27:01 -0800 Subject: [PATCH] feat: add `generalizeIndices` Helper tactic for `cases` --- src/Init/Lean/Meta/Tactic.lean | 1 + src/Init/Lean/Meta/Tactic/Cases.lean | 137 ++++++++++++++++++++++++--- src/Init/Lean/MetavarContext.lean | 2 +- tests/lean/run/genindices.lean | 27 ++++++ 4 files changed, 154 insertions(+), 13 deletions(-) create mode 100644 tests/lean/run/genindices.lean diff --git a/src/Init/Lean/Meta/Tactic.lean b/src/Init/Lean/Meta/Tactic.lean index 139b8d0c94..e1ff2d4947 100644 --- a/src/Init/Lean/Meta/Tactic.lean +++ b/src/Init/Lean/Meta/Tactic.lean @@ -15,3 +15,4 @@ import Init.Lean.Meta.Tactic.Rewrite import Init.Lean.Meta.Tactic.Generalize import Init.Lean.Meta.Tactic.LocalDecl import Init.Lean.Meta.Tactic.Induction +import Init.Lean.Meta.Tactic.Cases diff --git a/src/Init/Lean/Meta/Tactic/Cases.lean b/src/Init/Lean/Meta/Tactic/Cases.lean index 28cd5eb664..9476218492 100644 --- a/src/Init/Lean/Meta/Tactic/Cases.lean +++ b/src/Init/Lean/Meta/Tactic/Cases.lean @@ -9,6 +9,94 @@ import Init.Lean.Meta.Tactic.Induction namespace Lean namespace Meta +private def mkEq (lhs rhs : Expr) : MetaM (Expr × Expr) := do +lhsType ← inferType lhs; +rhsType ← inferType rhs; +u ← getLevel lhsType; +condM (isDefEq lhsType rhsType) + (pure (mkApp3 (mkConst `Eq [u]) lhsType lhs rhs, mkApp2 (mkConst `Eq.refl [u]) lhsType lhs)) + (pure (mkApp4 (mkConst `HEq [u]) lhsType lhs rhsType rhs, mkApp2 (mkConst `HEq.refl [u]) lhsType lhs)) + +private partial def withNewIndexEqsAux {α} (indices newIndices : Array Expr) (k : Array Expr → Array Expr → MetaM α) : Nat → Array Expr → Array Expr → MetaM α +| i, newEqs, newRefls => + if h : i < indices.size then do + let index := indices.get! i; + let newIndex := newIndices.get! i; + (newEqType, newRefl) ← mkEq index newIndex; + withLocalDecl `h newEqType BinderInfo.default $ fun newEq => do + withNewIndexEqsAux (i+1) (newEqs.push newEq) (newRefls.push newRefl) + else + k newEqs newRefls + +private def withNewIndexEqs {α} (indices newIndices : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α := +withNewIndexEqsAux indices newIndices k 0 #[] #[] + +structure GeneralizeIndicesSubgoal := +(mvarId : MVarId) +(indicesFVarIds : Array FVarId) +(fvarId : FVarId) +(numEqs : Nat) + +/-- + Given a metavariable `mvarId` representing the + ``` + Ctx, h : I A j, D |- T + ``` + where `fvarId` is `h`s id, and the type `I A j` is an inductive datatype where `A` are parameters, + and `j` the indices. Generate the goal + ``` + Ctx, h : I A j, D, j' : J, h' : I A j' |- j == j' -> h == h' -> T + ``` + Remark: `(j == j' -> h == h')` is a "telescopic" equality. + Remark: `j` is sequence of terms, and `j'` a sequence of free variables. + The result contains the fields + - `mvarId`: the new goal + - `indicesFVarIds`: `j'` ids + - `fvarId`: `h'` id + - `numEqs`: number of equations in the target -/ +def generalizeIndices (mvarId : MVarId) (fvarId : FVarId) : MetaM GeneralizeIndicesSubgoal := do +withMVarContext mvarId $ do + lctx ← getLCtx; + localInsts ← getLocalInstances; + env ← getEnv; + checkNotAssigned mvarId `generalizeIndices; + fvarDecl ← getLocalDecl fvarId; + type ← whnf fvarDecl.type; + type.withApp $ fun f args => matchConst env f (fun _ => throwTacticEx `generalizeIndices mvarId "inductive type expected") $ + fun cinfo _ => match cinfo with + | ConstantInfo.inductInfo val => do + unless (val.nindices > 0) $ throwTacticEx `generalizeIndices mvarId "indexed inductive type expected"; + unless (args.size == val.nindices + val.nparams) $ throwTacticEx `generalizeIndices mvarId "ill-formed inductive datatype"; + let indices := args.extract (args.size - val.nindices) args.size; + let IA := mkAppN f (args.extract 0 val.nparams); -- `I A` + IAType ← inferType IA; + forallTelescopeReducing IAType $ fun newIndices _ => do + let newType := mkAppN IA newIndices; + withLocalDecl fvarDecl.userName newType BinderInfo.default $ fun h' => + withNewIndexEqs indices newIndices $ fun newEqs newRefls => do + (newEqType, newRefl) ← mkEq fvarDecl.toExpr h'; + let newRefls := newRefls.push newRefl; + withLocalDecl `h newEqType BinderInfo.default $ fun newEq => do + let newEqs := newEqs.push newEq; + /- auxType `forall (j' : J) (h' : I A j'), j == j' -> h == h' -> target -/ + target ← getMVarType mvarId; + tag ← getMVarTag mvarId; + auxType ← mkForall newEqs target; + auxType ← mkForall #[h'] auxType; + auxType ← mkForall newIndices auxType; + newMVar ← mkFreshExprMVarAt lctx localInsts auxType tag MetavarKind.syntheticOpaque; + /- assign mvarId := newMVar indices h refls -/ + assignExprMVar mvarId (mkAppN (mkApp (mkAppN newMVar indices) fvarDecl.toExpr) newRefls); + (indicesFVarIds, newMVarId) ← introN newMVar.mvarId! newIndices.size; + (fvarId, newMVarId) ← intro1 newMVarId; + pure { + mvarId := newMVarId, + indicesFVarIds := indicesFVarIds, + fvarId := fvarId, + numEqs := newEqs.size + } + | _ => throwTacticEx `generalizeIndices mvarId "inductive type expected" + structure CasesSubgoal := (ctorName : Name) (mvarId : MVarId) @@ -18,11 +106,15 @@ structure CasesSubgoal := namespace Cases structure Context := -(inductiveVal : InductiveVal) -(casesOnVal : DefinitionVal) -(nminors : Nat := inductiveVal.ctors.length) +(inductiveVal : InductiveVal) +(casesOnVal : DefinitionVal) +(nminors : Nat := inductiveVal.ctors.length) +(majorDecl : LocalDecl) +(majorTypeFn : Expr) +(majorTypeArgs : Array Expr) +(majorTypeIndices : Array Expr := majorTypeArgs.extract (majorTypeArgs.size - inductiveVal.nindices) majorTypeArgs.size) -private def mkCasesContex? (majorFVarId : FVarId) : MetaM (Option Context) := do +private def mkCasesContext? (majorFVarId : FVarId) : MetaM (Option Context) := do env ← getEnv; if !env.contains `Eq || env.contains `HEq then pure none else do @@ -33,17 +125,38 @@ else do | ConstantInfo.inductInfo ival => if args.size != ival.nindices + ival.nparams then pure none else match env.find? (mkNameStr ival.name "casesOn") with - | ConstantInfo.defnInfo cval => pure $ some { inductiveVal := ival, casesOnVal := cval } + | ConstantInfo.defnInfo cval => pure $ some { + inductiveVal := ival, + casesOnVal := cval, + majorDecl := majorDecl, + majorTypeFn := f, + majorTypeArgs := args + } | _ => pure none | _ => pure none -private def mkEq (lhs rhs : Expr) : MetaM (Expr × Expr) := do -lhsType ← inferType lhs; -rhsType ← inferType rhs; -u ← getLevel lhsType; -condM (isDefEq lhsType rhsType) - (pure (mkApp3 (mkConst `Eq [u]) lhsType lhs rhs, mkApp2 (mkConst `Eq.refl [u]) lhsType lhs)) - (pure (mkApp4 (mkConst `HEq [u]) lhsType lhs rhsType rhs, mkApp2 (mkConst `HEq.refl [u]) lhsType lhs)) +/- +We say the major premise has independent indices IF +1- its type is *not* an indexed inductive family, OR +2- its type is an indexed inductive family, but all indices are distinct free variables, and + all local declarations different from the major and its indices do not depend on the indices. +-/ +private def hasIndepIndices (ctx : Context) : MetaM Bool := +if ctx.majorTypeIndices.isEmpty then + pure true +else if ctx.majorTypeIndices.any $ fun idx => !idx.isFVar then + /- One of the indices is not a free variable. -/ + pure false +else if ctx.majorTypeIndices.size.any $ fun i => i.any $ fun j => ctx.majorTypeIndices.get! i == ctx.majorTypeIndices.get! j then + /- An index ocurrs more than once -/ + pure false +else do + lctx ← getLCtx; + mctx ← getMCtx; + pure $ lctx.all $ fun decl => + decl.fvarId == ctx.majorDecl.fvarId || -- decl is the major + ctx.majorTypeIndices.any (fun index => decl.fvarId == index.fvarId!) || -- decl is one of the indices + mctx.findLocalDeclDependsOn decl (fun fvarId => ctx.majorTypeIndices.all $ fun idx => idx.fvarId! != fvarId) -- or does not depend on any index end Cases diff --git a/src/Init/Lean/MetavarContext.lean b/src/Init/Lean/MetavarContext.lean index 2f379a287a..29fa7222d3 100644 --- a/src/Init/Lean/MetavarContext.lean +++ b/src/Init/Lean/MetavarContext.lean @@ -636,7 +636,7 @@ end DependsOn (DependsOn.main mctx p e).run' {} /-- - Similar to `exprDependsOn`, but checks the expressions in the given local declaration + Similar to `findExprDependsOn`, but checks the expressions in the given local declaration depends on a free variable `x` s.t. `p x` is `true`. -/ @[inline] def findLocalDeclDependsOn (mctx : MetavarContext) (localDecl : LocalDecl) (p : FVarId → Bool) : Bool := match localDecl with diff --git a/tests/lean/run/genindices.lean b/tests/lean/run/genindices.lean new file mode 100644 index 0000000000..a7c390a8a3 --- /dev/null +++ b/tests/lean/run/genindices.lean @@ -0,0 +1,27 @@ +import Init.Lean + +universe u + +inductive Pred : ∀ (α : Type u), List α → Type (u+1) +| foo {α : Type u} (l1 : List α) (l2 : List (Option α)) : Pred (Option α) l2 → Pred α l1 + +axiom goal : forall (α : Type u) (xs : List (List α)) (h : Pred (List α) xs), xs ≠ [] → xs = xs + +open Lean +open Lean.Meta + +def tst1 : MetaM Unit := do +cinfo ← getConstInfo `goal; +let type := cinfo.type; +mvar ← mkFreshExprMVar type; +trace! `Elab (MessageData.ofGoal mvar.mvarId!); +(_, mvarId) ← introN mvar.mvarId! 2; +(fvarId, mvarId) ← intro1 mvarId; +trace! `Elab (MessageData.ofGoal mvarId); +s ← generalizeIndices mvarId fvarId; +trace! `Elab (MessageData.ofGoal s.mvarId); +pure () + +set_option trace.Elab true + +#eval tst1