feat: expand for x in xs notation
This commit is contained in:
parent
ac16393ae9
commit
eafd9bc0ad
2 changed files with 106 additions and 25 deletions
|
|
@ -588,7 +588,10 @@ else
|
|||
namespace ToTerm
|
||||
|
||||
inductive Kind
|
||||
| regular | forInNestedTerm | forIn | forInMap
|
||||
| regular
|
||||
| forInNestedTerm
|
||||
| forIn
|
||||
| forInMap (x : Name)
|
||||
|
||||
structure Context :=
|
||||
(m : Syntax) -- Syntax to reference the monad associated with the do notation.
|
||||
|
|
@ -614,12 +617,22 @@ else do
|
|||
uvars ← mkUVarTuple ref;
|
||||
liftM $ mkTuple ref #[val, uvars]
|
||||
|
||||
private def mkForInYield (ref : Syntax) : M Syntax := do
|
||||
u ← mkUVarTuple ref;
|
||||
`(HasPure.pure (ForInStep.yield $u))
|
||||
|
||||
private def mkForInMapYield (ref : Syntax) (x : Name) : M Syntax := do
|
||||
u ← mkUVarTuple ref;
|
||||
r ← liftM $ mkTuple ref #[mkIdentFrom ref x, u];
|
||||
`(HasPure.pure (ForInStep.yield $r))
|
||||
|
||||
def returnToTermCore (ref : Syntax) (val : Syntax) : M Syntax := do
|
||||
ctx ← read;
|
||||
match ctx.kind with
|
||||
| Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«return» $u))
|
||||
| Kind.regular => do r ← mkResultUVarTuple ref val; `(HasPure.pure $r)
|
||||
| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.yield $u))
|
||||
| Kind.forIn => mkForInYield ref
|
||||
| Kind.forInMap x => mkForInMapYield ref x
|
||||
|
||||
def returnToTerm (ref : Syntax) (val : Syntax) : M Syntax := do
|
||||
r ← returnToTermCore ref val;
|
||||
|
|
@ -630,7 +643,8 @@ ctx ← read;
|
|||
match ctx.kind with
|
||||
| Kind.regular => unreachable!
|
||||
| Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«continue» $u))
|
||||
| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.yield $u))
|
||||
| Kind.forIn => mkForInYield ref
|
||||
| Kind.forInMap x => mkForInMapYield ref x
|
||||
|
||||
def continueToTerm (ref : Syntax) : M Syntax := do
|
||||
r ← continueToTermCore ref;
|
||||
|
|
@ -641,7 +655,8 @@ ctx ← read;
|
|||
match ctx.kind with
|
||||
| Kind.regular => unreachable!
|
||||
| Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«break» $u))
|
||||
| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.done $u))
|
||||
| Kind.forIn => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.done $u))
|
||||
| Kind.forInMap x => do u ← mkUVarTuple ref; r ← liftM $ mkTuple ref #[mkIdentFrom ref x, u]; `(HasPure.pure (ForInStep.done $r))
|
||||
|
||||
def breakToTerm (ref : Syntax) : M Syntax := do
|
||||
r ← breakToTermCore ref;
|
||||
|
|
@ -748,25 +763,9 @@ partial def toTerm : Code → M Syntax
|
|||
| Code.ite ref _ o c t e => do t ← toTerm t; e ← toTerm e; pure $ mkIte ref o c t e
|
||||
| _ => liftM $ Macro.throwError Syntax.missing "WIP"
|
||||
|
||||
private def getKindUVars (c : CodeBlock) (forInVar? : Option Name) : Kind × Array Name :=
|
||||
match forInVar? with
|
||||
| none =>
|
||||
if hasContinueBreak c.code then
|
||||
(Kind.forInNestedTerm, nameSetToArray c.uvars)
|
||||
else
|
||||
(Kind.regular, nameSetToArray c.uvars)
|
||||
| some forInVar =>
|
||||
if c.uvars.contains forInVar then
|
||||
let uvars := #[forInVar] ++ nameSetToArray (c.uvars.erase forInVar);
|
||||
(Kind.forInMap, uvars)
|
||||
else
|
||||
(Kind.forIn, nameSetToArray c.uvars)
|
||||
|
||||
def run (c : CodeBlock) (m : Syntax) (forInVar? : Option Name := none) : MacroM (Array Name × Syntax) := do
|
||||
let code := c.code;
|
||||
let (kind, uvars) := getKindUVars c forInVar?;
|
||||
def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do
|
||||
term ← toTerm code { m := m, kind := kind, uvars := uvars };
|
||||
pure (uvars, term)
|
||||
pure term
|
||||
|
||||
end ToTerm
|
||||
|
||||
|
|
@ -779,6 +778,27 @@ structure Context :=
|
|||
|
||||
abbrev M := ReaderT Context TermElabM
|
||||
|
||||
@[inline] def withFor {α} (x : M α) : M α :=
|
||||
adaptReader (fun (ctx : Context) => { ctx with insideFor := true }) x
|
||||
|
||||
structure ToForInTermResult :=
|
||||
(isForInMap : Bool)
|
||||
(uvars : Array Name)
|
||||
(term : Syntax)
|
||||
|
||||
def toForInTerm (x : Syntax) (forCodeBlock : CodeBlock) : M ToForInTermResult := do
|
||||
ctx ← read;
|
||||
let uvars := forCodeBlock.uvars;
|
||||
if x.isIdent && uvars.contains x.getId then do
|
||||
-- It is a forInMap
|
||||
let uvars := nameSetToArray (uvars.erase x.getId);
|
||||
term ← liftMacroM $ ToTerm.run forCodeBlock.code ctx.m uvars (ToTerm.Kind.forInMap x.getId);
|
||||
pure ⟨true, uvars, term⟩
|
||||
else do
|
||||
let uvars := nameSetToArray uvars;
|
||||
term ← liftMacroM $ ToTerm.run forCodeBlock.code ctx.m uvars ToTerm.Kind.forIn;
|
||||
pure ⟨false, uvars, term⟩
|
||||
|
||||
def ensureInsideFor : M Unit := do
|
||||
ctx ← read;
|
||||
unless ctx.insideFor $
|
||||
|
|
@ -869,8 +889,23 @@ partial def doSeqToCode : List Syntax → M CodeBlock
|
|||
body ← doSeqToCode (getDoSeqElems doSeq);
|
||||
unless ← liftM $ mkUnless ref cond body;
|
||||
concatWithRest unless
|
||||
else if k == `Lean.Parser.Term.doFor then
|
||||
throwError "WIP"
|
||||
else if k == `Lean.Parser.Term.doFor then withFreshMacroScope do
|
||||
let ref := doElem;
|
||||
let x := doElem.getArg 1;
|
||||
let xs := doElem.getArg 3;
|
||||
let forElems := getDoSeqElems (doElem.getArg 5);
|
||||
forCodeBlock ← withFor (doSeqToCode forElems);
|
||||
⟨isForInMap, uvars, forInBody⟩ ← toForInTerm x forCodeBlock;
|
||||
uvarsTuple ← liftMacroM $ mkTuple ref (uvars.map (mkIdentFrom ref));
|
||||
auxDo ← if isForInMap then do
|
||||
forInTerm ← `($(xs).forInMap $uvarsTuple fun $x $uvarsTuple => $forInBody);
|
||||
`(do let r ← $forInTerm; $uvarsTuple:term := r.2; return r.1)
|
||||
else do {
|
||||
forInTerm ← `($(xs).forIn $uvarsTuple fun $x $uvarsTuple => $forInBody);
|
||||
`(do let r ← $forInTerm; $uvarsTuple:term := r)
|
||||
};
|
||||
let doElemsNew := getDoSeqElems (getDoSeq auxDo);
|
||||
doSeqToCode (doElemsNew ++ doElems)
|
||||
else if k == `Lean.Parser.Term.doMatch then
|
||||
throwError "WIP"
|
||||
else if k == `Lean.Parser.Term.doTry then
|
||||
|
|
@ -927,7 +962,7 @@ fun stx expectedType? => do
|
|||
m ← mkMonadAlias bindInfo.m;
|
||||
codeBlock ← ToCodeBlock.run stx m;
|
||||
-- trace! `Elab.do ("codeBlock: " ++ Format.line ++ codeBlock.toMessageData);
|
||||
(_, stxNew) ← liftMacroM $ ToTerm.run { codeBlock with uvars := {} } m;
|
||||
stxNew ← liftMacroM $ ToTerm.run codeBlock.code m;
|
||||
trace! `Elab.do stxNew;
|
||||
let expectedType := mkApp bindInfo.m bindInfo.α;
|
||||
withMacroExpansion stx stxNew $ elabTermEnsuringType stxNew expectedType
|
||||
|
|
|
|||
|
|
@ -44,3 +44,49 @@ rfl
|
|||
|
||||
theorem ex4 (y : Nat) : h 1 y = (1 + 1) + y :=
|
||||
rfl
|
||||
|
||||
def sumOdd (xs : List Nat) (threshold : Nat) : Nat := do
|
||||
let sum := 0
|
||||
for x in xs do
|
||||
if x % 2 == 1 then
|
||||
sum := sum + x
|
||||
if sum > threshold then
|
||||
break
|
||||
unless x % 2 == 1 do
|
||||
continue
|
||||
dbgTrace! ">> x: " ++ toString x
|
||||
return sum
|
||||
|
||||
#eval sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10
|
||||
|
||||
theorem ex5 : sumOdd [1, 2, 3, 4, 5, 6, 7, 9, 11, 101] 10 = 16 :=
|
||||
rfl
|
||||
|
||||
def mapOdd (f : Nat → Nat) (xs : List Nat) : List Nat := do
|
||||
for x in xs do
|
||||
if x % 2 == 1 then
|
||||
x := f x
|
||||
dbgTrace! ">> mapOdd x: " ++ toString x
|
||||
|
||||
#eval mapOdd (·+10) [1, 2, 3, 4, 5, 6, 7, 9]
|
||||
|
||||
theorem ex6 : mapOdd (·+10) [1, 2, 3, 4, 5, 6, 7, 9] = [11, 2, 13, 4, 15, 6, 17, 19] :=
|
||||
rfl
|
||||
|
||||
-- We need `Id.run` because we still have `Monad Option`
|
||||
def find? (xs : List Nat) (p : Nat → Bool) : Option Nat := Id.run do
|
||||
let result := none
|
||||
for x in xs do
|
||||
if p x then
|
||||
result := x
|
||||
break
|
||||
return result
|
||||
|
||||
def sumDiff (ps : List (Nat × Nat)) : Nat := do
|
||||
let sum := 0
|
||||
for (x, y) in ps do
|
||||
sum := sum + x - y
|
||||
return sum
|
||||
|
||||
theorem ex7 : sumDiff [(2, 1), (10, 5)] = 6 :=
|
||||
rfl
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue