fix: (try to) postpone when discriminant type is not known
This commit is contained in:
parent
7fec9587db
commit
c5e3da89e8
2 changed files with 133 additions and 10 deletions
|
|
@ -72,11 +72,28 @@ match e with
|
|||
| Expr.fvar fvarId _ => do localDecl ← getLocalDecl fvarId; pure localDecl.userName
|
||||
| _ => mkFreshBinderName
|
||||
|
||||
-- `expandNonAtomicDiscrs?` create auxiliary variables with base name `_discr`
|
||||
private def isAuxDiscrName (n : Name) : Bool :=
|
||||
n.eraseMacroScopes == `_discr
|
||||
|
||||
-- See expandNonAtomicDiscrs?
|
||||
private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do
|
||||
let term := discr.getArg 1;
|
||||
local? ← isLocalIdent? term;
|
||||
match local? with
|
||||
| some e@(Expr.fvar fvarId _) => do
|
||||
localDecl ← getLocalDecl fvarId;
|
||||
if !isAuxDiscrName localDecl.userName then
|
||||
pure e -- it is not an auxiliary local created by `expandNonAtomicDiscrs?`
|
||||
else
|
||||
pure localDecl.value
|
||||
| _ => throwErrorAt discr "unexpected discriminant"
|
||||
|
||||
private def elabMatchTypeAndDiscrsAux (discrStxs : Array Syntax) : Nat → Array Expr → Expr → Array MatchAltView → TermElabM (Array Expr × Expr × Array MatchAltView)
|
||||
| 0, discrs, matchType, matchAltViews => pure (discrs.reverse, matchType, matchAltViews)
|
||||
| i+1, discrs, matchType, matchAltViews => do
|
||||
let discrStx := discrStxs.get! i;
|
||||
discr ← elabTerm (discrStx.getArg 1) none;
|
||||
discr ← elabAtomicDiscr discrStx;
|
||||
discr ← instantiateMVars discr;
|
||||
discrType ← inferType discr;
|
||||
discrType ← instantiateMVars discrType;
|
||||
|
|
@ -738,12 +755,100 @@ let r := mkAppN r rhss;
|
|||
trace `Elab.match fun _ => "result: " ++ r;
|
||||
pure r
|
||||
|
||||
private def getDiscrs (matchStx : Syntax) : Array Syntax :=
|
||||
(matchStx.getArg 1).getArgs.getSepElems
|
||||
|
||||
private def getMatchOptType (matchStx : Syntax) : Syntax :=
|
||||
matchStx.getArg 2
|
||||
|
||||
private def expandNonAtomicDiscrsAux (matchStx : Syntax) : List Syntax → Array Syntax → TermElabM Syntax
|
||||
| [], discrsNew =>
|
||||
let discrs := mkSepStx discrsNew (mkAtomFrom matchStx ", ");
|
||||
pure $ matchStx.setArg 1 discrs
|
||||
| discr :: discrs, discrsNew => do
|
||||
-- Recall that
|
||||
-- matchDiscr := parser! optional (ident >> ":") >> termParser
|
||||
let term := discr.getArg 1;
|
||||
local? ← isLocalIdent? term;
|
||||
match local? with
|
||||
| some _ => expandNonAtomicDiscrsAux discrs (discrsNew.push discr)
|
||||
| none => withFreshMacroScope do
|
||||
d ← `(_discr);
|
||||
unless (isAuxDiscrName d.getId) $ -- Use assertion?
|
||||
throwError "unexpected internal auxiliary discriminant name";
|
||||
let discrNew := discr.setArg 1 d;
|
||||
r ← expandNonAtomicDiscrsAux discrs (discrsNew.push discrNew);
|
||||
`(let _discr := $term; $r)
|
||||
|
||||
private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Syntax) :=
|
||||
let matchOptType := getMatchOptType matchStx;
|
||||
if matchOptType.isNone then do
|
||||
let discrs := getDiscrs matchStx;
|
||||
allLocal ← discrs.allM fun discr => Option.isSome <$> isLocalIdent? (discr.getArg 1);
|
||||
if allLocal then
|
||||
pure none
|
||||
else
|
||||
some <$> expandNonAtomicDiscrsAux matchStx discrs.toList #[]
|
||||
else
|
||||
-- We do not pull non atomic discriminants when match type is provided explicitly by the user
|
||||
pure none
|
||||
|
||||
private def waitExpectedType (expectedType? : Option Expr) : TermElabM Expr := do
|
||||
tryPostponeIfNoneOrMVar expectedType?;
|
||||
match expectedType? with
|
||||
| some expectedType => pure expectedType
|
||||
| none => mkFreshTypeMVar
|
||||
|
||||
private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit :=
|
||||
-- We don't wait for the discriminants types when match type is provided by user
|
||||
when (getMatchOptType matchStx).isNone do
|
||||
let discrs := getDiscrs matchStx;
|
||||
discrs.forM fun discr => do
|
||||
let term := discr.getArg 1;
|
||||
local? ← isLocalIdent? term;
|
||||
match local? with
|
||||
| none => throwErrorAt discr "unexpected discriminant" -- see `expandNonAtomicDiscrs?
|
||||
| some d => do
|
||||
dType ← inferType d;
|
||||
tryPostponeIfMVar dType
|
||||
|
||||
/-
|
||||
We (try to) elaborate a `match` only when the expected type is available.
|
||||
If the `matchType` has not been provided by the user, we also try to postpone elaboration if the type
|
||||
of a discriminant is not available. That is, it is of the form `(?m ...)`.
|
||||
We use `expandNonAtomicDiscrs?` to make sure all discriminants are local variables.
|
||||
This is a standard trick we use in the elaborator, and it is also used to elaborate structure instances.
|
||||
Suppose, we are trying to elaborate
|
||||
```
|
||||
match g x with
|
||||
| ... => ...
|
||||
```
|
||||
`expandNonAtomicDiscrs?` converts it intro
|
||||
```
|
||||
let _discr := g x
|
||||
match _discr with
|
||||
| ... => ...
|
||||
```
|
||||
Thus, at `tryPostponeIfDiscrTypeIsMVar` we only need to check whether the type of `_discr` is not of the form `(?m ...)`.
|
||||
Note that, the auxiliary variable `_discr` is expanded at `elabAtomicDiscr`.
|
||||
|
||||
This elaboration technique is needed to elaborate terms such as:
|
||||
```lean
|
||||
xs.filter fun (a, b) => a > b
|
||||
```
|
||||
which are syntax sugar for
|
||||
```lean
|
||||
List.filter (fun p => match p with | (a, b) => a > b) xs
|
||||
```
|
||||
When we visit `match p with | (a, b) => a > b`, we don't know the type of `p` yet.
|
||||
-/
|
||||
private def waitExpectedTypeAndDiscrs (matchStx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
||||
tryPostponeIfNoneOrMVar expectedType?;
|
||||
tryPostponeIfDiscrTypeIsMVar matchStx;
|
||||
match expectedType? with
|
||||
| some expectedType => pure expectedType
|
||||
| none => mkFreshTypeMVar
|
||||
|
||||
/-
|
||||
```
|
||||
parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts
|
||||
|
|
@ -751,10 +856,10 @@ parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> ma
|
|||
Remark the `optIdent` must be `none` at `matchDiscr`. They are expanded by `expandMatchDiscr?`.
|
||||
-/
|
||||
private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
||||
expectedType ← waitExpectedType expectedType?;
|
||||
let discrStxs := (stx.getArg 1).getArgs.getSepElems.map fun d => d;
|
||||
let altViews := getMatchAlts stx;
|
||||
let matchOptType := stx.getArg 2;
|
||||
expectedType ← waitExpectedTypeAndDiscrs stx expectedType?;
|
||||
let discrStxs := (getDiscrs stx).map fun d => d;
|
||||
let altViews := getMatchAlts stx;
|
||||
let matchOptType := getMatchOptType stx;
|
||||
elabMatchAux discrStxs altViews matchOptType expectedType
|
||||
|
||||
-- parser! "match " >> sepBy1 termParser ", " >> optType >> " with " >> matchAlts
|
||||
|
|
@ -765,11 +870,15 @@ fun stx expectedType? => match_syntax stx with
|
|||
| `(match $discr:term : $type with $y:ident => $rhs:term) => expandSimpleMatchWithType stx discr y type rhs expectedType?
|
||||
| `(match $discr:term : $type with | $y:ident => $rhs:term) => expandSimpleMatchWithType stx discr y type rhs expectedType?
|
||||
| _ => do
|
||||
let discrs := (stx.getArg 1).getArgs;
|
||||
let matchOptType := stx.getArg 2;
|
||||
when (!matchOptType.isNone && discrs.getSepElems.any fun d => !(d.getArg 0).isNone) $
|
||||
throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used";
|
||||
elabMatchCore stx expectedType?
|
||||
stxNew? ← expandNonAtomicDiscrs? stx;
|
||||
match stxNew? with
|
||||
| some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
|
||||
| none => do
|
||||
let discrs := getDiscrs stx;
|
||||
let matchOptType := getMatchOptType stx;
|
||||
when (!matchOptType.isNone && discrs.any fun d => !(d.getArg 0).isNone) $
|
||||
throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used";
|
||||
elabMatchCore stx expectedType?
|
||||
|
||||
@[init] private def regTraceClasses : IO Unit := do
|
||||
registerTraceClass `Elab.match;
|
||||
|
|
|
|||
14
tests/lean/run/matchDiscrType.lean
Normal file
14
tests/lean/run/matchDiscrType.lean
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
new_frontend
|
||||
|
||||
def g (x : Nat) : List (Nat × List Nat) :=
|
||||
[(x, [x, x]), (x, [])]
|
||||
|
||||
def h (x : Nat) : List Nat :=
|
||||
let xs := g x $.filter (fun ⟨_, xs⟩ => xs.isEmpty)
|
||||
xs.map (·.1)
|
||||
|
||||
theorem ex1 : g 10 = [(10, [10, 10]), (10, [])] :=
|
||||
rfl
|
||||
|
||||
theorem ex2 : h 10 = [10] :=
|
||||
rfl
|
||||
Loading…
Add table
Reference in a new issue