From 66d35cdd76bbaa80fa5e428115c715fa094df477 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 24 Sep 2020 19:34:14 -0700 Subject: [PATCH] fix: the generated matcher must be able to eliminate into different universe levels --- src/Lean/Elab/Match.lean | 14 ++----- src/Lean/Meta/Match/Match.lean | 65 ++++++++++++++++++++++------- tests/lean/run/depElim1.lean | 6 +-- tests/lean/run/matcherElimUniv.lean | 9 ++++ 4 files changed, 67 insertions(+), 27 deletions(-) create mode 100644 tests/lean/run/matcherElimUniv.lean diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index ea80598747..f0925636b4 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -697,13 +697,8 @@ withPatternVars patternVars fun patternVarDecls => do -- TODO: check whether altLHS still has metavariables pure (altLHS, rhs) -def mkMotiveType (matchType : Expr) (numDiscrs : Nat) : TermElabM Expr := do -forallBoundedTelescope matchType numDiscrs fun xs matchType => do - u ← getLevel matchType; - mkForallFVars xs (mkSort u) - -def mkMatcher (elimName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult := -liftMetaM $ Meta.Match.mkMatcher elimName motiveType numDiscrs lhss +def mkMatcher (elimName : Name) (matchType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult := +liftMetaM $ Meta.Match.mkMatcher elimName matchType numDiscrs lhss def reportMatcherResultErrors (result : MatcherResult) : TermElabM Unit := do -- TODO: improve error messages @@ -721,10 +716,9 @@ alts ← matchAlts.mapM $ fun alt => elabMatchAltView alt matchType; let rhss := alts.map Prod.snd; let altLHSS := alts.map Prod.fst; let numDiscrs := discrs.size; -motiveType ← mkMotiveType matchType numDiscrs; -motive ← forallBoundedTelescope matchType numDiscrs fun xs matchType => mkLambdaFVars xs matchType; matcherName ← mkAuxName `match; -matcherResult ← mkMatcher matcherName motiveType numDiscrs altLHSS.toList; +matcherResult ← mkMatcher matcherName matchType numDiscrs altLHSS.toList; +motive ← forallBoundedTelescope matchType numDiscrs fun xs matchType => mkLambdaFVars xs matchType; reportMatcherResultErrors matcherResult; let r := mkApp matcherResult.matcher motive; let r := mkAppN r discrs; diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index 53e0998363..ce358a32cd 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -766,9 +766,11 @@ A "matcher" auxiliary declaration has the following structure: - motive - `numDiscrs` discriminators (aka major premises) - `altNumParams.size` alternatives (aka minor premises) where alternative `i` has `altNumParams[i]` alternatives --/ +- `uElimPos?` is `some pos` when the matcher can eliminate in different universe levels, and + `pos` is the position of the universe level parameter that specifies the elimination universe. + It is `none` if the matcher only eliminates into `Prop`. -/ structure MatcherInfo := -(numParams : Nat) (numDiscrs : Nat) (altNumParams : Array Nat) +(numParams : Nat) (numDiscrs : Nat) (altNumParams : Array Nat) (uElimPos? : Option Nat) def MatcherInfo.numAlts (matcherInfo : MatcherInfo) : Nat := matcherInfo.altNumParams.size @@ -809,11 +811,34 @@ end Extension def addMatcherInfo (matcherName : Name) (info : MatcherInfo) : MetaM Unit := modifyEnv fun env => Extension.addMatcherInfo env matcherName info -def mkMatcher (matcherName : Name) (motiveType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : MetaM MatcherResult := +private def getUElimPos? (matcherLevels : List Level) (uElim : Level) : MetaM (Option Nat) := +if uElim == levelZero then pure none +else match matcherLevels.toArray.indexOf? uElim with + | none => throwError "dependent match elimination failed, universe level not found" + | some pos => pure $ some pos.val + +/- +Create a dependent matcher for `matchType` where `matchType` is of the form +`(a_1 : A_1) -> (a_2 : A_2[a_1]) -> ... -> (a_n : A_n[a_1, a_2, ... a_{n-1}]) -> B[a_1, ..., a_n]` +where `n = numDiscrs`, and the `lhss` are the left-hand-sides of the `match`-expression alternatives. +Each `AltLHS` has a list of local declarations and a list of patterns. +The number of patterns must be the same in each `AltLHS`. +The generated matcher has the structure described at `MatcherInfo`. The motive argument is of the form +`(motive : (a_1 : A_1) -> (a_2 : A_2[a_1]) -> ... -> (a_n : A_n[a_1, a_2, ... a_{n-1}]) -> Sort v)` +where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition. +-/ +def mkMatcher (matcherName : Name) (matchType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : MetaM MatcherResult := +forallBoundedTelescope matchType numDiscrs fun majors matchTypeBody => do +checkNumPatterns majors lhss; +/- We generate an matcher that can eliminate using different motives with different universe levels. + `uElim` is the universe level the caller wants to eliminate to. + If it is not levelZero, we create a matcher that can eliminate in any universe level. + This is useful for implementing `MatcherApp.addArg` because it may have to change the universe level. -/ +uElim ← getLevel matchTypeBody; +uElimGen ← if uElim == levelZero then pure levelZero else mkFreshLevelMVar; +motiveType ← mkForallFVars majors (mkSort uElimGen); withLocalDeclD `motive motiveType fun motive => do trace! `Meta.Match.debug ("motiveType: " ++ motiveType); -forallBoundedTelescope motiveType numDiscrs fun majors _ => do -checkNumPatterns majors lhss; let mvarType := mkAppN motive majors; trace! `Meta.Match.debug ("target: " ++ mvarType); withAlts motive lhss fun alts minors => do @@ -825,7 +850,10 @@ withAlts motive lhss fun alts minors => do val ← mkLambdaFVars args mvar; trace! `Meta.Match.debug ("matcher value: " ++ val ++ "\ntype: " ++ type); matcher ← mkAuxDefinition matcherName type val; - addMatcherInfo matcherName { numParams := matcher.getAppNumArgs, numDiscrs := majors.size, altNumParams := minors.map Prod.snd }; + trace! `Meta.Match.debug ("matcher levels: " ++ toString matcher.getAppFn.constLevels! ++ ", uElim: " ++ toString uElimGen); + uElimPos? ← getUElimPos? matcher.getAppFn.constLevels! uElimGen; + isLevelDefEq uElimGen uElim; + addMatcherInfo matcherName { numParams := matcher.getAppNumArgs, numDiscrs := numDiscrs, altNumParams := minors.map Prod.snd, uElimPos? := uElimPos? }; setInlineAttribute matcherName; trace! `Meta.Match.debug ("matcher: " ++ matcher); let unusedAltIdxs : List Nat := lhss.length.fold @@ -847,7 +875,8 @@ pure info?.isSome structure MatcherApp := (matcherName : Name) -(matcherLevels : List Level) +(matcherLevels : Array Level) +(uElimPos? : Option Nat) (params : Array Expr) (motive : Expr) (discrs : Array Expr) @@ -864,7 +893,8 @@ match e.getAppFn with else pure $ some { matcherName := declName, - matcherLevels := declLevels, + matcherLevels := declLevels.toArray, + uElimPos? := info.uElimPos?, params := args.extract 0 info.numParams, motive := args.get! info.numParams, discrs := args.extract (info.numParams + 1) (info.numParams + 1 + info.numDiscrs), @@ -875,7 +905,7 @@ match e.getAppFn with | _ => pure none def MatcherApp.toExpr (matcherApp : MatcherApp) : Expr := -let result := mkAppN (mkConst matcherApp.matcherName matcherApp.matcherLevels) matcherApp.params; +let result := mkAppN (mkConst matcherApp.matcherName matcherApp.matcherLevels.toList) matcherApp.params; let result := mkApp result matcherApp.motive; let result := mkAppN result matcherApp.discrs; let result := mkAppN result matcherApp.alts; @@ -926,10 +956,16 @@ lambdaTelescope matcherApp.motive fun motiveArgs motiveBody => do pure $ eTypeAbst.instantiate1 motiveArg) eType; motiveBody ← mkArrow eTypeAbst motiveBody; + matcherLevels ← match matcherApp.uElimPos? with + | none => pure matcherApp.matcherLevels + | some pos => do { + uElim ← getLevel motiveBody; + pure $ matcherApp.matcherLevels.set! pos uElim + }; motive ← mkLambdaFVars motiveArgs motiveBody; -- Construct `aux` `match_i As (fun xs => B[xs] → motive[xs]) discrs`, and infer its type `auxType`. -- We use `auxType` to infer the type `B[C_i[ys_i]]` of the new argument in each alternative. - let aux := mkAppN (mkConst matcherApp.matcherName matcherApp.matcherLevels) matcherApp.params; + let aux := mkAppN (mkConst matcherApp.matcherName matcherLevels.toList) matcherApp.params; let aux := mkApp aux motive; let aux := mkAppN aux matcherApp.discrs; trace! `Meta.debug aux; @@ -939,10 +975,11 @@ lambdaTelescope matcherApp.motive fun motiveArgs motiveBody => do auxType ← inferType aux; (altNumParams, alts) ← updateAlts auxType matcherApp.altNumParams matcherApp.alts 0; pure { matcherApp with - motive := motive, - alts := alts, - altNumParams := altNumParams, - remaining := #[e] ++ matcherApp.remaining } + matcherLevels := matcherLevels, + motive := motive, + alts := alts, + altNumParams := altNumParams, + remaining := #[e] ++ matcherApp.remaining } @[init] private def regTraceClasses : IO Unit := do registerTraceClass `Meta.Match.match; diff --git a/tests/lean/run/depElim1.lean b/tests/lean/run/depElim1.lean index 35f041b22a..97647b5624 100644 --- a/tests/lean/run/depElim1.lean +++ b/tests/lean/run/depElim1.lean @@ -159,10 +159,10 @@ else do pure $ mkSort $ v def mkTester (elimName : Name) (majors : List Expr) (lhss : List AltLHS) (inProp : Bool := false) : MetaM MatcherResult := do -sortv ← mkElimSort majors lhss inProp; generalizeTelescope majors.toArray `_d fun majors => do - motiveType ← mkForallFVars majors sortv; - Match.mkMatcher elimName motiveType majors.size lhss + let resultType := if inProp then mkConst `True /- some proposition -/ else mkConst `Nat; + matchType ← mkForallFVars majors resultType; + Match.mkMatcher elimName matchType majors.size lhss def test (ex : Name) (numPats : Nat) (elimName : Name) (inProp : Bool := false) : MetaM Unit := withDepElimFrom ex numPats fun majors alts => do diff --git a/tests/lean/run/matcherElimUniv.lean b/tests/lean/run/matcherElimUniv.lean new file mode 100644 index 0000000000..c412ecea00 --- /dev/null +++ b/tests/lean/run/matcherElimUniv.lean @@ -0,0 +1,9 @@ +new_frontend +universes u + +def len {α : Type u} : List α → List α → Nat +| [], bs => bs.length +| a::as, bs => len as bs + 1 + +theorem ex1 : len [1, 2] [3, 4] = 4 := +rfl