From 7a81589c491209881c8aa7cb200e9864135f778d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 19 Feb 2022 08:09:31 -0800 Subject: [PATCH] feat: improve "constant approximation" heuristic used at `isDefEq` --- src/Lean/Meta/ExprDefEq.lean | 56 +++++++++++++++++++-- tests/lean/run/isDefEqConstApproxIssue.lean | 5 ++ 2 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 tests/lean/run/isDefEqConstApproxIssue.lean diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index 91e0c3a07c..5f38f7974e 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -920,14 +920,60 @@ private def assignConst (mvar : Expr) (numArgs : Nat) (v : Expr) : MetaM Bool := trace[Meta.isDefEq.constApprox] "{mvar} := {v}" checkTypesAndAssign mvar v -private def processConstApprox (mvar : Expr) (numArgs : Nat) (v : Expr) : MetaM Bool := do +/-- + Auxiliary procedure for solving `?m args =?= v` when `args[:patternVarPrefix]` contains + only pairwise distinct free variables. + Let `args[:patternVarPrefix] = #[a₁, ..., aₙ]`, and `args[patternVarPrefix:] = #[b₁, ..., bᵢ]`, + this procedure first reduces the constraint to + ``` + ?m a₁ ... aₙ =?= fun x₁ ... xᵢ => v + ``` + where the left-hand-side is a constant function. + Then, it tries to find the longest prefix `#[a₁, ..., aⱼ]` of `#[a₁, ..., aₙ]` such that the following assignment is valid. + ``` + ?m := fun y₁ ... y‌ⱼ => (fun y_{j+1} ... yₙ x₁ ... xᵢ => v)[a₁/y₁, .., aⱼ/yⱼ] + ``` + That is, after the longest prefix is found, we solve the contraint as the lhs was a pattern. See the definition of "pattern" above. +-/ +private partial def processConstApprox (mvar : Expr) (args : Array Expr) (patternVarPrefix : Nat) (v : Expr) : MetaM Bool := do + trace[Meta.isDefEq.constApprox] "{mvar} {args} := {v}" + let rec defaultCase : MetaM Bool := assignConst mvar args.size v let cfg ← getConfig let mvarId := mvar.mvarId! let mvarDecl ← getMVarDecl mvarId - if mvarDecl.numScopeArgs == numArgs || cfg.constApprox then - assignConst mvar numArgs v + let numArgs := args.size + if mvarDecl.numScopeArgs != numArgs && !cfg.constApprox then + return false + else if patternVarPrefix == 0 then + defaultCase else - pure false + let argsPrefix : Array Expr := args[:patternVarPrefix] + let type ← instantiateForall mvarDecl.type argsPrefix + let suffixSize := numArgs - argsPrefix.size + forallBoundedTelescope type suffixSize fun xs _ => do + if xs.size != suffixSize then + defaultCase + else + let some v ← mkLambdaFVarsWithLetDeps xs v | defaultCase + let rec go (argsPrefix : Array Expr) (v : Expr) : MetaM Bool := do + trace[Meta.isDefEq] "processConstApprox.go {mvar} {argsPrefix} := {v}" + let rec cont : MetaM Bool := do + if argsPrefix.isEmpty then + defaultCase + else + let some v ← mkLambdaFVarsWithLetDeps #[argsPrefix.back] v | defaultCase + go argsPrefix.pop v + match (← checkAssignment mvarId argsPrefix v) with + | none => cont + | some vNew => + let some vNew ← mkLambdaFVarsWithLetDeps argsPrefix vNew | cont + if argsPrefix.any (fun arg => mvarDecl.lctx.containsFVar arg) then + /- We need to type check `vNew` because abstraction using `mkLambdaFVars` may have produced + a type incorrect term. See discussion at A2 -/ + (isTypeCorrect vNew <&&> checkTypesAndAssign mvar vNew) <||> cont + else + checkTypesAndAssign mvar vNew <||> cont + go argsPrefix v /-- Tries to solve `?m a₁ ... aₙ =?= v` by assigning `?m`. It assumes `?m` is unassigned. -/ @@ -939,7 +985,7 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool : let rec process (i : Nat) (args : Array Expr) (v : Expr) := do let cfg ← getConfig let useFOApprox (args : Array Expr) : MetaM Bool := - processAssignmentFOApprox mvar args v <||> processConstApprox mvar args.size v + processAssignmentFOApprox mvar args v <||> processConstApprox mvar args i v if h : i < args.size then let arg := args.get ⟨i, h⟩ let arg ← simpAssignmentArg arg diff --git a/tests/lean/run/isDefEqConstApproxIssue.lean b/tests/lean/run/isDefEqConstApproxIssue.lean new file mode 100644 index 0000000000..4371191c05 --- /dev/null +++ b/tests/lean/run/isDefEqConstApproxIssue.lean @@ -0,0 +1,5 @@ +def allPairsAux (xs: List α) (ys: List β) (accum: List (α × β)) := + match xs, ys with + | _, [] => accum + | [], _ => accum + | x::xs, y::ys => allPairsAux xs ys ((x, y)::accum)