diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index ca36e25448..860621edca 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -38,6 +38,7 @@ where go (i+1) hs else k hs + /-- Given a list of `AltLHS`, create a minor premise for each one, convert them into `Alt`, and then execute `k` -/ private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo) (lhss : List AltLHS) (k : List Alt → Array (Expr × Nat) → MetaM α) : MetaM α := loop lhss [] #[] @@ -165,25 +166,26 @@ private def processSkipInaccessible (p : Problem) : Problem := | _ => unreachable! { p with alts := alts, vars := xs } -private def processLeaf (p : Problem) : StateRefT State MetaM Unit := +private def processLeaf (p : Problem) : StateRefT State MetaM Unit := do match p.alts with - | [] => do + | [] => /- TODO: allow users to configure which tactic is used to close leaves. -/ unless (← contradictionCore p.mvarId {}) do trace[Meta.Match.match] "missing alternative" admit p.mvarId modify fun s => { s with counterExamples := p.examples :: s.counterExamples } - | alt :: _ => do + | alt :: _ => -- TODO: check whether we have unassigned metavars in rhs - liftM $ assignGoalOf p alt.rhs + liftM <| assignGoalOf p alt.rhs modify fun s => { s with used := s.used.insert alt.idx } private def processAsPattern (p : Problem) : MetaM Problem := match p.vars with | [] => unreachable! | x :: xs => withGoalOf p do - let alts ← p.alts.mapM fun alt => match alt.patterns with - | Pattern.as fvarId p h :: ps => do + let alts ← p.alts.mapM fun alt => do + match alt.patterns with + | Pattern.as fvarId p h :: ps => /- We used to use `checkAndReplaceFVarId` here, but `x` and `fvarId` may have different types when dependent types are beind used. Let's consider the repro for issue #471 ``` @@ -209,19 +211,20 @@ private def processAsPattern (p : Problem) : MetaM Problem := we the pattern `(vec.cons n h t)`. TODO: try to find a cleaner solution. -/ let r ← mkEqRefl x - pure <| { alt with patterns := p :: ps }.replaceFVarId fvarId x |>.replaceFVarId h r - | _ => pure alt - pure { p with alts := alts } + return { alt with patterns := p :: ps }.replaceFVarId fvarId x |>.replaceFVarId h r + | _ => return alt + return { p with alts := alts } private def processVariable (p : Problem) : MetaM Problem := match p.vars with | [] => unreachable! | x :: xs => withGoalOf p do - let alts ← p.alts.mapM fun alt => match alt.patterns with - | Pattern.inaccessible _ :: ps => pure { alt with patterns := ps } - | Pattern.var fvarId :: ps => { alt with patterns := ps }.checkAndReplaceFVarId fvarId x + let alts ← p.alts.mapM fun alt => do + match alt.patterns with + | Pattern.inaccessible _ :: ps => return { alt with patterns := ps } + | Pattern.var fvarId :: ps => ({ alt with patterns := ps }).checkAndReplaceFVarId fvarId x | _ => unreachable! - pure { p with alts := alts, vars := xs } + return { p with alts := alts, vars := xs } private def throwInductiveTypeExpected {α} (e : Expr) : MetaM α := do let t ← inferType e @@ -249,28 +252,28 @@ def expandIfVar (e : Expr) : M Expr := do | _ => return e def occurs (fvarId : FVarId) (v : Expr) : Bool := - Option.isSome $ v.find? fun e => match e with + Option.isSome <| v.find? fun e => match e with | Expr.fvar fvarId' _ => fvarId == fvarId' | _=> false def assign (fvarId : FVarId) (v : Expr) : M Bool := do if occurs fvarId v then trace[Meta.Match.unify] "assign occurs check failed, {mkFVar fvarId} := {v}" - pure false + return false else let ctx ← read if (← isAltVar fvarId) then trace[Meta.Match.unify] "{mkFVar fvarId} := {v}" modify fun s => { s with fvarSubst := s.fvarSubst.insert fvarId v } - pure true + return true else trace[Meta.Match.unify] "assign failed variable is not local, {mkFVar fvarId} := {v}" - pure false + return false partial def unify (a : Expr) (b : Expr) : M Bool := do trace[Meta.Match.unify] "{a} =?= {b}" if (← isDefEq a b) then - pure true + return true else let a' ← whnfD (← expandIfVar a) let b' ← whnfD (← expandIfVar b) @@ -281,7 +284,7 @@ partial def unify (a : Expr) (b : Expr) : M Bool := do | Expr.fvar aFvarId _, b => assign aFvarId b | a, Expr.fvar bFVarId _ => assign bFVarId a | Expr.app aFn aArg _, Expr.app bFn bArg _ => unify aFn bFn <&&> unify aArg bArg - | _, _ => pure false + | _, _ => return false end Unify @@ -290,10 +293,10 @@ private def unify? (altFVarDecls : List LocalDecl) (a b : Expr) : MetaM (Option let b ← instantiateMVars b let (r, s) ← Unify.unify a b { altFVarDecls := altFVarDecls} |>.run {} if r then - pure s.fvarSubst + return s.fvarSubst else trace[Meta.Match.unify] "failed to unify{indentExpr a}\nwith{indentExpr b}" - pure none + return none private def expandVarIntoCtor? (alt : Alt) (fvarId : FVarId) (ctorName : Name) : MetaM (Option Alt) := withExistingLocalDecls alt.fvarDecls do @@ -311,7 +314,7 @@ private def expandVarIntoCtor? (alt : Alt) (fvarId : FVarId) (ctorName : Name) : let newAltDecls := ctorFieldDecls.toList ++ alt.fvarDecls let subst? ← unify? newAltDecls resultType expectedType match subst? with - | none => pure none + | none => return none | some subst => let newAltDecls := newAltDecls.filter fun d => !subst.contains d.fvarId -- remove declarations that were assigned let newAltDecls := newAltDecls.map fun d => d.applyFVarSubst subst -- apply substitution to remaining declaration types @@ -320,7 +323,7 @@ private def expandVarIntoCtor? (alt : Alt) (fvarId : FVarId) (ctorName : Name) : let ctorFieldPatterns := ctorFields.toList.map fun ctorField => match subst.get ctorField.fvarId! with | e@(Expr.fvar fvarId _) => if inLocalDecls newAltDecls fvarId then Pattern.var fvarId else Pattern.inaccessible e | e => Pattern.inaccessible e - pure $ some { alt with fvarDecls := newAltDecls, rhs := rhs, patterns := ctorFieldPatterns ++ patterns } + return some { alt with fvarDecls := newAltDecls, rhs := rhs, patterns := ctorFieldPatterns ++ patterns } private def getInductiveVal? (x : Expr) : MetaM (Option InductiveVal) := do let xType ← inferType x @@ -329,14 +332,14 @@ private def getInductiveVal? (x : Expr) : MetaM (Option InductiveVal) := do | Expr.const constName _ _ => let cinfo ← getConstInfo constName match cinfo with - | ConstantInfo.inductInfo val => pure (some val) - | _ => pure none - | _ => pure none + | ConstantInfo.inductInfo val => return some val + | _ => return none + | _ => return none private def hasRecursiveType (x : Expr) : MetaM Bool := do match (← getInductiveVal? x) with - | some val => pure val.isRec - | _ => pure false + | some val => return val.isRec + | _ => return false /- Given `alt` s.t. the next pattern is an inaccessible pattern `e`, try to normalize `e` into a constructor application. @@ -357,9 +360,9 @@ def processInaccessibleAsCtor (alt : Alt) (ctorName : Name) : MetaM (Option Alt) if ctorVal.name == ctorName then let fields := ctorArgs.extract ctorVal.numParams ctorArgs.size let fields := fields.toList.map Pattern.inaccessible - pure $ some { alt with patterns := fields ++ ps } + return some { alt with patterns := fields ++ ps } else - pure none + return none | _ => throwErrorAt alt.ref "dependent match elimination failed, inaccessible pattern found{indentD p.toMessageData}\nconstructor expected" | _ => unreachable! @@ -396,41 +399,45 @@ private def processConstructor (p : Problem) : MetaM (Array Problem) := do throwCasesException p ex if subgoals.isEmpty then /- Easy case: we have solved problem `p` since there are no subgoals -/ - pure (some #[]) + return some #[] else if !p.alts.isEmpty then - pure (some subgoals) + return some subgoals else do - let isRec ← withGoalOf p $ hasRecursiveType x + let isRec ← withGoalOf p <| hasRecursiveType x /- If there are no alternatives and the type of the current variable is recursive, we do NOT consider a constructor-transition to avoid nontermination. TODO: implement a more general approach if this is not sufficient in practice -/ - if isRec then pure none - else pure (some subgoals) + if isRec then + return none + else + return some subgoals + match subgoals? with - | none => pure #[{ p with vars := xs }] + | none => return #[{ p with vars := xs }] | some subgoals => subgoals.mapM fun subgoal => withMVarContext subgoal.mvarId do let subst := subgoal.subst let fields := subgoal.fields.toList let newVars := fields ++ xs let newVars := newVars.map fun x => x.applyFVarSubst subst - let subex := Example.ctor subgoal.ctorName $ fields.map fun field => match field with + let subex := Example.ctor subgoal.ctorName <| fields.map fun field => match field with | Expr.fvar fvarId _ => Example.var fvarId | _ => Example.underscore -- This case can happen due to dependent elimination - let examples := p.examples.map $ Example.replaceFVarId x.fvarId! subex - let examples := examples.map $ Example.applyFVarSubst subst + let examples := p.examples.map <| Example.replaceFVarId x.fvarId! subex + let examples := examples.map <| Example.applyFVarSubst subst let newAlts := p.alts.filter fun alt => match alt.patterns with | Pattern.ctor n _ _ _ :: _ => n == subgoal.ctorName | Pattern.var _ :: _ => true | Pattern.inaccessible _ :: _ => true | _ => false let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst - let newAlts ← newAlts.filterMapM fun alt => match alt.patterns with - | Pattern.ctor _ _ _ fields :: ps => pure $ some { alt with patterns := fields ++ ps } + let newAlts ← newAlts.filterMapM fun alt => do + match alt.patterns with + | Pattern.ctor _ _ _ fields :: ps => return some { alt with patterns := fields ++ ps } | Pattern.var fvarId :: ps => expandVarIntoCtor? { alt with patterns := ps } fvarId subgoal.ctorName | Pattern.inaccessible _ :: _ => processInaccessibleAsCtor alt subgoal.ctorName | _ => unreachable! - pure { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } + return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } private def processNonVariable (p : Problem) : MetaM Problem := match p.vars with @@ -440,27 +447,28 @@ private def processNonVariable (p : Problem) : MetaM Problem := let env ← getEnv match x.constructorApp? env with | some (ctorVal, xArgs) => - let alts ← p.alts.filterMapM fun alt => match alt.patterns with + let alts ← p.alts.filterMapM fun alt => do + match alt.patterns with | Pattern.ctor n _ _ fields :: ps => if n != ctorVal.name then - pure none + return none else - pure $ some { alt with patterns := fields ++ ps } + return some { alt with patterns := fields ++ ps } | Pattern.inaccessible _ :: _ => processInaccessibleAsCtor alt ctorVal.name | p :: _ => throwError "failed to compile pattern matching, inaccessible pattern or constructor expected{indentD p.toMessageData}" | _ => unreachable! let xFields := xArgs.extract ctorVal.numParams xArgs.size - pure { p with alts := alts, vars := xFields.toList ++ xs } + return { p with alts := alts, vars := xFields.toList ++ xs } | none => let alts ← p.alts.filterMapM fun alt => match alt.patterns with | Pattern.inaccessible e :: ps => do if (← isDefEq x e) then - pure $ some { alt with patterns := ps } + return some { alt with patterns := ps } else - pure none + return none | p :: _ => throwError "failed to compile pattern matching, unexpected pattern{indentD p.toMessageData}\ndiscriminant{indentExpr x}" | _ => unreachable! - pure { p with alts := alts, vars := xs } + return { p with alts := alts, vars := xs } private def collectValues (p : Problem) : Array Expr := p.alts.foldl (init := #[]) fun values alt => @@ -477,7 +485,7 @@ private def processValue (p : Problem) : MetaM (Array Problem) := do trace[Meta.Match.match] "value step" match p.vars with | [] => unreachable! - | x :: xs => do + | x :: xs => let values := collectValues p let subgoals ← caseValues p.mvarId x.fvarId! values (substNewEqs := true) subgoals.mapIdxM fun i subgoal => do @@ -487,8 +495,8 @@ private def processValue (p : Problem) : MetaM (Array Problem) := do -- (x = value) branch let subst := subgoal.subst trace[Meta.Match.match] "processValue subst: {subst.map.toList.map fun p => mkFVar p.1}, {subst.map.toList.map fun p => p.2}" - let examples := p.examples.map $ Example.replaceFVarId x.fvarId! (Example.val value) - let examples := examples.map $ Example.applyFVarSubst subst + let examples := p.examples.map <| Example.replaceFVarId x.fvarId! (Example.val value) + let examples := examples.map <| Example.applyFVarSubst subst let newAlts := p.alts.filter fun alt => match alt.patterns with | Pattern.val v :: _ => v == value | Pattern.var _ :: _ => true @@ -501,11 +509,11 @@ private def processValue (p : Problem) : MetaM (Array Problem) := do alt.replaceFVarId fvarId value | _ => unreachable! let newVars := xs.map fun x => x.applyFVarSubst subst - pure { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } + return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } else -- else branch for value let newAlts := p.alts.filter isFirstPatternVar - pure { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs } + return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs } private def collectArraySizes (p : Problem) : Array Nat := p.alts.foldl (init := #[]) fun sizes alt => @@ -517,16 +525,17 @@ private def expandVarIntoArrayLit (alt : Alt) (fvarId : FVarId) (arrayElemType : withExistingLocalDecls alt.fvarDecls do let fvarDecl ← getLocalDecl fvarId let varNamePrefix := fvarDecl.userName - let rec loop - | n+1, newVars => - withLocalDeclD (varNamePrefix.appendIndexAfter (n+1)) arrayElemType fun x => - loop n (newVars.push x) - | 0, newVars => do - let arrayLit ← mkArrayLit arrayElemType newVars.toList - let alt := alt.replaceFVarId fvarId arrayLit - let newDecls ← newVars.toList.mapM fun newVar => getLocalDecl newVar.fvarId! - let newPatterns := newVars.toList.map fun newVar => Pattern.var newVar.fvarId! - pure { alt with fvarDecls := newDecls ++ alt.fvarDecls, patterns := newPatterns ++ alt.patterns } + let rec loop (n : Nat) (newVars : Array Expr) := do + match n with + | n+1 => + withLocalDeclD (varNamePrefix.appendIndexAfter (n+1)) arrayElemType fun x => + loop n (newVars.push x) + | 0 => + let arrayLit ← mkArrayLit arrayElemType newVars.toList + let alt := alt.replaceFVarId fvarId arrayLit + let newDecls ← newVars.toList.mapM fun newVar => getLocalDecl newVar.fvarId! + let newPatterns := newVars.toList.map fun newVar => Pattern.var newVar.fvarId! + return { alt with fvarDecls := newDecls ++ alt.fvarDecls, patterns := newPatterns ++ alt.patterns } loop arraySize #[] private def processArrayLit (p : Problem) : MetaM (Array Problem) := do @@ -551,17 +560,18 @@ private def processArrayLit (p : Problem) : MetaM (Array Problem) := do | Pattern.var _ :: _ => true | _ => false let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst - let newAlts ← newAlts.mapM fun alt => match alt.patterns with - | Pattern.arrayLit _ pats :: ps => pure { alt with patterns := pats ++ ps } - | Pattern.var fvarId :: ps => do + let newAlts ← newAlts.mapM fun alt => do + match alt.patterns with + | Pattern.arrayLit _ pats :: ps => return { alt with patterns := pats ++ ps } + | Pattern.var fvarId :: ps => let α ← getArrayArgType <| subst.apply x expandVarIntoArrayLit { alt with patterns := ps } fvarId α size | _ => unreachable! - pure { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } - else do + return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } + else -- else branch let newAlts := p.alts.filter isFirstPatternVar - pure { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs } + return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs } private def expandNatValuePattern (p : Problem) : Problem := let alts := p.alts.map fun alt => match alt.patterns with @@ -583,10 +593,10 @@ private def throwNonSupported (p : Problem) : MetaM Unit := def isCurrVarInductive (p : Problem) : MetaM Bool := do match p.vars with - | [] => pure false + | [] => return false | x::_ => withGoalOf p do let val? ← getInductiveVal? x - pure val?.isSome + return val?.isSome private def checkNextPatternTypes (p : Problem) : MetaM Unit := do match p.vars with @@ -595,7 +605,7 @@ private def checkNextPatternTypes (p : Problem) : MetaM Unit := do for alt in p.alts do withRef alt.ref do match alt.patterns with - | [] => pure () + | [] => return () | p::_ => let e ← p.toExpr let xType ← inferType x @@ -632,7 +642,7 @@ where /- If `p.vars` is empty, then we are done. Otherwise, we process `p.vars[0]`. -/ tryToProcess (p : Problem) : StateRefT State MetaM Unit := withIncRecDepth do traceState p - let isInductive ← liftM $ isCurrVarInductive p + let isInductive ← isCurrVarInductive p if isDone p then processLeaf p else if hasAsPattern p then @@ -738,7 +748,7 @@ def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (E let mkMatcherConst name := mkAppN (mkConst name result.levelArgs.toList) result.exprArgs match (matcherExt.getState env).find? (result.value, compile) with - | some nameNew => pure (mkMatcherConst nameNew, none) + | some nameNew => return (mkMatcherConst nameNew, none) | none => let decl := Declaration.defnDecl { name @@ -875,18 +885,17 @@ def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput let matcherName := matcherApp.matcherName let some matcherInfo ← getMatcherInfo? matcherName | throwError "not a matcher: {matcherName}" let matcherConst ← getConstInfo matcherName - let matcherType ← instantiateForall matcherConst.type $ matcherApp.params ++ #[matcherApp.motive] + let matcherType ← instantiateForall matcherConst.type <| matcherApp.params ++ #[matcherApp.motive] let matchType ← do let u := if let some idx := matcherInfo.uElimPos? then mkLevelParam matcherConst.levelParams.toArray[idx] else levelZero - forallBoundedTelescope matcherType (some matcherInfo.numDiscrs) fun discrs t => do mkForallFVars discrs (mkConst ``PUnit [u]) let matcherType ← instantiateForall matcherType matcherApp.discrs - let lhss := Array.toList $ ←forallBoundedTelescope matcherType (some matcherApp.alts.size) fun alts _ => + let lhss ← forallBoundedTelescope matcherType (some matcherApp.alts.size) fun alts _ => alts.mapM fun alt => do let ty ← inferType alt forallTelescope ty fun xs body => do @@ -900,12 +909,10 @@ def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput fvarDecls := localDecls.toList patterns := patterns.toList : Match.AltLHS } - return { matcherName, matchType, discrInfos := matcherInfo.discrInfos, lhss } + return { matcherName, matchType, discrInfos := matcherInfo.discrInfos, lhss := lhss.toList } - -def withMkMatcherInput - (matcherName : Name) - (k : MkMatcherInput → MetaM α) : MetaM α := do +/- This function is only used for testing purposes -/ +def withMkMatcherInput (matcherName : Name) (k : MkMatcherInput → MetaM α) : MetaM α := do let some matcherInfo ← getMatcherInfo? matcherName | throwError "not a matcher: {matcherName}" let matcherConst ← getConstInfo matcherName forallBoundedTelescope matcherConst.type (some matcherInfo.arity) fun xs t => do @@ -929,12 +936,11 @@ private partial def updateAlts (typeNew : Expr) (altNumParams : Array Nat) (alts let alt ← try instantiateLambda alt xs catch _ => throwError "unexpected matcher application, insufficient number of parameters in alternative" forallBoundedTelescope d (some 1) fun x d => do let alt ← mkLambdaFVars x alt -- x is the new argument we are adding to the alternative - let alt ← mkLambdaFVars xs alt - pure alt + mkLambdaFVars xs alt updateAlts (b.instantiate1 alt) (altNumParams.set! i (numParams+1)) (alts.set ⟨i, h⟩ alt) (i+1) | _ => throwError "unexpected type at MatcherApp.addArg" else - pure (altNumParams, alts) + return (altNumParams, alts) /- Given - matcherApp `match_i As (fun xs => motive[xs]) discrs (fun ys_1 => (alt_1 : motive (C_1[ys_1])) ... (fun ys_n => (alt_n : motive (C_n[ys_n]) remaining`, and @@ -956,25 +962,24 @@ def MatcherApp.addArg (matcherApp : MatcherApp) (e : Expr) : MetaM MatcherApp := let motiveArg := motiveArgs[i] let discr := matcherApp.discrs[i] let eTypeAbst ← kabstract eTypeAbst discr - pure $ eTypeAbst.instantiate1 motiveArg + return eTypeAbst.instantiate1 motiveArg let motiveBody ← mkArrow eTypeAbst motiveBody let matcherLevels ← match matcherApp.uElimPos? with | none => pure matcherApp.matcherLevels | some pos => let uElim ← getLevel motiveBody - pure $ matcherApp.matcherLevels.set! pos uElim + pure <| matcherApp.matcherLevels.set! pos uElim let motive ← mkLambdaFVars motiveArgs motiveBody -- Construct `aux` `match_i As (fun xs => B[xs] → motive[xs]) discrs`, and infer its type `auxType`. -- We use `auxType` to infer the type `B[C_i[ys_i]]` of the new argument in each alternative. let aux := mkAppN (mkConst matcherApp.matcherName matcherLevels.toList) matcherApp.params let aux := mkApp aux motive let aux := mkAppN aux matcherApp.discrs - check aux unless (← isTypeCorrect aux) do throwError "failed to add argument to matcher application, type error when constructing the new motive" let auxType ← inferType aux let (altNumParams, alts) ← updateAlts auxType matcherApp.altNumParams matcherApp.alts 0 - pure { matcherApp with + return { matcherApp with matcherLevels := matcherLevels, motive := motive, alts := alts,