feat: improve "constant approximation" heuristic used at isDefEq
This commit is contained in:
parent
19bcb5fb31
commit
7a81589c49
2 changed files with 56 additions and 5 deletions
|
|
@ -920,14 +920,60 @@ private def assignConst (mvar : Expr) (numArgs : Nat) (v : Expr) : MetaM Bool :=
|
|||
trace[Meta.isDefEq.constApprox] "{mvar} := {v}"
|
||||
checkTypesAndAssign mvar v
|
||||
|
||||
private def processConstApprox (mvar : Expr) (numArgs : Nat) (v : Expr) : MetaM Bool := do
|
||||
/--
|
||||
Auxiliary procedure for solving `?m args =?= v` when `args[:patternVarPrefix]` contains
|
||||
only pairwise distinct free variables.
|
||||
Let `args[:patternVarPrefix] = #[a₁, ..., aₙ]`, and `args[patternVarPrefix:] = #[b₁, ..., bᵢ]`,
|
||||
this procedure first reduces the constraint to
|
||||
```
|
||||
?m a₁ ... aₙ =?= fun x₁ ... xᵢ => v
|
||||
```
|
||||
where the left-hand-side is a constant function.
|
||||
Then, it tries to find the longest prefix `#[a₁, ..., aⱼ]` of `#[a₁, ..., aₙ]` such that the following assignment is valid.
|
||||
```
|
||||
?m := fun y₁ ... yⱼ => (fun y_{j+1} ... yₙ x₁ ... xᵢ => v)[a₁/y₁, .., aⱼ/yⱼ]
|
||||
```
|
||||
That is, after the longest prefix is found, we solve the contraint as the lhs was a pattern. See the definition of "pattern" above.
|
||||
-/
|
||||
private partial def processConstApprox (mvar : Expr) (args : Array Expr) (patternVarPrefix : Nat) (v : Expr) : MetaM Bool := do
|
||||
trace[Meta.isDefEq.constApprox] "{mvar} {args} := {v}"
|
||||
let rec defaultCase : MetaM Bool := assignConst mvar args.size v
|
||||
let cfg ← getConfig
|
||||
let mvarId := mvar.mvarId!
|
||||
let mvarDecl ← getMVarDecl mvarId
|
||||
if mvarDecl.numScopeArgs == numArgs || cfg.constApprox then
|
||||
assignConst mvar numArgs v
|
||||
let numArgs := args.size
|
||||
if mvarDecl.numScopeArgs != numArgs && !cfg.constApprox then
|
||||
return false
|
||||
else if patternVarPrefix == 0 then
|
||||
defaultCase
|
||||
else
|
||||
pure false
|
||||
let argsPrefix : Array Expr := args[:patternVarPrefix]
|
||||
let type ← instantiateForall mvarDecl.type argsPrefix
|
||||
let suffixSize := numArgs - argsPrefix.size
|
||||
forallBoundedTelescope type suffixSize fun xs _ => do
|
||||
if xs.size != suffixSize then
|
||||
defaultCase
|
||||
else
|
||||
let some v ← mkLambdaFVarsWithLetDeps xs v | defaultCase
|
||||
let rec go (argsPrefix : Array Expr) (v : Expr) : MetaM Bool := do
|
||||
trace[Meta.isDefEq] "processConstApprox.go {mvar} {argsPrefix} := {v}"
|
||||
let rec cont : MetaM Bool := do
|
||||
if argsPrefix.isEmpty then
|
||||
defaultCase
|
||||
else
|
||||
let some v ← mkLambdaFVarsWithLetDeps #[argsPrefix.back] v | defaultCase
|
||||
go argsPrefix.pop v
|
||||
match (← checkAssignment mvarId argsPrefix v) with
|
||||
| none => cont
|
||||
| some vNew =>
|
||||
let some vNew ← mkLambdaFVarsWithLetDeps argsPrefix vNew | cont
|
||||
if argsPrefix.any (fun arg => mvarDecl.lctx.containsFVar arg) then
|
||||
/- We need to type check `vNew` because abstraction using `mkLambdaFVars` may have produced
|
||||
a type incorrect term. See discussion at A2 -/
|
||||
(isTypeCorrect vNew <&&> checkTypesAndAssign mvar vNew) <||> cont
|
||||
else
|
||||
checkTypesAndAssign mvar vNew <||> cont
|
||||
go argsPrefix v
|
||||
|
||||
/-- Tries to solve `?m a₁ ... aₙ =?= v` by assigning `?m`.
|
||||
It assumes `?m` is unassigned. -/
|
||||
|
|
@ -939,7 +985,7 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool :
|
|||
let rec process (i : Nat) (args : Array Expr) (v : Expr) := do
|
||||
let cfg ← getConfig
|
||||
let useFOApprox (args : Array Expr) : MetaM Bool :=
|
||||
processAssignmentFOApprox mvar args v <||> processConstApprox mvar args.size v
|
||||
processAssignmentFOApprox mvar args v <||> processConstApprox mvar args i v
|
||||
if h : i < args.size then
|
||||
let arg := args.get ⟨i, h⟩
|
||||
let arg ← simpAssignmentArg arg
|
||||
|
|
|
|||
5
tests/lean/run/isDefEqConstApproxIssue.lean
Normal file
5
tests/lean/run/isDefEqConstApproxIssue.lean
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
def allPairsAux (xs: List α) (ys: List β) (accum: List (α × β)) :=
|
||||
match xs, ys with
|
||||
| _, [] => accum
|
||||
| [], _ => accum
|
||||
| x::xs, y::ys => allPairsAux xs ys ((x, y)::accum)
|
||||
Loading…
Add table
Reference in a new issue