diff --git a/library/init/lean/expander.lean b/library/init/lean/expander.lean index 0475e832c7..c3c87a5571 100644 --- a/library/init/lean/expander.lean +++ b/library/init/lean/expander.lean @@ -311,12 +311,43 @@ def pi.transform : transformer := let v := view pi stx, expandBinders (λ binders body, review pi {op := v.op, binders := binders, range := body}) v.binders v.range +def getAppArgs : Syntax → (Syntax × List Syntax) +| stx := match tryView app stx with + | some v := let (fn, args) := getAppArgs v.fn in (fn, v.Arg::args) + | none := (stx, []) + +def termToIdentsUnivs : Syntax → List identUnivs.View × Option Syntax +| stx := match stx.kind with + | some @identUnivs := ([view identUnivs stx], none) + | some @app := + let v := view app stx in + (match tryView identUnivs v.fn with + | some id := + let (ids, rem) := termToIdentsUnivs v.Arg in (id::ids, rem) + | none := ([], some stx)) + | _ := ([], some stx) + +def termToExplicitBinder (stx : Syntax) : Option explicitBinder.View := do + v ← tryView paren stx, + {Term := t, special := parenSpecial.View.typed pst} ← v.content | failure, + (ids, none) ← pure $ termToIdentsUnivs t | failure, + let bids := ids.map $ λ id, binderIdent.View.id id.id, + pure {content := explicitBinderContent.View.other {ids := bids, type := some $ {type := pst.type}}} + def arrow.transform : transformer := λ stx, do let v := view arrow stx, + -- if `stx` is of the form `(id... : e)... → f`, use the type expressions as binders + -- Ex: `(a : b) -> c` ~> `Π (a : b), c` + let (fn, args) := getAppArgs v.dom, + let groups := fn::args, pure $ review pi { op := Syntax.atom {val := "Π"}, - binders := binders.View.simple $ simpleBinder.View.explicit {id := `a, type := v.dom}, + binders := match groups.mmap termToExplicitBinder with + | some bnders := binders.View.extended { + leadingIds := [], + remainder := bindersRemainder.View.mixed $ bnders.map (mixedBinder.View.bracketed ∘ bracketedBinder.View.explicit)} + | none := binders.View.simple $ simpleBinder.View.explicit {id := `a, type := v.dom}, range := v.range} def paren.transform : transformer :=