fix: matchUnit simplification
This commit is contained in:
parent
d1d26e5ba6
commit
d493e700cc
2 changed files with 26 additions and 14 deletions
|
|
@ -36,10 +36,11 @@ private def expandSimpleMatchWithType (stx discr lhsVar type rhs : Syntax) (expe
|
|||
let newStx ← `(let $lhsVar : $type := $discr; $rhs)
|
||||
withMacroExpansion stx newStx <| elabTerm newStx expectedType?
|
||||
|
||||
private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) := do
|
||||
private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr × Bool) := do
|
||||
let mut discrs := #[]
|
||||
let mut i := 0
|
||||
let mut matchType := matchType
|
||||
let mut isDep := false
|
||||
for discrStx in discrStxs do
|
||||
i := i + 1
|
||||
matchType ← whnf matchType
|
||||
|
|
@ -47,11 +48,13 @@ private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr)
|
|||
| Expr.forallE _ d b _ =>
|
||||
let discr ← fullApproxDefEq <| elabTermEnsuringType discrStx[1] d
|
||||
trace[Elab.match]! "discr #{i} {discr} : {d}"
|
||||
if b.hasLooseBVars then
|
||||
isDep := true
|
||||
matchType ← b.instantiate1 discr
|
||||
discrs := discrs.push discr
|
||||
| _ =>
|
||||
throwError! "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected"
|
||||
pure discrs
|
||||
pure (discrs, isDep)
|
||||
|
||||
private def mkUserNameFor (e : Expr) : TermElabM Name :=
|
||||
match e with
|
||||
|
|
@ -74,13 +77,19 @@ private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do
|
|||
pure localDecl.value
|
||||
| _ => throwErrorAt discr "unexpected discriminant"
|
||||
|
||||
structure ElabMatchTypeAndDiscsResult where
|
||||
discrs : Array Expr
|
||||
matchType : Expr
|
||||
isDep : Bool
|
||||
alts : Array MatchAltView
|
||||
|
||||
private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr)
|
||||
: TermElabM (Array Expr × Expr × Array MatchAltView) := do
|
||||
: TermElabM ElabMatchTypeAndDiscsResult := do
|
||||
let numDiscrs := discrStxs.size
|
||||
if matchOptType.isNone then
|
||||
let rec loop (i : Nat) (discrs : Array Expr) (matchType : Expr) (matchAltViews : Array MatchAltView) := do
|
||||
let rec loop (i : Nat) (discrs : Array Expr) (matchType : Expr) (isDep : Bool) (matchAltViews : Array MatchAltView) := do
|
||||
match i with
|
||||
| 0 => pure (discrs.reverse, matchType, matchAltViews)
|
||||
| 0 => return { discrs := discrs.reverse, matchType := matchType, isDep := isDep, alts := matchAltViews }
|
||||
| i+1 =>
|
||||
let discrStx := discrStxs[i]
|
||||
let discr ← elabAtomicDiscr discrStx
|
||||
|
|
@ -88,9 +97,10 @@ private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Sy
|
|||
let discrType ← inferType discr
|
||||
let discrType ← instantiateMVars discrType
|
||||
let matchTypeBody ← kabstract matchType discr
|
||||
let isDep := isDep || matchTypeBody.hasLooseBVars
|
||||
let userName ← mkUserNameFor discr
|
||||
if discrStx[0].isNone then
|
||||
loop i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) matchAltViews
|
||||
loop i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) isDep matchAltViews
|
||||
else
|
||||
let identStx := discrStx[0][0]
|
||||
withLocalDeclD userName discrType fun x => do
|
||||
|
|
@ -102,13 +112,13 @@ private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Sy
|
|||
let discrs := (discrs.push refl).push discr
|
||||
let matchAltViews := matchAltViews.map fun altView =>
|
||||
{ altView with patterns := altView.patterns.insertAt (i+1) identStx }
|
||||
loop i discrs matchType matchAltViews
|
||||
loop discrStxs.size #[] expectedType matchAltViews
|
||||
loop i discrs matchType isDep matchAltViews
|
||||
loop discrStxs.size (discrs := #[]) (isDep := false) expectedType matchAltViews
|
||||
else
|
||||
let matchTypeStx := matchOptType[0][1]
|
||||
let matchType ← elabType matchTypeStx
|
||||
let discrs ← elabDiscrsWitMatchType discrStxs matchType expectedType
|
||||
pure (discrs, matchType, matchAltViews)
|
||||
let (discrs, isDep) ← elabDiscrsWitMatchType discrStxs matchType expectedType
|
||||
return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews }
|
||||
|
||||
def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array MatchAltView) := do
|
||||
matchAlts.mapM fun matchAlt => do
|
||||
|
|
@ -723,8 +733,8 @@ private def isMatchUnit? (altLHSS : List Match.AltLHS) (rhss : Array Expr) : Met
|
|||
|
||||
private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr)
|
||||
: TermElabM Expr := do
|
||||
let (discrs, matchType, altLHSS, rhss) ← commitIfDidNotPostpone do
|
||||
let (discrs, matchType, altViews) ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
|
||||
let (discrs, matchType, altLHSS, isDep, rhss) ← commitIfDidNotPostpone do
|
||||
let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
|
||||
let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews
|
||||
trace[Elab.match]! "matchType: {matchType}"
|
||||
let alts ← matchAlts.mapM fun alt => elabMatchAltView alt matchType
|
||||
|
|
@ -780,8 +790,8 @@ private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltVi
|
|||
tryPostpone
|
||||
throwMVarError m!"invalid match-expression, pattern contains metavariables{indentExpr (← p.toExpr)}"
|
||||
pure altLHS
|
||||
return (discrs, matchType, altLHSS, rhss)
|
||||
if let some r ← isMatchUnit? altLHSS rhss then
|
||||
return (discrs, matchType, altLHSS, isDep, rhss)
|
||||
if let some r ← if isDep then pure none else isMatchUnit? altLHSS rhss then
|
||||
return r
|
||||
else
|
||||
let numDiscrs := discrs.size
|
||||
|
|
|
|||
2
tests/lean/run/match_unit.lean
Normal file
2
tests/lean/run/match_unit.lean
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
theorem ex : ∀ x : Unit, x = () := by
|
||||
intro (); rfl
|
||||
Loading…
Add table
Reference in a new issue