diff --git a/src/Lean/Meta/CongrTheorems.lean b/src/Lean/Meta/CongrTheorems.lean index e93593ebf8..fa93449bc3 100644 --- a/src/Lean/Meta/CongrTheorems.lean +++ b/src/Lean/Meta/CongrTheorems.lean @@ -26,6 +26,11 @@ inductive CongrArgKind where The lemma contains three parameters for this kind of argument `a_i`, `b_i` and `eq_i : HEq a_i b_i`. `a_i` and `b_i` represent the left and right hand sides, and `eq_i` is a proof for their heterogeneous equality. -/ heq + | /-- + For congr-simp theorems only. Indicates a decidable instance argument. + The lemma contains two arguments [a_i : Decidable ...] [b_i : Decidable ...] -/ + subsingletonInst + deriving Inhabited structure CongrTheorem where type : Expr @@ -110,4 +115,144 @@ where def mkHCongr (f : Expr) : MetaM CongrTheorem := do mkHCongrWithArity f (← getFunInfo f).getArity +/-- + Ensure that all dependencies for `congr_arg_kind::Eq` are `congr_arg_kind::Fixed`. +-/ +private def fixKindsForDependencies (info : FunInfo) (kinds : Array CongrArgKind) : Array CongrArgKind := Id.run do + let mut kinds := kinds + for i in [:info.paramInfo.size] do + for j in [i+1:info.paramInfo.size] do + if info.paramInfo[j].backDeps.contains i then + if kinds[j] matches CongrArgKind.eq || kinds[j] matches CongrArgKind.fixed then + -- We must fix `i` because there is a `j` that depends on `i` and `j` is not cast-fixed. + kinds := kinds.set! i CongrArgKind.fixed + break + return kinds + +/-- + (Try to) cast expression `e` to the given type using the equations `eqs`. + `deps` contains the indices of the relevant equalities. + Remark: deps is sorted. -/ +private partial def mkCast (e : Expr) (type : Expr) (deps : Array Nat) (eqs : Array (Option Expr)) : MetaM Expr := do + let rec go (i : Nat) (type : Expr) : MetaM Expr := do + if i < deps.size then + match eqs[deps[i]] with + | none => go (i+1) type + | some major => + let some (_, lhs, rhs) := (← inferType major).eq? | unreachable! + if (← dependsOn type major.fvarId!) then + let motive ← mkLambdaFVars #[rhs, major] type + let typeNew := type.replaceFVar rhs lhs |>.replaceFVar major (← mkEqRefl lhs) + let minor ← go (i+1) typeNew + mkEqRec motive minor major + else + let motive ← mkLambdaFVars #[rhs] type + let typeNew := type.replaceFVar rhs lhs + let minor ← go (i+1) typeNew + mkEqNDRec motive minor major + else + return e + go 0 type + +private def hasCastLike (kinds : Array CongrArgKind) : Bool := + kinds.any fun kind => kind matches CongrArgKind.cast || kind matches CongrArgKind.subsingletonInst + +/-- + Create a congruence theorem that is useful for the simplifier. +-/ +partial def mkCongrSimpWithArity? (f : Expr) (numArgs : Nat) : MetaM (Option CongrTheorem) := do + let info ← getFunInfo f + let kinds := getKinds info + if let some result ← mk? f info kinds then + return some result + else if hasCastLike kinds then + -- Simplify kinds and try again + let kinds := kinds.map fun kind => + if kind matches CongrArgKind.cast || kind matches CongrArgKind.subsingletonInst then CongrArgKind.fixed else kind + mk? f info kinds + else + return none +where + /-- + Create a congruence theorem that is useful for the simplifier. + In this kind of theorem, if the i-th argument is a `cast` argument, then the theorem + contains an input `a_i` representing the i-th argument in the left-hand-side, and + it appears with a cast (e.g., `Eq.drec ... a_i ...`) in the right-hand-side. + The idea is that the right-hand-side of this lemma "tells" the simplifier + how the resulting term looks like. -/ + mk? (f : Expr) (info : FunInfo) (kinds : Array CongrArgKind) : MetaM (Option CongrTheorem) := do + try + let fType ← inferType f + forallBoundedTelescope fType kinds.size fun lhss xType => do + if lhss.size != kinds.size then return none + let rec go (i : Nat) (rhss : Array Expr) (eqs : Array (Option Expr)) (hyps : Array Expr) : MetaM CongrTheorem := do + if i == kinds.size then + let lhs := mkAppN f lhss + let rhs := mkAppN f rhss + let type ← mkForallFVars hyps (← mkEq lhs rhs) + let proof ← mkProof type kinds + return { type, proof, argKinds := kinds } + else + let hyps := hyps.push lhss[i] + match kinds[i] with + | CongrArgKind.eq => + let localDecl ← getLocalDecl lhss[i].fvarId! + withLocalDecl localDecl.userName localDecl.binderInfo localDecl.type fun rhs => do + withLocalDeclD ((`e).appendIndexAfter (eqs.size+1)) (← mkEq lhss[i] rhs) fun eq => do + go (i+1) (rhss.push rhs) (eqs.push eq) (hyps.push rhs |>.push eq) + | CongrArgKind.heq => unreachable! + | CongrArgKind.fixed => go (i+1) (rhss.push lhss[i]) (eqs.push none) hyps + | CongrArgKind.fixedNoParam => unreachable! + | CongrArgKind.cast => + let rhsType := (← inferType lhss[i]).replaceFVars (lhss[:rhss.size]) rhss + let rhs ← mkCast lhss[i] rhsType info.paramInfo[i].backDeps eqs + go (i+1) (rhss.push rhs) (eqs.push none) hyps + | CongrArgKind.subsingletonInst => + let rhsType := (← inferType lhss[i]).replaceFVars (lhss[:rhss.size]) rhss + withLocalDecl (← getLocalDecl lhss[i].fvarId!).userName BinderInfo.instImplicit rhsType fun rhs => + go (i+1) (rhss.push rhs) (eqs.push none) (hyps.push rhs) + return some (← go 0 #[] #[] #[]) + catch _ => + return none + + mkProof (type : Expr) (kinds : Array CongrArgKind) : MetaM Expr := do + mkSorry type false -- TODO + + getKinds (info : FunInfo) : Array CongrArgKind := Id.run do + /- The default `CongrArgKind` is `eq`, which allows `simp` to rewrite this + argument. However, if there are references from `i` to `j`, we cannot + rewrite both `i` and `j`. So we must change the `CongrArgKind` at + either `i` or `j`. In principle, if there is a dependency with `i` + appearing after `j`, then we set `j` to `fixed` (or `cast`). But there is + an optimization: if `i` is a subsingleton, we can fix it instead of + `j`, since all subsingletons are equal anyway. The fixing happens in + two loops: one for the special cases, and one for the general case. -/ + let mut result := #[] + for i in [:info.paramInfo.size] do + if info.resultDeps.contains i then + result := result.push CongrArgKind.fixed + else if info.paramInfo[i].isProp then + result := result.push CongrArgKind.cast + else if info.paramInfo[i].isInstImplicit then + if shouldUseSubsingletonInst info result i then + result := result.push CongrArgKind.subsingletonInst + else + result := result.push CongrArgKind.fixed + else + result := result.push CongrArgKind.eq + return fixKindsForDependencies info result + + /-- + Test whether we should use `subsingletonInst` kind for instances which depend on `eq`. + (Otherwise `fixKindsForDependencies`will downgrade them to Fixed -/ + shouldUseSubsingletonInst (info : FunInfo) (kinds : Array CongrArgKind) (i : Nat) : Bool := Id.run do + if info.paramInfo[i].isDecInst then + for j in info.paramInfo[i].backDeps do + if kinds[j] matches CongrArgKind.eq then + return true + return false + +def mkCongrSimp? (f : Expr) : MetaM (Option CongrTheorem) := do + mkCongrSimpWithArity? f (← getFunInfo f).getArity + end Lean.Meta diff --git a/tests/lean/run/congrThm.lean b/tests/lean/run/congrThm.lean new file mode 100644 index 0000000000..3eb355e5d3 --- /dev/null +++ b/tests/lean/run/congrThm.lean @@ -0,0 +1,12 @@ +import Lean + +open Lean +open Lean.Meta + +def test (f : Expr) : MetaM Unit := do + let some thm ← mkCongrSimp? f | unreachable! + check thm.type + IO.println (← Meta.ppExpr thm.type) + +#eval test (mkConst ``decide) +#eval test (mkConst ``Array.uget [levelZero])