diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index f45a504eb1..423a2a7019 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -36,10 +36,10 @@ namespace Lean.Meta private def isDefEqEtaStruct (a b : Expr) : MetaM Bool := do if !(← getConfig).etaStruct then return false else - matchConstCtor b.getAppFn (fun _ => return false) fun ctorVal _ => - matchConstCtor a.getAppFn (fun _ => go ctorVal) fun _ _ => return false + matchConstCtor b.getAppFn (fun _ => return false) fun ctorVal us => + matchConstCtor a.getAppFn (fun _ => go ctorVal us) fun _ _ => return false where - go ctorVal := do + go ctorVal us := do if ctorVal.numParams + ctorVal.numFields != b.getAppNumArgs then trace[Meta.isDefEq.eta.struct] "failed, insufficient number of arguments at{indentExpr b}" return false @@ -50,9 +50,12 @@ where else if (← isDefEq (← inferType a) (← inferType b)) then checkpointDefEq do let args := b.getAppArgs + let params := args[:ctorVal.numParams].toArray + let info? := getStructureInfo? (← getEnv) ctorVal.induct for i in [ctorVal.numParams : args.size] do - let proj := mkProj ctorVal.induct (i - ctorVal.numParams) a - trace[Meta.isDefEq.eta.struct] "{a} =?= {b} @ [{i - ctorVal.numParams}], {proj} =?= {args[i]}" + let j := i - ctorVal.numParams + let proj ← mkProjFn ctorVal us params j a + trace[Meta.isDefEq.eta.struct] "{a} =?= {b} @ [{j}], {proj} =?= {args[i]}" unless (← isDefEq proj args[i]) do trace[Meta.isDefEq.eta.struct] "failed, unexpect arg #{i}, projection{indentExpr proj}\nis not defeq to{indentExpr args[i]}" return false diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index 4dad9c83c2..d404b15fce 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -115,6 +115,17 @@ private def toCtorWhenK (recVal : RecursorVal) (major : Expr) : MetaM Expr := do else return major +/-- + Create the `i`th projection `major`. It tries to use the auto-generated projection functions if available. Otherwise falls back + to `Expr.proj`. +-/ +def mkProjFn (ctorVal : ConstructorVal) (us : List Level) (params : Array Expr) (i : Nat) (major : Expr) : CoreM Expr := do + match getStructureInfo? (← getEnv) ctorVal.induct with + | none => return mkProj ctorVal.induct i major + | some info => match info.getProjFn? i with + | none => return mkProj ctorVal.induct i major + | some projFn => return mkApp (mkAppN (mkConst projFn us) params) major + /-- If `major` is not a constructor application, and its type is a structure `C ...`, then return `C.mk major.1 ... major.n` @@ -142,9 +153,10 @@ private def toCtorWhenStructure (inductName : Name) (major : Expr) : MetaM Expr else let some ctorName ← getFirstCtor d | pure major let ctorInfo ← getConstInfoCtor ctorName - let mut result := mkAppN (mkConst ctorName us) (majorType.getAppArgs.shrink ctorInfo.numParams) + let params := majorType.getAppArgs.shrink ctorInfo.numParams + let mut result := mkAppN (mkConst ctorName us) params for i in [:ctorInfo.numFields] do - result := mkApp result (mkProj inductName i major) + result := mkApp result (← mkProjFn ctorInfo us params i major) return result | _ => return major diff --git a/src/Lean/Structure.lean b/src/Lean/Structure.lean index b43a90b3f2..282e8c1144 100644 --- a/src/Lean/Structure.lean +++ b/src/Lean/Structure.lean @@ -30,6 +30,13 @@ structure StructureInfo where def StructureInfo.lt (i₁ i₂ : StructureInfo) : Bool := Name.quickLt i₁.structName i₂.structName +def StructureInfo.getProjFn? (info : StructureInfo) (i : Nat) : Option Name := + if h : i < info.fieldNames.size then + let fieldName := info.fieldNames.get ⟨i, h⟩ + info.fieldInfo.binSearch { fieldName := fieldName, projFn := default, subobject? := none, binderInfo := default, inferMod := false } StructureFieldInfo.lt |>.map (·.projFn) + else + none + /-- Auxiliary state for structures defined in the current module. -/ private structure StructureState where map : Std.PersistentHashMap Name StructureInfo := {} diff --git a/tests/lean/run/primProjEtaIssue.lean b/tests/lean/run/primProjEtaIssue.lean new file mode 100644 index 0000000000..582981a337 --- /dev/null +++ b/tests/lean/run/primProjEtaIssue.lean @@ -0,0 +1,11 @@ +example (f : Fin n → Prop) (h : ∀ i h, i = 0 → f ⟨i, h⟩) : f i := by + apply h + rw [show i.1 = 0 from sorry] + +def foo (x : Fin n) : Nat := + match x with + | ⟨i, _⟩ => 5 + i + +example (x : Fin n) : foo x = 5 := by + simp [foo] + rw [show x.1 = 0 from sorry]