fix: matchUnit simplification

This commit is contained in:
Leonardo de Moura 2021-02-19 13:51:08 -08:00
parent d1d26e5ba6
commit d493e700cc
2 changed files with 26 additions and 14 deletions

View file

@ -36,10 +36,11 @@ private def expandSimpleMatchWithType (stx discr lhsVar type rhs : Syntax) (expe
let newStx ← `(let $lhsVar : $type := $discr; $rhs)
withMacroExpansion stx newStx <| elabTerm newStx expectedType?
private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) := do
private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr × Bool) := do
let mut discrs := #[]
let mut i := 0
let mut matchType := matchType
let mut isDep := false
for discrStx in discrStxs do
i := i + 1
matchType ← whnf matchType
@ -47,11 +48,13 @@ private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr)
| Expr.forallE _ d b _ =>
let discr ← fullApproxDefEq <| elabTermEnsuringType discrStx[1] d
trace[Elab.match]! "discr #{i} {discr} : {d}"
if b.hasLooseBVars then
isDep := true
matchType ← b.instantiate1 discr
discrs := discrs.push discr
| _ =>
throwError! "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected"
pure discrs
pure (discrs, isDep)
private def mkUserNameFor (e : Expr) : TermElabM Name :=
match e with
@ -74,13 +77,19 @@ private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do
pure localDecl.value
| _ => throwErrorAt discr "unexpected discriminant"
structure ElabMatchTypeAndDiscsResult where
discrs : Array Expr
matchType : Expr
isDep : Bool
alts : Array MatchAltView
private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr)
: TermElabM (Array Expr × Expr × Array MatchAltView) := do
: TermElabM ElabMatchTypeAndDiscsResult := do
let numDiscrs := discrStxs.size
if matchOptType.isNone then
let rec loop (i : Nat) (discrs : Array Expr) (matchType : Expr) (matchAltViews : Array MatchAltView) := do
let rec loop (i : Nat) (discrs : Array Expr) (matchType : Expr) (isDep : Bool) (matchAltViews : Array MatchAltView) := do
match i with
| 0 => pure (discrs.reverse, matchType, matchAltViews)
| 0 => return { discrs := discrs.reverse, matchType := matchType, isDep := isDep, alts := matchAltViews }
| i+1 =>
let discrStx := discrStxs[i]
let discr ← elabAtomicDiscr discrStx
@ -88,9 +97,10 @@ private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Sy
let discrType ← inferType discr
let discrType ← instantiateMVars discrType
let matchTypeBody ← kabstract matchType discr
let isDep := isDep || matchTypeBody.hasLooseBVars
let userName ← mkUserNameFor discr
if discrStx[0].isNone then
loop i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) matchAltViews
loop i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) isDep matchAltViews
else
let identStx := discrStx[0][0]
withLocalDeclD userName discrType fun x => do
@ -102,13 +112,13 @@ private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Sy
let discrs := (discrs.push refl).push discr
let matchAltViews := matchAltViews.map fun altView =>
{ altView with patterns := altView.patterns.insertAt (i+1) identStx }
loop i discrs matchType matchAltViews
loop discrStxs.size #[] expectedType matchAltViews
loop i discrs matchType isDep matchAltViews
loop discrStxs.size (discrs := #[]) (isDep := false) expectedType matchAltViews
else
let matchTypeStx := matchOptType[0][1]
let matchType ← elabType matchTypeStx
let discrs ← elabDiscrsWitMatchType discrStxs matchType expectedType
pure (discrs, matchType, matchAltViews)
let (discrs, isDep) ← elabDiscrsWitMatchType discrStxs matchType expectedType
return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews }
def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array MatchAltView) := do
matchAlts.mapM fun matchAlt => do
@ -723,8 +733,8 @@ private def isMatchUnit? (altLHSS : List Match.AltLHS) (rhss : Array Expr) : Met
private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr)
: TermElabM Expr := do
let (discrs, matchType, altLHSS, rhss) ← commitIfDidNotPostpone do
let (discrs, matchType, altViews) ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
let (discrs, matchType, altLHSS, isDep, rhss) ← commitIfDidNotPostpone do
let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews
trace[Elab.match]! "matchType: {matchType}"
let alts ← matchAlts.mapM fun alt => elabMatchAltView alt matchType
@ -780,8 +790,8 @@ private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltVi
tryPostpone
throwMVarError m!"invalid match-expression, pattern contains metavariables{indentExpr (← p.toExpr)}"
pure altLHS
return (discrs, matchType, altLHSS, rhss)
if let some r ← isMatchUnit? altLHSS rhss then
return (discrs, matchType, altLHSS, isDep, rhss)
if let some r ← if isDep then pure none else isMatchUnit? altLHSS rhss then
return r
else
let numDiscrs := discrs.size

View file

@ -0,0 +1,2 @@
theorem ex : ∀ x : Unit, x = () := by
intro (); rfl