Her er mine tanker.
Jeg ved ikke om det er en god løsning, men den er i hvert fald meget objektorienteret.
:-)
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
namespace Protocol
{
namespace Packets
{
/// <summary>
/// Kind of a packet. The numeric value is serialized as the type byte that
/// follows the length byte in every packet header.
/// </summary>
public enum PacketType : int
{
    /// <summary>Announces the protocol version used for the packets that follow.</summary>
    Version = 1,

    /// <summary>Packet whose body layout depends on the protocol version.</summary>
    Foo = 2,

    /// <summary>Packet whose body layout is the same in every version.</summary>
    Bar = 3
}
/// <summary>
/// Base type of every protocol packet. A concrete packet reports its
/// <see cref="PacketType"/> and the value that goes into the header's length byte.
/// </summary>
public abstract class Packet
{
    /// <summary>Discriminator written as the second header byte.</summary>
    public abstract PacketType Type { get; }

    /// <summary>
    /// Value written as the first header byte. The readers in this library reject a
    /// packet whose header length differs from this value.
    /// </summary>
    public abstract byte Length { get; }
}
/// <summary>
/// Protocol version. It selects the FooPacket layout used by readers and writers
/// and is serialized as a single byte in the body of a VersionPacket.
/// </summary>
public enum Version
{
    /// <summary>Foo packets carry field A only.</summary>
    V1 = 1,

    /// <summary>Foo packets carry fields A and B.</summary>
    V2 = 2,
}
/// <summary>
/// Packet announcing a protocol version. Its body is the version as one byte.
/// </summary>
public class VersionPacket : Packet
{
    // Header length value for this packet kind.
    // NOTE(review): Bar and Foo packets report header + body size (2 + 2, 2 + 2, 2 + 4),
    // under which this would be 3 (2 header bytes + 1 version byte). The reader and
    // writer in this library agree on 2, so data round-trips. Confirm against the
    // protocol spec before changing it, because it alters the wire format.
    private const byte WireLength = 2;

    /// <summary>The protocol version being announced.</summary>
    public Version Version { get; set; }

    public override PacketType Type
    {
        get { return PacketType.Version; }
    }

    public override byte Length
    {
        get { return WireLength; }
    }
}
/// <summary>
/// Version-independent packet carrying a single value, <see cref="C"/>.
/// </summary>
public class BarPacket : Packet
{
    // 2 header bytes + C serialized as a 16-bit integer.
    private const byte WireLength = 4;

    /// <summary>Payload value; serialized as a 16-bit integer.</summary>
    public int C { get; set; }

    public override PacketType Type
    {
        get { return PacketType.Bar; }
    }

    public override byte Length
    {
        get { return WireLength; }
    }
}
namespace V1
{
    /// <summary>
    /// Foo packet layout for protocol version 1: a single field, A.
    /// </summary>
    public class FooPacket : Packet
    {
        // 2 header bytes + A serialized as a 16-bit integer.
        private const byte WireLength = 4;

        /// <summary>Payload value; serialized as a 16-bit integer.</summary>
        public int A { get; set; }

        public override PacketType Type
        {
            get { return PacketType.Foo; }
        }

        public override byte Length
        {
            get { return WireLength; }
        }
    }
}
namespace V2
{
    /// <summary>
    /// Foo packet layout for protocol version 2: fields A and B.
    /// </summary>
    public class FooPacket : Packet
    {
        // 2 header bytes + A and B, each serialized as a 16-bit integer.
        private const byte WireLength = 6;

        /// <summary>First payload value; serialized as a 16-bit integer.</summary>
        public int A { get; set; }

        /// <summary>Second payload value; serialized as a 16-bit integer after A.</summary>
        public int B { get; set; }

        public override PacketType Type
        {
            get { return PacketType.Foo; }
        }

        public override byte Length
        {
            get { return WireLength; }
        }
    }
}
}
namespace IO
{
using Protocol.Packets;
/// <summary>
/// Version-specific packet serializer; <see cref="PacketWriter"/> delegates to one of these.
/// </summary>
internal interface IWriter
{
    /// <summary>Serializes one packet, header followed by body.</summary>
    void Write(Packet p);

