fix: make sure "eta for structures" in the elaborator uses projection functions if available
This commit is contained in:
parent
1add9b814b
commit
3bdb385c19
4 changed files with 40 additions and 7 deletions
|
|
@ -36,10 +36,10 @@ namespace Lean.Meta
|
|||
private def isDefEqEtaStruct (a b : Expr) : MetaM Bool := do
|
||||
if !(← getConfig).etaStruct then return false
|
||||
else
|
||||
matchConstCtor b.getAppFn (fun _ => return false) fun ctorVal _ =>
|
||||
matchConstCtor a.getAppFn (fun _ => go ctorVal) fun _ _ => return false
|
||||
matchConstCtor b.getAppFn (fun _ => return false) fun ctorVal us =>
|
||||
matchConstCtor a.getAppFn (fun _ => go ctorVal us) fun _ _ => return false
|
||||
where
|
||||
go ctorVal := do
|
||||
go ctorVal us := do
|
||||
if ctorVal.numParams + ctorVal.numFields != b.getAppNumArgs then
|
||||
trace[Meta.isDefEq.eta.struct] "failed, insufficient number of arguments at{indentExpr b}"
|
||||
return false
|
||||
|
|
@ -50,9 +50,12 @@ where
|
|||
else if (← isDefEq (← inferType a) (← inferType b)) then
|
||||
checkpointDefEq do
|
||||
let args := b.getAppArgs
|
||||
let params := args[:ctorVal.numParams].toArray
|
||||
let info? := getStructureInfo? (← getEnv) ctorVal.induct
|
||||
for i in [ctorVal.numParams : args.size] do
|
||||
let proj := mkProj ctorVal.induct (i - ctorVal.numParams) a
|
||||
trace[Meta.isDefEq.eta.struct] "{a} =?= {b} @ [{i - ctorVal.numParams}], {proj} =?= {args[i]}"
|
||||
let j := i - ctorVal.numParams
|
||||
let proj ← mkProjFn ctorVal us params j a
|
||||
trace[Meta.isDefEq.eta.struct] "{a} =?= {b} @ [{j}], {proj} =?= {args[i]}"
|
||||
unless (← isDefEq proj args[i]) do
|
||||
trace[Meta.isDefEq.eta.struct] "failed, unexpect arg #{i}, projection{indentExpr proj}\nis not defeq to{indentExpr args[i]}"
|
||||
return false
|
||||
|
|
|
|||
|
|
@ -115,6 +115,17 @@ private def toCtorWhenK (recVal : RecursorVal) (major : Expr) : MetaM Expr := do
|
|||
else
|
||||
return major
|
||||
|
||||
/--
|
||||
Create the `i`th projection `major`. It tries to use the auto-generated projection functions if available. Otherwise falls back
|
||||
to `Expr.proj`.
|
||||
-/
|
||||
def mkProjFn (ctorVal : ConstructorVal) (us : List Level) (params : Array Expr) (i : Nat) (major : Expr) : CoreM Expr := do
|
||||
match getStructureInfo? (← getEnv) ctorVal.induct with
|
||||
| none => return mkProj ctorVal.induct i major
|
||||
| some info => match info.getProjFn? i with
|
||||
| none => return mkProj ctorVal.induct i major
|
||||
| some projFn => return mkApp (mkAppN (mkConst projFn us) params) major
|
||||
|
||||
/--
|
||||
If `major` is not a constructor application, and its type is a structure `C ...`, then return `C.mk major.1 ... major.n`
|
||||
|
||||
|
|
@ -142,9 +153,10 @@ private def toCtorWhenStructure (inductName : Name) (major : Expr) : MetaM Expr
|
|||
else
|
||||
let some ctorName ← getFirstCtor d | pure major
|
||||
let ctorInfo ← getConstInfoCtor ctorName
|
||||
let mut result := mkAppN (mkConst ctorName us) (majorType.getAppArgs.shrink ctorInfo.numParams)
|
||||
let params := majorType.getAppArgs.shrink ctorInfo.numParams
|
||||
let mut result := mkAppN (mkConst ctorName us) params
|
||||
for i in [:ctorInfo.numFields] do
|
||||
result := mkApp result (mkProj inductName i major)
|
||||
result := mkApp result (← mkProjFn ctorInfo us params i major)
|
||||
return result
|
||||
| _ => return major
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,13 @@ structure StructureInfo where
|
|||
def StructureInfo.lt (i₁ i₂ : StructureInfo) : Bool :=
|
||||
Name.quickLt i₁.structName i₂.structName
|
||||
|
||||
def StructureInfo.getProjFn? (info : StructureInfo) (i : Nat) : Option Name :=
|
||||
if h : i < info.fieldNames.size then
|
||||
let fieldName := info.fieldNames.get ⟨i, h⟩
|
||||
info.fieldInfo.binSearch { fieldName := fieldName, projFn := default, subobject? := none, binderInfo := default, inferMod := false } StructureFieldInfo.lt |>.map (·.projFn)
|
||||
else
|
||||
none
|
||||
|
||||
/-- Auxiliary state for structures defined in the current module. -/
|
||||
private structure StructureState where
|
||||
map : Std.PersistentHashMap Name StructureInfo := {}
|
||||
|
|
|
|||
11
tests/lean/run/primProjEtaIssue.lean
Normal file
11
tests/lean/run/primProjEtaIssue.lean
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
example (f : Fin n → Prop) (h : ∀ i h, i = 0 → f ⟨i, h⟩) : f i := by
|
||||
apply h
|
||||
rw [show i.1 = 0 from sorry]
|
||||
|
||||
def foo (x : Fin n) : Nat :=
|
||||
match x with
|
||||
| ⟨i, _⟩ => 5 + i
|
||||
|
||||
example (x : Fin n) : foo x = 5 := by
|
||||
simp [foo]
|
||||
rw [show x.1 = 0 from sorry]
|
||||
Loading…
Add table
Reference in a new issue