lean4-htt/src/Lean/Compiler/LCNF/StructProjCases.lean
Henrik Böving 57afb23c5c
fix: compilation of projections on non trivial structures (#11340)
This PR fixes a miscompilation when encountering projections of
non-trivial structure types.

Closes: #11322
2025-11-24 19:25:03 +00:00

139 lines
5 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Cameron Zwarich
-/
module
prelude
public import Lean.Compiler.LCNF.PrettyPrinter
public section
namespace Lean.Compiler.LCNF
namespace StructProjCases
/--
Looks up the single constructor of the inductive type `typeName`.

Returns `none` when `typeName` does not resolve to an inductive declaration,
when that declaration does not have exactly one constructor, or when the
constructor name is not found in the environment as a constructor.
-/
def findStructCtorInfo? (typeName : Name) : CoreM (Option ConstructorVal) := do
  match (← getConstInfo typeName) with
  | .inductInfo indVal =>
    match indVal.ctors with
    | [onlyCtor] =>
      -- The environment is only consulted once a unique constructor is known.
      match (← getEnv).find? onlyCtor with
      | some (.ctorInfo ctorVal) => return some ctorVal
      | _ => return none
    | _ => return none
  | _ => return none
/--
Creates one fresh parameter per field of a structure constructor.

`ctorType` is converted with `toLCNFType` and then `toMonoType`. The first
`numParams` binders of the resulting type are skipped, and each of the next
`numFields` binders yields a parameter built by `mkParam` from the binder name
and binder type. A non-`forallE` type at any of these positions is treated as
unreachable.
-/
def mkFieldParamsForCtorType (ctorType : Expr) (numParams : Nat) (numFields : Nat) :
    CompilerM (Array Param) := do
  let lcnfType ← Meta.MetaM.run' <| toLCNFType ctorType
  let mut rest ← toMonoType lcnfType
  -- Peel off the binders that correspond to the structure parameters.
  for _ in *...numParams do
    let .forallE _ _ body _ := rest | unreachable!
    rest := body
  -- Every remaining binder up to `numFields` describes one field.
  let mut fieldParams := Array.emptyWithCapacity numFields
  for _ in *...numFields do
    let .forallE fieldName fieldType body _ := rest | unreachable!
    -- NOTE(review): the trailing `false` is presumably the "borrow" flag of
    -- `mkParam` — confirm against its definition.
    fieldParams := fieldParams.push (← mkParam fieldName fieldType false)
    rest := body
  return fieldParams
/-- Mutable state threaded through the `StructProjCases` traversal. -/
structure State where
  /--
  For a structure-valued variable, the variables standing for its fields.
  `visitCode` inserts an entry before visiting the code in which the fields
  are bound and erases it again afterwards.
  -/
  projMap : Std.HashMap FVarId (Array FVarId) := {}
  /--
  Redirects variables that `visitCode` erased (projection results and
  parameters of merged alternatives) to the field variable replacing them.
  Consulted by `remapFVar`.
  -/
  fvarMap : Std.HashMap FVarId FVarId := {}
/-- Monad of this pass: `CompilerM` extended with a mutable `State`. -/
abbrev M := StateRefT State CompilerM
/-- Runs `x` starting from the default (empty) `State` and discards the final state. -/
def M.run (x : M α) : CompilerM α :=
  let initial : State := {}
  x.run' initial
/--
Returns the variable recorded for `fvarId` in `fvarMap`, or `fvarId` itself
when there is no entry. Only a single lookup is performed.
-/
def remapFVar (fvarId : FVarId) : M FVarId := do
  match (← get).fvarMap[fvarId]? with
  | some target => return target
  | none => return fvarId
mutual
/--
Rewrites `code`, eliminating the `.proj` let-values it encounters.

* A `let x := base.i` whose (remapped) `base` already has field variables
  recorded in `projMap` is erased, and `x` is redirected to the `i`-th of
  those variables through `fvarMap`.
* Otherwise a single-alternative `cases` on `base` is produced whose
  parameters are fresh field variables (see `mkFieldParamsForCtorType`); the
  rewritten continuation becomes the body of that alternative.
* An existing single-alternative `cases` either registers its parameters as
  the field variables of its discriminant or, when the discriminant is already
  registered, is replaced by its rewritten body with the parameters redirected
  to the registered variables.

Free variables in let-values, jump targets and arguments, discriminants and
returns are passed through `remapFVar`.
-/
partial def visitCode (code : Code) : M Code := do
  match code with
  | .let decl k =>
    match decl.value with
    | .proj typeName i base =>
      -- The projection binding itself never survives; erase it up front.
      eraseLetDecl decl
      let base ← remapFVar base
      if let some projVars := (← get).projMap[base]? then
        -- The fields of `base` are already bound: reuse the `i`-th one.
        modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId projVars[i]! }
        visitCode k
      else
        -- First projection of `base` seen in this scope: bind all fields at
        -- once with a single-alternative `cases`.
        let some ctorInfo ← findStructCtorInfo? typeName | panic! "expected struct constructor"
        let params ← mkFieldParamsForCtorType ctorInfo.type ctorInfo.numParams ctorInfo.numFields
        let fvars := params.map (·.fvarId)
        modify fun s => { s with projMap := s.projMap.insert base fvars,
                                 fvarMap := s.fvarMap.insert decl.fvarId fvars[i]! }
        let k ← visitCode k
        -- The field variables are only in scope inside the new alternative,
        -- so the `projMap` entry is dropped once its body has been visited.
        modify fun s => { s with projMap := s.projMap.erase base }
        let resultType ← toMonoType (← k.inferType)
        let alts := #[.alt ctorInfo.name params k]
        return .cases { typeName, resultType, discr := base, alts }
    | _ => return code.updateLet! (← decl.updateValue (← visitLetValue decl.value)) (← visitCode k)
  | .fun decl k =>
    -- Visit the function body first, then the continuation.
    let decl ← decl.updateValue (← visitCode decl.value)
    return code.updateFun! decl (← visitCode k)
  | .jp decl k =>
    -- Same treatment as `.fun`; `updateFun!` is used for both here.
    let decl ← decl.updateValue (← visitCode decl.value)
    return code.updateFun! decl (← visitCode k)
  | .jmp fvarId args =>
    return code.updateJmp! (← remapFVar fvarId) (← args.mapM visitArg)
  | .cases cases =>
    let discr ← remapFVar cases.discr
    if let #[.alt ctorName params k] := cases.alts then
      if let some projVars := (← get).projMap[discr]? then
        -- The fields of `discr` are already bound by an enclosing `cases`:
        -- redirect this alternative's parameters to them and keep only the
        -- rewritten body (the `cases` node itself is not rebuilt).
        assert! projVars.size == params.size
        for param in params, projVar in projVars do
          modify fun s => { s with fvarMap := s.fvarMap.insert param.fvarId projVar }
          eraseParam param
        visitCode k
      else
        -- Register this alternative's parameters as the field variables of
        -- `discr` while its body is being visited.
        let fvars := params.map (·.fvarId)
        modify fun s => { s with projMap := s.projMap.insert discr fvars }
        let k ← visitCode k
        modify fun s => { s with projMap := s.projMap.erase discr }
        -- TODO: This should preserve the .alt allocation, but binding it to
        -- a variable above while also destructuring an array doesn't work.
        return code.updateCases! cases.resultType discr #[.alt ctorName params k]
    else
      -- Any other shape of alternatives: just rewrite each alternative.
      let alts ← cases.alts.mapM (visitAlt ·)
      return code.updateCases! cases.resultType discr alts
  | .return fvarId => return code.updateReturn! (← remapFVar fvarId)
  | .unreach .. => return code