    /// <summary>
    /// Releases the serializer. The implementation in this library closes its BinaryWriter.
    /// </summary>
    void Close();
}
/// <summary>
/// Serializes packets to a <see cref="BinaryWriter"/> using the layout of the current
/// protocol version. Writing a <see cref="VersionPacket"/> switches every later packet
/// to the announced version, mirroring how <see cref="PacketReader"/> switches when
/// it reads one.
/// </summary>
public class PacketWriter : IDisposable
{
    /// <summary>Version whose layout is used for the next packet written.</summary>
    protected Version ver;

    // Kept so a new version-specific writer can be created on a version switch.
    private readonly BinaryWriter bw;
    private IWriter real;

    /// <param name="bw">Destination; closed by <see cref="Close"/>.</param>
    /// <param name="ver">Initial protocol version.</param>
    /// <exception cref="NotSupportedException"><paramref name="ver"/> has no known layout.</exception>
    public PacketWriter(BinaryWriter bw, Version ver)
    {
        this.bw = bw;
        this.ver = ver;
        real = CreateWriter(bw, ver);
    }

    /// <summary>
    /// Writes one packet. A version packet is written with the current layout and then
    /// makes its version current for the packets that follow.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="p"/> is null.</exception>
    /// <exception cref="NotSupportedException">
    /// A version packet announces a version with no known layout; nothing is written.
    /// </exception>
    public virtual void Write(Packet p)
    {
        if(p == null)
        {
            throw new ArgumentNullException("p");
        }
        if(p.Type != PacketType.Version)
        {
            real.Write(p);
            return;
        }
        VersionPacket vp = (VersionPacket)p;
        // Build the replacement first so an unsupported version fails before any bytes
        // are emitted and leaves this writer's state untouched.
        IWriter next = CreateWriter(bw, vp.Version);
        real.Write(vp);
        // The old writer is not closed: it shares bw, and closing it would close the stream.
        ver = vp.Version;
        real = next;
    }

    /// <summary>Writes every packet of <paramref name="plist"/> in order via <see cref="Write(Packet)"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="plist"/> is null.</exception>
    public void Write(IEnumerable<Packet> plist)
    {
        if(plist == null)
        {
            throw new ArgumentNullException("plist");
        }
        foreach(Packet p in plist)
        {
            Write(p);
        }
    }

    /// <summary>Closes the underlying BinaryWriter.</summary>
    public void Close()
    {
        real.Close();
    }

    public void Dispose()
    {
        Close();
    }

    // Maps a protocol version to its serializer; rejects versions with no layout.
    private static IWriter CreateWriter(BinaryWriter bw, Version ver)
    {
        switch(ver)
        {
            case Version.V1:
                return new WriterV1(bw);
            case Version.V2:
                return new WriterV2(bw);
            default:
                throw new NotSupportedException("Unsupported version");
        }
    }
}
/// <summary>
/// Serialization shared by all protocol versions. Every packet is written as a
/// two-byte header (length byte, type byte) followed by its body. Subclasses supply
/// the version-specific Foo body.
/// </summary>
internal abstract class WriterBase : IWriter
{
    protected BinaryWriter bw;

    internal WriterBase(BinaryWriter bw)
    {
        this.bw = bw;
    }

    /// <summary>Writes the header: the packet's Length byte, then its Type byte.</summary>
    protected void WriteHeader(Packet p)
    {
        bw.Write(p.Length);
        bw.Write((byte)p.Type);
    }

    /// <summary>Writes a version packet: header plus the version as one byte.</summary>
    /// <exception cref="OverflowException">The version value does not fit in a byte.</exception>
    internal void WriteVersion(VersionPacket vp)
    {
        // checked: an unchecked cast would silently truncate and announce a different version.
        byte version = checked((byte)vp.Version);
        WriteHeader(vp);
        bw.Write(version);
    }

    /// <summary>Writes a Foo packet using the subclass's version-specific layout.</summary>
    internal abstract void WriteFoo(Packet p);

    /// <summary>Writes a bar packet: header plus C as a 16-bit integer.</summary>
    /// <exception cref="OverflowException">C does not fit in 16 bits.</exception>
    internal void WriteBar(BarPacket bp)
    {
        // checked: an unchecked cast silently truncates values outside the Int16 range,
        // and the reader would then decode a different number. Converting before the
        // header is written avoids leaving a partial packet on overflow.
        short c = checked((short)bp.C);
        WriteHeader(bp);
        bw.Write(c);
    }

    /// <summary>Dispatches on the packet type and writes header and body.</summary>
    /// <exception cref="NotSupportedException">The packet type is not one of the known types.</exception>
    public void Write(Packet p)
    {
        switch(p.Type)
        {
            case PacketType.Version:
                WriteVersion((VersionPacket)p);
                break;
            case PacketType.Foo:
                WriteFoo(p);
                break;
            case PacketType.Bar:
                WriteBar((BarPacket)p);
                break;
            default:
                throw new NotSupportedException("Unknown packet type");
        }
    }

    /// <summary>Closes the underlying BinaryWriter.</summary>
    public void Close()
    {
        bw.Close();
    }
}
/// <summary>Protocol version 1 serializer: Foo bodies carry field A only.</summary>
internal class WriterV1 : WriterBase
{
    internal WriterV1(BinaryWriter bw) : base(bw)
    {
    }

    /// <summary>Writes a V1 Foo packet: header plus A as a 16-bit integer.</summary>
    /// <exception cref="InvalidCastException"><paramref name="p"/> is not a V1 FooPacket.</exception>
    /// <exception cref="OverflowException">A does not fit in 16 bits.</exception>
    internal override void WriteFoo(Packet p)
    {
        Protocol.Packets.V1.FooPacket fp = (Protocol.Packets.V1.FooPacket)p;
        // checked: an unchecked cast silently truncates out-of-range values. Converting
        // before the header is written avoids leaving a partial packet on overflow.
        short a = checked((short)fp.A);
        WriteHeader(fp);
        bw.Write(a);
    }
}
/// <summary>Protocol version 2 serializer: Foo bodies carry fields A and B.</summary>
internal class WriterV2 : WriterBase
{
    internal WriterV2(BinaryWriter bw) : base(bw)
    {
    }

    /// <summary>Writes a V2 Foo packet: header, then A and B as 16-bit integers.</summary>
    /// <exception cref="InvalidCastException"><paramref name="p"/> is not a V2 FooPacket.</exception>
    /// <exception cref="OverflowException">A or B does not fit in 16 bits.</exception>
    internal override void WriteFoo(Packet p)
    {
        Protocol.Packets.V2.FooPacket fp = (Protocol.Packets.V2.FooPacket)p;
        // checked: an unchecked cast silently truncates out-of-range values. Both are
        // converted before anything is written so an overflow leaves no partial packet.
        short a = checked((short)fp.A);
        short b = checked((short)fp.B);
        WriteHeader(fp);
        bw.Write(a);
        bw.Write(b);
    }
}
/// <summary>
/// Version-specific packet deserializer; <see cref="PacketReader"/> delegates to one of these.
/// </summary>
public interface IReader
{
    /// <summary>Reads and returns the next packet.</summary>
    Packet Read();

    /// <summary>
    /// Releases the deserializer. When <paramref name="clsrdr"/> is true the underlying
    /// BinaryReader is closed as well. PacketReader passes false when it swaps
    /// deserializers after a version change, so the shared stream stays open.
    /// </summary>
    void Close(bool clsrdr);
}
/// <summary>
/// Deserializes packets from a <see cref="BinaryReader"/> using the layout of the current
/// protocol version. Reading a <see cref="VersionPacket"/> switches all later reads to
/// the announced version.
/// </summary>
public class PacketReader : IDisposable
{
    private readonly BinaryReader br;

