refactor: use isDefEq instead of custom unify procedure

See comment with new issue at #1361
This commit is contained in:
Leonardo de Moura 2022-08-02 17:59:32 -07:00
parent ae5db0f563
commit a9e7290e4b
2 changed files with 68 additions and 249 deletions

View file

@ -221,134 +221,62 @@ private def processVariable (p : Problem) : MetaM Problem := withGoalOf p do
| _ => unreachable!
return { p with alts := alts, vars := xs }
private def throwInductiveTypeExpected {α} (e : Expr) : MetaM α := do
let t ← inferType e
throwError "failed to compile pattern matching, inductive type expected{indentExpr e}\nhas type{indentExpr t}"
/-
TODO: FIX the following issue.
`fvarId` is not an alternative variable, and we used to return `false` here, but it is incorrect, and may
incorrectly discard applicable alternatives. It was buggy because of the way we handle inaccessible patterns
in variable transitions. The bug was exposed by issue #1279
Here is a simplified version of the example on this issue (see test: `1279_simplified.lean`)
```lean
inductive Arrow : Type → Type → Type 1
| id : Arrow a a
| unit : Arrow Unit Unit
| comp : Arrow β γ → Arrow α β → Arrow α γ
deriving Repr
def Arrow.compose (f : Arrow β γ) (g : Arrow α β) : Arrow α γ :=
match f, g with
| id, g => g
| f, id => f
| f, g => comp f g
```
The initial state for the `match`-expression above is
```lean
[Meta.Match.match] remaining variables: [β✝:(Type), γ✝:(Type), f✝:(Arrow β✝ γ✝), g✝:(Arrow α β✝)]
alternatives:
[β:(Type), g:(Arrow α β)] |- [β, .(β), (Arrow.id .(β)), g] => h_1 β g
[γ:(Type), f:(Arrow α γ)] |- [.(α), γ, f, (Arrow.id .(α))] => h_2 γ f
[β:(Type), γ:(Type), f:(Arrow β γ), g:(Arrow α β)] |- [β, γ, f, g] => h_3 β γ f g
```
The first step is a variable-transition which replaces `β` with `β✝` in the first and third alternatives.
The constraint `β✝ === α` in the second alternative is lost. Note that `α` is not an alternative variable.
After applying the variable-transition step twice, we reach the following state
```lean
[Meta.Match.match] remaining variables: [f✝:(Arrow β✝ γ✝), g✝:(Arrow α β✝)]
alternatives:
[g:(Arrow α β✝)] |- [(Arrow.id .(β✝)), g] => h_1 β✝ g
[f:(Arrow α γ✝)] |- [f, (Arrow.id .(α))] => h_2 γ✝ f
[f:(Arrow β✝ γ✝), g:(Arrow α β✝)] |- [f, g] => h_3 β✝ γ✝ f g
```
A constructor-transition should be used, and the functions `expandVarIntoCtor?` is required for the second and
third alternatives. There are 3 constructors, in the `Arrow.id` case, we use unify to solve
```
Arrow a a =?= Arrow α β✝
```
Where `a` is new alternative variable corresponding to the `Arrow.id` field.
The first assignment is fine `a := α`.
In the second assignment we have `α := β✝` where both `α` and `β✝` are not alternative variables.
We did not store information that `β✝ === α` in the first step, and the alternative was being incorrectly discarded.
Returning `true` here "solves" the problem, but it is a bit hackish. We see two possible improvements:
- We store the constraint `β✝ === α`.
- We postpone variable-transition steps.
It is unclear at this point what is the best solution. We should keep accumulating problematic examples.
-/
private def inLocalDecls (localDecls : List LocalDecl) (fvarId : FVarId) : Bool :=
localDecls.any fun d => d.fvarId == fvarId
namespace Unify
structure Context where
altFVarDecls : List LocalDecl
structure State where
fvarSubst : FVarSubst := {}
abbrev M := ReaderT Context $ StateRefT State MetaM
def isAltVar (fvarId : FVarId) : M Bool := do
return inLocalDecls (← read).altFVarDecls fvarId
def expandIfVar (e : Expr) : M Expr := do
match e with
| .fvar _ => return (← get).fvarSubst.apply e
| _ => return e
def occurs (fvarId : FVarId) (v : Expr) : Bool :=
Option.isSome <| v.find? fun e => match e with
| .fvar fvarId' => fvarId == fvarId'
| _ => false
def assign (fvarId : FVarId) (v : Expr) : M Bool := do
if occurs fvarId v then
trace[Meta.Match.unify] "assign occurs check failed, {mkFVar fvarId} := {v}"
return false
else
if (← isAltVar fvarId) then
trace[Meta.Match.unify] "{mkFVar fvarId} := {v}"
modify fun s => { s with fvarSubst := s.fvarSubst.insert fvarId v }
return true
else
/-
TODO: improve this branch. Returning `true` here is an approximation.
`fvarId` is not an alternative variable, and we used to return `false` here, but it is incorrect, and may
incorrectly discard applicable alternatives. It was buggy because of the way we handle inaccessible patterns
in variable transitions. The bug was exposed by issue #1279
Here is a simplified version of the example on this issue (see test: `1279_simplified.lean`)
```lean
inductive Arrow : Type → Type → Type 1
| id : Arrow a a
| unit : Arrow Unit Unit
| comp : Arrow β γ → Arrow α β → Arrow α γ
deriving Repr
def Arrow.compose (f : Arrow β γ) (g : Arrow α β) : Arrow α γ :=
match f, g with
| id, g => g
| f, id => f
| f, g => comp f g
```
The initial state for the `match`-expression above is
```lean
[Meta.Match.match] remaining variables: [β✝:(Type), γ✝:(Type), f✝:(Arrow β✝ γ✝), g✝:(Arrow α β✝)]
alternatives:
[β:(Type), g:(Arrow α β)] |- [β, .(β), (Arrow.id .(β)), g] => h_1 β g
[γ:(Type), f:(Arrow α γ)] |- [.(α), γ, f, (Arrow.id .(α))] => h_2 γ f
[β:(Type), γ:(Type), f:(Arrow β γ), g:(Arrow α β)] |- [β, γ, f, g] => h_3 β γ f g
```
The first step is a variable-transition which replaces `β` with `β✝` in the first and third alternatives.
The constraint `β✝ === α` in the second alternative is lost. Note that `α` is not an alternative variable.
After applying the variable-transition step twice, we reach the following state
```lean
[Meta.Match.match] remaining variables: [f✝:(Arrow β✝ γ✝), g✝:(Arrow α β✝)]
alternatives:
[g:(Arrow α β✝)] |- [(Arrow.id .(β✝)), g] => h_1 β✝ g
[f:(Arrow α γ✝)] |- [f, (Arrow.id .(α))] => h_2 γ✝ f
[f:(Arrow β✝ γ✝), g:(Arrow α β✝)] |- [f, g] => h_3 β✝ γ✝ f g
```
A constructor-transition should be used, and the functions `expandVarIntoCtor?` is required for the second and
third alternatives. There are 3 constructors, in the `Arrow.id` case, we use unify to solve
```
Arrow a a =?= Arrow α β✝
```
Where `a` is new alternative variable corresponding to the `Arrow.id` field.
The first assignment is fine `a := α`.
In the second assignment we have `α := β✝` where both `α` and `β✝` are not alternative variables.
We did not store information that `β✝ === α` in the first step, and the alternative was being incorrectly discarded.
Returning `true` here "solves" the problem, but it is a bit hackish. We see two possible improvements:
- We store the constraint `β✝ === α`.
- We postpone variable-transition steps.
It is unclear at this point what is the best solution. We should keep accumulating problematic examples.
-/
return true
partial def unify (a : Expr) (b : Expr) : M Bool := do
trace[Meta.Match.unify] "{a} =?= {b}"
if (← isDefEq a b) then
return true
else
let a' ← whnfD (← expandIfVar a)
let b' ← whnfD (← expandIfVar b)
if a != a' || b != b' then
unify a' b'
else match a, b with
| .fvar aFvarId, .fvar bFVarId => assign aFvarId b <||> assign bFVarId a
| .fvar aFvarId, b => assign aFvarId b
| a, .fvar bFVarId => assign bFVarId a
| .app aFn aArg, .app bFn bArg => unify aFn bFn <&&> unify aArg bArg
| _, _ =>
let a' := (← get).fvarSubst.apply a
let b' := (← get).fvarSubst.apply b
if a != a' || b != b' then
unify a' b'
else
return false
end Unify
private def unify? (altFVarDecls : List LocalDecl) (a b : Expr) : MetaM (Option FVarSubst) := do
trace[Meta.Match.unify] "altFVarDecls: {altFVarDecls.map fun d => d.userName}, {a} =?= {b}"
let a ← instantiateMVars a
let b ← instantiateMVars b
let (r, s) ← Unify.unify a b { altFVarDecls := altFVarDecls} |>.run {}
if r then
return s.fvarSubst
else
trace[Meta.Match.unify] "failed to unify{indentExpr a}\nwith{indentExpr b}"
return none
private def expandVarIntoCtor? (alt : Alt) (fvarId : FVarId) (ctorName : Name) : MetaM (Option Alt) :=
withExistingLocalDecls alt.fvarDecls do
trace[Meta.Match.unify] "expandVarIntoCtor? fvarId: {mkFVar fvarId}, ctorName: {ctorName}, alt:\n{← alt.toMessageData}"
@ -357,27 +285,23 @@ private def expandVarIntoCtor? (alt : Alt) (fvarId : FVarId) (ctorName : Name) :
let (ctorLevels, ctorParams) ← getInductiveUniverseAndParams expectedType
let ctor := mkAppN (mkConst ctorName ctorLevels) ctorParams
let ctorType ← inferType ctor
-- TODO: try to rewrite this code using metavariables using `isDefEq` instead of `unify?`, and then
-- convert unassigned metavariables to fresh free variables.
-- Reason: `unify?` is too buggy
forallTelescopeReducing ctorType fun ctorFields resultType => do
let ctor := mkAppN ctor ctorFields
let ctorAbst? ← withNewMCtxDepth do
let (ctorFields, _, resultType) ← forallMetaTelescopeReducing ctorType
unless (← isDefEq resultType expectedType) do
return none
return some (← abstractMVars (mkAppN ctor ctorFields))
let some ctorAbst := ctorAbst? | return none
lambdaTelescope ctorAbst.expr fun newFVars ctor => do
let ctorArgs := ctor.getAppArgs
let ctorFields := ctorArgs[ctorParams.size:].toArray
let alt := alt.replaceFVarId fvarId ctor
let ctorFieldDecls ← ctorFields.mapM fun ctorField => ctorField.fvarId!.getDecl
let newAltDecls := ctorFieldDecls.toList ++ alt.fvarDecls
trace[Meta.Match.unify] "expandVarIntoCtor? {mkFVar fvarId} : {expectedType}, ctor: {ctor}, resultType: {resultType}"
let subst? ← unify? newAltDecls resultType expectedType
match subst? with
| none => return none
| some subst =>
let newAltDecls := newAltDecls.filter fun d => !subst.contains d.fvarId -- remove declarations that were assigned
let newAltDecls := newAltDecls.map fun d => d.applyFVarSubst subst -- apply substitution to remaining declaration types
let patterns := alt.patterns.map fun p => p.applyFVarSubst subst
let rhs := subst.apply alt.rhs
let ctorFieldPatterns := ctorFields.toList.map fun ctorField => match subst.get ctorField.fvarId! with
| e@(.fvar fvarId) => if inLocalDecls newAltDecls fvarId then .var fvarId else .inaccessible e
| e => .inaccessible e
return some { alt with fvarDecls := newAltDecls, rhs := rhs, patterns := ctorFieldPatterns ++ patterns }
let newAltDecls ← newFVars.mapM fun newFVar => newFVar.fvarId!.getDecl
let newAltDecls := newAltDecls.toList ++ alt.fvarDecls
trace[Meta.Match.unify] "expandVarIntoCtor? {mkFVar fvarId} : {expectedType}, ctor: {ctor}"
let ctorFieldPatterns := ctorFields.toList.map fun ctorField => match ctorField with
| .fvar fvarId => if inLocalDecls newAltDecls fvarId then Pattern.var fvarId else Pattern.inaccessible ctorField
| _ => Pattern.inaccessible ctorField
return some { alt with fvarDecls := newAltDecls, patterns := ctorFieldPatterns ++ alt.patterns }
private def getInductiveVal? (x : Expr) : MetaM (Option InductiveVal) := do
let xType ← inferType x

View file

@ -240,108 +240,3 @@ default
-- set_option trace.Meta.Match.match true
-- set_option trace.Meta.Match.debug true
#eval test `ex6 2 `elimTest6
-- #print elimTest6
def ex7 (α : Type u) (n : Nat) (xs : Vec α n) :
LHS (forall (a : α), Pat (inaccessible 1) × Pat (Vec.cons a Vec.nil))
× LHS (forall (N : Nat) (XS : Vec α N), Pat (inaccessible N) × Pat XS) :=
default
#eval test `ex7 2 `elimTest7
-- #check elimTest7
def isSizeOne {n : Nat} (xs : Vec Nat n) : Bool :=
elimTest7 _ (fun _ _ => Bool) n xs (fun _ => true) (fun _ _ => false)
#eval isSizeOne Vec.nil
#eval isSizeOne (Vec.cons 1 Vec.nil)
#eval isSizeOne (Vec.cons 2 (Vec.cons 1 Vec.nil))
def singleton? {n : Nat} (xs : Vec Nat n) : Option Nat :=
elimTest7 _ (fun _ _ => Option Nat) n xs (fun a => some a) (fun _ _ => none)
#eval singleton? Vec.nil
#eval singleton? (Vec.cons 10 Vec.nil)
#eval singleton? (Vec.cons 20 (Vec.cons 10 Vec.nil))
def ex8 (α : Type u) (n : Nat) (xs : Vec α n) :
LHS (forall (a b : α), Pat (inaccessible 2) × Pat (Vec.cons a (Vec.cons b Vec.nil)))
× LHS (forall (N : Nat) (XS : Vec α N), Pat (inaccessible N) × Pat XS) :=
default
#eval test `ex8 2 `elimTest8
#print elimTest8
def pair? {n : Nat} (xs : Vec Nat n) : Option (Nat × Nat) :=
elimTest8 _ (fun _ _ => Option (Nat × Nat)) n xs (fun a b => some (a, b)) (fun _ _ => none)
#eval pair? Vec.nil
#eval pair? (Vec.cons 10 Vec.nil)
#eval pair? (Vec.cons 20 (Vec.cons 10 Vec.nil))
inductive Op : Nat → Nat → Type
| mk : ∀ n, Op n n
#print Op
inductive Foo : Bool → Prop
| bar : Foo false
| baz : Foo false
def ex10 (x : Bool) (y : Foo x) :
LHS (Pat (inaccessible false) × Pat Foo.bar)
× LHS (forall (x : Bool) (y : Foo x), Pat (inaccessible x) × Pat y) :=
default
#eval test `ex10 2 `elimTest10 true
def ex12 (x y z : Bool) :
LHS (forall (x y : Bool), Pat x × Pat y × Pat true)
× LHS (forall (x z : Bool), Pat false × Pat true × Pat z)
× LHS (forall (y z : Bool), Pat true × Pat false × Pat z) :=
default
#eval testFailure `ex12 3 `elimTest12 -- should produce error message
def ex14 (x y : Nat) :
LHS (Pat (val 1) × Pat (val 2))
× LHS (Pat (val 2) × Pat (val 3))
× LHS (forall (x y : Nat), Pat x × Pat y) :=
default
-- set_option trace.Meta.Match true
#eval test `ex14 2 `elimTest14
#print elimTest14
def h2 (x y : Nat) : Nat :=
elimTest14 (fun _ _ => Nat) x y (fun _ => 0) (fun _ => 1) (fun x y => x + y)
#eval check (h2 1 2 == 0)
#eval check (h2 1 4 == 5)
#eval check (h2 2 3 == 1)
#eval check (h2 2 4 == 6)
#eval check (h2 3 4 == 7)
def ex15 (xs : Array (List Nat)) :
LHS (forall (a : Nat), Pat (ArrayLit1 [a]))
× LHS (forall (a b : Nat), Pat (ArrayLit2 [a] [b]))
× LHS (forall (ys : Array (List Nat)), Pat ys) :=
default
#eval test `ex15 1 `elimTest15
-- #check elimTest15
def h3 (xs : Array (List Nat)) : Nat :=
elimTest15 (fun _ => Nat) xs
(fun a => a + 1)
(fun a b => a + b)
(fun ys => ys.size)
#eval check (h3 #[[1]] == 2)
#eval check (h3 #[[3], [2]] == 5)
#eval check (h3 #[[1, 2]] == 1)
#eval check (h3 #[[1, 2], [2, 3], [3]] == 3)