From fde43e071d7c1bde5ef1edda2b03d36c2c4a21b3 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 14 Sep 2020 19:37:05 -0700 Subject: [PATCH] feat: improve matchType inference --- src/Lean/Elab/Match.lean | 52 ++++++++++++++--------- tests/lean/evalWithMVar.lean.expected.out | 4 +- tests/lean/run/matchtac.lean | 20 +++++++++ 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 1863fa8708..f85b3235cc 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -48,22 +48,7 @@ private def expandSimpleMatchWithType (stx discr lhsVar type rhs : Syntax) (expe newStx ← `(let $lhsVar : $type := $discr; $rhs); withMacroExpansion stx newStx $ elabTerm newStx expectedType? -private def expandMatchOptTypeAux (ref : Syntax) : Nat → MacroM Syntax -| 0 => pure $ mkHole ref -| n+1 => do t ← expandMatchOptTypeAux n; r ← `(forall _, $t); pure (r.copyInfo ref) - -private def expandMatchOptType (ref : Syntax) (optType : Syntax) (numDiscrs : Nat) : MacroM Syntax := -if optType.isNone then - expandMatchOptTypeAux ref numDiscrs -else - pure $ (optType.getArg 0).getArg 1 - -private def elabMatchOptType (matchOptType : Syntax) (numDiscrs : Nat) : TermElabM Expr := do -ref ← getRef; -typeStx ← liftMacroM $ expandMatchOptType ref matchOptType numDiscrs; -elabType typeStx - -private partial def elabDiscrsAux (discrStxs : Array Syntax) (expectedType : Expr) : Nat → Expr → Array Expr → TermElabM (Array Expr) +private partial def elabDiscrsWitMatchTypeAux (discrStxs : Array Syntax) (expectedType : Expr) : Nat → Expr → Array Expr → TermElabM (Array Expr) | i, matchType, discrs => if h : i < discrStxs.size then do let discrStx := discrStxs.get ⟨i, h⟩; @@ -72,15 +57,40 @@ private partial def elabDiscrsAux (discrStxs : Array Syntax) (expectedType : Exp | Expr.forallE _ d b _ => do discr ← fullApproxDefEq $ elabTermEnsuringType discrStx d; trace `Elab.match fun _ => "discr #" ++ toString i ++ " " ++ discr ++ " : " ++ d; - elabDiscrsAux (i+1) (b.instantiate1 discr) (discrs.push discr) + elabDiscrsWitMatchTypeAux (i+1) (b.instantiate1 discr) (discrs.push discr) | _ => throwError ("invalid type provided to match-expression, function type with arity #" ++ toString discrStxs ++ " expected") else do unlessM (fullApproxDefEq $ isDefEq matchType expectedType) $ throwError ("invalid result type provided to match-expression" ++ indentExpr matchType ++ Format.line ++ "expected type" ++ indentExpr expectedType); pure discrs -private def elabDiscrs (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) := -elabDiscrsAux discrStxs expectedType 0 matchType #[] +private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) := +elabDiscrsWitMatchTypeAux discrStxs expectedType 0 matchType #[] + +private def mkUserNameFor (e : Expr) : TermElabM Name := +match e with +| Expr.fvar fvarId _ => do localDecl ← getLocalDecl fvarId; pure localDecl.userName +| _ => withFreshMacroScope do x ← `(x); pure x.getId + +private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (expectedType : Expr) : TermElabM (Array Expr × Expr) := +let numDiscrs := discrStxs.size; +if matchOptType.isNone then do + discrs ← discrStxs.mapM fun discrStx => elabTerm discrStx none; + matchType ← discrs.foldrM + (fun (discr : Expr) (matchType : Expr) => do + discr ← instantiateMVars discr; + discrType ← inferType discr; + discrType ← instantiateMVars discrType; + matchTypeBody ← kabstract matchType discr; + userName ← mkUserNameFor discr; + pure $ Lean.mkForall userName BinderInfo.default discrType matchTypeBody) + expectedType; + pure (discrs, matchType) +else do + let matchTypeStx := (matchOptType.getArg 0).getArg 1; + matchType ← elabType matchTypeStx; + discrs ← elabDiscrsWitMatchType discrStxs matchType expectedType; + pure (discrs, matchType) /- nodeWithAntiquot "matchAlt" `Lean.Parser.Term.matchAlt $ sepBy1 termParser ", " >> darrow >> termParser @@ -575,6 +585,7 @@ match val? with /- HACK: `fvarId` is not in the scope of `mvarId` If this generates problems in the future, we should update the metavariable declarations. -/ assignExprMVar mvarId (mkFVar fvarId); + -- TODO: use macro scopes for creating userName let userName := (`_x).appendIndexAfter (s.localDecls.size+1); let newDecl := LocalDecl.cdecl (arbitrary _) fvarId userName type BinderInfo.default; modify $ fun s => @@ -688,9 +699,8 @@ unless result.unusedAltIdxs.isEmpty $ private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr) : TermElabM Expr := do -matchType ← elabMatchOptType matchOptType discrStxs.size; +(discrs, matchType) ← elabMatchTypeAndDiscrs discrStxs matchOptType expectedType; matchAlts ← expandMacrosInPatterns altViews; -discrs ← elabDiscrs discrStxs matchType expectedType; trace `Elab.match fun _ => "matchType: " ++ matchType; alts ← matchAlts.mapM $ fun alt => elabMatchAltView alt matchType; let rhss := alts.map Prod.snd; diff --git a/tests/lean/evalWithMVar.lean.expected.out b/tests/lean/evalWithMVar.lean.expected.out index 3cf9387749..690bb23088 100644 --- a/tests/lean/evalWithMVar.lean.expected.out +++ b/tests/lean/evalWithMVar.lean.expected.out @@ -1,10 +1,10 @@ Sum.someRight c : Option Nat evalWithMVar.lean:13:6: error: don't know how to synthesize placeholder - @Sum.someRight ?m.178 … … + @Sum.someRight ?m.168 … … context: ⊢ Type ? evalWithMVar.lean:13:20: error: don't know how to synthesize placeholder - @c ?m.178 + @c ?m.168 context: ⊢ Type ? Sum.someRight c : Option Nat diff --git a/tests/lean/run/matchtac.lean b/tests/lean/run/matchtac.lean index bc598d900b..318848df1a 100644 --- a/tests/lean/run/matchtac.lean +++ b/tests/lean/run/matchtac.lean @@ -59,3 +59,23 @@ by match h with | Or.inr (Or.inr h) => apply Or.inl; assumption + +inductive ListLast.{u} {α : Type u} : List α → Type u +| empty : ListLast [] +| nonEmpty : (as : List α) → (a : α) → ListLast (as ++ [a]) + +axiom last {α} (xs : List α) : ListLast xs +axiom back {α} [Inhabited α] (xs : List α) : α +axiom popBack {α} : List α → List α +axiom backEq {α} [Inhabited α] : (xs : List α) → (x : α) → back (xs ++ [x]) = x +axiom popBackEq {α} : (xs : List α) → (x : α) → popBack (xs ++ [x]) = xs + +theorem tst8 {α} [Inhabited α] (xs : List α) : xs ≠ [] → xs = popBack xs ++ [back xs] := +match xs, last xs with +| _, ListLast.empty => fun h => absurd rfl h +| _, ListLast.nonEmpty ys y => fun _ => sorry + +theorem tst9 {α} [Inhabited α] (xs : List α) : xs ≠ [] → xs = popBack xs ++ [back xs] := by +match xs, last xs with +| _, ListLast.empty => intro h; exact absurd rfl h +| _, ListLast.nonEmpty ys y => intro; exact sorry