diff --git a/src/Lean/Meta/Sym/Apply.lean b/src/Lean/Meta/Sym/Apply.lean index 05f52a2326..f284700519 100644 --- a/src/Lean/Meta/Sym/Apply.lean +++ b/src/Lean/Meta/Sym/Apply.lean @@ -87,6 +87,21 @@ public def mkBackwardRuleFromDecl (declName : Name) (num? : Option Nat := none) let resultPos := mkResultPos pattern return { expr := mkConst declName, pattern, resultPos } +/-- +Creates a `BackwardRule` from an expression. + +`levelParams` is not `[]` if the expression is supposed to be +universe polymorphic. + +The `num?` parameter optionally limits how many arguments are included in the pattern +(useful for partially applying theorems). +-/ +public def mkBackwardRuleFromExpr (e : Expr) (levelParams : List Name := []) (num? : Option Nat := none) : MetaM BackwardRule := do + let pattern ← mkPatternFromExpr e levelParams num? + let resultPos := mkResultPos pattern + let e := e.instantiateLevelParams levelParams (pattern.levelParams.map mkLevelParam) + return { expr := e, pattern, resultPos } + /-- Creates a value to assign to input goal metavariable using unification result. diff --git a/src/Lean/Meta/Sym/Pattern.lean b/src/Lean/Meta/Sym/Pattern.lean index 90502e0f92..741ed18caa 100644 --- a/src/Lean/Meta/Sym/Pattern.lean +++ b/src/Lean/Meta/Sym/Pattern.lean @@ -99,7 +99,7 @@ def isUVar? (n : Name) : Option Nat := Id.run do return some idx /-- Helper function for implementing `mkPatternFromDecl` and `mkEqPatternFromDecl` -/ -def preprocessPattern (declName : Name) : MetaM (List Name × Expr) := do +def preprocessDeclPattern (declName : Name) : MetaM (List Name × Expr) := do let info ← getConstInfo declName let levelParams := info.levelParams.mapIdx fun i _ => Name.num uvarPrefix i let us := levelParams.map mkLevelParam @@ -107,6 +107,14 @@ def preprocessPattern (declName : Name) : MetaM (List Name × Expr) := do let type ← preprocessType type return (levelParams, type) +def preprocessExprPattern (e : Expr) (levelParams₀ : List Name) : MetaM (List Name × Expr) := do + let type ← inferType e + let levelParams := levelParams₀.mapIdx fun i _ => Name.num uvarPrefix i + let us := levelParams.map mkLevelParam + let type := type.instantiateLevelParams levelParams₀ us + let type ← preprocessType type + return (levelParams, type) + /-- Creates a mask indicating which pattern variables require type checking during matching. @@ -167,6 +175,16 @@ def mkPatternCore (type : Expr) (levelParams : List Name) (varTypes : Array Expr mkProofInstArgInfo? xs return { levelParams, varTypes, pattern, fnInfos, varInfos?, checkTypeMask? } +def mkPatternFromType (levelParams : List Name) (type : Expr) (num? : Option Nat) : MetaM Pattern := do + let hugeNumber := 10000000 + let num := num?.getD hugeNumber + let rec go (i : Nat) (pattern : Expr) (varTypes : Array Expr) : MetaM Pattern := do + if i < num then + if let .forallE _ d b _ := pattern then + return (← go (i+1) b (varTypes.push d)) + mkPatternCore type levelParams varTypes pattern + go 0 type #[] + /-- Creates a `Pattern` from the type of a theorem. @@ -181,15 +199,22 @@ If `num?` is `some n`, at most `n` leading quantifiers are stripped. If `num?` is `none`, all leading quantifiers are stripped. -/ public def mkPatternFromDecl (declName : Name) (num? : Option Nat := none) : MetaM Pattern := do - let (levelParams, type) ← preprocessPattern declName - let hugeNumber := 10000000 - let num := num?.getD hugeNumber - let rec go (i : Nat) (pattern : Expr) (varTypes : Array Expr) : MetaM Pattern := do - if i < num then - if let .forallE _ d b _ := pattern then - return (← go (i+1) b (varTypes.push d)) - mkPatternCore type levelParams varTypes pattern - go 0 type #[] + let (levelParams, type) ← preprocessDeclPattern declName + mkPatternFromType levelParams type num? + +public def mkPatternFromExpr (e : Expr) (levelParams : List Name := []) (num? : Option Nat := none) : MetaM Pattern := do + let (levelParams, type) ← preprocessExprPattern e levelParams + mkPatternFromType levelParams type num? + +def mkEqPatternFromType (levelParams : List Name) (type : Expr) : MetaM (Pattern × Expr) := do + let rec go (pattern : Expr) (varTypes : Array Expr) : MetaM (Pattern × Expr) := do + if let .forallE _ d b _ := pattern then + return (← go b (varTypes.push d)) + else + let_expr Eq _ lhs rhs := pattern | throwError "conclusion is not a equality{indentExpr type}" + let pattern ← mkPatternCore type levelParams varTypes lhs + return (pattern, rhs) + go type #[] /-- Creates a `Pattern` from an equational theorem, using the left-hand side of the equation. @@ -203,15 +228,8 @@ For a theorem `∀ x₁ ... xₙ, lhs = rhs`, returns a pattern matching `lhs` w Throws an error if the theorem's conclusion is not an equality. -/ public def mkEqPatternFromDecl (declName : Name) : MetaM (Pattern × Expr) := do - let (levelParams, type) ← preprocessPattern declName - let rec go (pattern : Expr) (varTypes : Array Expr) : MetaM (Pattern × Expr) := do - if let .forallE _ d b _ := pattern then - return (← go b (varTypes.push d)) - else - let_expr Eq _ lhs rhs := pattern | throwError "resulting type for `{.ofConstName declName}` is not an equality" - let pattern ← mkPatternCore type levelParams varTypes lhs - return (pattern, rhs) - go type #[] + let (levelParams, type) ← preprocessDeclPattern declName + mkEqPatternFromType levelParams type structure UnifyM.Context where pattern : Pattern diff --git a/tests/lean/run/sym_pattern.lean b/tests/lean/run/sym_pattern.lean index 1f9ca3326d..d8ff7c3b47 100644 --- a/tests/lean/run/sym_pattern.lean +++ b/tests/lean/run/sym_pattern.lean @@ -84,3 +84,58 @@ def test4 : SymM Unit := do /-- info: pFoo (3 + y) -/ #guard_msgs in #eval SymM.run test4 + +def ex₂ := ∃ x : Nat, True ∧ x = .zero + +def test5 : SymM Unit := do + let ruleEx ← mkBackwardRuleFromExpr <| mkApp (mkConst ``Exists.intro [1]) Nat.mkType + let ruleAnd ← mkBackwardRuleFromExpr <| mkApp (mkConst ``And.intro) (mkConst ``True) + let ruleTrue ← mkBackwardRuleFromExpr <| (mkConst ``True.intro) + let ruleRefl ← mkBackwardRuleFromDecl ``Eq.refl + let mvar ← mkFreshExprMVar (← getConstInfo ``ex₂).value! + let mvarId ← preprocessMVar mvar.mvarId! + let .goals [mvarId, _] ← ruleEx.apply mvarId | failure + let .goals [mvarId₁, mvarId₂] ← ruleAnd.apply mvarId | failure + let .goals [] ← ruleTrue.apply mvarId₁ | failure + let .goals [] ← ruleRefl.apply mvarId₂ | failure + logInfo mvar + +/-- +info: @Exists.intro Nat (fun x => And True (@Eq Nat x Nat.zero)) Nat.zero + (@And.intro True (@Eq Nat Nat.zero Nat.zero) True.intro (@Eq.refl Nat Nat.zero)) +-/ +#guard_msgs in +set_option pp.explicit true in +#eval SymM.run test5 + +def ex₃ := (Nat × Type) × (Nat × Prop) + +def test6 : SymM Unit := do + let ruleProd ← mkBackwardRuleFromDecl ``Prod.mk + -- `u` is universe parameter in the following rule + let ruleProdNat ← mkBackwardRuleFromExpr (mkApp (mkConst ``Prod.mk [0, mkLevelParam `u]) Nat.mkType) [`u] + let mvar ← mkFreshExprMVar (← getConstInfo ``ex₃).value! + let mvarId ← preprocessMVar mvar.mvarId! + let .goals [mvarId₁, mvarId₂] ← ruleProd.apply mvarId | failure + logInfo mvarId₁ + logInfo mvarId₂ + -- **Note**: `ruleProdNat` is applied with different `u`s in the following two applications + let .goals [mvarId₁₁, mvarId₁₂] ← ruleProdNat.apply mvarId₁ | failure + let .goals [mvarId₂₁, mvarId₂₂] ← ruleProdNat.apply mvarId₂ | failure + mvarId₁₁.assign (mkNatLit 0) + mvarId₂₁.assign (mkNatLit 1) + mvarId₁₂.assign Nat.mkType + mvarId₂₂.assign (mkConst ``True) + logInfo mvar + check (← instantiateMVars mvar) + +/-- +info: ⊢ Prod.{0, 1} Nat Type +--- +info: ⊢ Prod.{0, 0} Nat Prop +--- +info: Prod.mk.{1, 0} (Prod.mk.{0, 1} 0 Nat) (Prod.mk.{0, 0} 1 True) +-/ +#guard_msgs in +set_option pp.universes true in +#eval SymM.run test6