diff --git a/src/Lean/Elab/PreDefinition/Structural/Eqns.lean b/src/Lean/Elab/PreDefinition/Structural/Eqns.lean index 2af1d8aff2..62463832b5 100644 --- a/src/Lean/Elab/PreDefinition/Structural/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Structural/Eqns.lean @@ -19,10 +19,109 @@ structure EqnInfo where recArgPos : Nat deriving Inhabited -def mkEqns (info : EqnInfo) : MetaM (Array Name) := do +private partial def expand : Expr → Expr + | Expr.letE _ t v b _ => expand (b.instantiate1 v) + | Expr.mdata _ b _ => expand b + | e => e + +private def expandRHS? (mvarId : MVarId) : MetaM (Option MVarId) := do + let target ← instantiateMVars (← getMVarType mvarId) + let some (_, lhs, rhs) ← target.eq? | return none + unless rhs.isLet || rhs.isMData do return none + return some (← replaceTargetDefEq mvarId (← mkEq lhs (expand rhs))) + +private def funext? (mvarId : MVarId) : MetaM (Option MVarId) := do + let target ← getMVarType mvarId + let some (_, lhs, rhs) ← target.eq? | return none + unless rhs.isLambda do return none + commitWhenSome? do + let [mvarId] ← apply mvarId (← mkConstWithFreshMVarLevels ``funext) | return none + let (_, mvarId) ← intro1 mvarId + return some mvarId + +private def simpMatch? (mvarId : MVarId) : MetaM (Option MVarId) := do + let mvarId' ← Split.simpMatchTarget mvarId + if mvarId != mvarId' then return some mvarId' else return none + +/-- + Return true if the right-hand-side is matching on one of the variables in + the recursion position at the left-hand-side. + Example: returns true for + ``` + f ys (x :: xs) = match xs with ... + ``` + if `recArgPos == 1` +-/ +private def matchRecArg (mvarId : MVarId) (recArgPos : Nat) : MetaM Bool := do + let target ← instantiateMVars (← getMVarType mvarId) + let some (_, lhs, rhs) ← target.eq? | return false + let lhsArgs := lhs.getAppArgs + if h : recArgPos < lhsArgs.size then + let recArg := lhsArgs.get ⟨recArgPos, h⟩ + let recFVarSet := collectFVars {} recArg |>.fvarSet + let env ← getEnv + return Option.isSome <| rhs.find? fun e => do + if let some info := isMatcherAppCore? env e then + let args := e.getAppArgs + for i in [info.getFirstDiscrPos : info.getFirstDiscrPos + info.numDiscrs] do + let discr := args[i] + if recFVarSet.any discr.containsFVar then + return true + return false + else + return false + else + return true -- conservative answer + +private def saveEqn (mvarId : MVarId) : StateRefT (Array Expr) MetaM Unit := withMVarContext mvarId do + let target ← instantiateMVars (← getMVarType mvarId) + let fvarIds := collectFVars {} target |>.fvarSet.toArray + let (_, mvarId) ← revert mvarId fvarIds + let type ← instantiateMVars (← getMVarType mvarId) + modify (·.push type) + +private partial def mkEqnTypes (mvarId : MVarId) : ReaderT EqnInfo (StateRefT (Array Expr) MetaM) Unit := do + if !(← matchRecArg mvarId (← read).recArgPos) then + saveEqn mvarId + else if let some mvarId ← expandRHS? mvarId then + mkEqnTypes mvarId + else if let some mvarId ← funext? mvarId then + mkEqnTypes mvarId + else if let some mvarId ← simpMatch? mvarId then + mkEqnTypes mvarId + else if let some mvarIds ← splitTarget? mvarId then + mvarIds.forM mkEqnTypes + else + saveEqn mvarId + +/-- Create a "unique" base name for equations and splitter -/ +private def mkBaseNameFor (env : Environment) (declName : Name) : Name := + Lean.mkBaseNameFor env declName `eq_1 `_eqns + +private def mkProof (type : Expr) : MetaM Expr := do -- TODO - trace[Elab.definition.structural.eqns] "mkEqns:\n{info.value}" - return #[] + mkSorry type false + +def mkEqns (info : EqnInfo) : MetaM (Array Name) := do + withOptions (tactic.hygienic.set . false) do + lambdaTelescope info.value fun xs body => do + let us := info.levelParams.map mkLevelParam + let target ← mkEq (mkAppN (Lean.mkConst info.declName us) xs) body + let goal ← mkFreshExprSyntheticOpaqueMVar target + let (_, eqnTypes) ← mkEqnTypes goal.mvarId! |>.run info |>.run #[] + let baseName := mkBaseNameFor (← getEnv) info.declName + let mut thmNames := #[] + for i in [: eqnTypes.size] do + let type := eqnTypes[i] + trace[Elab.definition.structural.eqns] "{eqnTypes[i]}" + let name := baseName ++ (`eq).appendIndexAfter (i+1) + thmNames := thmNames.push name + let value ← mkProof type + addDecl <| Declaration.thmDecl { + name, type, value + levelParams := info.levelParams + } + return thmNames builtin_initialize eqnInfoExt : MapDeclarationExtension EqnInfo ← mkMapDeclarationExtension `structEqInfo diff --git a/tests/lean/run/structuralEqns.lean b/tests/lean/run/structuralEqns.lean new file mode 100644 index 0000000000..126711802f --- /dev/null +++ b/tests/lean/run/structuralEqns.lean @@ -0,0 +1,42 @@ +import Lean + +set_option trace.Elab.definition.structural.eqns true + +open Lean +open Lean.Meta +def tst (declName : Name) : MetaM Unit := do + IO.println (← getEqnsFor? declName) + +#eval tst ``List.map +#check @List.map.eq_1 +#check @List.map.eq_2 + +def foo (xs ys zs : List Nat) : List Nat := + match (xs, ys) with + | (xs', ys') => + match zs with + | z::zs => foo xs ys zs + | _ => match ys' with + | [] => [1] + | _ => [2] + +#eval tst ``foo + +#check foo.eq_1 +#check foo.eq_2 + +#eval tst ``foo + +def g : List Nat → List Nat → Nat + | [], y::ys => y + | [], ys => 0 + | [x1], ys => g [] ys + | x::xs, y::ys => g xs ys + y + | x::xs, [] => g xs [] + +#eval tst ``g +#check g.eq_1 +#check g.eq_2 +#check g.eq_3 +#check g.eq_4 +#check g.eq_5