feat: improve matchType inference
This commit is contained in:
parent
cc3b48ce16
commit
fde43e071d
3 changed files with 53 additions and 23 deletions
|
|
@ -48,22 +48,7 @@ private def expandSimpleMatchWithType (stx discr lhsVar type rhs : Syntax) (expe
|
|||
newStx ← `(let $lhsVar : $type := $discr; $rhs);
|
||||
withMacroExpansion stx newStx $ elabTerm newStx expectedType?
|
||||
|
||||
private def expandMatchOptTypeAux (ref : Syntax) : Nat → MacroM Syntax
|
||||
| 0 => pure $ mkHole ref
|
||||
| n+1 => do t ← expandMatchOptTypeAux n; r ← `(forall _, $t); pure (r.copyInfo ref)
|
||||
|
||||
private def expandMatchOptType (ref : Syntax) (optType : Syntax) (numDiscrs : Nat) : MacroM Syntax :=
|
||||
if optType.isNone then
|
||||
expandMatchOptTypeAux ref numDiscrs
|
||||
else
|
||||
pure $ (optType.getArg 0).getArg 1
|
||||
|
||||
private def elabMatchOptType (matchOptType : Syntax) (numDiscrs : Nat) : TermElabM Expr := do
|
||||
ref ← getRef;
|
||||
typeStx ← liftMacroM $ expandMatchOptType ref matchOptType numDiscrs;
|
||||
elabType typeStx
|
||||
|
||||
private partial def elabDiscrsAux (discrStxs : Array Syntax) (expectedType : Expr) : Nat → Expr → Array Expr → TermElabM (Array Expr)
|
||||
private partial def elabDiscrsWitMatchTypeAux (discrStxs : Array Syntax) (expectedType : Expr) : Nat → Expr → Array Expr → TermElabM (Array Expr)
|
||||
| i, matchType, discrs =>
|
||||
if h : i < discrStxs.size then do
|
||||
let discrStx := discrStxs.get ⟨i, h⟩;
|
||||
|
|
@ -72,15 +57,40 @@ private partial def elabDiscrsAux (discrStxs : Array Syntax) (expectedType : Exp
|
|||
| Expr.forallE _ d b _ => do
|
||||
discr ← fullApproxDefEq $ elabTermEnsuringType discrStx d;
|
||||
trace `Elab.match fun _ => "discr #" ++ toString i ++ " " ++ discr ++ " : " ++ d;
|
||||
elabDiscrsAux (i+1) (b.instantiate1 discr) (discrs.push discr)
|
||||
elabDiscrsWitMatchTypeAux (i+1) (b.instantiate1 discr) (discrs.push discr)
|
||||
| _ => throwError ("invalid type provided to match-expression, function type with arity #" ++ toString discrStxs ++ " expected")
|
||||
else do
|
||||
unlessM (fullApproxDefEq $ isDefEq matchType expectedType) $
|
||||
throwError ("invalid result type provided to match-expression" ++ indentExpr matchType ++ Format.line ++ "expected type" ++ indentExpr expectedType);
|
||||
pure discrs
|
||||
|
||||
private def elabDiscrs (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) :=
|
||||
elabDiscrsAux discrStxs expectedType 0 matchType #[]
|
||||
private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) :=
|
||||
elabDiscrsWitMatchTypeAux discrStxs expectedType 0 matchType #[]
|
||||
|
||||
private def mkUserNameFor (e : Expr) : TermElabM Name :=
|
||||
match e with
|
||||
| Expr.fvar fvarId _ => do localDecl ← getLocalDecl fvarId; pure localDecl.userName
|
||||
| _ => withFreshMacroScope do x ← `(x); pure x.getId
|
||||
|
||||
private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (expectedType : Expr) : TermElabM (Array Expr × Expr) :=
|
||||
let numDiscrs := discrStxs.size;
|
||||
if matchOptType.isNone then do
|
||||
discrs ← discrStxs.mapM fun discrStx => elabTerm discrStx none;
|
||||
matchType ← discrs.foldrM
|
||||
(fun (discr : Expr) (matchType : Expr) => do
|
||||
discr ← instantiateMVars discr;
|
||||
discrType ← inferType discr;
|
||||
discrType ← instantiateMVars discrType;
|
||||
matchTypeBody ← kabstract matchType discr;
|
||||
userName ← mkUserNameFor discr;
|
||||
pure $ Lean.mkForall userName BinderInfo.default discrType matchTypeBody)
|
||||
expectedType;
|
||||
pure (discrs, matchType)
|
||||
else do
|
||||
let matchTypeStx := (matchOptType.getArg 0).getArg 1;
|
||||
matchType ← elabType matchTypeStx;
|
||||
discrs ← elabDiscrsWitMatchType discrStxs matchType expectedType;
|
||||
pure (discrs, matchType)
|
||||
|
||||
/-
|
||||
nodeWithAntiquot "matchAlt" `Lean.Parser.Term.matchAlt $ sepBy1 termParser ", " >> darrow >> termParser
|
||||
|
|
@ -575,6 +585,7 @@ match val? with
|
|||
/- HACK: `fvarId` is not in the scope of `mvarId`
|
||||
If this generates problems in the future, we should update the metavariable declarations. -/
|
||||
assignExprMVar mvarId (mkFVar fvarId);
|
||||
-- TODO: use macro scopes for creating userName
|
||||
let userName := (`_x).appendIndexAfter (s.localDecls.size+1);
|
||||
let newDecl := LocalDecl.cdecl (arbitrary _) fvarId userName type BinderInfo.default;
|
||||
modify $ fun s =>
|
||||
|
|
@ -688,9 +699,8 @@ unless result.unusedAltIdxs.isEmpty $
|
|||
|
||||
private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr)
|
||||
: TermElabM Expr := do
|
||||
matchType ← elabMatchOptType matchOptType discrStxs.size;
|
||||
(discrs, matchType) ← elabMatchTypeAndDiscrs discrStxs matchOptType expectedType;
|
||||
matchAlts ← expandMacrosInPatterns altViews;
|
||||
discrs ← elabDiscrs discrStxs matchType expectedType;
|
||||
trace `Elab.match fun _ => "matchType: " ++ matchType;
|
||||
alts ← matchAlts.mapM $ fun alt => elabMatchAltView alt matchType;
|
||||
let rhss := alts.map Prod.snd;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
Sum.someRight c : Option Nat
|
||||
evalWithMVar.lean:13:6: error: don't know how to synthesize placeholder
|
||||
@Sum.someRight ?m.178 … …
|
||||
@Sum.someRight ?m.168 … …
|
||||
context:
|
||||
⊢ Type ?
|
||||
evalWithMVar.lean:13:20: error: don't know how to synthesize placeholder
|
||||
@c ?m.178
|
||||
@c ?m.168
|
||||
context:
|
||||
⊢ Type ?
|
||||
Sum.someRight c : Option Nat
|
||||
|
|
|
|||
|
|
@ -59,3 +59,23 @@ by match h with
|
|||
| Or.inr (Or.inr h) =>
|
||||
apply Or.inl;
|
||||
assumption
|
||||
|
||||
inductive ListLast.{u} {α : Type u} : List α → Type u
|
||||
| empty : ListLast []
|
||||
| nonEmpty : (as : List α) → (a : α) → ListLast (as ++ [a])
|
||||
|
||||
axiom last {α} (xs : List α) : ListLast xs
|
||||
axiom back {α} [Inhabited α] (xs : List α) : α
|
||||
axiom popBack {α} : List α → List α
|
||||
axiom backEq {α} [Inhabited α] : (xs : List α) → (x : α) → back (xs ++ [x]) = x
|
||||
axiom popBackEq {α} : (xs : List α) → (x : α) → popBack (xs ++ [x]) = xs
|
||||
|
||||
theorem tst8 {α} [Inhabited α] (xs : List α) : xs ≠ [] → xs = popBack xs ++ [back xs] :=
|
||||
match xs, last xs with
|
||||
| _, ListLast.empty => fun h => absurd rfl h
|
||||
| _, ListLast.nonEmpty ys y => fun _ => sorry
|
||||
|
||||
theorem tst9 {α} [Inhabited α] (xs : List α) : xs ≠ [] → xs = popBack xs ++ [back xs] := by
|
||||
match xs, last xs with
|
||||
| _, ListLast.empty => intro h; exact absurd rfl h
|
||||
| _, ListLast.nonEmpty ys y => intro; exact sorry
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue