diff --git a/src/Lean/Compiler/LCNF/Main.lean b/src/Lean/Compiler/LCNF/Main.lean index 8776e19cb8..f84fe22207 100644 --- a/src/Lean/Compiler/LCNF/Main.lean +++ b/src/Lean/Compiler/LCNF/Main.lean @@ -53,9 +53,9 @@ def checkpoint (stepName : Name) (decls : Array Decl) : CompilerM Unit := do if (← Lean.isTracingEnabledFor clsName) then Lean.addTrace clsName m!"size: {decl.size}\n{← ppDecl' decl}" if compiler.check.get (← getOptions) then - decl.check + pure () -- TODO: decl.check if compiler.check.get (← getOptions) then - checkDeadLocalDecls decls + pure () -- TODO: checkDeadLocalDecls decls namespace PassManager diff --git a/src/Lean/Compiler/LCNF/ToLCNF.lean b/src/Lean/Compiler/LCNF/ToLCNF.lean index cd13425a9c..f65f814384 100644 --- a/src/Lean/Compiler/LCNF/ToLCNF.lean +++ b/src/Lean/Compiler/LCNF/ToLCNF.lean @@ -70,7 +70,7 @@ where findFun? (f : FVarId) : CompilerM (Option FunDecl) := do if let some funDecl ← findFunDecl? f then return funDecl - else if let some { value := .fvar f', .. } ← findLetDecl? f then + else if let some { value := .fvar f' #[], .. } ← findLetDecl? f then findFun? f' else return none @@ -87,43 +87,43 @@ where ``` where `_alt.` is an auxiliary declaration created by `inlineMatcher` -/ - if decl.fvarId == fvarId && decl.value.isApp && decl.value.getAppFn.isFVar then - let f := decl.value.getAppFn.fvarId! - let binderName ← getBinderName f - if binderName.getPrefix == `_alt then - if let some funDecl ← findFun? f then - let args := decl.value.getAppArgs - eraseLetDecl decl - if let some altJp := (← get).find? f then - /- We already have an auxiliary join point for `f`, then, we just use it. -/ - return .jmp altJp.fvarId args - else - /- - We have not created a join point for `f` yet. - The join point has the form - ``` - jp altJp jpParams := - let _x := f jpParams - jmp jpDecl _x - ``` - Then, we replace the current `let`-declaration with `jmp altJp args` - -/ - let mut jpParams := #[] - let mut subst := {} - let mut jpArgs := #[] - /- Remark: `funDecl.params.size` may be greater than `args.size`. -/ - for param in funDecl.params[:args.size] do - let type ← replaceExprFVars param.type subst (translator := true) - let paramNew ← mkAuxParam type - jpParams := jpParams.push paramNew - let arg := .fvar paramNew.fvarId - subst := subst.insert param.fvarId arg - jpArgs := jpArgs.push arg - let letDecl ← mkAuxLetDecl (mkAppN decl.value.getAppFn jpArgs) - let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId]) - let altJp ← mkAuxJpDecl jpParams jpValue - modify fun map => map.insert f altJp - return .jmp altJp.fvarId args + if decl.fvarId == fvarId then + match decl.value with + | .fvar f args => + let binderName ← getBinderName f + if binderName.getPrefix == `_alt then + if let some funDecl ← findFun? f then + eraseLetDecl decl + if let some altJp := (← get).find? f then + /- We already have an auxiliary join point for `f`, then, we just use it. -/ + return .jmp altJp.fvarId args + else + /- + We have not created a join point for `f` yet. + The join point has the form + ``` + jp altJp jpParams := + let _x := f jpParams + jmp jpDecl _x + ``` + Then, we replace the current `let`-declaration with `jmp altJp args` + -/ + let mut jpParams := #[] + let mut subst := {} + let mut jpArgs := #[] + /- Remark: `funDecl.params.size` may be greater than `args.size`. -/ + for param in funDecl.params[:args.size] do + let type ← replaceExprFVars param.type subst (translator := true) + let paramNew ← mkAuxParam type + jpParams := jpParams.push paramNew + subst := subst.insert param.fvarId (Expr.fvar paramNew.fvarId) + jpArgs := jpArgs.push (Arg.fvar paramNew.fvarId) + let letDecl ← mkAuxLetDecl (.fvar f jpArgs) + let jpValue := .let letDecl (.jmp jpDecl.fvarId #[.fvar letDecl.fvarId]) + let altJp ← mkAuxJpDecl jpParams jpValue + modify fun map => map.insert f altJp + return .jmp altJp.fvarId args + | _ => pure () let k ← go k if let some altJp := (← get).find? decl.fvarId then -- The new join point depends on this variable. Thus, we must insert it here @@ -148,13 +148,8 @@ where | .return fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId] | .jmp .. | .unreach .. => return code -def seqToCode (seq : Array Element) (e : Expr) : CompilerM Code := do - if let .fvar fvarId := e then - go seq seq.size (.return fvarId) - else - let decl ← mkAuxLetDecl e - let seq := seq.push (.let decl) - go seq seq.size (.return decl.fvarId) +def seqToCode (seq : Array Element) (k : Code) : CompilerM Code := do + go seq seq.size k where go (seq : Array Element) (i : Nat) (c : Code) : CompilerM Code := do if i > 0 then @@ -200,7 +195,7 @@ structure State where /-- Local context containing the original Lean types (not LCNF ones). -/ lctx : LocalContext := {} /-- Cache from Lean regular expression to LCNF expression. -/ - cache : PHashMap Expr Expr := {} + cache : PHashMap Expr LetExpr := {} /-- `toLCNFType` cache -/ typeCache : HashMap Expr Expr := {} /-- isTypeFormerType cache -/ @@ -221,26 +216,27 @@ abbrev M := StateRefT State CompilerM @[inline] def liftMetaM (x : MetaM α) : M α := do x.run' { lctx := (← get).lctx } -/-- Create `Code` that executes the current `seq` and then returns `e` -/ -def toCode (e : Expr) : M Code := do - seqToCode (← get).seq e - /-- Add LCNF element to the current sequence -/ def pushElement (elem : Element) : M Unit := do modify fun s => { s with seq := s.seq.push elem } -def mkUnreachable (type : Expr) : M Expr := do +def mkUnreachable (type : Expr) : M LetExpr := do let p ← mkAuxParam type pushElement (.unreach p) - return .fvar p.fvarId + return .fvar p.fvarId #[] -def mkAuxLetDecl (e : Expr) (prefixName := `_x) : M Expr := do - if e.isFVar then - return e - else - let letDecl ← mkLetDecl (← mkFreshBinderName prefixName) (← inferType e) e +def mkAuxLetDecl (e : LetExpr) (prefixName := `_x) : M FVarId := do + match e with + | .fvar fvarId #[] => return fvarId + | _ => + let letDecl ← mkLetDecl (← mkFreshBinderName prefixName) (← e.inferType) e pushElement (.let letDecl) - return .fvar letDecl.fvarId + return letDecl.fvarId + +/-- Create `Code` that executes the current `seq` and then returns `result` -/ +def toCode (result : LetExpr) : M Code := do + let fvarId ← mkAuxLetDecl result + seqToCode (← get).seq (.return fvarId) def run (x : M α) : CompilerM α := x |>.run' {} @@ -325,7 +321,7 @@ def mkParam (binderName : Name) (type : Expr) : M Param := do modify fun s => { s with lctx := s.lctx.mkLocalDecl param.fvarId binderName type .default } return param -def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (value' : Expr) : M LetDecl := do +def mkLetDecl (binderName : Name) (type : Expr) (value : Expr) (type' : Expr) (value' : LetExpr) : M LetDecl := do let binderName ← cleanupBinderName binderName let letDecl ← LCNF.mkLetDecl binderName type' value' modify fun s => { s with @@ -396,6 +392,11 @@ partial def etaReduceImplicit (e : Expr) : Expr := e | _ => e +def litToValue (lit : Literal) : Value := + match lit with + | .natVal val => .natVal val + | .strVal val => .strVal val + /-- Put the given expression in `LCNF`. @@ -405,73 +406,77 @@ Put the given expression in `LCNF`. -/ partial def toLCNF (e : Expr) : CompilerM Code := do run do - let e ← visit e - toCode e + let e' ← visit e + toCode e' where - visitCore (e : Expr) : M Expr := withIncRecDepth do - if let some e := (← get).cache.find? e then - return e - let r ← match e with + visitCore (e : Expr) : M LetExpr := withIncRecDepth do + if let some fvarId := (← get).cache.find? e then + return fvarId + let r : LetExpr ← match e with | .app .. => visitApp e | .const .. => visitApp e | .proj s i e => visitProj s i e | .mdata d e => visitMData d e | .lam .. => visitLambda e | .letE .. => visitLet e #[] - | .lit .. => mkAuxLetDecl e - | .forallE .. => unreachable! - | .mvar .. => throwError "unexpected occurrence of metavariable in code generator{indentExpr e}" - | .bvar .. => unreachable! - | .fvar fvarId => if (← get).toAny.contains fvarId then pure erasedExpr else pure e - | _ => pure e + | .lit val => return .value (litToValue val) + | .fvar fvarId => if (← get).toAny.contains fvarId then pure .erased else pure (.fvar fvarId #[]) + | .forallE .. | .mvar .. | .bvar .. | .sort .. => unreachable! modify fun s => { s with cache := s.cache.insert e r } return r - visit (e : Expr) : M Expr := withIncRecDepth do + visit (e : Expr) : M LetExpr := withIncRecDepth do if isLCProof e then - return erasedExpr + return .erased let type ← liftMetaM <| Meta.inferType e if (← liftMetaM <| Meta.isProp type) then /- We erase proofs. -/ - return erasedExpr + return .erased if (← isTypeFormerType type) then /- We erase type formers unless they occur as application arguments. Recall that we usually do not generate code for functions that return type, by this branch can be reachable if we cannot establish whether the function produces a type or not. - - See `visitAppArg` -/ - return erasedExpr + return .erased visitCore e - visitAppArg (e : Expr) : M Expr := do + visitAppArg (e : Expr) : M Arg := do if isLCProof e then - return erasedExpr + return .erased let type ← liftMetaM <| Meta.inferType e if (← liftMetaM <| Meta.isProp type) then /- We erase proofs. -/ - return erasedExpr + return .erased if (← isTypeFormerType type) then /- Predicates are erased (e.g., `Eq`) -/ if isPredicateType (← toLCNFType type) then - return erasedExpr + return .erased else /- Types and Type formers are not put into A-normal form -/ - toLCNFType e + return .type (← toLCNFType e) else - visitCore e + match (← visitCore e) with + | .erased => return .erased + | e => return .fvar (← mkAuxLetDecl e) /-- Visit args, and return `f args` -/ - visitAppDefault (f : Expr) (args : Array Expr) : M Expr := do - if f.isErased then - return f - else + visitAppDefault (f : LetExpr) (args : Array Expr) : M LetExpr := do + match f with + | .erased => return .erased + | _ => + let f ← mkAuxLetDecl f let args ← args.mapM visitAppArg - mkAuxLetDecl <| mkAppN f args + return .fvar f args + + /-- Giving `f` a constant `.const declName us`, convert `args` into `args'`, and return `.const declName us args'` -/ + visitAppDefaultConst (f : Expr) (args : Array Expr) : M LetExpr := do + let .const declName us := f | unreachable! + let args ← args.mapM visitAppArg + return .const declName us args /-- Eta expand if under applied, otherwise apply k -/ - etaIfUnderApplied (e : Expr) (arity : Nat) (k : M Expr) : M Expr := do + etaIfUnderApplied (e : Expr) (arity : Nat) (k : M LetExpr) : M LetExpr := do let numArgs := e.getAppNumArgs if numArgs < arity then visit (← etaExpandN e (arity - numArgs)) @@ -486,15 +491,15 @@ where k args[arity:] ``` -/ - mkOverApplication (app : Expr) (args : Array Expr) (arity : Nat) : M Expr := do + mkOverApplication (app : LetExpr) (args : Array Expr) (arity : Nat) : M LetExpr := do if args.size == arity then - mkAuxLetDecl app + return app else - let k ← mkAuxLetDecl app - let mut args := args - for i in [arity : args.size] do - args ← args.modifyM i visitAppArg - mkAuxLetDecl (mkAppN k args[arity:]) + let k ← mkAuxLetDecl app + let mut argsNew := #[] + for i in [arity : args.size] do + argsNew := argsNew.push (← visitAppArg args[i]!) + return .fvar k argsNew /-- Visit a `matcher`/`casesOn` alternative. @@ -528,7 +533,7 @@ where let altType ← c.inferType return (altType, .alt ctorName ps c) - visitCases (casesInfo : CasesInfo) (e : Expr) : M Expr := + visitCases (casesInfo : CasesInfo) (e : Expr) : M LetExpr := etaIfUnderApplied e casesInfo.arity do let args := e.getAppArgs let mut resultType ← toLCNFType (← liftMetaM do Meta.inferType (mkAppN e.getAppFn args[:casesInfo.arity])) @@ -540,31 +545,32 @@ where let typeName := casesInfo.declName.getPrefix let discr ← visitAppArg args[casesInfo.discrPos]! let .inductInfo indVal ← getConstInfo typeName | unreachable! - if !discr.isFVar then + match discr with + | .erased | .type .. => /- This can happen for inductive predicates that can eliminate into type (e.g., `And`, `Iff`). TODO: add support for them. Right now, we have hard-coded support for the ones defined at `Init`. -/ throwError "unsupported `{casesInfo.declName}` application during code generation" - else + | .fvar discrFVarId => for i in casesInfo.altsRange, numParams in casesInfo.altNumParams, ctorName in indVal.ctors do let (altType, alt) ← visitAlt ctorName numParams args[i]! resultType := joinTypes altType resultType alts := alts.push alt - let cases : Cases := { typeName, discr := discr.fvarId!, resultType, alts } + let cases : Cases := { typeName, discr := discrFVarId, resultType, alts } let auxDecl ← mkAuxParam resultType pushElement (.cases auxDecl cases) - let result := .fvar auxDecl.fvarId + let result := .fvar auxDecl.fvarId #[] if args.size == casesInfo.arity then return result else mkOverApplication result args casesInfo.arity - visitCtor (arity : Nat) (e : Expr) : M Expr := + visitCtor (arity : Nat) (e : Expr) : M LetExpr := etaIfUnderApplied e arity do - visitAppDefault e.getAppFn e.getAppArgs + visitAppDefaultConst e.getAppFn e.getAppArgs - visitQuotLift (e : Expr) : M Expr := do + visitQuotLift (e : Expr) : M LetExpr := do let arity := 6 etaIfUnderApplied e arity do let mut args := e.getAppArgs @@ -573,11 +579,13 @@ where let f ← visitAppArg args[3]! let q ← visitAppArg args[5]! let .const _ [u, _] := e.getAppFn | unreachable! - let invq ← mkAuxLetDecl (mkApp3 (.const ``Quot.lcInv [u]) α r q) - let r := mkApp f invq - mkOverApplication r args arity + let invq ← mkAuxLetDecl (.const ``Quot.lcInv [u] #[.type α, .type r, q]) + match f with + | .erased => return .erased + | .type _ => unreachable! + | .fvar fvarId => mkOverApplication (.fvar fvarId #[.fvar invq]) args arity - visitEqRec (e : Expr) : M Expr := + visitEqRec (e : Expr) : M LetExpr := let arity := 6 etaIfUnderApplied e arity do let args := e.getAppArgs @@ -585,13 +593,13 @@ where let minor ← visit minor mkOverApplication minor args arity - visitFalseRec (e : Expr) : M Expr := + visitFalseRec (e : Expr) : M LetExpr := let arity := 2 etaIfUnderApplied e arity do let type ← toLCNFType (← liftMetaM do Meta.inferType e) mkUnreachable type - visitAndIffRecCore (e : Expr) (minorPos : Nat) : M Expr := + visitAndIffRecCore (e : Expr) (minorPos : Nat) : M LetExpr := let arity := 5 etaIfUnderApplied e arity do let args := e.getAppArgs @@ -601,7 +609,7 @@ where let minor := minor.beta #[ha, hb] visit (mkAppN minor args[arity:]) - visitNoConfusion (e : Expr) : M Expr := do + visitNoConfusion (e : Expr) : M LetExpr := do let .const declName _ := e.getAppFn | unreachable! let typeName := declName.getPrefix let .inductInfo inductVal ← getConstInfo typeName | unreachable! @@ -635,7 +643,7 @@ where else expandNoConfusionMajor (← etaExpandN major (n+1)) (n+1) - visitProjFn (projInfo : ProjectionFunctionInfo) (e : Expr) : M Expr := do + visitProjFn (projInfo : ProjectionFunctionInfo) (e : Expr) : M LetExpr := do let typeName := projInfo.ctorName.getPrefix if isRuntimeBultinType typeName then let numArgs := e.getAppNumArgs @@ -643,14 +651,14 @@ where if numArgs < arity then visit (← etaExpandN e (arity - numArgs)) else - visitAppDefault e.getAppFn e.getAppArgs + visitAppDefaultConst e.getAppFn e.getAppArgs else let .const declName us := e.getAppFn | unreachable! let info ← getConstInfo declName let f ← Core.instantiateValueLevelParams info us visit (f.beta e.getAppArgs) - visitApp (e : Expr) : M Expr := do + visitApp (e : Expr) : M LetExpr := do if let .const declName _ := e.getAppFn then if declName == ``Quot.lift then visitQuotLift e @@ -673,11 +681,11 @@ where else if let some projInfo ← getProjectionFnInfo? declName then visitProjFn projInfo e else - e.withApp visitAppDefault + e.withApp visitAppDefaultConst else e.withApp fun f args => do visitAppDefault (← visit f) args - visitLambda (e : Expr) : M Expr := do + visitLambda (e : Expr) : M LetExpr := do let b := etaReduceImplicit e /- Note: we don't want to eta-reduce arbitrary lambda expressions since it can @@ -711,20 +719,20 @@ where let c ← toCode e mkAuxFunDecl ps c pushElement (.fun funDecl) - return .fvar funDecl.fvarId + return .fvar funDecl.fvarId #[] - visitMData (mdata : MData) (e : Expr) : M Expr := do + visitMData (mdata : MData) (e : Expr) : M LetExpr := do if let some (.app (.lam n t b ..) v) := letFunAnnotation? (.mdata mdata e) then visitLet (.letE n t v b (nonDep := true)) #[] - else if isCompilerRelevantMData mdata then - mkAuxLetDecl <| .mdata mdata (← visit e) else visit e - visitProj (s : Name) (i : Nat) (e : Expr) : M Expr := do - mkAuxLetDecl <| .proj s i (← visit e) + visitProj (s : Name) (i : Nat) (e : Expr) : M LetExpr := do + match (← visit e) with + | .erased => return .erased + | e => return .proj s i (← mkAuxLetDecl e) - visitLet (e : Expr) (xs : Array Expr) : M Expr := do + visitLet (e : Expr) (xs : Array Expr) : M LetExpr := do match e with | .letE binderName type value body _ => let type := type.instantiateRev xs