feat: improve matchType inference

This commit is contained in:
Leonardo de Moura 2020-09-14 19:37:05 -07:00
parent cc3b48ce16
commit fde43e071d
3 changed files with 53 additions and 23 deletions

View file

@ -48,22 +48,7 @@ private def expandSimpleMatchWithType (stx discr lhsVar type rhs : Syntax) (expe
newStx ← `(let $lhsVar : $type := $discr; $rhs);
withMacroExpansion stx newStx $ elabTerm newStx expectedType?
private def expandMatchOptTypeAux (ref : Syntax) : Nat → MacroM Syntax
| 0 => pure $ mkHole ref
| n+1 => do t ← expandMatchOptTypeAux n; r ← `(forall _, $t); pure (r.copyInfo ref)
private def expandMatchOptType (ref : Syntax) (optType : Syntax) (numDiscrs : Nat) : MacroM Syntax :=
if optType.isNone then
expandMatchOptTypeAux ref numDiscrs
else
pure $ (optType.getArg 0).getArg 1
private def elabMatchOptType (matchOptType : Syntax) (numDiscrs : Nat) : TermElabM Expr := do
ref ← getRef;
typeStx ← liftMacroM $ expandMatchOptType ref matchOptType numDiscrs;
elabType typeStx
private partial def elabDiscrsAux (discrStxs : Array Syntax) (expectedType : Expr) : Nat → Expr → Array Expr → TermElabM (Array Expr)
private partial def elabDiscrsWitMatchTypeAux (discrStxs : Array Syntax) (expectedType : Expr) : Nat → Expr → Array Expr → TermElabM (Array Expr)
| i, matchType, discrs =>
if h : i < discrStxs.size then do
let discrStx := discrStxs.get ⟨i, h⟩;
@ -72,15 +57,40 @@ private partial def elabDiscrsAux (discrStxs : Array Syntax) (expectedType : Exp
| Expr.forallE _ d b _ => do
discr ← fullApproxDefEq $ elabTermEnsuringType discrStx d;
trace `Elab.match fun _ => "discr #" ++ toString i ++ " " ++ discr ++ " : " ++ d;
elabDiscrsAux (i+1) (b.instantiate1 discr) (discrs.push discr)
elabDiscrsWitMatchTypeAux (i+1) (b.instantiate1 discr) (discrs.push discr)
| _ => throwError ("invalid type provided to match-expression, function type with arity #" ++ toString discrStxs ++ " expected")
else do
unlessM (fullApproxDefEq $ isDefEq matchType expectedType) $
throwError ("invalid result type provided to match-expression" ++ indentExpr matchType ++ Format.line ++ "expected type" ++ indentExpr expectedType);
pure discrs
private def elabDiscrs (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) :=
elabDiscrsAux discrStxs expectedType 0 matchType #[]
private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) :=
elabDiscrsWitMatchTypeAux discrStxs expectedType 0 matchType #[]
private def mkUserNameFor (e : Expr) : TermElabM Name :=
match e with
| Expr.fvar fvarId _ => do localDecl ← getLocalDecl fvarId; pure localDecl.userName
| _ => withFreshMacroScope do x ← `(x); pure x.getId
private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (expectedType : Expr) : TermElabM (Array Expr × Expr) :=
let numDiscrs := discrStxs.size;
if matchOptType.isNone then do
discrs ← discrStxs.mapM fun discrStx => elabTerm discrStx none;
matchType ← discrs.foldrM
(fun (discr : Expr) (matchType : Expr) => do
discr ← instantiateMVars discr;
discrType ← inferType discr;
discrType ← instantiateMVars discrType;
matchTypeBody ← kabstract matchType discr;
userName ← mkUserNameFor discr;
pure $ Lean.mkForall userName BinderInfo.default discrType matchTypeBody)
expectedType;
pure (discrs, matchType)
else do
let matchTypeStx := (matchOptType.getArg 0).getArg 1;
matchType ← elabType matchTypeStx;
discrs ← elabDiscrsWitMatchType discrStxs matchType expectedType;
pure (discrs, matchType)
/-
nodeWithAntiquot "matchAlt" `Lean.Parser.Term.matchAlt $ sepBy1 termParser ", " >> darrow >> termParser
@ -575,6 +585,7 @@ match val? with
/- HACK: `fvarId` is not in the scope of `mvarId`
If this generates problems in the future, we should update the metavariable declarations. -/
assignExprMVar mvarId (mkFVar fvarId);
-- TODO: use macro scopes for creating userName
let userName := (`_x).appendIndexAfter (s.localDecls.size+1);
let newDecl := LocalDecl.cdecl (arbitrary _) fvarId userName type BinderInfo.default;
modify $ fun s =>
@ -688,9 +699,8 @@ unless result.unusedAltIdxs.isEmpty $
private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr)
: TermElabM Expr := do
matchType ← elabMatchOptType matchOptType discrStxs.size;
(discrs, matchType) ← elabMatchTypeAndDiscrs discrStxs matchOptType expectedType;
matchAlts ← expandMacrosInPatterns altViews;
discrs ← elabDiscrs discrStxs matchType expectedType;
trace `Elab.match fun _ => "matchType: " ++ matchType;
alts ← matchAlts.mapM $ fun alt => elabMatchAltView alt matchType;
let rhss := alts.map Prod.snd;

View file

@ -1,10 +1,10 @@
Sum.someRight c : Option Nat
evalWithMVar.lean:13:6: error: don't know how to synthesize placeholder
@Sum.someRight ?m.178 … …
@Sum.someRight ?m.168 … …
context:
⊢ Type ?
evalWithMVar.lean:13:20: error: don't know how to synthesize placeholder
@c ?m.178
@c ?m.168
context:
⊢ Type ?
Sum.someRight c : Option Nat

View file

@ -59,3 +59,23 @@ by match h with
| Or.inr (Or.inr h) =>
apply Or.inl;
assumption
inductive ListLast.{u} {α : Type u} : List α → Type u
| empty : ListLast []
| nonEmpty : (as : List α) → (a : α) → ListLast (as ++ [a])
axiom last {α} (xs : List α) : ListLast xs
axiom back {α} [Inhabited α] (xs : List α) : α
axiom popBack {α} : List α → List α
axiom backEq {α} [Inhabited α] : (xs : List α) → (x : α) → back (xs ++ [x]) = x
axiom popBackEq {α} : (xs : List α) → (x : α) → popBack (xs ++ [x]) = xs
theorem tst8 {α} [Inhabited α] (xs : List α) : xs ≠ [] → xs = popBack xs ++ [back xs] :=
match xs, last xs with
| _, ListLast.empty => fun h => absurd rfl h
| _, ListLast.nonEmpty ys y => fun _ => sorry
theorem tst9 {α} [Inhabited α] (xs : List α) : xs ≠ [] → xs = popBack xs ++ [back xs] := by
match xs, last xs with
| _, ListLast.empty => intro h; exact absurd rfl h
| _, ListLast.nonEmpty ys y => intro; exact sorry