feat: propagate return type to for-in block
This commit is contained in:
parent
71edf731f9
commit
bdaabd4e7b
4 changed files with 42 additions and 26 deletions
|
|
@ -87,24 +87,24 @@ private partial def hasLiftMethod : Syntax → Bool
|
|||
|
||||
structure ExtractMonadResult where
|
||||
m : Expr
|
||||
α : Expr
|
||||
returnType : Expr
|
||||
expectedType : Expr
|
||||
|
||||
private def mkUnknownMonadResult : MetaM ExtractMonadResult := do
|
||||
let u ← mkFreshLevelMVar
|
||||
let v ← mkFreshLevelMVar
|
||||
let m ← mkFreshExprMVar (← mkArrow (mkSort (mkLevelSucc u)) (mkSort (mkLevelSucc v)))
|
||||
let α ← mkFreshExprMVar (mkSort (mkLevelSucc u))
|
||||
return { m, α, expectedType := mkApp m α }
|
||||
let returnType ← mkFreshExprMVar (mkSort (mkLevelSucc u))
|
||||
return { m, returnType, expectedType := mkApp m returnType }
|
||||
|
||||
private partial def extractBind (expectedType? : Option Expr) : TermElabM ExtractMonadResult := do
|
||||
let some expectedType := expectedType? | mkUnknownMonadResult
|
||||
let extractStep? (type : Expr) : MetaM (Option ExtractMonadResult) := do
|
||||
let .app m α _ := type | return none
|
||||
let .app m returnType _ := type | return none
|
||||
try
|
||||
let bindInstType ← mkAppM ``Bind #[m]
|
||||
discard <| Meta.synthInstance bindInstType
|
||||
return some { m, α, expectedType }
|
||||
return some { m, returnType, expectedType }
|
||||
catch _ =>
|
||||
return none
|
||||
let rec extract? (type : Expr) : MetaM (Option ExtractMonadResult) := do
|
||||
|
|
@ -863,9 +863,12 @@ def Kind.isRegular : Kind → Bool
|
|||
| _ => false
|
||||
|
||||
structure Context where
|
||||
m : Syntax -- Syntax to reference the monad associated with the do notation.
|
||||
uvars : Array Var
|
||||
kind : Kind
|
||||
/-- Syntax to reference the monad associated with the do notation. -/
|
||||
m : Syntax
|
||||
/-- Syntax to reference the result of the monadic computation performed by the do notation. -/
|
||||
returnType : Syntax
|
||||
uvars : Array Var
|
||||
kind : Kind
|
||||
|
||||
abbrev M := ReaderT Context MacroM
|
||||
|
||||
|
|
@ -1031,8 +1034,8 @@ partial def toTerm (c : Code) : M Syntax := do
|
|||
let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode termAlts]
|
||||
return mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", genParam, optMotive, discrs, mkAtomFrom ref "with", termMatchAlts]
|
||||
|
||||
def run (code : Code) (m : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax :=
|
||||
toTerm code { m := m, kind := kind, uvars := uvars }
|
||||
def run (code : Code) (m : Syntax) (returnType : Syntax) (uvars : Array Var := #[]) (kind := Kind.regular) : MacroM Syntax :=
|
||||
toTerm code { m, returnType, kind, uvars }
|
||||
|
||||
/- Given
|
||||
- `a` is true if the code block has a `Code.action _` exit point
|
||||
|
|
@ -1051,8 +1054,8 @@ def mkNestedKind (a r bc : Bool) : Kind :=
|
|||
| true, true, true => .nestedPRBC
|
||||
| false, false, false => unreachable!
|
||||
|
||||
def mkNestedTerm (code : Code) (m : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do
|
||||
ToTerm.run code m uvars (mkNestedKind a r bc)
|
||||
def mkNestedTerm (code : Code) (m : Syntax) (returnType : Syntax) (uvars : Array Var) (a r bc : Bool) : MacroM Syntax := do
|
||||
ToTerm.run code m returnType uvars (mkNestedKind a r bc)
|
||||
|
||||
/- Given a term `term` produced by `ToTerm.run`, pattern match on its result.
|
||||
See comment at the beginning of the `ToTerm` namespace.
|
||||
|
|
@ -1119,7 +1122,10 @@ namespace ToCodeBlock
|
|||
|
||||
structure Context where
|
||||
ref : Syntax
|
||||
m : Syntax -- Syntax representing the monad associated with the do notation.
|
||||
/-- Syntax representing the monad associated with the do notation. -/
|
||||
m : Syntax
|
||||
/-- Syntax to reference the result of the monadic computation performed by the do notation. -/
|
||||
returnType : Syntax
|
||||
mutableVars : VarSet := {}
|
||||
insideFor : Bool := false
|
||||
|
||||
|
|
@ -1155,7 +1161,7 @@ def mkForInBody (_ : Syntax) (forInBody : CodeBlock) : M ToForInTermResult := d
|
|||
let ctx ← read
|
||||
let uvars := forInBody.uvars
|
||||
let uvars := varSetToArray uvars
|
||||
let term ← liftMacroM <| ToTerm.run forInBody.code ctx.m uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn)
|
||||
let term ← liftMacroM <| ToTerm.run forInBody.code ctx.m ctx.returnType uvars (if hasReturn forInBody.code then ToTerm.Kind.forInWithReturn else ToTerm.Kind.forIn)
|
||||
return ⟨uvars, term⟩
|
||||
|
||||
def ensureInsideFor : M Unit :=
|
||||
|
|
@ -1413,10 +1419,11 @@ mutual
|
|||
let uvarsTuple ← liftMacroM do mkTuple uvars
|
||||
if hasReturn forInBodyCodeBlock.code then
|
||||
let forInBody ← liftMacroM <| destructTuple uvars (← `(r)) forInBody
|
||||
let optType ← `(Option $((← read).returnType))
|
||||
let forInTerm ← if let some h := h? then
|
||||
`(for_in'% $(xs) (MProd.mk none $uvarsTuple) fun $x $h r => let r := r.2; $forInBody)
|
||||
`(for_in'% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x $h (r : MProd $optType _) => let r := r.2; $forInBody)
|
||||
else
|
||||
`(for_in% $(xs) (MProd.mk none $uvarsTuple) fun $x r => let r := r.2; $forInBody)
|
||||
`(for_in% $(xs) (MProd.mk (none : $optType) $uvarsTuple) fun $x (r : MProd $optType _) => let r := r.2; $forInBody)
|
||||
let auxDo ← `(do let r ← $forInTerm:term;
|
||||
$uvarsTuple:term := r.2;
|
||||
match r.1 with
|
||||
|
|
@ -1495,7 +1502,7 @@ mutual
|
|||
let bc := tryCatchPred tryCode catches finallyCode? hasBreakContinue
|
||||
let toTerm (codeBlock : CodeBlock) : M Syntax := do
|
||||
let codeBlock ← liftM $ extendUpdatedVars codeBlock ws
|
||||
liftMacroM <| ToTerm.mkNestedTerm codeBlock.code ctx.m uvars a r bc
|
||||
liftMacroM <| ToTerm.mkNestedTerm codeBlock.code ctx.m ctx.returnType uvars a r bc
|
||||
let term ← toTerm tryCode
|
||||
let term ← catches.foldlM (init := term) fun term «catch» => do
|
||||
let catchTerm ← toTerm «catch».codeBlock
|
||||
|
|
@ -1511,7 +1518,7 @@ mutual
|
|||
throwError "'finally' currently does not support reassignments"
|
||||
if hasBreakContinueReturn finallyCode.code then
|
||||
throwError "'finally' currently does 'return', 'break', nor 'continue'"
|
||||
let finallyTerm ← liftMacroM <| ToTerm.run finallyCode.code ctx.m {} ToTerm.Kind.regular
|
||||
let finallyTerm ← liftMacroM <| ToTerm.run finallyCode.code ctx.m ctx.returnType {} ToTerm.Kind.regular
|
||||
``(tryFinally $term $finallyTerm)
|
||||
let doElemsNew ← liftMacroM <| ToTerm.matchNestedTermResult term uvars a r bc
|
||||
doSeqToCode (doElemsNew ++ doElems)
|
||||
|
|
@ -1592,8 +1599,8 @@ mutual
|
|||
throwError "unexpected do-element of kind {doElem.getKind}:\n{doElem}"
|
||||
end
|
||||
|
||||
def run (doStx : Syntax) (m : Syntax) : TermElabM CodeBlock :=
|
||||
(doSeqToCode <| getDoSeqElems <| getDoSeq doStx).run { ref := doStx, m }
|
||||
def run (doStx : Syntax) (m : Syntax) (returnType : Syntax) : TermElabM CodeBlock :=
|
||||
(doSeqToCode <| getDoSeqElems <| getDoSeq doStx).run { ref := doStx, m, returnType }
|
||||
|
||||
end ToCodeBlock
|
||||
|
||||
|
|
@ -1601,8 +1608,9 @@ end ToCodeBlock
|
|||
tryPostponeIfNoneOrMVar expectedType?
|
||||
let bindInfo ← extractBind expectedType?
|
||||
let m ← Term.exprToSyntax bindInfo.m
|
||||
let codeBlock ← ToCodeBlock.run stx m
|
||||
let stxNew ← liftMacroM <| ToTerm.run codeBlock.code m
|
||||
let returnType ← Term.exprToSyntax bindInfo.returnType
|
||||
let codeBlock ← ToCodeBlock.run stx m returnType
|
||||
let stxNew ← liftMacroM <| ToTerm.run codeBlock.code m returnType
|
||||
trace[Elab.do] stxNew
|
||||
withMacroExpansion stx stxNew <| elabTermEnsuringType stxNew bindInfo.expectedType
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
217.lean:5:28-5:29: error: don't know how to synthesize placeholder for argument 'f'
|
||||
context:
|
||||
⊢ CoreM Unit → Name → ConstantInfo → CoreM Unit
|
||||
217.lean:5:30-5:31: error: don't know how to synthesize placeholder for argument 'init'
|
||||
context:
|
||||
⊢ CoreM Unit
|
||||
217.lean:5:28-5:29: error: don't know how to synthesize placeholder for argument 'f'
|
||||
context:
|
||||
⊢ CoreM Unit → Name → ConstantInfo → CoreM Unit
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ linterUnusedVariables.lean:50:11-50:12: warning: unused variable `z`
|
|||
linterUnusedVariables.lean:55:14-55:15: warning: unused variable `y`
|
||||
linterUnusedVariables.lean:61:20-61:21: warning: unused variable `y`
|
||||
linterUnusedVariables.lean:66:34-66:38: warning: unused variable `inst`
|
||||
linterUnusedVariables.lean:108:6-108:7: warning: unused variable `y`
|
||||
linterUnusedVariables.lean:107:25-107:26: warning: unused variable `x`
|
||||
linterUnusedVariables.lean:108:6-108:7: warning: unused variable `y`
|
||||
linterUnusedVariables.lean:114:6-114:7: warning: unused variable `a`
|
||||
linterUnusedVariables.lean:124:26-124:27: warning: unused variable `z`
|
||||
linterUnusedVariables.lean:132:9-132:10: warning: unused variable `h`
|
||||
|
|
|
|||
8
tests/lean/run/forInReturnPropagation.lean
Normal file
8
tests/lean/run/forInReturnPropagation.lean
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
def main (args : List String) : IO UInt32 := do
|
||||
for (arg : String) in args do
|
||||
match arg with
|
||||
| "--print-cflags" =>
|
||||
return 1
|
||||
| _ =>
|
||||
pure ()
|
||||
return 0
|
||||
Loading…
Add table
Reference in a new issue