chore: remove ref from patterns

We don't use them to report errors. We only need `ref` at `Alt`
This commit is contained in:
Leonardo de Moura 2020-08-13 12:31:32 -07:00
parent a6b22728ca
commit 145a3dddca
2 changed files with 110 additions and 118 deletions

View file

@ -18,99 +18,91 @@ namespace DepElim
abbrev VarId := Name
inductive Pattern (internal : Bool := false) : Type
| inaccessible (ref : Syntax) (e : Expr) : Pattern
| var (ref : Syntax) (varId : VarId) : Pattern
| ctor (ref : Syntax) (ctorName : Name) (us : List Level) (params : List Expr) (fields : List Pattern) : Pattern
| val (ref : Syntax) (e : Expr) : Pattern
| arrayLit (ref : Syntax) (type : Expr) (xs : List Pattern) : Pattern
| as (ref : Syntax) (varId : VarId) (p : Pattern) : Pattern
| inaccessible (e : Expr) : Pattern
| var (varId : VarId) : Pattern
| ctor (ctorName : Name) (us : List Level) (params : List Expr) (fields : List Pattern) : Pattern
| val (e : Expr) : Pattern
| arrayLit (type : Expr) (xs : List Pattern) : Pattern
| as (varId : VarId) (p : Pattern) : Pattern
abbrev IPattern := Pattern true
namespace Pattern
instance {b} : Inhabited (Pattern b) := ⟨Pattern.inaccessible Syntax.missing (arbitrary _)⟩
def ref {b : Bool} : Pattern b → Syntax
| inaccessible r _ => r
| var r _ => r
| ctor r _ _ _ _ => r
| val r _ => r
| arrayLit r _ _ => r
| as r _ _ => r
instance {b} : Inhabited (Pattern b) := ⟨Pattern.inaccessible (arbitrary _)⟩
partial def toMessageData {b : Bool} : Pattern b → MessageData
| inaccessible _ e => ".(" ++ e ++ ")"
| var _ varId => if b then mkMVar varId else mkFVar varId
| ctor _ ctorName _ _ [] => ctorName
| ctor _ ctorName _ _ pats => "(" ++ ctorName ++ pats.foldl (fun (msg : MessageData) pat => msg ++ " " ++ toMessageData pat) Format.nil ++ ")"
| val _ e => "val!(" ++ e ++ ")"
| arrayLit _ _ pats => "#[" ++ MessageData.joinSep (pats.map toMessageData) ", " ++ "]"
| as _ varId p => (if b then mkMVar varId else mkFVar varId) ++ "@" ++toMessageData p
| inaccessible e => ".(" ++ e ++ ")"
| var varId => if b then mkMVar varId else mkFVar varId
| ctor ctorName _ _ [] => ctorName
| ctor ctorName _ _ pats => "(" ++ ctorName ++ pats.foldl (fun (msg : MessageData) pat => msg ++ " " ++ toMessageData pat) Format.nil ++ ")"
| val e => "val!(" ++ e ++ ")"
| arrayLit _ pats => "#[" ++ MessageData.joinSep (pats.map toMessageData) ", " ++ "]"
| as varId p => (if b then mkMVar varId else mkFVar varId) ++ "@" ++toMessageData p
partial def toExpr {b} : Pattern b → MetaM Expr
| inaccessible _ e => pure e
| var _ varId => if b then pure (mkMVar varId) else pure (mkFVar varId)
| val _ e => pure e
| as _ _ p => toExpr p
| arrayLit _ type xs => do
| inaccessible e => pure e
| var varId => if b then pure (mkMVar varId) else pure (mkFVar varId)
| val e => pure e
| as _ p => toExpr p
| arrayLit type xs => do
xs ← xs.mapM toExpr;
mkArrayLit type xs
| ctor _ ctorName us params fields => do
| ctor ctorName us params fields => do
fields ← fields.mapM toExpr;
pure $ mkAppN (mkConst ctorName us) (params ++ fields).toArray
/- Apply the free variable substitution `s` to the given (internal) pattern -/
partial def applyFVarSubst (s : FVarSubst) : Pattern true → IPattern
| inaccessible r e => inaccessible r $ s.apply e
| ctor r n us ps fs => ctor r n us (ps.map s.apply) $ fs.map applyFVarSubst
| val r e => val r $ s.apply e
| arrayLit r t xs => arrayLit r (s.apply t) $ xs.map applyFVarSubst
| var r id => var r id
| as r v p => as r v $ applyFVarSubst p
| inaccessible e => inaccessible $ s.apply e
| ctor n us ps fs => ctor n us (ps.map s.apply) $ fs.map applyFVarSubst
| val e => val $ s.apply e
| arrayLit t xs => arrayLit (s.apply t) $ xs.map applyFVarSubst
| var id => var id
| as v p => as v $ applyFVarSubst p
partial def instantiateMVars : IPattern → MetaM IPattern
| inaccessible r e => inaccessible r <$> Meta.instantiateMVars e
| ctor r n us ps fs => ctor r n us <$> ps.mapM Meta.instantiateMVars <*> fs.mapM instantiateMVars
| val r e => val r <$> Meta.instantiateMVars e
| arrayLit r t xs => arrayLit r <$> Meta.instantiateMVars t <*> xs.mapM instantiateMVars
| var ref mvarId => do
| inaccessible e => inaccessible <$> Meta.instantiateMVars e
| ctor n us ps fs => ctor n us <$> ps.mapM Meta.instantiateMVars <*> fs.mapM instantiateMVars
| val e => val <$> Meta.instantiateMVars e
| arrayLit t xs => arrayLit <$> Meta.instantiateMVars t <*> xs.mapM instantiateMVars
| var mvarId => do
mctx ← getMCtx;
match mctx.getExprAssignment? mvarId with
| some v => inaccessible ref <$> Meta.instantiateMVars v
| none => pure (var ref mvarId)
| as ref mvarId p => do
| some v => inaccessible <$> Meta.instantiateMVars v
| none => pure (var mvarId)
| as mvarId p => do
mctx ← getMCtx;
match mctx.getExprAssignment? mvarId with
| some v => instantiateMVars p
| none => as ref mvarId <$> instantiateMVars p
| none => as mvarId <$> instantiateMVars p
partial def applyMVarRenaming (m : MVarRenaming) : Pattern true → IPattern
| inaccessible r e => inaccessible r $ m.apply e
| ctor r n us ps fs => ctor r n us (ps.map m.apply) $ fs.map applyMVarRenaming
| val r e => val r $ m.apply e
| arrayLit r t xs => arrayLit r (m.apply t) $ xs.map applyMVarRenaming
| var ref mvarId =>
| inaccessible e => inaccessible $ m.apply e
| ctor n us ps fs => ctor n us (ps.map m.apply) $ fs.map applyMVarRenaming
| val e => val $ m.apply e
| arrayLit t xs => arrayLit (m.apply t) $ xs.map applyMVarRenaming
| var mvarId =>
match m.find? mvarId with
| some newMVarId => var ref newMVarId
| none => var ref mvarId
| as ref mvarId p =>
| some newMVarId => var newMVarId
| none => var mvarId
| as mvarId p =>
match m.find? mvarId with
| some newMVarId => as ref newMVarId $ applyMVarRenaming p
| none => as ref mvarId $ applyMVarRenaming p
| some newMVarId => as newMVarId $ applyMVarRenaming p
| none => as mvarId $ applyMVarRenaming p
partial def toIPattern (s : FVarSubst) : Pattern → IPattern
| inaccessible r e => inaccessible r $ s.apply e
| ctor r n us ps fs => ctor r n us (ps.map s.apply) $ fs.map toIPattern
| val r e => val r $ s.apply e
| arrayLit r t xs => arrayLit r (s.apply t) $ xs.map toIPattern
| var ref fvarId =>
| inaccessible e => inaccessible $ s.apply e
| ctor n us ps fs => ctor n us (ps.map s.apply) $ fs.map toIPattern
| val e => val $ s.apply e
| arrayLit t xs => arrayLit (s.apply t) $ xs.map toIPattern
| var fvarId =>
match s.get fvarId with
| Expr.mvar mvarId _ => Pattern.var ref mvarId
| Expr.mvar mvarId _ => Pattern.var mvarId
| _ => unreachable!
| as ref fvarId p =>
| as fvarId p =>
match s.get fvarId with
| Expr.mvar mvarId _ => Pattern.as ref mvarId $ toIPattern p
| Expr.mvar mvarId _ => Pattern.as mvarId $ toIPattern p
| _ => unreachable!
end Pattern
@ -340,21 +332,21 @@ match p.vars with
private def hasAsPattern (p : Problem) : Bool :=
p.alts.any fun alt => match alt.patterns with
| Pattern.as _ _ _ :: _ => true
| _ => false
| Pattern.as _ _ :: _ => true
| _ => false
/- Return true if the next pattern of each remaining alternative is an inaccessible term or a variable -/
private def isVariableTransition (p : Problem) : Bool :=
p.alts.all fun alt => match alt.patterns with
| Pattern.inaccessible _ _ :: _ => true
| Pattern.var _ _ :: _ => true
| _ => false
| Pattern.inaccessible _ :: _ => true
| Pattern.var _ :: _ => true
| _ => false
/- Return true if the next pattern of each remaining alternative is a constructor application -/
private def isConstructorTransition (p : Problem) : Bool :=
p.alts.all fun alt => match alt.patterns with
| Pattern.ctor _ _ _ _ _ :: _ => true
| _ => false
| Pattern.ctor _ _ _ _ :: _ => true
| _ => false
/- Return true if the next pattern of the remaining alternatives contain variables AND constructors. -/
private def isCompleteTransition (p : Problem) : Bool :=
@ -362,9 +354,9 @@ let (ok, hasVar, hasCtor) := p.alts.foldl
(fun (acc : Bool × Bool × Bool) (alt : Alt) =>
let (ok, hasVar, hasCtor) := acc;
match alt.patterns with
| Pattern.ctor _ _ _ _ _ :: _ => (ok, hasVar, true)
| Pattern.var _ _ :: _ => (ok, true, hasCtor)
| _ => (false, hasVar, hasCtor))
| Pattern.ctor _ _ _ _ :: _ => (ok, hasVar, true)
| Pattern.var _ :: _ => (ok, true, hasCtor)
| _ => (false, hasVar, hasCtor))
(true, false, false);
ok && hasVar && hasCtor
@ -374,9 +366,9 @@ let (ok, hasVar, hasVal) := p.alts.foldl
(fun (acc : Bool × Bool × Bool) (alt : Alt) =>
let (ok, hasVar, hasVal) := acc;
match alt.patterns with
| Pattern.val _ _ :: _ => (ok, hasVar, true)
| Pattern.var _ _ :: _ => (ok, true, hasVal)
| _ => (false, hasVar, hasVal))
| Pattern.val _ :: _ => (ok, hasVar, true)
| Pattern.var _ :: _ => (ok, true, hasVal)
| _ => (false, hasVar, hasVal))
(true, false, false);
ok && hasVar && hasVal
@ -386,9 +378,9 @@ let (ok, hasVar, hasArray) := p.alts.foldl
(fun (acc : Bool × Bool × Bool) (alt : Alt) =>
let (ok, hasVar, hasArray) := acc;
match alt.patterns with
| Pattern.arrayLit _ _ _ :: _ => (ok, hasVar, true)
| Pattern.var _ _ :: _ => (ok, true, hasArray)
| _ => (false, hasVar, hasArray))
| Pattern.arrayLit _ _ :: _ => (ok, hasVar, true)
| Pattern.var _ :: _ => (ok, true, hasArray)
| _ => (false, hasVar, hasArray))
(true, false, false);
ok && hasVar && hasArray
@ -418,7 +410,7 @@ match p.vars with
| [] => unreachable!
| x :: xs => do
alts ← p.alts.mapM fun alt => match alt.patterns with
| Pattern.as _ mvarId p :: ps => do
| Pattern.as mvarId p :: ps => do
assignExprMVar mvarId x;
rhs ← instantiateMVars alt.rhs;
let mvars := alt.mvars.erase mvarId;
@ -434,8 +426,8 @@ match p.vars with
| [] => unreachable!
| x :: xs => do
alts ← p.alts.mapM fun alt => match alt.patterns with
| Pattern.inaccessible _ _ :: ps => pure { alt with patterns := ps }
| Pattern.var _ mvarId :: ps => do
| Pattern.inaccessible _ :: ps => pure { alt with patterns := ps }
| Pattern.var mvarId :: ps => do
-- trace! `Meta.EqnCompiler.matchDebug (">> assign " ++ mkMVar mvarId ++ " := " ++ x);
assignExprMVar mvarId x;
rhs ← instantiateMVars alt.rhs;
@ -449,8 +441,8 @@ match p.vars with
private def isFirstPatternCtor (ctorName : Name) (alt : Alt) : Bool :=
match alt.patterns with
| Pattern.ctor _ n _ _ _ :: _ => n == ctorName
| _ => false
| Pattern.ctor n _ _ _ :: _ => n == ctorName
| _ => false
private def processConstructor (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
trace! `Meta.EqnCompiler.match ("constructor step");
@ -471,8 +463,8 @@ match p.vars with
let examples := examples.map $ Example.applyFVarSubst subst;
let newAlts := p.alts.filter $ isFirstPatternCtor subgoal.ctorName;
let newAlts := newAlts.map fun alt => match alt.patterns with
| Pattern.ctor _ _ _ _ fields :: ps => { alt with patterns := fields ++ ps }
| _ => unreachable!;
| Pattern.ctor _ _ _ fields :: ps => { alt with patterns := fields ++ ps }
| _ => unreachable!;
newAlts ← newAlts.mapM fun alt => alt.applyFVarSubst subst;
newAlts ← newAlts.mapM fun alt => alt.copy;
process { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } s)
@ -481,17 +473,17 @@ match p.vars with
private def throwInductiveTypeExpected {α} (type : Expr) : MetaM α := do
throwOther ("failed to compile pattern matching, inductive type expected" ++ indentExpr type)
private partial def tryConstructorAux (alt : Alt) (ref : Syntax) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr)
private partial def tryConstructorAux (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (mvars : Array Expr)
: Nat → Array MVarId → Array IPattern → MetaM Alt
| i, newMVars, fields => do
if h : i < mvars.size then do
let mvar := mvars.get ⟨i, h⟩;
e ← instantiateMVars mvar;
match e with
| Expr.mvar mvarId _ => tryConstructorAux (i+1) (newMVars.push mvarId) (fields.push (Pattern.var ref mvarId))
| _ => tryConstructorAux (i+1) newMVars (fields.push (Pattern.inaccessible ref e))
| Expr.mvar mvarId _ => tryConstructorAux (i+1) (newMVars.push mvarId) (fields.push (Pattern.var mvarId))
| _ => tryConstructorAux (i+1) newMVars (fields.push (Pattern.inaccessible e))
else do
let p := Pattern.ctor ref ctorName us params.toList fields.toList;
let p := Pattern.ctor ctorName us params.toList fields.toList;
e ← p.toExpr;
assignExprMVar mvarId e;
ps ← alt.patterns.mapM Pattern.instantiateMVars;
@ -503,17 +495,17 @@ private partial def tryConstructorAux (alt : Alt) (ref : Syntax) (mvarId : MVarI
mvars ← mvars.filterM fun mvarId => not <$> isExprMVarAssigned mvarId;
pure { alt with rhs := rhs, mvars := mvars, patterns := ps }
private def tryConstructor? (alt : Alt) (ref : Syntax) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr)
private def tryConstructor? (alt : Alt) (mvarId : MVarId) (ctorName : Name) (us : List Level) (params : Array Expr) (expectedType : Expr)
: MetaM (Option Alt) := do
let ctor := mkAppN (mkConst ctorName us) params;
ctorType ← inferType ctor;
(mvars, _, resultType) ← forallMetaTelescopeReducing ctorType;
trace! `Meta.EqnCompiler.matchDebug ("ctorName: " ++ ctorName ++ ", resultType: " ++ resultType ++ ", expectedType: " ++ expectedType);
condM (isDefEq resultType expectedType)
(Option.some <$> tryConstructorAux alt ref mvarId ctorName us params mvars 0 #[] #[])
(Option.some <$> tryConstructorAux alt mvarId ctorName us params mvars 0 #[] #[])
(pure none)
private def expandAlt (alt : Alt) (ref : Syntax) (mvarId : MVarId) : MetaM (List Alt) := do
private def expandAlt (alt : Alt) (mvarId : MVarId) : MetaM (List Alt) := do
env ← getEnv;
mvarDecl ← getMVarDecl mvarId;
let expectedType := mvarDecl.type;
@ -531,7 +523,7 @@ matchConst env expectedType.getAppFn (fun _ => throwInductiveTypeExpected expect
let I := expectedType.getAppFn;
let Iargs := expectedType.getAppArgs;
let params := Iargs.extract 0 val.nparams;
alt? ← tryConstructor? alt ref mvarId ctor us params expectedType;
alt? ← tryConstructor? alt mvarId ctor us params expectedType;
match alt? with
| none => pure result
| some alt => pure (alt :: result))
@ -545,10 +537,10 @@ env ← getEnv;
newAlts ← p.alts.foldlM
(fun (newAlts : List Alt) alt =>
match alt.patterns with
| Pattern.ctor _ _ _ _ _ :: ps => pure (alt :: newAlts)
| p@(Pattern.var ref mvarId) :: ps => do
| Pattern.ctor _ _ _ _ :: ps => pure (alt :: newAlts)
| p@(Pattern.var mvarId) :: ps => do
let alt := { alt with patterns := ps };
alts ← expandAlt alt ref mvarId;
alts ← expandAlt alt mvarId;
pure (alts ++ newAlts)
| _ => unreachable!)
[];
@ -558,14 +550,14 @@ private def collectValues (p : Problem) : Array Expr :=
p.alts.foldl
(fun (values : Array Expr) alt =>
match alt.patterns with
| Pattern.val _ v :: _ => if values.contains v then values else values.push v
| _ => values)
| Pattern.val v :: _ => if values.contains v then values else values.push v
| _ => values)
#[]
private def isFirstPatternVar (alt : Alt) : Bool :=
match alt.patterns with
| Pattern.var _ _ :: _ => true
| _ => false
| Pattern.var _ :: _ => true
| _ => false
private def processValue (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
trace! `Meta.EqnCompiler.match ("value step");
@ -584,14 +576,14 @@ match p.vars with
let examples := p.examples.map $ Example.replaceFVarId x.fvarId! (Example.val value);
let examples := examples.map $ Example.applyFVarSubst subst;
let newAlts := p.alts.filter fun alt => match alt.patterns with
| Pattern.val _ v :: _ => v == value
| Pattern.var _ _ :: _ => true
| Pattern.val v :: _ => v == value
| Pattern.var _ :: _ => true
| _ => false;
newAlts ← newAlts.mapM fun alt => alt.applyFVarSubst subst;
newAlts ← newAlts.mapM fun alt => alt.copy;
newAlts ← newAlts.mapM fun alt => match alt.patterns with
| Pattern.val _ _ :: ps => pure { alt with patterns := ps }
| Pattern.var _ mvarId :: ps => do
| Pattern.val _ :: ps => pure { alt with patterns := ps }
| Pattern.var mvarId :: ps => do
assignExprMVar mvarId value;
ps ← ps.mapM Pattern.instantiateMVars;
rhs ← instantiateMVars alt.rhs;
@ -611,8 +603,8 @@ private def collectArraySizes (p : Problem) : Array Nat :=
p.alts.foldl
(fun (sizes : Array Nat) alt =>
match alt.patterns with
| Pattern.arrayLit _ _ ps :: _ => let sz := ps.length; if sizes.contains sz then sizes else sizes.push sz
| _ => sizes)
| Pattern.arrayLit _ ps :: _ => let sz := ps.length; if sizes.contains sz then sizes else sizes.push sz
| _ => sizes)
#[]
private def processArrayLit (process : Problem → State → MetaM State) (p : Problem) (s : State) : MetaM State := do
@ -635,14 +627,14 @@ match p.vars with
let examples := p.examples.map $ Example.replaceFVarId x.fvarId! subex;
let examples := examples.map $ Example.applyFVarSubst subst;
let newAlts := p.alts.filter fun alt => match alt.patterns with
| Pattern.arrayLit _ _ ps :: _ => ps.length == size
| Pattern.var _ _ :: _ => true
| _ => false;
| Pattern.arrayLit _ ps :: _ => ps.length == size
| Pattern.var _ :: _ => true
| _ => false;
newAlts ← newAlts.mapM fun alt => alt.applyFVarSubst subst;
newAlts ← newAlts.mapM fun alt => alt.copy;
newAlts ← newAlts.mapM fun alt => match alt.patterns with
| Pattern.arrayLit _ _ pats :: ps => pure { alt with patterns := pats ++ ps }
| Pattern.var ref mvarId :: ps => do
| Pattern.arrayLit _ pats :: ps => pure { alt with patterns := pats ++ ps }
| Pattern.var mvarId :: ps => do
α ← getArrayArgType x;
newMVars ← size.foldM
(fun _ (newMVars : List Expr) => do
@ -655,7 +647,7 @@ match p.vars with
rhs ← instantiateMVars alt.rhs;
let mvars := alt.mvars.erase mvarId;
let mvars := newMVars.map Expr.mvarId! ++ mvars;
let ps := newMVars.map (fun mvar => Pattern.var ref mvar.mvarId!) ++ ps;
let ps := newMVars.map (fun mvar => Pattern.var mvar.mvarId!) ++ ps;
pure { alt with rhs := rhs, mvars := mvars, patterns := ps }
| _ => unreachable!;
process { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples } s

View file

@ -45,16 +45,16 @@ match e with
partial def mkPattern : Expr → MetaM Pattern
| e =>
if e.isAppOfArity `val 2 then
pure $ Pattern.val Syntax.missing e.appArg!
pure $ Pattern.val e.appArg!
else if e.isAppOfArity `inaccessible 2 then
pure $ Pattern.inaccessible Syntax.missing e.appArg!
pure $ Pattern.inaccessible e.appArg!
else if e.isFVar then
pure $ Pattern.var Syntax.missing e.fvarId!
pure $ Pattern.var e.fvarId!
else if e.isAppOfArity `As 3 && (e.getArg! 1).isFVar then do
let v := e.getArg! 1;
let p := e.getArg! 2;
p ← mkPattern p;
pure $ Pattern.as Syntax.missing v.fvarId! p
pure $ Pattern.as v.fvarId! p
else if e.isAppOfArity `ArrayLit0 1 ||
e.isAppOfArity `ArrayLit1 2 ||
e.isAppOfArity `ArrayLit2 3 ||
@ -64,14 +64,14 @@ partial def mkPattern : Expr → MetaM Pattern
let type := args.get! 0;
let ps := args.extract 1 args.size;
ps ← ps.toList.mapM mkPattern;
pure $ Pattern.arrayLit Syntax.missing type ps
pure $ Pattern.arrayLit type ps
else match e.arrayLit? with
| some es => do
pats ← es.mapM mkPattern;
type ← inferType e;
type ← whnfD type;
let elemType := type.appArg!;
pure $ Pattern.arrayLit Syntax.missing elemType pats
pure $ Pattern.arrayLit elemType pats
| none => do
e ← whnfD e;
r? ← constructorApp? e;
@ -81,7 +81,7 @@ partial def mkPattern : Expr → MetaM Pattern
let params := args.extract 0 cval.nparams;
let fields := args.extract cval.nparams args.size;
pats ← fields.toList.mapM mkPattern;
pure $ Pattern.ctor Syntax.missing cval.name fn.constLevels! params.toList pats
pure $ Pattern.ctor cval.name fn.constLevels! params.toList pats
partial def decodePats : Expr → MetaM (List Pattern)
| e =>