diff --git a/src/Lean/Meta/EqnCompiler/DepElim.lean b/src/Lean/Meta/EqnCompiler/DepElim.lean index 710048effc..2c6dda88e6 100644 --- a/src/Lean/Meta/EqnCompiler/DepElim.lean +++ b/src/Lean/Meta/EqnCompiler/DepElim.lean @@ -339,23 +339,16 @@ p.alts.all fun alt => match alt.patterns with | Pattern.var _ :: _ => true | _ => false -/- Return true if the next pattern of each remaining alternative is a constructor application -/ +/- Return true if the next pattern of each remaining alternative is a constructor application or variable -/ private def isConstructorTransition (p : Problem) : Bool := -p.alts.all fun alt => match alt.patterns with - | Pattern.ctor _ _ _ _ :: _ => true - | _ => false - -/- Return true if the next pattern of the remaining alternatives contain variables AND constructors. -/ -private def isCompleteTransition (p : Problem) : Bool := -let (ok, hasVar, hasCtor) := p.alts.foldl - (fun (acc : Bool × Bool × Bool) (alt : Alt) => - let (ok, hasVar, hasCtor) := acc; - match alt.patterns with - | Pattern.ctor _ _ _ _ :: _ => (ok, hasVar, true) - | Pattern.var _ :: _ => (ok, true, hasCtor) - | _ => (false, hasVar, hasCtor)) - (true, false, false); -ok && hasVar && hasCtor +(p.alts.any fun alt => match alt.patterns with + | Pattern.ctor _ _ _ _ :: _ => true + | _ => false) +&& +(p.alts.all fun alt => match alt.patterns with + | Pattern.ctor _ _ _ _ :: _ => true + | Pattern.var _ :: _ => true + | _ => false) /- Return true if the next pattern of the remaining alternatives contain variables AND values. -/ private def isValueTransition (p : Problem) : Bool := @@ -436,10 +429,45 @@ match p.vars with | _ => unreachable!; process { p with alts := alts, vars := xs } s -private def isFirstPatternCtor (ctorName : Name) (alt : Alt) : Bool := -match alt.patterns with -| Pattern.ctor n _ _ _ :: _ => n == ctorName -| _ => false +private def throwInductiveTypeExpected {α} (type : Expr) : MetaM α := do +throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr type) + +private def getInductiveUniverseAndParams (type : Expr) : MetaM (List Level × Array Expr) := do +env ← getEnv; +type ← whnfD type; +matchConst env type.getAppFn (fun _ => throwInductiveTypeExpected type) fun info us => + match info with + | ConstantInfo.inductInfo val => + let I := type.getAppFn; + let Iargs := type.getAppArgs; + let params := Iargs.extract 0 val.nparams; + pure (us, params) + | _ => throwInductiveTypeExpected type + +private def tryConstructor? (alt : Alt) (mvarId : MVarId) (ctorName : Name) : MetaM (Option Alt) := do +expectedType ← inferType (mkMVar mvarId); +(us, params) ← getInductiveUniverseAndParams expectedType; +let ctor := mkAppN (mkConst ctorName us) params; +ctorType ← inferType ctor; +(fieldMVars, _, resultType) ← forallMetaTelescopeReducing ctorType; +let ctor := mkAppN ctor fieldMVars; +trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType); +isCompatible ← isDefEq resultType expectedType; +if !isCompatible then pure none +else do + let fieldMVars := fieldMVars.toList; + assignExprMVar mvarId ctor; + rhs ← instantiateMVars alt.rhs; + newPatterns ← fieldMVars.mapM fun fieldMVar => do { + e ← instantiateMVars fieldMVar; + match e with + | Expr.mvar mvarId _ => pure (Pattern.var mvarId : IPattern) + | _ => pure (Pattern.inaccessible e) + }; + newMVarIds ← fieldMVars.filterMapM fun fieldMVar => condM (isExprMVarAssigned fieldMVar.mvarId!) (pure none) (pure (some fieldMVar.mvarId!)); + let mvars := (alt.mvars.map fun mvarId' => if mvarId' == mvarId then newMVarIds else [mvarId']).join; + mvars ← mvars.filterM fun mvarId => not <$> isExprMVarAssigned mvarId; + pure $ some { alt with rhs := rhs, mvars := mvars, patterns := newPatterns ++ alt.patterns } private def processConstructor (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do trace! `Meta.EqnCompiler.match ("constructor step"); @@ -458,91 +486,19 @@ match p.vars with | _ => Example.underscore; -- This case can happen due to dependent elimination let examples := p.examples.map $ Example.replaceFVarId x.fvarId! subex; let examples := examples.map $ Example.applyFVarSubst subst; - let newAlts := p.alts.filter $ isFirstPatternCtor subgoal.ctorName; - let newAlts := newAlts.map fun alt => match alt.patterns with - | Pattern.ctor _ _ _ fields :: ps => { alt with patterns := fields ++ ps } - | _ => unreachable!; + let newAlts := p.alts.filter fun alt => match alt.patterns with + | Pattern.ctor n _ _ _ :: _ => n == subgoal.ctorName + | Pattern.var _ :: _ => true + | _ => false; newAlts ← newAlts.mapM fun alt => alt.applyFVarSubst subst; newAlts ← newAlts.mapM fun alt => alt.copy; + newAlts ← newAlts.filterMapM fun alt => match alt.patterns with + | Pattern.ctor _ _ _ fields :: ps => pure $ some { alt with patterns := fields ++ ps } + | Pattern.var mvarId :: ps => tryConstructor? { alt with patterns := ps } mvarId subgoal.ctorName + | _ => unreachable!; process { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } s) s -private def throwInductiveTypeExpected {α} (type : Expr) : MetaM α := do -throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr type) - -private partial def tryConstructorAux (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr) - : Nat → Array MVarId → Array IPattern → MetaM Alt -| i, newMVars, fields => do - if h : i < mvars.size then do - let mvar := mvars.get ⟨i, h⟩; - e ← instantiateMVars mvar; - match e with - | Expr.mvar mvarId _ => tryConstructorAux (i+1) (newMVars.push mvarId) (fields.push (Pattern.var mvarId)) - | _ => tryConstructorAux (i+1) newMVars (fields.push (Pattern.inaccessible e)) - else do - let p := Pattern.ctor ctorName us params.toList fields.toList; - e ← p.toExpr; - assignExprMVar mvarId e; - ps ← alt.patterns.mapM Pattern.instantiateMVars; - let ps := p :: ps; - rhs ← instantiateMVars alt.rhs; - unless (alt.mvars.contains mvarId) $ - throwOther "ill-format alternative"; -- TODO: improve error message - let mvars := (alt.mvars.map fun mvarId' => if mvarId' == mvarId then newMVars.toList else [mvarId']).join; - mvars ← mvars.filterM fun mvarId => not <$> isExprMVarAssigned mvarId; - pure { alt with rhs := rhs, mvars := mvars, patterns := ps } - -private def tryConstructor? (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr) - : MetaM (Option Alt) := do -let ctor := mkAppN (mkConst ctorName us) params; -ctorType ← inferType ctor; -(mvars, _, resultType) ← forallMetaTelescopeReducing ctorType; -trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType); -condM (isDefEq resultType expectedType) - (Option.some <$> tryConstructorAux alt mvarId ctorName us params mvars 0 #[] #[]) - (pure none) - -private def expandAlt (alt : Alt) (mvarId : MVarId) : MetaM (List Alt) := do -env ← getEnv; -mvarDecl ← getMVarDecl mvarId; -let expectedType := mvarDecl.type; -expectedType ← whnfD expectedType; -matchConst env expectedType.getAppFn (fun _ => throwInductiveTypeExpected expectedType) fun info us => - match info with - | ConstantInfo.inductInfo val => - val.ctors.foldlM - (fun (result : List Alt) ctor => do - (mvarSubst, alt) ← alt.copyCore; - let mvarId := mvarSubst.find! mvarId; - mvarDecl ← getMVarDecl mvarId; - let expectedType := mvarDecl.type; - expectedType ← whnfD expectedType; - let I := expectedType.getAppFn; - let Iargs := expectedType.getAppArgs; - let params := Iargs.extract 0 val.nparams; - alt? ← tryConstructor? alt mvarId ctor us params expectedType; - match alt? with - | none => pure result - | some alt => pure (alt :: result)) - [] - | _ => throwInductiveTypeExpected expectedType - -private def processComplete (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do -trace! `Meta.EqnCompiler.match ("complete step"); -withGoalOf p do -env ← getEnv; -newAlts ← p.alts.foldlM - (fun (newAlts : List Alt) alt => - match alt.patterns with - | Pattern.ctor _ _ _ _ :: ps => pure (alt :: newAlts) - | p@(Pattern.var mvarId) :: ps => do - let alt := { alt with patterns := ps }; - alts ← expandAlt alt mvarId; - pure (alts ++ newAlts) - | _ => unreachable!) - []; -process { p with alts := newAlts.reverse } s - private def collectValues (p : Problem) : Array Expr := p.alts.foldl (fun (values : Array Expr) alt => @@ -668,8 +624,6 @@ private partial def process : Problem → State → MetaM State processVariable process p s else if isConstructorTransition p then processConstructor process p s - else if isCompleteTransition p then - processComplete process p s else if isValueTransition p then processValue process p s else if isArrayLitTransition p then