lean4-htt/src/Lean/Meta/Match/Match.lean
2022-06-23 10:21:32 -07:00

1023 lines
45 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.CollectLevelParams
import Lean.Util.CollectFVars
import Lean.Util.Recognizers
import Lean.Compiler.ExternAttr
import Lean.Meta.Check
import Lean.Meta.Closure
import Lean.Meta.Tactic.Cases
import Lean.Meta.Tactic.Contradiction
import Lean.Meta.GeneralizeTelescope
import Lean.Meta.Match.Basic
import Lean.Meta.Match.MVarRenaming
import Lean.Meta.Match.CaseValues
namespace Lean.Meta.Match
/-- The number of patterns in each AltLHS must be equal to the number of discriminants. -/
private def checkNumPatterns (numDiscrs : Nat) (lhss : List AltLHS) : MetaM Unit := do
if lhss.any fun lhs => lhs.patterns.length != numDiscrs then
throwError "incorrect number of patterns"
/--
Execute `k hs` where `hs` contains new equalities `h : lhs[i] = rhs[i]` for each `discrInfos[i] = some h`.
Assume `lhs.size == rhs.size == discrInfos.size`
-/
private partial def withEqs (lhs rhs : Array Expr) (discrInfos : Array DiscrInfo) (k : Array Expr → MetaM α) : MetaM α := do
go 0 #[]
where
go (i : Nat) (hs : Array Expr) : MetaM α := do
if i < lhs.size then
if let some hName := discrInfos[i].hName? then
withLocalDeclD hName (← mkEqHEq lhs[i] rhs[i]) fun h =>
go (i+1) (hs.push h)
else
go (i+1) hs
else
k hs
/-- Given a list of `AltLHS`, create a minor premise for each one, convert them into `Alt`, and then execute `k` -/
private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo) (lhss : List AltLHS) (k : List Alt → Array (Expr × Nat) → MetaM α) : MetaM α :=
loop lhss [] #[]
where
mkMinorType (xs : Array Expr) (lhs : AltLHS) : MetaM Expr :=
withExistingLocalDecls lhs.fvarDecls do
let args ← lhs.patterns.toArray.mapM (Pattern.toExpr · (annotate := true))
let minorType := mkAppN motive args
withEqs discrs args discrInfos fun eqs => do
mkForallFVars (xs ++ eqs) minorType
loop (lhss : List AltLHS) (alts : List Alt) (minors : Array (Expr × Nat)) : MetaM α := do
match lhss with
| [] => k alts.reverse minors
| lhs::lhss =>
let xs := lhs.fvarDecls.toArray.map LocalDecl.toExpr
let minorType ← mkMinorType xs lhs
let hasParams := !xs.isEmpty || discrInfos.any fun info => info.hName?.isSome
let (minorType, minorNumParams) := if hasParams then (minorType, xs.size) else (mkSimpleThunkType minorType, 1)
let idx := alts.length
let minorName := (`h).appendIndexAfter (idx+1)
trace[Meta.Match.debug] "minor premise {minorName} : {minorType}"
withLocalDeclD minorName minorType fun minor => do
let rhs := if hasParams then mkAppN minor xs else mkApp minor (mkConst `Unit.unit)
let minors := minors.push (minor, minorNumParams)
let fvarDecls ← lhs.fvarDecls.mapM instantiateLocalDeclMVars
let alts := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns : Alt } :: alts
loop lhss alts minors
def assignGoalOf (p : Problem) (e : Expr) : MetaM Unit :=
withGoalOf p do
let mvar := mkMVar p.mvarId
let mvarType ← inferType mvar
let eType ← inferType e
unless (← isDefEq mvarType eType) do
throwError "dependent elimination failed, type mismatch when solving alternative with type{indentExpr eType}\nbut expected{indentExpr mvarType}"
assignExprMVar p.mvarId e
structure State where
used : Std.HashSet Nat := {} -- used alternatives
counterExamples : List (List Example) := []
/-- Return true if the given (sub-)problem has been solved. -/
private def isDone (p : Problem) : Bool :=
p.vars.isEmpty
/-- Return true if the next element on the `p.vars` list is a variable. -/
private def isNextVar (p : Problem) : Bool :=
match p.vars with
| Expr.fvar _ _ :: _ => true
| _ => false
private def hasAsPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
| Pattern.as _ _ _ :: _ => true
| _ => false
private def hasCtorPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
| Pattern.ctor _ _ _ _ :: _ => true
| _ => false
private def hasValPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
| Pattern.val _ :: _ => true
| _ => false
private def hasNatValPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
| Pattern.val v :: _ => v.isNatLit
| _ => false
private def hasVarPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
| Pattern.var _ :: _ => true
| _ => false
private def hasArrayLitPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
| Pattern.arrayLit _ _ :: _ => true
| _ => false
private def isVariableTransition (p : Problem) : Bool :=
p.alts.all fun alt => match alt.patterns with
| Pattern.inaccessible _ :: _ => true
| Pattern.var _ :: _ => true
| _ => false
private def isConstructorTransition (p : Problem) : Bool :=
(hasCtorPattern p || p.alts.isEmpty)
&& p.alts.all fun alt => match alt.patterns with
| Pattern.ctor _ _ _ _ :: _ => true
| Pattern.var _ :: _ => true
| Pattern.inaccessible _ :: _ => true
| _ => false
private def isValueTransition (p : Problem) : Bool :=
hasVarPattern p && hasValPattern p
&& p.alts.all fun alt => match alt.patterns with
| Pattern.val _ :: _ => true
| Pattern.var _ :: _ => true
| _ => false
private def isArrayLitTransition (p : Problem) : Bool :=
hasArrayLitPattern p && hasVarPattern p
&& p.alts.all fun alt => match alt.patterns with
| Pattern.arrayLit _ _ :: _ => true
| Pattern.var _ :: _ => true
| _ => false
private def isNatValueTransition (p : Problem) : Bool :=
hasNatValPattern p
&& (!isNextVar p ||
p.alts.any fun alt => match alt.patterns with
| Pattern.ctor _ _ _ _ :: _ => true
| Pattern.inaccessible _ :: _ => true
| _ => false)
private def processSkipInaccessible (p : Problem) : Problem :=
match p.vars with
| [] => unreachable!
| _ :: xs =>
let alts := p.alts.map fun alt => match alt.patterns with
| Pattern.inaccessible _ :: ps => { alt with patterns := ps }
| _ => unreachable!
{ p with alts := alts, vars := xs }
private def processLeaf (p : Problem) : StateRefT State MetaM Unit := do
match p.alts with
| [] =>
/- TODO: allow users to configure which tactic is used to close leaves. -/
unless (← contradictionCore p.mvarId {}) do
trace[Meta.Match.match] "missing alternative"
admit p.mvarId
modify fun s => { s with counterExamples := p.examples :: s.counterExamples }
| alt :: _ =>
-- TODO: check whether we have unassigned metavars in rhs
liftM <| assignGoalOf p alt.rhs
modify fun s => { s with used := s.used.insert alt.idx }
private def processAsPattern (p : Problem) : MetaM Problem :=
match p.vars with
| [] => unreachable!
| x :: _ => withGoalOf p do
let alts ← p.alts.mapM fun alt => do
match alt.patterns with
| Pattern.as fvarId p h :: ps =>
/- We used to use `checkAndReplaceFVarId` here, but `x` and `fvarId` may have different types
when dependent types are beind used. Let's consider the repro for issue #471
```
inductive vec : Nat → Type
| nil : vec 0
| cons : Int → vec n → vec n.succ
def vec_len : vec n → Nat
| vec.nil => 0
| x@(vec.cons h t) => vec_len t + 1
```
we reach the state
```
[Meta.Match.match] remaining variables: [x✝:(vec n✝)]
alternatives:
[n:(Nat), x:(vec (Nat.succ n)), h:(Int), t:(vec n)] |- [x@(vec.cons n h t)] => h_1 n x h t
[x✝:(vec n✝)] |- [x✝] => h_2 n✝ x✝
```
The variables `x✝:(vec n✝)` and `x:(vec (Nat.succ n))` have different types, but we perform the substitution anyway,
because we claim the "discrepancy" will be corrected after we process the pattern `(vec.cons n h t)`.
The right-hand-side is temporarily type incorrect, but we claim this is fine because it will be type correct again after
we the pattern `(vec.cons n h t)`. TODO: try to find a cleaner solution.
-/
let r ← mkEqRefl x
return { alt with patterns := p :: ps }.replaceFVarId fvarId x |>.replaceFVarId h r
| _ => return alt
return { p with alts := alts }
private def processVariable (p : Problem) : MetaM Problem :=
match p.vars with
| [] => unreachable!
| x :: xs => withGoalOf p do
let alts ← p.alts.mapM fun alt => do
match alt.patterns with
| Pattern.inaccessible _ :: ps => return { alt with patterns := ps }
| Pattern.var fvarId :: ps => ({ alt with patterns := ps }).checkAndReplaceFVarId fvarId x
| _ => unreachable!
return { p with alts := alts, vars := xs }
private def throwInductiveTypeExpected {α} (e : Expr) : MetaM α := do
let t ← inferType e
throwError "failed to compile pattern matching, inductive type expected{indentExpr e}\nhas type{indentExpr t}"
private def inLocalDecls (localDecls : List LocalDecl) (fvarId : FVarId) : Bool :=
localDecls.any fun d => d.fvarId == fvarId
namespace Unify
structure Context where
altFVarDecls : List LocalDecl
structure State where
fvarSubst : FVarSubst := {}
abbrev M := ReaderT Context $ StateRefT State MetaM
def isAltVar (fvarId : FVarId) : M Bool := do
return inLocalDecls (← read).altFVarDecls fvarId
def expandIfVar (e : Expr) : M Expr := do
match e with
| Expr.fvar _ _ => return (← get).fvarSubst.apply e
| _ => return e
def occurs (fvarId : FVarId) (v : Expr) : Bool :=
Option.isSome <| v.find? fun e => match e with
| Expr.fvar fvarId' _ => fvarId == fvarId'
| _=> false
def assign (fvarId : FVarId) (v : Expr) : M Bool := do
if occurs fvarId v then
trace[Meta.Match.unify] "assign occurs check failed, {mkFVar fvarId} := {v}"
return false
else
if (← isAltVar fvarId) then
trace[Meta.Match.unify] "{mkFVar fvarId} := {v}"
modify fun s => { s with fvarSubst := s.fvarSubst.insert fvarId v }
return true
else
trace[Meta.Match.unify] "assign failed variable is not local, {mkFVar fvarId} := {v}"
return false
partial def unify (a : Expr) (b : Expr) : M Bool := do
trace[Meta.Match.unify] "{a} =?= {b}"
if (← isDefEq a b) then
return true
else
let a' ← whnfD (← expandIfVar a)
let b' ← whnfD (← expandIfVar b)
if a != a' || b != b' then
unify a' b'
else match a, b with
| Expr.fvar aFvarId _, Expr.fvar bFVarId _ => assign aFvarId b <||> assign bFVarId a
| Expr.fvar aFvarId _, b => assign aFvarId b
| a, Expr.fvar bFVarId _ => assign bFVarId a
| Expr.app aFn aArg _, Expr.app bFn bArg _ => unify aFn bFn <&&> unify aArg bArg
| _, _ => return false
end Unify
private def unify? (altFVarDecls : List LocalDecl) (a b : Expr) : MetaM (Option FVarSubst) := do
let a ← instantiateMVars a
let b ← instantiateMVars b
let (r, s) ← Unify.unify a b { altFVarDecls := altFVarDecls} |>.run {}
if r then
return s.fvarSubst
else
trace[Meta.Match.unify] "failed to unify{indentExpr a}\nwith{indentExpr b}"
return none
private def expandVarIntoCtor? (alt : Alt) (fvarId : FVarId) (ctorName : Name) : MetaM (Option Alt) :=
withExistingLocalDecls alt.fvarDecls do
let expectedType ← inferType (mkFVar fvarId)
let expectedType ← whnfD expectedType
let (ctorLevels, ctorParams) ← getInductiveUniverseAndParams expectedType
let ctor := mkAppN (mkConst ctorName ctorLevels) ctorParams
let ctorType ← inferType ctor
forallTelescopeReducing ctorType fun ctorFields resultType => do
let ctor := mkAppN ctor ctorFields
let alt := alt.replaceFVarId fvarId ctor
let ctorFieldDecls ← ctorFields.mapM fun ctorField => getLocalDecl ctorField.fvarId!
let newAltDecls := ctorFieldDecls.toList ++ alt.fvarDecls
let subst? ← unify? newAltDecls resultType expectedType
match subst? with
| none => return none
| some subst =>
let newAltDecls := newAltDecls.filter fun d => !subst.contains d.fvarId -- remove declarations that were assigned
let newAltDecls := newAltDecls.map fun d => d.applyFVarSubst subst -- apply substitution to remaining declaration types
let patterns := alt.patterns.map fun p => p.applyFVarSubst subst
let rhs := subst.apply alt.rhs
let ctorFieldPatterns := ctorFields.toList.map fun ctorField => match subst.get ctorField.fvarId! with
| e@(Expr.fvar fvarId _) => if inLocalDecls newAltDecls fvarId then Pattern.var fvarId else Pattern.inaccessible e
| e => Pattern.inaccessible e
return some { alt with fvarDecls := newAltDecls, rhs := rhs, patterns := ctorFieldPatterns ++ patterns }
private def getInductiveVal? (x : Expr) : MetaM (Option InductiveVal) := do
let xType ← inferType x
let xType ← whnfD xType
match xType.getAppFn with
| Expr.const constName _ _ =>
let cinfo ← getConstInfo constName
match cinfo with
| ConstantInfo.inductInfo val => return some val
| _ => return none
| _ => return none
private def hasRecursiveType (x : Expr) : MetaM Bool := do
match (← getInductiveVal? x) with
| some val => return val.isRec
| _ => return false
/- Given `alt` s.t. the next pattern is an inaccessible pattern `e`,
try to normalize `e` into a constructor application.
If it is not a constructor, throw an error.
Otherwise, if it is a constructor application of `ctorName`,
update the next patterns with the fields of the constructor.
Otherwise, return none. -/
def processInaccessibleAsCtor (alt : Alt) (ctorName : Name) : MetaM (Option Alt) := do
let env ← getEnv
match alt.patterns with
| p@(Pattern.inaccessible e) :: ps =>
trace[Meta.Match.match] "inaccessible in ctor step {e}"
withExistingLocalDecls alt.fvarDecls do
-- Try to push inaccessible annotations.
let e ← whnfD e
match e.constructorApp? env with
| some (ctorVal, ctorArgs) =>
if ctorVal.name == ctorName then
let fields := ctorArgs.extract ctorVal.numParams ctorArgs.size
let fields := fields.toList.map Pattern.inaccessible
return some { alt with patterns := fields ++ ps }
else
return none
| _ => throwErrorAt alt.ref "dependent match elimination failed, inaccessible pattern found{indentD p.toMessageData}\nconstructor expected"
| _ => unreachable!
private def hasNonTrivialExample (p : Problem) : Bool :=
p.examples.any fun | Example.underscore => false | _ => true
private def throwCasesException (p : Problem) (ex : Exception) : MetaM α := do
match ex with
| Exception.error ref msg =>
let exampleMsg :=
if hasNonTrivialExample p then m!" after processing{indentD <| examplesToMessageData p.examples}" else ""
throw <| Exception.error ref <| m!"{msg}{exampleMsg}\n" ++
"the dependent pattern matcher can solve the following kinds of equations\n" ++
"- <var> = <term> and <term> = <var>\n" ++
"- <term> = <term> where the terms are definitionally equal\n" ++
"- <constructor> = <constructor>, examples: List.cons x xs = List.cons y ys, and List.cons x xs = List.nil"
| _ => throw ex
private def processConstructor (p : Problem) : MetaM (Array Problem) := do
trace[Meta.Match.match] "constructor step"
match p.vars with
| [] => unreachable!
| x :: xs => do
let subgoals? ← commitWhenSome? do
let subgoals ←
try
cases p.mvarId x.fvarId!
catch ex =>
if p.alts.isEmpty then
/- If we have no alternatives and dependent pattern matching fails, then a "missing cases" error is bettern than a "stuck" error message. -/
return none
else
throwCasesException p ex
if subgoals.isEmpty then
/- Easy case: we have solved problem `p` since there are no subgoals -/
return some #[]
else if !p.alts.isEmpty then
return some subgoals
else do
let isRec ← withGoalOf p <| hasRecursiveType x
/- If there are no alternatives and the type of the current variable is recursive, we do NOT consider
a constructor-transition to avoid nontermination.
TODO: implement a more general approach if this is not sufficient in practice -/
if isRec then
return none
else
return some subgoals
match subgoals? with
| none => return #[{ p with vars := xs }]
| some subgoals =>
subgoals.mapM fun subgoal => withMVarContext subgoal.mvarId do
let subst := subgoal.subst
let fields := subgoal.fields.toList
let newVars := fields ++ xs
let newVars := newVars.map fun x => x.applyFVarSubst subst
let subex := Example.ctor subgoal.ctorName <| fields.map fun field => match field with
| Expr.fvar fvarId _ => Example.var fvarId
| _ => Example.underscore -- This case can happen due to dependent elimination
let examples := p.examples.map <| Example.replaceFVarId x.fvarId! subex
let examples := examples.map <| Example.applyFVarSubst subst
let newAlts := p.alts.filter fun alt => match alt.patterns with
| Pattern.ctor n _ _ _ :: _ => n == subgoal.ctorName
| Pattern.var _ :: _ => true
| Pattern.inaccessible _ :: _ => true
| _ => false
let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
let newAlts ← newAlts.filterMapM fun alt => do
match alt.patterns with
| Pattern.ctor _ _ _ fields :: ps => return some { alt with patterns := fields ++ ps }
| Pattern.var fvarId :: ps => expandVarIntoCtor? { alt with patterns := ps } fvarId subgoal.ctorName
| Pattern.inaccessible _ :: _ => processInaccessibleAsCtor alt subgoal.ctorName
| _ => unreachable!
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
private def processNonVariable (p : Problem) : MetaM Problem :=
match p.vars with
| [] => unreachable!
| x :: xs => withGoalOf p do
let x ← whnfD x
let env ← getEnv
match x.constructorApp? env with
| some (ctorVal, xArgs) =>
let alts ← p.alts.filterMapM fun alt => do
match alt.patterns with
| Pattern.ctor n _ _ fields :: ps =>
if n != ctorVal.name then
return none
else
return some { alt with patterns := fields ++ ps }
| Pattern.inaccessible _ :: _ => processInaccessibleAsCtor alt ctorVal.name
| p :: _ => throwError "failed to compile pattern matching, inaccessible pattern or constructor expected{indentD p.toMessageData}"
| _ => unreachable!
let xFields := xArgs.extract ctorVal.numParams xArgs.size
return { p with alts := alts, vars := xFields.toList ++ xs }
| none =>
let alts ← p.alts.filterMapM fun alt => match alt.patterns with
| Pattern.inaccessible e :: ps => do
if (← isDefEq x e) then
return some { alt with patterns := ps }
else
return none
| p :: _ => throwError "failed to compile pattern matching, unexpected pattern{indentD p.toMessageData}\ndiscriminant{indentExpr x}"
| _ => unreachable!
return { p with alts := alts, vars := xs }
private def collectValues (p : Problem) : Array Expr :=
p.alts.foldl (init := #[]) fun values alt =>
match alt.patterns with
| Pattern.val v :: _ => if values.contains v then values else values.push v
| _ => values
private def isFirstPatternVar (alt : Alt) : Bool :=
match alt.patterns with
| Pattern.var _ :: _ => true
| _ => false
private def processValue (p : Problem) : MetaM (Array Problem) := do
trace[Meta.Match.match] "value step"
match p.vars with
| [] => unreachable!
| x :: xs =>
let values := collectValues p
let subgoals ← caseValues p.mvarId x.fvarId! values (substNewEqs := true)
subgoals.mapIdxM fun i subgoal => do
trace[Meta.Match.match] "processValue subgoal\n{MessageData.ofGoal subgoal.mvarId}"
if h : i.val < values.size then
let value := values.get ⟨i, h⟩
-- (x = value) branch
let subst := subgoal.subst
trace[Meta.Match.match] "processValue subst: {subst.map.toList.map fun p => mkFVar p.1}, {subst.map.toList.map fun p => p.2}"
let examples := p.examples.map <| Example.replaceFVarId x.fvarId! (Example.val value)
let examples := examples.map <| Example.applyFVarSubst subst
let newAlts := p.alts.filter fun alt => match alt.patterns with
| Pattern.val v :: _ => v == value
| Pattern.var _ :: _ => true
| _ => false
let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
let newAlts := newAlts.map fun alt => match alt.patterns with
| Pattern.val _ :: ps => { alt with patterns := ps }
| Pattern.var fvarId :: ps =>
let alt := { alt with patterns := ps }
alt.replaceFVarId fvarId value
| _ => unreachable!
let newVars := xs.map fun x => x.applyFVarSubst subst
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
else
-- else branch for value
let newAlts := p.alts.filter isFirstPatternVar
return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs }
private def collectArraySizes (p : Problem) : Array Nat :=
p.alts.foldl (init := #[]) fun sizes alt =>
match alt.patterns with
| Pattern.arrayLit _ ps :: _ => let sz := ps.length; if sizes.contains sz then sizes else sizes.push sz
| _ => sizes
private def expandVarIntoArrayLit (alt : Alt) (fvarId : FVarId) (arrayElemType : Expr) (arraySize : Nat) : MetaM Alt :=
withExistingLocalDecls alt.fvarDecls do
let fvarDecl ← getLocalDecl fvarId
let varNamePrefix := fvarDecl.userName
let rec loop (n : Nat) (newVars : Array Expr) := do
match n with
| n+1 =>
withLocalDeclD (varNamePrefix.appendIndexAfter (n+1)) arrayElemType fun x =>
loop n (newVars.push x)
| 0 =>
let arrayLit ← mkArrayLit arrayElemType newVars.toList
let alt := alt.replaceFVarId fvarId arrayLit
let newDecls ← newVars.toList.mapM fun newVar => getLocalDecl newVar.fvarId!
let newPatterns := newVars.toList.map fun newVar => Pattern.var newVar.fvarId!
return { alt with fvarDecls := newDecls ++ alt.fvarDecls, patterns := newPatterns ++ alt.patterns }
loop arraySize #[]
private def processArrayLit (p : Problem) : MetaM (Array Problem) := do
trace[Meta.Match.match] "array literal step"
match p.vars with
| [] => unreachable!
| x :: xs => do
let sizes := collectArraySizes p
let subgoals ← caseArraySizes p.mvarId x.fvarId! sizes
subgoals.mapIdxM fun i subgoal => do
if i.val < sizes.size then
let size := sizes.get! i
let subst := subgoal.subst
let elems := subgoal.elems.toList
let newVars := elems.map mkFVar ++ xs
let newVars := newVars.map fun x => x.applyFVarSubst subst
let subex := Example.arrayLit <| elems.map Example.var
let examples := p.examples.map <| Example.replaceFVarId x.fvarId! subex
let examples := examples.map <| Example.applyFVarSubst subst
let newAlts := p.alts.filter fun alt => match alt.patterns with
| Pattern.arrayLit _ ps :: _ => ps.length == size
| Pattern.var _ :: _ => true
| _ => false
let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
let newAlts ← newAlts.mapM fun alt => do
match alt.patterns with
| Pattern.arrayLit _ pats :: ps => return { alt with patterns := pats ++ ps }
| Pattern.var fvarId :: ps =>
let α ← getArrayArgType <| subst.apply x
expandVarIntoArrayLit { alt with patterns := ps } fvarId α size
| _ => unreachable!
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
else
-- else branch
let newAlts := p.alts.filter isFirstPatternVar
return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs }
private def expandNatValuePattern (p : Problem) : Problem :=
let alts := p.alts.map fun alt => match alt.patterns with
| Pattern.val (Expr.lit (Literal.natVal 0) _) :: ps => { alt with patterns := Pattern.ctor `Nat.zero [] [] [] :: ps }
| Pattern.val (Expr.lit (Literal.natVal (n+1)) _) :: ps => { alt with patterns := Pattern.ctor `Nat.succ [] [] [Pattern.val (mkRawNatLit n)] :: ps }
| _ => alt
{ p with alts := alts }
private def traceStep (msg : String) : StateRefT State MetaM Unit := do
trace[Meta.Match.match] "{msg} step"
private def traceState (p : Problem) : MetaM Unit :=
withGoalOf p (traceM `Meta.Match.match p.toMessageData)
private def throwNonSupported (p : Problem) : MetaM Unit :=
withGoalOf p do
let msg ← p.toMessageData
throwError "failed to compile pattern matching, stuck at{indentD msg}"
def isCurrVarInductive (p : Problem) : MetaM Bool := do
match p.vars with
| [] => return false
| x::_ => withGoalOf p do
let val? ← getInductiveVal? x
return val?.isSome
private def checkNextPatternTypes (p : Problem) : MetaM Unit := do
match p.vars with
| [] => return ()
| x::_ => withGoalOf p do
for alt in p.alts do
withRef alt.ref do
match alt.patterns with
| [] => return ()
| p::_ =>
let e ← p.toExpr
let xType ← inferType x
let eType ← inferType e
unless (← isDefEq xType eType) do
throwError "pattern{indentExpr e}\n{← mkHasTypeButIsExpectedMsg eType xType}"
private def List.moveToFront [Inhabited α] (as : List α) (i : Nat) : List α :=
let rec loop : (as : List α) → (i : Nat) → α × List α
| [], _ => unreachable!
| a::as, 0 => (a, as)
| a::as, i+1 =>
let (b, bs) := loop as i
(b, a::bs)
let (b, bs) := loop as i
b :: bs
/-- Move variable `#i` to the beginning of the to-do list `p.vars`. -/
private def moveToFront (p : Problem) (i : Nat) : Problem :=
if i == 0 then
p
else if i < p.vars.length then
{ p with
vars := List.moveToFront p.vars i
alts := p.alts.map fun alt => { alt with patterns := List.moveToFront alt.patterns i }
}
else
p
private partial def process (p : Problem) : StateRefT State MetaM Unit :=
search 0
where
/- If `p.vars` is empty, then we are done. Otherwise, we process `p.vars[0]`. -/
tryToProcess (p : Problem) : StateRefT State MetaM Unit := withIncRecDepth do
traceState p
let isInductive ← isCurrVarInductive p
if isDone p then
processLeaf p
else if hasAsPattern p then
traceStep ("as-pattern")
let p ← processAsPattern p
process p
else if isNatValueTransition p then
traceStep ("nat value to constructor")
process (expandNatValuePattern p)
else if !isNextVar p then
traceStep ("non variable")
let p ← processNonVariable p
process p
else if isInductive && isConstructorTransition p then
let ps ← processConstructor p
ps.forM process
else if isVariableTransition p then
traceStep ("variable")
let p ← processVariable p
process p
else if isValueTransition p then
let ps ← processValue p
ps.forM process
else if isArrayLitTransition p then
let ps ← processArrayLit p
ps.forM process
else if hasNatValPattern p then
-- This branch is reachable when `p`, for example, is just values without an else-alternative.
-- We added it just to get better error messages.
traceStep ("nat value to constructor")
process (expandNatValuePattern p)
else
checkNextPatternTypes p
throwNonSupported p
/- Return `true` if `type` does not depend on the first `i` elements in `xs` -/
checkVarDeps (xs : List Expr) (i : Nat) (type : Expr) : MetaM Bool := do
match i, xs with
| 0, _ => return true
| _, [] => unreachable!
| i+1, x::xs =>
if x.isFVar then
if (← dependsOn type x.fvarId!) then
return false
checkVarDeps xs i type
/--
Auxiliary method for `search`. Find next variable to "try".
`i` is the position of the variable s.t. `tryToProcess` failed.
We only consider variables that do not depend on other variables at `p.vars`. -/
findNext (i : Nat) : MetaM (Option Nat) := do
if h : i < p.vars.length then
let x := p.vars.get ⟨i, h⟩
if (← checkVarDeps p.vars i (← inferType x)) then
return i
else
findNext (i+1)
else
return none
/--
Auxiliary method for trying the next variable to process.
It moves variable `#i` to the front of the to-do list and invokes `tryToProcess`.
Note that for non-dependent elimination, variable `#0` always work.
If variable `#i` fails, we use `findNext` to select the next variable to try.
Remark: the "missing cases" error is not considered a failure. -/
search (i : Nat) : StateRefT State MetaM Unit := do
let s₁ ← getThe Meta.State
let s₂ ← get
let p' := moveToFront p i
try
tryToProcess p'
catch ex =>
match (← withGoalOf p <| findNext (i+1)) with
| none => throw ex
| some j =>
trace[Meta.Match.match] "failed with #{i}, trying #{j}{indentD ex.toMessageData}"
set s₁; set s₂; search j
private def getUElimPos? (matcherLevels : List Level) (uElim : Level) : MetaM (Option Nat) :=
if uElim == levelZero then
return none
else match matcherLevels.toArray.indexOf? uElim with
| none => throwError "dependent match elimination failed, universe level not found"
| some pos => return some pos.val
/- See comment at `mkMatcher` before `mkAuxDefinition` -/
register_builtin_option bootstrap.genMatcherCode : Bool := {
defValue := true
group := "bootstrap"
descr := "disable code generation for auxiliary matcher function"
}
builtin_initialize matcherExt : EnvExtension (Std.PHashMap (Expr × Bool) Name) ← registerEnvExtension (pure {})
/- Similar to `mkAuxDefinition`, but uses the cache `matcherExt`.
It also returns an Boolean that indicates whether a new matcher function was added to the environment or not. -/
def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (Expr × Option (MatcherInfo → MetaM Unit)) := do
trace[Meta.Match.debug] "{name} : {type} := {value}"
let compile := bootstrap.genMatcherCode.get (← getOptions)
let result ← Closure.mkValueTypeClosure type value (zeta := false)
let env ← getEnv
let mkMatcherConst name :=
mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
match (matcherExt.getState env).find? (result.value, compile) with
| some nameNew => return (mkMatcherConst nameNew, none)
| none =>
let decl := Declaration.defnDecl {
name
levelParams := result.levelParams.toList
type := result.type
value := result.value
hints := ReducibilityHints.abbrev
safety := if env.hasUnsafe result.type || env.hasUnsafe result.value then DefinitionSafety.unsafe else DefinitionSafety.safe
}
trace[Meta.Match.debug] "{name} : {result.type} := {result.value}"
let addMatcher : MatcherInfo → MetaM Unit := fun mi => do
addDecl decl
modifyEnv fun env => matcherExt.modifyState env fun s => s.insert (result.value, compile) name
addMatcherInfo name mi
setInlineAttribute name
if compile then
compileDecl decl
return (mkMatcherConst name, some addMatcher)
structure MkMatcherInput where
matcherName : Name
matchType : Expr
discrInfos : Array DiscrInfo
lhss : List AltLHS
def MkMatcherInput.numDiscrs (m : MkMatcherInput) :=
m.discrInfos.size
def MkMatcherInput.collectFVars (m : MkMatcherInput) : StateRefT CollectFVars.State MetaM Unit := do
m.matchType.collectFVars
m.lhss.forM fun alt => alt.collectFVars
def MkMatcherInput.collectDependencies (m : MkMatcherInput) : MetaM FVarIdSet := do
let (_, s) ← m.collectFVars |>.run {}
let s ← s.addDependencies
return s.fvarSet
/--
Auxiliary method used at `mkMatcher`. It executes `k` in a local context that contains only
the local declarations `m` depends on. This is important because otherwise dependent elimination
may "refine" the types of unnecessary declarations and accidentally introduce unnecessary dependencies
in the auto-generated auxiliary declaration. Note that this is not just an optimization because the
unnecessary dependencies may prevent the termination checker from succeeding. For an example,
see issue #1237.
-/
def withCleanLCtxFor (m : MkMatcherInput) (k : MetaM α) : MetaM α := do
let s ← m.collectDependencies
let lctx ← getLCtx
let lctx := lctx.foldr (init := lctx) fun localDecl lctx =>
if s.contains localDecl.fvarId then lctx else lctx.erase localDecl.fvarId
let localInstances := (← getLocalInstances).filter fun localInst => s.contains localInst.fvar.fvarId!
withLCtx lctx localInstances k
/--
Create a dependent matcher for `matchType` where `matchType` is of the form
`(a_1 : A_1) -> (a_2 : A_2[a_1]) -> ... -> (a_n : A_n[a_1, a_2, ... a_{n-1}]) -> B[a_1, ..., a_n]`
where `n = numDiscrs`, and the `lhss` are the left-hand-sides of the `match`-expression alternatives.
Each `AltLHS` has a list of local declarations and a list of patterns.
The number of patterns must be the same in each `AltLHS`.
The generated matcher has the structure described at `MatcherInfo`. The motive argument is of the form
`(motive : (a_1 : A_1) -> (a_2 : A_2[a_1]) -> ... -> (a_n : A_n[a_1, a_2, ... a_{n-1}]) -> Sort v)`
where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition. -/
def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor input do
let ⟨matcherName, matchType, discrInfos, lhss⟩ := input
let numDiscrs := discrInfos.size
let numEqs := getNumEqsFromDiscrInfos discrInfos
checkNumPatterns numDiscrs lhss
forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do
/- We generate an matcher that can eliminate using different motives with different universe levels.
`uElim` is the universe level the caller wants to eliminate to.
If it is not levelZero, we create a matcher that can eliminate in any universe level.
This is useful for implementing `MatcherApp.addArg` because it may have to change the universe level. -/
let uElim ← getLevel matchTypeBody
let uElimGen ← if uElim == levelZero then pure levelZero else mkFreshLevelMVar
let mkMatcher (type val : Expr) (minors : Array (Expr × Nat)) (s : State) : MetaM MatcherResult := do
trace[Meta.Match.debug] "matcher value: {val}\ntype: {type}"
trace[Meta.Match.debug] "minors num params: {minors.map (·.2)}"
/- The option `bootstrap.gen_matcher_code` is a helper hack. It is useful, for example,
for compiling `src/Init/Data/Int`. It is needed because the compiler uses `Int.decLt`
for generating code for `Int.casesOn` applications, but `Int.casesOn` is used to
give the reference implementation for
```
@[extern "lean_int_neg"] def neg (n : @& Int) : Int :=
match n with
| ofNat n => negOfNat n
| negSucc n => succ n
```
which is defined **before** `Int.decLt` -/
let (matcher, addMatcher) ← mkMatcherAuxDefinition matcherName type val
trace[Meta.Match.debug] "matcher levels: {matcher.getAppFn.constLevels!}, uElim: {uElimGen}"
let uElimPos? ← getUElimPos? matcher.getAppFn.constLevels! uElimGen
discard <| isLevelDefEq uElimGen uElim
let addMatcher :=
match addMatcher with
| some addMatcher => addMatcher <|
{ numParams := matcher.getAppNumArgs
altNumParams := minors.map fun minor => minor.2 + numEqs
discrInfos
numDiscrs
uElimPos?
}
| none => pure ()
trace[Meta.Match.debug] "matcher: {matcher}"
let unusedAltIdxs := lhss.length.fold (init := []) fun i r =>
if s.used.contains i then r else i::r
return {
matcher,
counterExamples := s.counterExamples,
unusedAltIdxs := unusedAltIdxs.reverse,
addMatcher
}
let motiveType ← mkForallFVars discrs (mkSort uElimGen)
trace[Meta.Match.debug] "motiveType: {motiveType}"
withLocalDeclD `motive motiveType fun motive => do
if discrInfos.any fun info => info.hName?.isSome then
forallBoundedTelescope matchType numDiscrs fun discrs' _ => do
let (mvarType, isEqMask) ← withEqs discrs discrs' discrInfos fun eqs => do
let mvarType ← mkForallFVars eqs (mkAppN motive discrs')
let isEqMask ← eqs.mapM fun eq => return (← inferType eq).isEq
return (mvarType, isEqMask)
trace[Meta.Match.debug] "target: {mvarType}"
withAlts motive discrs discrInfos lhss fun alts minors => do
let mvar ← mkFreshExprMVar mvarType
trace[Meta.Match.debug] "goal\n{mvar.mvarId!}"
let examples := discrs'.toList.map fun discr => Example.var discr.fvarId!
let (_, s) ← (process { mvarId := mvar.mvarId!, vars := discrs'.toList, alts := alts, examples := examples }).run {}
let val ← mkLambdaFVars discrs' mvar
trace[Meta.Match.debug] "matcher\nvalue: {val}\ntype: {← inferType val}"
let mut rfls := #[]
let mut isEqMaskIdx := 0
for discr in discrs, info in discrInfos do
if info.hName?.isSome then
if isEqMask[isEqMaskIdx] then
rfls := rfls.push (← mkEqRefl discr)
else
rfls := rfls.push (← mkHEqRefl discr)
isEqMaskIdx := isEqMaskIdx + 1
let val := mkAppN (mkAppN val discrs) rfls
let args := #[motive] ++ discrs ++ minors.map Prod.fst
let val ← mkLambdaFVars args val
let type ← mkForallFVars args (mkAppN motive discrs)
mkMatcher type val minors s
else
let mvarType := mkAppN motive discrs
trace[Meta.Match.debug] "target: {mvarType}"
withAlts motive discrs discrInfos lhss fun alts minors => do
let mvar ← mkFreshExprMVar mvarType
let examples := discrs.toList.map fun discr => Example.var discr.fvarId!
let (_, s) ← (process { mvarId := mvar.mvarId!, vars := discrs.toList, alts := alts, examples := examples }).run {}
let args := #[motive] ++ discrs ++ minors.map Prod.fst
let type ← mkForallFVars args mvarType
let val ← mkLambdaFVars args mvar
mkMatcher type val minors s
def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput := do
let matcherName := matcherApp.matcherName
let some matcherInfo ← getMatcherInfo? matcherName | throwError "not a matcher: {matcherName}"
let matcherConst ← getConstInfo matcherName
let matcherType ← instantiateForall matcherConst.type <| matcherApp.params ++ #[matcherApp.motive]
let matchType ← do
let u :=
if let some idx := matcherInfo.uElimPos?
then mkLevelParam matcherConst.levelParams.toArray[idx]
else levelZero
forallBoundedTelescope matcherType (some matcherInfo.numDiscrs) fun discrs _ => do
mkForallFVars discrs (mkConst ``PUnit [u])
let matcherType ← instantiateForall matcherType matcherApp.discrs
let lhss ← forallBoundedTelescope matcherType (some matcherApp.alts.size) fun alts _ =>
alts.mapM fun alt => do
let ty ← inferType alt
forallTelescope ty fun xs body => do
let xs ← xs.filterM fun x => dependsOn body x.fvarId!
body.withApp fun _ args => do
let ctx ← getLCtx
let localDecls := xs.map ctx.getFVar!
let patterns ← args.mapM Match.toPattern
return {
ref := Syntax.missing
fvarDecls := localDecls.toList
patterns := patterns.toList : Match.AltLHS }
return { matcherName, matchType, discrInfos := matcherInfo.discrInfos, lhss := lhss.toList }
/- This function is only used for testing purposes -/
def withMkMatcherInput (matcherName : Name) (k : MkMatcherInput → MetaM α) : MetaM α := do
let some matcherInfo ← getMatcherInfo? matcherName | throwError "not a matcher: {matcherName}"
let matcherConst ← getConstInfo matcherName
forallBoundedTelescope matcherConst.type (some matcherInfo.arity) fun xs _ => do
let matcherApp ← mkConstWithLevelParams matcherConst.name
let matcherApp := mkAppN matcherApp xs
let some matcherApp ← matchMatcherApp? matcherApp | throwError "not a matcher app: {matcherApp}"
let mkMatcherInput ← getMkMatcherInputInContext matcherApp
k mkMatcherInput
end Match
/- Auxiliary function for MatcherApp.addArg -/
private partial def updateAlts (typeNew : Expr) (altNumParams : Array Nat) (alts : Array Expr) (i : Nat) : MetaM (Array Nat × Array Expr) := do
if h : i < alts.size then
let alt := alts.get ⟨i, h⟩
let numParams := altNumParams[i]
let typeNew ← whnfD typeNew
match typeNew with
| Expr.forallE _ d b _ =>
let alt ← forallBoundedTelescope d (some numParams) fun xs d => do
let alt ← try instantiateLambda alt xs catch _ => throwError "unexpected matcher application, insufficient number of parameters in alternative"
forallBoundedTelescope d (some 1) fun x _ => do
let alt ← mkLambdaFVars x alt -- x is the new argument we are adding to the alternative
mkLambdaFVars xs alt
updateAlts (b.instantiate1 alt) (altNumParams.set! i (numParams+1)) (alts.set ⟨i, h⟩ alt) (i+1)
| _ => throwError "unexpected type at MatcherApp.addArg"
else
return (altNumParams, alts)
/- Given
- matcherApp `match_i As (fun xs => motive[xs]) discrs (fun ys_1 => (alt_1 : motive (C_1[ys_1])) ... (fun ys_n => (alt_n : motive (C_n[ys_n]) remaining`, and
- expression `e : B[discrs]`,
Construct the term
`match_i As (fun xs => B[xs] -> motive[xs]) discrs (fun ys_1 (y : B[C_1[ys_1]]) => alt_1) ... (fun ys_n (y : B[C_n[ys_n]]) => alt_n) e remaining`, and
We use `kabstract` to abstract the discriminants from `B[discrs]`.
This method assumes
- the `matcherApp.motive` is a lambda abstraction where `xs.size == discrs.size`
- each alternative is a lambda abstraction where `ys_i.size == matcherApp.altNumParams[i]`
-/
def MatcherApp.addArg (matcherApp : MatcherApp) (e : Expr) : MetaM MatcherApp :=
lambdaTelescope matcherApp.motive fun motiveArgs motiveBody => do
unless motiveArgs.size == matcherApp.discrs.size do
-- This error can only happen if someone implemented a transformation that rewrites the motive created by `mkMatcher`.
throwError "unexpected matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
let eType ← inferType e
let eTypeAbst ← matcherApp.discrs.size.foldRevM (init := eType) fun i eTypeAbst => do
let motiveArg := motiveArgs[i]
let discr := matcherApp.discrs[i]
let eTypeAbst ← kabstract eTypeAbst discr
return eTypeAbst.instantiate1 motiveArg
let motiveBody ← mkArrow eTypeAbst motiveBody
let matcherLevels ← match matcherApp.uElimPos? with
| none => pure matcherApp.matcherLevels
| some pos =>
let uElim ← getLevel motiveBody
pure <| matcherApp.matcherLevels.set! pos uElim
let motive ← mkLambdaFVars motiveArgs motiveBody
-- Construct `aux` `match_i As (fun xs => B[xs] → motive[xs]) discrs`, and infer its type `auxType`.
-- We use `auxType` to infer the type `B[C_i[ys_i]]` of the new argument in each alternative.
let aux := mkAppN (mkConst matcherApp.matcherName matcherLevels.toList) matcherApp.params
let aux := mkApp aux motive
let aux := mkAppN aux matcherApp.discrs
unless (← isTypeCorrect aux) do
throwError "failed to add argument to matcher application, type error when constructing the new motive"
let auxType ← inferType aux
let (altNumParams, alts) ← updateAlts auxType matcherApp.altNumParams matcherApp.alts 0
return { matcherApp with
matcherLevels := matcherLevels,
motive := motive,
alts := alts,
altNumParams := altNumParams,
remaining := #[e] ++ matcherApp.remaining
}
/-- Similar `MatcherApp.addArg?`, but returns `none` on failure. -/
def MatcherApp.addArg? (matcherApp : MatcherApp) (e : Expr) : MetaM (Option MatcherApp) :=
try
return some (← matcherApp.addArg e)
catch _ =>
return none
builtin_initialize
registerTraceClass `Meta.Match.match
registerTraceClass `Meta.Match.debug
registerTraceClass `Meta.Match.unify
end Lean.Meta