    /// <summary>Version whose layout is used for the next packet read.</summary>
    protected Version ver;
    private IReader real;

    /// <param name="br">Source; closed by <see cref="Close"/>.</param>
    /// <param name="ver">Initial protocol version.</param>
    /// <exception cref="NotSupportedException"><paramref name="ver"/> has no known layout.</exception>
    public PacketReader(BinaryReader br, Version ver)
    {
        this.br = br;
        this.ver = ver;
        real = CreateReader(br, ver);
    }

    /// <summary>
    /// Reads the next packet. If it is a version packet, later reads use its version.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// The packet announces a version with no known layout. The reader keeps its previous
    /// version, so <c>ver</c> still matches the layout actually in use.
    /// </exception>
    public virtual Packet Read()
    {
        Packet res = real.Read();
        if(res.Type == PacketType.Version)
        {
            VersionPacket vp = (VersionPacket)res;
            // Create the replacement before touching any state. Previously ver was
            // overwritten before an unsupported version was rejected, desynchronising
            // ver from real.
            IReader next = CreateReader(br, vp.Version);
            // false: release the old deserializer without closing the shared stream.
            real.Close(false);
            ver = vp.Version;
            real = next;
        }
        return res;
    }

    /// <summary>Closes the underlying BinaryReader.</summary>
    public void Close()
    {
        real.Close(true);
    }

    public void Dispose()
    {
        Close();
    }

    // Maps a protocol version to its deserializer; rejects versions with no layout.
    private static IReader CreateReader(BinaryReader br, Version ver)
    {
        switch(ver)
        {
            case Version.V1:
                return new ReaderV1(br);
            case Version.V2:
                return new ReaderV2(br);
            default:
                throw new NotSupportedException("Unsupported version");
        }
    }
}
/// <summary>
/// Deserialization shared by all protocol versions. Every packet starts with a length
/// byte and a type byte. Subclasses supply the version-specific Foo body.
/// </summary>
internal abstract class ReaderBase : IReader
{
    protected BinaryReader br;

    internal ReaderBase(BinaryReader br)
    {
        this.br = br;
    }

    /// <summary>Reads a version packet body: one byte holding the version.</summary>
    internal VersionPacket ReadVersion()
    {
        return new VersionPacket { Version = (Version)br.ReadByte() };
    }

    /// <summary>Reads a Foo packet body using the subclass's version-specific layout.</summary>
    internal abstract Packet ReadFoo();

    /// <summary>Reads a bar packet body: C as a 16-bit integer.</summary>
    internal BarPacket ReadBar()
    {
        return new BarPacket { C = br.ReadInt16() };
    }

    /// <summary>
    /// Reads one packet: the header, then the body selected by the type byte.
    /// </summary>
    /// <exception cref="Exception">
    /// The type byte is unknown, or the header length differs from the decoded packet's Length.
    /// </exception>
    public Packet Read()
    {
        int declaredLength = br.ReadByte();
        PacketType kind = (PacketType)br.ReadByte();

        Packet packet;
        if(kind == PacketType.Version)
        {
            packet = ReadVersion();
        }
        else if(kind == PacketType.Foo)
        {
            packet = ReadFoo();
        }
        else if(kind == PacketType.Bar)
        {
            packet = ReadBar();
        }
        else
        {
            throw new Exception("Invalid packet type");
        }

        // The declared length is only verified after the body has been consumed.
        if(declaredLength == packet.Length)
        {
            return packet;
        }
        throw new Exception("Invalid packet length");
    }

    /// <summary>Closes the BinaryReader only when <paramref name="clsrdr"/> is true.</summary>
    public void Close(bool clsrdr)
    {
        if(!clsrdr)
        {
            return;
        }
        br.Close();
    }
}
/// <summary>Protocol version 1 deserializer: Foo bodies carry field A only.</summary>
internal class ReaderV1 : ReaderBase
{
    internal ReaderV1(BinaryReader br) : base(br)
    {
    }