/--
Remaps the free variables mentioned by a let-value: the head of an `.fvar`
application goes through `remapFVar`, and arguments go through `visitArg`.
Literals and erased values are returned unchanged. A `.proj` value must never
reach this function.
-/
partial def visitLetValue (v : LetValue) : M LetValue := do
  match v with
  | .lit _ | .erased => return v
  | .const _ _ args =>
    let newArgs ← args.mapM visitArg
    return v.updateArgs! newArgs
  | .fvar fvarId args =>
    let head ← remapFVar fvarId
    let newArgs ← args.mapM visitArg
    return v.updateFVar! head newArgs
  -- Projections should be handled directly by `visitCode`.
  | .proj .. => unreachable!
/-- Rewrites the code of `alt` with `visitCode` and stores the result back into `alt`. -/
partial def visitAlt (alt : Alt) : M Alt := do
  let body ← visitCode alt.getCode
  return alt.updateCode body
/--
Remaps a free-variable argument through `remapFVar`; type and erased
arguments are returned unchanged.
-/
partial def visitArg (arg : Arg) : M Arg := do
  match arg with
  | .type _ | .erased => return arg
  | .fvar fvarId =>
    let target ← remapFVar fvarId
    return arg.updateFVar! target
end
/-- Applies `visitCode` to the code stored in `decl.value` (via `mapCodeM`) and returns the updated declaration. -/
def visitDecl (decl : Decl) : M Decl := do
  let newValue ← decl.value.mapCodeM fun code => visitCode code
  return { decl with value := newValue }
end StructProjCases
/--
The `structProjCases` pass: a per-declaration pass for the `.mono` phase that
runs `StructProjCases.visitDecl` on each declaration, starting from a fresh
state every time.
-/
def structProjCases : Pass :=
  Pass.mkPerDeclaration `structProjCases
    (fun decl => (StructProjCases.visitDecl decl).run) .mono
-- Registers the `Compiler.structProjCases` trace class as an inherited trace class.
builtin_initialize registerTraceClass `Compiler.structProjCases (inherited := true)
end Lean.Compiler.LCNF