fix: combine "complete" and constructor transitions

This commit is contained in:
Leonardo de Moura 2020-08-14 10:48:35 -07:00
parent b4b60dc326
commit 8e81c19162

View file

@ -339,23 +339,16 @@ p.alts.all fun alt => match alt.patterns with
| Pattern.var _ :: _ => true
| _ => false
/- Return true if the next pattern of each remaining alternative is a constructor application -/
/- Return true if the next pattern of each remaining alternative is a constructor application or variable -/
private def isConstructorTransition (p : Problem) : Bool :=
p.alts.all fun alt => match alt.patterns with
| Pattern.ctor _ _ _ _ :: _ => true
| _ => false
/- Return true if the next pattern of the remaining alternatives contain variables AND constructors. -/
private def isCompleteTransition (p : Problem) : Bool :=
let (ok, hasVar, hasCtor) := p.alts.foldl
(fun (acc : Bool × Bool × Bool) (alt : Alt) =>
let (ok, hasVar, hasCtor) := acc;
match alt.patterns with
| Pattern.ctor _ _ _ _ :: _ => (ok, hasVar, true)
| Pattern.var _ :: _ => (ok, true, hasCtor)
| _ => (false, hasVar, hasCtor))
(true, false, false);
ok && hasVar && hasCtor
(p.alts.any fun alt => match alt.patterns with
| Pattern.ctor _ _ _ _ :: _ => true
| _ => false)
&&
(p.alts.all fun alt => match alt.patterns with
| Pattern.ctor _ _ _ _ :: _ => true
| Pattern.var _ :: _ => true
| _ => false)
/- Return true if the next pattern of the remaining alternatives contain variables AND values. -/
private def isValueTransition (p : Problem) : Bool :=
@ -436,10 +429,45 @@ match p.vars with
| _ => unreachable!;
process { p with alts := alts, vars := xs } s
private def isFirstPatternCtor (ctorName : Name) (alt : Alt) : Bool :=
match alt.patterns with
| Pattern.ctor n _ _ _ :: _ => n == ctorName
| _ => false
private def throwInductiveTypeExpected {α} (type : Expr) : MetaM α := do
throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr type)
private def getInductiveUniverseAndParams (type : Expr) : MetaM (List Level × Array Expr) := do
env ← getEnv;
type ← whnfD type;
matchConst env type.getAppFn (fun _ => throwInductiveTypeExpected type) fun info us =>
match info with
| ConstantInfo.inductInfo val =>
let I := type.getAppFn;
let Iargs := type.getAppArgs;
let params := Iargs.extract 0 val.nparams;
pure (us, params)
| _ => throwInductiveTypeExpected type
private def tryConstructor? (alt : Alt) (mvarId : MVarId) (ctorName : Name) : MetaM (Option Alt) := do
expectedType ← inferType (mkMVar mvarId);
(us, params) ← getInductiveUniverseAndParams expectedType;
let ctor := mkAppN (mkConst ctorName us) params;
ctorType ← inferType ctor;
(fieldMVars, _, resultType) ← forallMetaTelescopeReducing ctorType;
let ctor := mkAppN ctor fieldMVars;
trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType);
isCompatible ← isDefEq resultType expectedType;
if !isCompatible then pure none
else do
let fieldMVars := fieldMVars.toList;
assignExprMVar mvarId ctor;
rhs ← instantiateMVars alt.rhs;
newPatterns ← fieldMVars.mapM fun fieldMVar => do {
e ← instantiateMVars fieldMVar;
match e with
| Expr.mvar mvarId _ => pure (Pattern.var mvarId : IPattern)
| _ => pure (Pattern.inaccessible e)
};
newMVarIds ← fieldMVars.filterMapM fun fieldMVar => condM (isExprMVarAssigned fieldMVar.mvarId!) (pure none) (pure (some fieldMVar.mvarId!));
let mvars := (alt.mvars.map fun mvarId' => if mvarId' == mvarId then newMVarIds else [mvarId']).join;
mvars ← mvars.filterM fun mvarId => not <$> isExprMVarAssigned mvarId;
pure $ some { alt with rhs := rhs, mvars := mvars, patterns := newPatterns ++ alt.patterns }
private def processConstructor (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
trace! `Meta.EqnCompiler.match ("constructor step");
@ -458,91 +486,19 @@ match p.vars with
| _ => Example.underscore; -- This case can happen due to dependent elimination
let examples := p.examples.map $ Example.replaceFVarId x.fvarId! subex;
let examples := examples.map $ Example.applyFVarSubst subst;
let newAlts := p.alts.filter $ isFirstPatternCtor subgoal.ctorName;
let newAlts := newAlts.map fun alt => match alt.patterns with
| Pattern.ctor _ _ _ fields :: ps => { alt with patterns := fields ++ ps }
| _ => unreachable!;
let newAlts := p.alts.filter fun alt => match alt.patterns with
| Pattern.ctor n _ _ _ :: _ => n == subgoal.ctorName
| Pattern.var _ :: _ => true
| _ => false;
newAlts ← newAlts.mapM fun alt => alt.applyFVarSubst subst;
newAlts ← newAlts.mapM fun alt => alt.copy;
newAlts ← newAlts.filterMapM fun alt => match alt.patterns with
| Pattern.ctor _ _ _ fields :: ps => pure $ some { alt with patterns := fields ++ ps }
| Pattern.var mvarId :: ps => tryConstructor? { alt with patterns := ps } mvarId subgoal.ctorName
| _ => unreachable!;
process { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } s)
s
private def throwInductiveTypeExpected {α} (type : Expr) : MetaM α := do
throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr type)
private partial def tryConstructorAux (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr)
: Nat → Array MVarId → Array IPattern → MetaM Alt
| i, newMVars, fields => do
if h : i < mvars.size then do
let mvar := mvars.get ⟨i, h⟩;
e ← instantiateMVars mvar;
match e with
| Expr.mvar mvarId _ => tryConstructorAux (i+1) (newMVars.push mvarId) (fields.push (Pattern.var mvarId))
| _ => tryConstructorAux (i+1) newMVars (fields.push (Pattern.inaccessible e))
else do
let p := Pattern.ctor ctorName us params.toList fields.toList;
e ← p.toExpr;
assignExprMVar mvarId e;
ps ← alt.patterns.mapM Pattern.instantiateMVars;
let ps := p :: ps;
rhs ← instantiateMVars alt.rhs;
unless (alt.mvars.contains mvarId) $
throwOther "ill-format alternative"; -- TODO: improve error message
let mvars := (alt.mvars.map fun mvarId' => if mvarId' == mvarId then newMVars.toList else [mvarId']).join;
mvars ← mvars.filterM fun mvarId => not <$> isExprMVarAssigned mvarId;
pure { alt with rhs := rhs, mvars := mvars, patterns := ps }
private def tryConstructor? (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr)
: MetaM (Option Alt) := do
let ctor := mkAppN (mkConst ctorName us) params;
ctorType ← inferType ctor;
(mvars, _, resultType) ← forallMetaTelescopeReducing ctorType;
trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType);
condM (isDefEq resultType expectedType)
(Option.some <$> tryConstructorAux alt mvarId ctorName us params mvars 0 #[] #[])
(pure none)
private def expandAlt (alt : Alt) (mvarId : MVarId) : MetaM (List Alt) := do
env ← getEnv;
mvarDecl ← getMVarDecl mvarId;
let expectedType := mvarDecl.type;
expectedType ← whnfD expectedType;
matchConst env expectedType.getAppFn (fun _ => throwInductiveTypeExpected expectedType) fun info us =>
match info with
| ConstantInfo.inductInfo val =>
val.ctors.foldlM
(fun (result : List Alt) ctor => do
(mvarSubst, alt) ← alt.copyCore;
let mvarId := mvarSubst.find! mvarId;
mvarDecl ← getMVarDecl mvarId;
let expectedType := mvarDecl.type;
expectedType ← whnfD expectedType;
let I := expectedType.getAppFn;
let Iargs := expectedType.getAppArgs;
let params := Iargs.extract 0 val.nparams;
alt? ← tryConstructor? alt mvarId ctor us params expectedType;
match alt? with
| none => pure result
| some alt => pure (alt :: result))
[]
| _ => throwInductiveTypeExpected expectedType
private def processComplete (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
trace! `Meta.EqnCompiler.match ("complete step");
withGoalOf p do
env ← getEnv;
newAlts ← p.alts.foldlM
(fun (newAlts : List Alt) alt =>
match alt.patterns with
| Pattern.ctor _ _ _ _ :: ps => pure (alt :: newAlts)
| p@(Pattern.var mvarId) :: ps => do
let alt := { alt with patterns := ps };
alts ← expandAlt alt mvarId;
pure (alts ++ newAlts)
| _ => unreachable!)
[];
process { p with alts := newAlts.reverse } s
private def collectValues (p : Problem) : Array Expr :=
p.alts.foldl
(fun (values : Array Expr) alt =>
@ -668,8 +624,6 @@ private partial def process : Problem → State → MetaM State
processVariable process p s
else if isConstructorTransition p then
processConstructor process p s
else if isCompleteTransition p then
processComplete process p s
else if isValueTransition p then
processValue process p s
else if isArrayLitTransition p then