    /// <summary>Reads a V1 Foo body: A as a 16-bit integer.</summary>
    internal override Packet ReadFoo()
    {
        return new Protocol.Packets.V1.FooPacket { A = br.ReadInt16() };
    }
}
/// <summary>Protocol version 2 deserializer: Foo bodies carry fields A and B.</summary>
internal class ReaderV2 : ReaderBase
{
    internal ReaderV2(BinaryReader br) : base(br)
    {
    }

    /// <summary>Reads a V2 Foo body: A, then B, each as a 16-bit integer.</summary>
    internal override Packet ReadFoo()
    {
        // Object-initializer members are evaluated left to right, so A is read before B.
        return new Protocol.Packets.V2.FooPacket
        {
            A = br.ReadInt16(),
            B = br.ReadInt16()
        };
    }
}
}
namespace Validation
{
using Protocol.Packets;
using Protocol.IO;
/// <summary>
/// Checks a single packet on its own, e.g. that its field values are in range.
/// </summary>
public interface IPacketValidator
{
    /// <summary>
    /// Returns true if <paramref name="p"/> is acceptable under protocol version <paramref name="ver"/>.
    /// </summary>
    bool IsValid(Packet p, Version ver);
}
/// <summary>
/// One validation rule used by GeneralPacketValidator. It returns false to reject
/// <paramref name="p"/> under protocol version <paramref name="ver"/>.
/// </summary>
public delegate bool Validate(Packet p, Version ver);
/// <summary>
/// Per-packet validator that accepts everything; use it to disable per-packet checks.
/// </summary>
public class DummyPacketValidator : IPacketValidator
{
    /// <returns>Always true, whatever the packet and version.</returns>
    public bool IsValid(Packet p, Version ver)
    {
        const bool acceptEverything = true;
        return acceptEverything;
    }
}
/// <summary>
/// Rule-based per-packet validator. Each packet type has an ordered list of
/// <see cref="Validate"/> rules. A packet is valid when every rule registered for its
/// type returns true; types with no rules are always valid.
/// </summary>
public class GeneralPacketValidator : IPacketValidator
{
    private readonly IDictionary<PacketType, IList<Validate>> rules = new Dictionary<PacketType, IList<Validate>>();

    /// <summary>Appends a rule for packets of type <paramref name="typ"/>.</summary>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="val"/> is null. Rejected here instead of failing later inside IsValid.
    /// </exception>
    public void Add(PacketType typ, Validate val)
    {
        if(val == null)
        {
            throw new ArgumentNullException("val");
        }
        IList<Validate> list;
        if(!rules.TryGetValue(typ, out list))
        {
            list = new List<Validate>();
            rules.Add(typ, list);
        }
        list.Add(val);
    }

    /// <summary>
    /// Runs the rules for the packet's type in the order they were added, and stops at
    /// the first rule that returns false.
    /// </summary>
    public bool IsValid(Packet p, Version ver)
    {
        IList<Validate> list;
        if(!rules.TryGetValue(p.Type, out list))
        {
            return true;
        }
        foreach(Validate val in list)
        {
            if(!val(p, ver))
            {
                return false;
            }
        }
        return true;
    }
}
/// <summary>
/// Per-packet validator with the standard range rules:
/// Version packets must announce V1 or V2;
/// under V1, Foo requires 0 to 1000 for A;
/// under V2, Foo requires 0 to 2000 for A and 0 to 1000 for B;
/// Bar requires 0 to 1000 for C (all bounds inclusive).
/// Foo under any other version is rejected.
/// </summary>
public class StandardPacketValidator : GeneralPacketValidator
{
    public StandardPacketValidator() : base()
    {
        Add(PacketType.Version, CheckVersion);
        Add(PacketType.Foo, CheckFoo);
        Add(PacketType.Bar, CheckBar);
    }

    // Inclusive range test with a lower bound of zero.
    private static bool InRange(int value, int max)
    {
        return value >= 0 && value <= max;
    }

    // The announced version must be one of the known versions.
    private static bool CheckVersion(Packet p, Version ver)
    {
        VersionPacket vp = (VersionPacket)p;
        return vp.Version >= Version.V1 && vp.Version <= Version.V2;
    }

    // Foo's layout and limits depend on the version currently in effect.
    private static bool CheckFoo(Packet p, Version ver)
    {
        if(ver == Version.V1)
        {
            Protocol.Packets.V1.FooPacket fp1 = (Protocol.Packets.V1.FooPacket)p;
            return InRange(fp1.A, 1000);
        }
        if(ver == Version.V2)
        {
            Protocol.Packets.V2.FooPacket fp2 = (Protocol.Packets.V2.FooPacket)p;
            return InRange(fp2.A, 2000) && InRange(fp2.B, 1000);
        }
        return false;
    }

    private static bool CheckBar(Packet p, Version ver)
    {
        BarPacket bp = (BarPacket)p;
        return InRange(bp.C, 1000);
    }
}
/// <summary>
/// Checks whether a packet is allowed at its position in the stream (ordering, counts).
/// Implementations may be stateful and expect to see every packet in stream order.
/// </summary>
public interface IPacketStreamValidator
{
    /// <summary>
    /// Returns true if <paramref name="p"/> may appear next under protocol version <paramref name="ver"/>.
    /// </summary>
    bool IsValid(Packet p, Version ver);
}
/// <summary>
/// Stream validator that accepts every packet in any order; use it to disable stream checks.
/// </summary>
public class DummyPacketStreamValidator : IPacketStreamValidator
{
    /// <returns>Always true, whatever the packet and version.</returns>
    public bool IsValid(Packet p, Version ver)
    {
        const bool acceptEverything = true;
        return acceptEverything;
    }
}
/// <summary>
/// Stateful stream validator. It tracks which packet types may come next and how many
/// packets of each type have been accepted. A packet is rejected if its type is not
/// currently expected, or if accepting it would exceed the limit registered for
/// (version, type). An accepted packet replaces the expected set when a transition is
/// registered for (version, type). Rejected packets leave the state unchanged.
/// </summary>
public class PacketStreamValidator : IPacketStreamValidator
{
    private readonly IDictionary<Tuple<Version, PacketType>, ISet<PacketType>> transition = new Dictionary<Tuple<Version, PacketType>, ISet<PacketType>>();
    private ISet<PacketType> expected;
    private readonly IDictionary<Tuple<Version, PacketType>, int> limit = new Dictionary<Tuple<Version, PacketType>, int>();
    // Counts are kept per packet type regardless of the version they arrived under.
    private readonly IDictionary<PacketType, int> count = new Dictionary<PacketType, int>();

    /// <param name="expected">Packet types allowed as the first packet of the stream.</param>
    /// <exception cref="ArgumentNullException"><paramref name="expected"/> is null.</exception>
    public PacketStreamValidator(ISet<PacketType> expected)
    {
        if(expected == null)
        {
            throw new ArgumentNullException("expected");
        }
        this.expected = expected;
    }

    /// <summary>
    /// After an accepted packet of type <paramref name="pt"/> under <paramref name="ver"/>,
    /// only the types in <paramref name="expected"/> are allowed next.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expected"/> is null.</exception>
    /// <exception cref="ArgumentException">A transition for (ver, pt) is already registered.</exception>
    public void AddTransition(Version ver, PacketType pt, ISet<PacketType> expected)
    {
        if(expected == null)
        {
            throw new ArgumentNullException("expected");
        }
        transition.Add(Tuple.Create(ver, pt), expected);
    }

    /// <summary>
    /// While validating under <paramref name="ver"/>, reject a packet of type
    /// <paramref name="pt"/> once more than <paramref name="max"/> packets of that type have
    /// been accepted (counting across all versions).
    /// </summary>
    /// <exception cref="ArgumentException">A limit for (ver, pt) is already registered.</exception>
    public void AddLimit(Version ver, PacketType pt, int max)
    {
        limit.Add(Tuple.Create(ver, pt), max);
    }

