lean4-htt/src/Lean/Compiler/IR/ToIR.lean

341 lines
12 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Cameron Zwarich
-/
module
prelude
public import Lean.Compiler.LCNF.Basic
public import Lean.Compiler.LCNF.CompilerM
public import Lean.Compiler.LCNF.PhaseExt
public import Lean.Compiler.IR.Basic
public import Lean.Compiler.IR.CompilerM
public import Lean.Compiler.IR.ToIRType
public import Lean.CoreM
public import Lean.Environment
public section
namespace Lean.IR
open Lean.Compiler (LCNF.Alt LCNF.Arg LCNF.Code LCNF.Decl LCNF.DeclValue LCNF.LCtx LCNF.LetDecl
LCNF.LetValue LCNF.LitValue LCNF.Param LCNF.getMonoDecl?)
namespace ToIR
inductive FVarClassification where
| var (id : VarId)
| joinPoint (id : JoinPointId)
| erased
deriving Inhabited
structure BuilderState where
fvars : Std.HashMap FVarId FVarClassification := {}
nextId : Nat := 1
abbrev M := StateRefT BuilderState CoreM
def M.run (x : M α) : CoreM α := do
x.run' {}
def getFVarValue (fvarId : FVarId) : M FVarClassification := do
return (← get).fvars.get! fvarId
def bindVar (fvarId : FVarId) : M VarId := do
modifyGet fun s =>
let varId := { idx := s.nextId }
⟨varId, { s with fvars := s.fvars.insertIfNew fvarId (.var varId),
nextId := s.nextId + 1 }⟩
def bindVarToVarId (fvarId : FVarId) (varId : VarId) : M Unit := do
modify fun s => { s with fvars := s.fvars.insertIfNew fvarId (.var varId) }
def newVar : M VarId := do
modifyGet fun s =>
let varId := { idx := s.nextId }
⟨varId, { s with nextId := s.nextId + 1 }⟩
def bindJoinPoint (fvarId : FVarId) : M JoinPointId := do
modifyGet fun s =>
let joinPointId := { idx := s.nextId }
⟨joinPointId, { s with fvars := s.fvars.insertIfNew fvarId (.joinPoint joinPointId),
nextId := s.nextId + 1 }⟩
def bindErased (fvarId : FVarId) : M Unit := do
modify fun s => { s with fvars := s.fvars.insertIfNew fvarId .erased }
def findDecl (n : Name) : M (Option Decl) :=
return findEnvDecl (← Lean.getEnv) n
def addDecl (d : Decl) : M Unit :=
Lean.modifyEnv fun env => declMapExt.addEntry env d
def lowerLitValue (v : LCNF.LitValue) : LitVal × IRType :=
match v with
| .nat n =>
let type := if n < UInt32.size then .tagged else .tobject
⟨.num n, type⟩
| .str s => ⟨.str s, .object⟩
| .uint8 v => ⟨.num (UInt8.toNat v), .uint8⟩
| .uint16 v => ⟨.num (UInt16.toNat v), .uint16⟩
| .uint32 v => ⟨.num (UInt32.toNat v), .uint32⟩
| .uint64 v => ⟨.num (UInt64.toNat v), .uint64⟩
| .usize v => ⟨.num (UInt64.toNat v), .usize⟩
def lowerArg (a : LCNF.Arg) : M Arg := do
match a with
| .fvar fvarId =>
match (← getFVarValue fvarId) with
| .var varId => return .var varId
| .erased => return .erased
| .joinPoint .. => panic! "unexpected value"
| .erased | .type .. => return .erased
inductive TranslatedProj where
| expr (e : Expr)
| erased
deriving Inhabited
def lowerProj (base : VarId) (ctorInfo : CtorInfo) (field : CtorFieldInfo)
: TranslatedProj × IRType :=
match field with
| .object i irType => ⟨.expr (.proj i base), irType⟩
| .usize i => ⟨.expr (.uproj i base), .usize⟩
| .scalar _ offset irType => ⟨.expr (.sproj (ctorInfo.size + ctorInfo.usize) offset base), irType⟩
| .erased => ⟨.erased, .erased⟩
def lowerParam (p : LCNF.Param) : M Param := do
let x ← bindVar p.fvarId
let ty ← toIRType p.type
return { x, borrow := p.borrow, ty }
mutual
partial def lowerCode (c : LCNF.Code) : M FnBody := do
match c with
| .let decl k => lowerLet decl k
| .jp decl k =>
let joinPoint ← bindJoinPoint decl.fvarId
let params ← decl.params.mapM lowerParam
let body ← lowerCode decl.value
return .jdecl joinPoint params body (← lowerCode k)
| .jmp fvarId args =>
match (← getFVarValue fvarId) with
| .joinPoint joinPointId =>
return .jmp joinPointId (← args.mapM lowerArg)
| .var .. | .erased => panic! "unexpected value"
| .cases cases =>
match (← getFVarValue cases.discr) with
| .var varId =>
return .case cases.typeName
varId
(← nameToIRType cases.typeName)
(← cases.alts.mapM (lowerAlt varId))
| .erased =>
let #[alt] := cases.alts | panic! "erased inductive should only have one case"
match alt with
| .alt _ params code =>
params.forM (bindErased ·.fvarId)
lowerCode code
| .default code =>
lowerCode code
| .joinPoint .. => panic! "unexpected value"
| .return fvarId =>
let arg := match (← getFVarValue fvarId) with
| .var varId => .var varId
| .erased => .erased
| .joinPoint .. => panic! "unexpected value"
return .ret arg
| .unreach .. => return .unreachable
| .fun .. => panic! "all local functions should be λ-lifted"
partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do
match decl.value with
| .lit litValue =>
let var ← bindVar decl.fvarId
let ⟨litValue, type⟩ := lowerLitValue litValue
return .vdecl var type (.lit litValue) (← lowerCode k)
| .proj typeName i fvarId =>
match (← getFVarValue fvarId) with
| .var varId =>
let some (.inductInfo { ctors := [ctorName], .. }) := (← Lean.getEnv).find? typeName
| panic! "projection of non-structure type"
let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName
let ⟨result, type⟩ := lowerProj varId ctorInfo fields[i]!
match result with
| .expr e =>
let var ← bindVar decl.fvarId
return .vdecl var type e (← lowerCode k)
| .erased =>
bindErased decl.fvarId
lowerCode k
| .erased =>
bindErased decl.fvarId
lowerCode k
| .joinPoint .. => panic! "unexpected value"
| .const name _ args =>
let irArgs ← args.mapM lowerArg
if let some code ← tryIrDecl? name irArgs then
return code
let env ← Lean.getEnv
match env.find? name with
| some (.ctorInfo ctorVal) =>
if isExtern env name then
return (← mkFap name irArgs)
let type ← nameToIRType ctorVal.induct
if type.isScalar then
let var ← bindVar decl.fvarId
return .vdecl var type (.lit (.num ctorVal.cidx)) (← lowerCode k)
let ⟨ctorInfo, fields⟩ ← getCtorLayout name
let irArgs := irArgs.extract (start := ctorVal.numParams)
let objArgs : Array Arg ← do
let mut result : Array Arg := #[]
for h : i in *...fields.size do
match fields[i] with
| .object .. =>
result := result.push irArgs[i]!
| .usize .. | .scalar .. | .erased => pure ()
pure result
let objVar ← bindVar decl.fvarId
let rec lowerNonObjectFields (_ : Unit) : M FnBody :=
let rec loop (i : Nat) : M FnBody := do
match irArgs[i]? with
| some (.var varId) =>
match fields[i]! with
| .usize usizeIdx =>
let k ← loop (i + 1)
return .uset objVar usizeIdx varId k
| .scalar _ offset argType =>
let k ← loop (i + 1)
return .sset objVar (ctorInfo.size + ctorInfo.usize) offset varId argType k
| .object .. | .erased => loop (i + 1)
| some .erased => loop (i + 1)
| none => lowerCode k
loop 0
return .vdecl objVar ctorInfo.type (.ctor ctorInfo objArgs) (← lowerNonObjectFields ())
| some (.defnInfo ..) | some (.opaqueInfo ..) =>
mkFap name irArgs
| some (.axiomInfo ..) | .some (.quotInfo ..) | .some (.inductInfo ..) | .some (.thmInfo ..) =>
throwNamedError lean.dependsOnNoncomputable f!"'{name}' not supported by code generator; consider marking definition as 'noncomputable'"
| some (.recInfo ..) =>
throwError f!"code generator does not support recursor '{name}' yet, consider using 'match ... with' and/or structural recursion"
| none => panic! "reference to unbound name"
| .fvar fvarId args =>
match (← getFVarValue fvarId) with
| .var id =>
let irArgs ← args.mapM lowerArg
mkAp id irArgs
| .erased => mkErased ()
| .joinPoint .. => panic! "unexpected value"
| .erased => mkErased ()
where
mkVar (v : VarId) : M FnBody := do
bindVarToVarId decl.fvarId v
lowerCode k
mkErased (_ : Unit) : M FnBody := do
bindErased decl.fvarId
lowerCode k
mkFap (name : Name) (args : Array Arg) : M FnBody := do
let var ← bindVar decl.fvarId
let type ← toIRType decl.type
return .vdecl var type (.fap name args) (← lowerCode k)
mkPap (name : Name) (args : Array Arg) : M FnBody := do
let var ← bindVar decl.fvarId
return .vdecl var .object (.pap name args) (← lowerCode k)
mkAp (fnVar : VarId) (args : Array Arg) : M FnBody := do
let var ← bindVar decl.fvarId
let type := (← toIRType decl.type).boxed
return .vdecl var type (.ap fnVar args) (← lowerCode k)
mkOverApplication (name : Name) (numParams : Nat) (args : Array Arg) : M FnBody := do
let var ← bindVar decl.fvarId
let type := (← toIRType decl.type).boxed
let tmpVar ← newVar
let firstArgs := args.extract 0 numParams
let restArgs := args.extract numParams args.size
return .vdecl tmpVar .object (.fap name firstArgs) <|
.vdecl var type (.ap tmpVar restArgs) (← lowerCode k)
tryIrDecl? (name : Name) (args : Array Arg) : M (Option FnBody) := do
if let some decl ← LCNF.getMonoDecl? name then
let numArgs := args.size
let numParams := decl.params.size
if numArgs < numParams then
return some (← mkPap name args)
else if numArgs == numParams then
return some (← mkFap name args)
else
return some (← mkOverApplication name numParams args)
else
return none
partial def lowerAlt (discr : VarId) (a : LCNF.Alt) : M Alt := do
match a with
| .alt ctorName params code =>
let ⟨ctorInfo, fields⟩ ← getCtorLayout ctorName
let lowerParams (params : Array LCNF.Param) (fields : Array CtorFieldInfo) : M FnBody := do
let rec loop (i : Nat) : M FnBody := do
match params[i]?, fields[i]? with
| some param, some field =>
let ⟨result, type⟩ := lowerProj discr ctorInfo field
match result with
| .expr e =>
return .vdecl (← bindVar param.fvarId)
type
e
(← loop (i + 1))
| .erased =>
bindErased param.fvarId
loop (i + 1)
| none, none => lowerCode code
| _, _ => panic! "mismatched fields and params"
loop 0
let body ← lowerParams params fields
return .ctor ctorInfo body
| .default code =>
return .default (← lowerCode code)
end
def lowerResultType (type : Lean.Expr) (arity : Nat) : M IRType :=
toIRType (resultTypeForArity type arity)
where resultTypeForArity (type : Lean.Expr) (arity : Nat) : Lean.Expr :=
if arity == 0 then
type
else
match type with
| .forallE _ _ b _ => resultTypeForArity b (arity - 1)
| .const ``lcErased _ => mkConst ``lcErased
| _ => panic! "invalid arity"
def lowerDecl (d : LCNF.Decl) : M (Option Decl) := do
let params ← d.params.mapM lowerParam
let resultType ← lowerResultType d.type d.params.size
match d.value with
| .code code =>
let body ← lowerCode code
pure <| some <| .fdecl d.name params resultType body {}
| .extern externAttrData =>
if externAttrData.entries.isEmpty then
-- TODO: This matches the behavior of the old compiler, but we should
-- find a better way to handle this.
addDecl (mkDummyExternDecl d.name params resultType)
pure <| none
else
pure <| some <| .extern d.name params resultType externAttrData
end ToIR
def toIR (decls: Array LCNF.Decl) : CoreM (Array Decl) := do
let mut irDecls := #[]
for decl in decls do
if let some irDecl ← ToIR.lowerDecl decl |>.run then
irDecls := irDecls.push irDecl
return irDecls
end Lean.IR