fix: combine "complete" and constructor transitions
This commit is contained in:
parent
b4b60dc326
commit
8e81c19162
1 changed files with 56 additions and 102 deletions
|
|
@ -339,23 +339,16 @@ p.alts.all fun alt => match alt.patterns with
|
|||
| Pattern.var _ :: _ => true
|
||||
| _ => false
|
||||
|
||||
/- Return true if the next pattern of each remaining alternative is a constructor application -/
|
||||
/- Return true if the next pattern of each remaining alternative is a constructor application or variable -/
|
||||
private def isConstructorTransition (p : Problem) : Bool :=
|
||||
p.alts.all fun alt => match alt.patterns with
|
||||
| Pattern.ctor _ _ _ _ :: _ => true
|
||||
| _ => false
|
||||
|
||||
/- Return true if the next pattern of the remaining alternatives contain variables AND constructors. -/
|
||||
private def isCompleteTransition (p : Problem) : Bool :=
|
||||
let (ok, hasVar, hasCtor) := p.alts.foldl
|
||||
(fun (acc : Bool × Bool × Bool) (alt : Alt) =>
|
||||
let (ok, hasVar, hasCtor) := acc;
|
||||
match alt.patterns with
|
||||
| Pattern.ctor _ _ _ _ :: _ => (ok, hasVar, true)
|
||||
| Pattern.var _ :: _ => (ok, true, hasCtor)
|
||||
| _ => (false, hasVar, hasCtor))
|
||||
(true, false, false);
|
||||
ok && hasVar && hasCtor
|
||||
(p.alts.any fun alt => match alt.patterns with
|
||||
| Pattern.ctor _ _ _ _ :: _ => true
|
||||
| _ => false)
|
||||
&&
|
||||
(p.alts.all fun alt => match alt.patterns with
|
||||
| Pattern.ctor _ _ _ _ :: _ => true
|
||||
| Pattern.var _ :: _ => true
|
||||
| _ => false)
|
||||
|
||||
/- Return true if the next pattern of the remaining alternatives contain variables AND values. -/
|
||||
private def isValueTransition (p : Problem) : Bool :=
|
||||
|
|
@ -436,10 +429,45 @@ match p.vars with
|
|||
| _ => unreachable!;
|
||||
process { p with alts := alts, vars := xs } s
|
||||
|
||||
private def isFirstPatternCtor (ctorName : Name) (alt : Alt) : Bool :=
|
||||
match alt.patterns with
|
||||
| Pattern.ctor n _ _ _ :: _ => n == ctorName
|
||||
| _ => false
|
||||
private def throwInductiveTypeExpected {α} (type : Expr) : MetaM α := do
|
||||
throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr type)
|
||||
|
||||
private def getInductiveUniverseAndParams (type : Expr) : MetaM (List Level × Array Expr) := do
|
||||
env ← getEnv;
|
||||
type ← whnfD type;
|
||||
matchConst env type.getAppFn (fun _ => throwInductiveTypeExpected type) fun info us =>
|
||||
match info with
|
||||
| ConstantInfo.inductInfo val =>
|
||||
let I := type.getAppFn;
|
||||
let Iargs := type.getAppArgs;
|
||||
let params := Iargs.extract 0 val.nparams;
|
||||
pure (us, params)
|
||||
| _ => throwInductiveTypeExpected type
|
||||
|
||||
private def tryConstructor? (alt : Alt) (mvarId : MVarId) (ctorName : Name) : MetaM (Option Alt) := do
|
||||
expectedType ← inferType (mkMVar mvarId);
|
||||
(us, params) ← getInductiveUniverseAndParams expectedType;
|
||||
let ctor := mkAppN (mkConst ctorName us) params;
|
||||
ctorType ← inferType ctor;
|
||||
(fieldMVars, _, resultType) ← forallMetaTelescopeReducing ctorType;
|
||||
let ctor := mkAppN ctor fieldMVars;
|
||||
trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType);
|
||||
isCompatible ← isDefEq resultType expectedType;
|
||||
if !isCompatible then pure none
|
||||
else do
|
||||
let fieldMVars := fieldMVars.toList;
|
||||
assignExprMVar mvarId ctor;
|
||||
rhs ← instantiateMVars alt.rhs;
|
||||
newPatterns ← fieldMVars.mapM fun fieldMVar => do {
|
||||
e ← instantiateMVars fieldMVar;
|
||||
match e with
|
||||
| Expr.mvar mvarId _ => pure (Pattern.var mvarId : IPattern)
|
||||
| _ => pure (Pattern.inaccessible e)
|
||||
};
|
||||
newMVarIds ← fieldMVars.filterMapM fun fieldMVar => condM (isExprMVarAssigned fieldMVar.mvarId!) (pure none) (pure (some fieldMVar.mvarId!));
|
||||
let mvars := (alt.mvars.map fun mvarId' => if mvarId' == mvarId then newMVarIds else [mvarId']).join;
|
||||
mvars ← mvars.filterM fun mvarId => not <$> isExprMVarAssigned mvarId;
|
||||
pure $ some { alt with rhs := rhs, mvars := mvars, patterns := newPatterns ++ alt.patterns }
|
||||
|
||||
private def processConstructor (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
|
||||
trace! `Meta.EqnCompiler.match ("constructor step");
|
||||
|
|
@ -458,91 +486,19 @@ match p.vars with
|
|||
| _ => Example.underscore; -- This case can happen due to dependent elimination
|
||||
let examples := p.examples.map $ Example.replaceFVarId x.fvarId! subex;
|
||||
let examples := examples.map $ Example.applyFVarSubst subst;
|
||||
let newAlts := p.alts.filter $ isFirstPatternCtor subgoal.ctorName;
|
||||
let newAlts := newAlts.map fun alt => match alt.patterns with
|
||||
| Pattern.ctor _ _ _ fields :: ps => { alt with patterns := fields ++ ps }
|
||||
| _ => unreachable!;
|
||||
let newAlts := p.alts.filter fun alt => match alt.patterns with
|
||||
| Pattern.ctor n _ _ _ :: _ => n == subgoal.ctorName
|
||||
| Pattern.var _ :: _ => true
|
||||
| _ => false;
|
||||
newAlts ← newAlts.mapM fun alt => alt.applyFVarSubst subst;
|
||||
newAlts ← newAlts.mapM fun alt => alt.copy;
|
||||
newAlts ← newAlts.filterMapM fun alt => match alt.patterns with
|
||||
| Pattern.ctor _ _ _ fields :: ps => pure $ some { alt with patterns := fields ++ ps }
|
||||
| Pattern.var mvarId :: ps => tryConstructor? { alt with patterns := ps } mvarId subgoal.ctorName
|
||||
| _ => unreachable!;
|
||||
process { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } s)
|
||||
s
|
||||
|
||||
private def throwInductiveTypeExpected {α} (type : Expr) : MetaM α := do
|
||||
throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr type)
|
||||
|
||||
private partial def tryConstructorAux (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr)
|
||||
: Nat → Array MVarId → Array IPattern → MetaM Alt
|
||||
| i, newMVars, fields => do
|
||||
if h : i < mvars.size then do
|
||||
let mvar := mvars.get ⟨i, h⟩;
|
||||
e ← instantiateMVars mvar;
|
||||
match e with
|
||||
| Expr.mvar mvarId _ => tryConstructorAux (i+1) (newMVars.push mvarId) (fields.push (Pattern.var mvarId))
|
||||
| _ => tryConstructorAux (i+1) newMVars (fields.push (Pattern.inaccessible e))
|
||||
else do
|
||||
let p := Pattern.ctor ctorName us params.toList fields.toList;
|
||||
e ← p.toExpr;
|
||||
assignExprMVar mvarId e;
|
||||
ps ← alt.patterns.mapM Pattern.instantiateMVars;
|
||||
let ps := p :: ps;
|
||||
rhs ← instantiateMVars alt.rhs;
|
||||
unless (alt.mvars.contains mvarId) $
|
||||
throwOther "ill-format alternative"; -- TODO: improve error message
|
||||
let mvars := (alt.mvars.map fun mvarId' => if mvarId' == mvarId then newMVars.toList else [mvarId']).join;
|
||||
mvars ← mvars.filterM fun mvarId => not <$> isExprMVarAssigned mvarId;
|
||||
pure { alt with rhs := rhs, mvars := mvars, patterns := ps }
|
||||
|
||||
private def tryConstructor? (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr)
|
||||
: MetaM (Option Alt) := do
|
||||
let ctor := mkAppN (mkConst ctorName us) params;
|
||||
ctorType ← inferType ctor;
|
||||
(mvars, _, resultType) ← forallMetaTelescopeReducing ctorType;
|
||||
trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType);
|
||||
condM (isDefEq resultType expectedType)
|
||||
(Option.some <$> tryConstructorAux alt mvarId ctorName us params mvars 0 #[] #[])
|
||||
(pure none)
|
||||
|
||||
private def expandAlt (alt : Alt) (mvarId : MVarId) : MetaM (List Alt) := do
|
||||
env ← getEnv;
|
||||
mvarDecl ← getMVarDecl mvarId;
|
||||
let expectedType := mvarDecl.type;
|
||||
expectedType ← whnfD expectedType;
|
||||
matchConst env expectedType.getAppFn (fun _ => throwInductiveTypeExpected expectedType) fun info us =>
|
||||
match info with
|
||||
| ConstantInfo.inductInfo val =>
|
||||
val.ctors.foldlM
|
||||
(fun (result : List Alt) ctor => do
|
||||
(mvarSubst, alt) ← alt.copyCore;
|
||||
let mvarId := mvarSubst.find! mvarId;
|
||||
mvarDecl ← getMVarDecl mvarId;
|
||||
let expectedType := mvarDecl.type;
|
||||
expectedType ← whnfD expectedType;
|
||||
let I := expectedType.getAppFn;
|
||||
let Iargs := expectedType.getAppArgs;
|
||||
let params := Iargs.extract 0 val.nparams;
|
||||
alt? ← tryConstructor? alt mvarId ctor us params expectedType;
|
||||
match alt? with
|
||||
| none => pure result
|
||||
| some alt => pure (alt :: result))
|
||||
[]
|
||||
| _ => throwInductiveTypeExpected expectedType
|
||||
|
||||
private def processComplete (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
|
||||
trace! `Meta.EqnCompiler.match ("complete step");
|
||||
withGoalOf p do
|
||||
env ← getEnv;
|
||||
newAlts ← p.alts.foldlM
|
||||
(fun (newAlts : List Alt) alt =>
|
||||
match alt.patterns with
|
||||
| Pattern.ctor _ _ _ _ :: ps => pure (alt :: newAlts)
|
||||
| p@(Pattern.var mvarId) :: ps => do
|
||||
let alt := { alt with patterns := ps };
|
||||
alts ← expandAlt alt mvarId;
|
||||
pure (alts ++ newAlts)
|
||||
| _ => unreachable!)
|
||||
[];
|
||||
process { p with alts := newAlts.reverse } s
|
||||
|
||||
private def collectValues (p : Problem) : Array Expr :=
|
||||
p.alts.foldl
|
||||
(fun (values : Array Expr) alt =>
|
||||
|
|
@ -668,8 +624,6 @@ private partial def process : Problem → State → MetaM State
|
|||
processVariable process p s
|
||||
else if isConstructorTransition p then
|
||||
processConstructor process p s
|
||||
else if isCompleteTransition p then
|
||||
processComplete process p s
|
||||
else if isValueTransition p then
|
||||
processValue process p s
|
||||
else if isArrayLitTransition p then
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue