summaryrefslogtreecommitdiffstats
path: root/src/Network.cpp
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/Network.cpp126
1 files changed, 102 insertions, 24 deletions
diff --git a/src/Network.cpp b/src/Network.cpp
index cf4fe15..9cb2097 100644
--- a/src/Network.cpp
+++ b/src/Network.cpp
@@ -1,5 +1,7 @@
#include "Network.hpp"
+#include <zlib.h>
+
Network::Network(std::string address, unsigned short port) {
try {
socket = new Socket(address, port);
@@ -13,6 +15,8 @@ Network::Network(std::string address, unsigned short port) {
} catch (std::exception &e) {
LOG(WARNING) << "Stream creation failed: " << e.what();
}
+
+
}
Network::~Network() {
@@ -20,40 +24,114 @@ Network::~Network() {
delete socket;
}
-std::shared_ptr<Packet> Network::ReceivePacket(ConnectionState state) {
- int packetSize = stream->ReadVarInt();
- auto packetData = stream->ReadByteArray(packetSize);
- StreamBuffer streamBuffer(packetData.data(), packetData.size());
- int packetId = streamBuffer.ReadVarInt();
- auto packet = ReceivePacketByPacketId(packetId, state, streamBuffer);
- return packet;
+std::shared_ptr<Packet> Network::ReceivePacket(ConnectionState state, bool useCompression) {
+ if (useCompression) {
+ int packetLength = stream->ReadVarInt();
+ auto packetData = stream->ReadByteArray(packetLength);
+ StreamBuffer streamBuffer(packetData.data(), packetData.size());
+
+ int dataLength = streamBuffer.ReadVarInt();
+ if (dataLength == 0) {
+ auto packetData = streamBuffer.ReadByteArray(packetLength - streamBuffer.GetReadedLength());
+ StreamBuffer streamBuffer(packetData.data(), packetData.size());
+ int packetId = streamBuffer.ReadVarInt();
+ auto packet = ReceivePacketByPacketId(packetId, state, streamBuffer);
+ return packet;
+ } else {
+ std::vector<unsigned char> compressedData = streamBuffer.ReadByteArray(packetLength - streamBuffer.GetReadedLength());
+ std::vector<unsigned char> uncompressedData;
+ uncompressedData.resize(dataLength);
+
+ z_stream stream;
+ stream.avail_in = compressedData.size();
+ stream.next_in = compressedData.data();
+ stream.avail_out = uncompressedData.size();
+ stream.next_out = uncompressedData.data();
+ stream.zalloc = Z_NULL;
+ stream.zfree = Z_NULL;
+ stream.opaque = Z_NULL;
+ if (inflateInit(&stream) != Z_OK)
+ throw std::runtime_error("Zlib decompression initalization error");
+
+ int status = inflate(&stream, Z_FINISH);
+ switch (status) {
+ case Z_STREAM_END:
+ break;
+ case Z_OK:
+ case Z_STREAM_ERROR:
+ case Z_BUF_ERROR:
+ throw std::runtime_error("Zlib decompression error: " + std::to_string(status));
+ }
+
+ if (inflateEnd(&stream) != Z_OK)
+ throw std::runtime_error("Zlib decompression end error");
+
+ StreamBuffer streamBuffer(uncompressedData.data(), uncompressedData.size());
+ int packetId = streamBuffer.ReadVarInt();
+ auto packet = ReceivePacketByPacketId(packetId, state, streamBuffer);
+ return packet;
+ }
+ } else {
+ int packetSize = stream->ReadVarInt();
+ auto packetData = stream->ReadByteArray(packetSize);
+ StreamBuffer streamBuffer(packetData.data(), packetData.size());
+ int packetId = streamBuffer.ReadVarInt();
+ auto packet = ReceivePacketByPacketId(packetId, state, streamBuffer);
+ return packet;
+ }
}
-void Network::SendPacket(Packet &packet) {
- StreamCounter packetSize;
- packetSize.WriteVarInt(packet.GetPacketId());
- packet.ToStream(&packetSize);
- stream->WriteVarInt(packetSize.GetCountedSize());
- stream->WriteVarInt(packet.GetPacketId());
- packet.ToStream(stream);
+void Network::SendPacket(Packet &packet, int compressionThreshold) {
+ if (compressionThreshold >= 0) {
+ StreamCounter packetSize;
+ packetSize.WriteVarInt(packet.GetPacketId());
+ packetSize.WriteVarInt(0);
+ packet.ToStream(&packetSize);
+ if (packetSize.GetCountedSize() < compressionThreshold) {
+ stream->WriteVarInt(packetSize.GetCountedSize());
+ stream->WriteVarInt(0);
+ stream->WriteVarInt(packet.GetPacketId());
+ packet.ToStream(stream);
+ } else {
+ throw std::runtime_error("Compressing data");
+ /*StreamBuffer buffer(packetSize.GetCountedSize());
+ packet.ToStream(&buffer);
+
+ z_stream stream;*/
+ }
+ }
+ else {
+ StreamCounter packetSize;
+ packetSize.WriteVarInt(packet.GetPacketId());
+ packet.ToStream(&packetSize);
+ stream->WriteVarInt(packetSize.GetCountedSize());
+ stream->WriteVarInt(packet.GetPacketId());
+ packet.ToStream(stream);
+ }
}
std::shared_ptr<Packet> Network::ReceivePacketByPacketId(int packetId, ConnectionState state, StreamInput &stream) {
std::shared_ptr < Packet > packet(nullptr);
switch (state) {
case Handshaking:
- switch (packetId) {
- case PacketNameHandshakingCB::Handshake:
- packet = std::make_shared<PacketHandshake>();
- break;
- }
+ switch (packetId) {
+ case PacketNameHandshakingCB::Handshake:
+ packet = std::make_shared<PacketHandshake>();
+ break;
+ }
break;
case Login:
- switch (packetId) {
- case PacketNameLoginCB::LoginSuccess:
- packet = std::make_shared<PacketLoginSuccess>();
- break;
- }
+ switch (packetId) {
+ case PacketNameLoginCB::LoginSuccess:
+ packet = std::make_shared<PacketLoginSuccess>();
+ break;
+ case PacketNameLoginCB::SetCompression:
+ packet = std::make_shared<PacketSetCompression>();
+ break;
+ case PacketNameLoginCB::Disconnect:
+ packet = std::make_shared<PacketDisconnect>();
+ break;
+ }
break;
case Play:
packet = ParsePacketPlay((PacketNamePlayCB) packetId);