    public bool IsValid(Packet p, Version ver)
    {
        PacketType type = p.Type;
        if(!expected.Contains(type))
        {
            return false;
        }
        Tuple<Version, PacketType> key = Tuple.Create(ver, type);

        int seen;
        count.TryGetValue(type, out seen); // leaves seen = 0 for a type not yet counted
        int newCount = seen + 1;

        int max;
        if(limit.TryGetValue(key, out max) && newCount > max)
        {
            // Reject before committing anything. Previously the count and expected set
            // were already advanced here, corrupting the state for later packets.
            return false;
        }

        count[type] = newCount;
        ISet<PacketType> next;
        if(transition.TryGetValue(key, out next))
        {
            expected = next;
        }
        return true;
    }
}
/// <summary>
/// Stream validator for the standard protocol, with identical rules for V1 and V2:
/// the stream must start with a Version packet; a Version packet may only be followed
/// by Foo or Bar; Foo or Bar may be followed by any type; and a limit of 1 is registered
/// for Version packets.
/// </summary>
public class StandardPacketStreamValidator : PacketStreamValidator
{
    public StandardPacketStreamValidator() : base(new HashSet<PacketType>() { PacketType.Version })
    {
        foreach(Version v in new Version[] { Version.V1, Version.V2 })
        {
            AddTransition(v, PacketType.Version, new HashSet<PacketType>() { PacketType.Foo, PacketType.Bar });
            AddTransition(v, PacketType.Foo, AnyPacketType());
            AddTransition(v, PacketType.Bar, AnyPacketType());
            AddLimit(v, PacketType.Version, 1);
        }
    }

    // A fresh set per transition, as each registration previously got its own instance.
    private static ISet<PacketType> AnyPacketType()
    {
        return new HashSet<PacketType>() { PacketType.Version, PacketType.Foo, PacketType.Bar };
    }
}
/// <summary>
/// PacketWriter that runs the per-packet validator and then the stream validator before
/// writing. A packet failing either check is not written.
/// </summary>
public class ValidatingPacketWriter : PacketWriter
{
    private IPacketValidator pval;
    private IPacketStreamValidator psval;

    public ValidatingPacketWriter(BinaryWriter bw, Version ver, IPacketValidator pval, IPacketStreamValidator psval) : base(bw, ver)
    {
        this.pval = pval;
        this.psval = psval;
    }

    /// <exception cref="Exception">
    /// The packet fails the per-packet check ("Invalid packet") or the stream check
    /// ("Invalid context for packet").
    /// </exception>
    public override void Write(Packet p)
    {
        // The stream validator is consulted only for packets that pass the per-packet check.
        if(!pval.IsValid(p, ver))
        {
            throw new Exception(String.Format("Invalid packet (type:{0})", p.Type));
        }
        if(!psval.IsValid(p, ver))
        {
            throw new Exception(String.Format("Invalid context for packet (type:{0})", p.Type));
        }
        base.Write(p);
    }
}
/// <summary>
/// PacketReader that runs the per-packet validator and then the stream validator on
/// every packet it reads, throwing instead of returning a packet that fails either check.
/// </summary>
public class ValidatingPacketReader : PacketReader
{
    private IPacketValidator pval;
    private IPacketStreamValidator psval;

    public ValidatingPacketReader(BinaryReader br, Version ver, IPacketValidator pval, IPacketStreamValidator psval) : base(br, ver)
    {
        this.pval = pval;
        this.psval = psval;
    }

    /// <exception cref="Exception">
    /// The packet fails the per-packet check ("Invalid packet") or the stream check
    /// ("Invalid context for packet").
    /// </exception>
    public override Packet Read()
    {
        // base.Read has already switched ver if this packet was a version packet.
        Packet packet = base.Read();
        if(!pval.IsValid(packet, ver))
        {
            throw new Exception(String.Format("Invalid packet (type:{0})", packet.Type));
        }
        if(!psval.IsValid(packet, ver))
        {
            throw new Exception(String.Format("Invalid context for packet (type:{0})", packet.Type));
        }
        return packet;
    }
}
}
}