// WingMan – Rev 34
//
// Subversion Repositories:
// Rev:
using System;
using System.IO;
using System.IO.Compression;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using LZ4;
using MQTTnet;
using MQTTnet.Client;
using MQTTnet.Extensions.ManagedClient;
using MQTTnet.Protocol;
using MQTTnet.Server;
using WingMan.Utilities;
using MqttClientConnectedEventArgs = MQTTnet.Client.MqttClientConnectedEventArgs;
using MqttClientDisconnectedEventArgs = MQTTnet.Client.MqttClientDisconnectedEventArgs;

namespace WingMan.Communication
{
    /// <summary>
    ///     Wraps an MQTTnet managed client and an MQTTnet server behind one object that runs in either
    ///     client or server mode (see <c>Start</c>).
    ///     Outgoing payloads are LZ4-compressed and then encrypted with <c>AES.Encrypt</c> using the
    ///     configured password. Incoming payloads are decrypted and decompressed before being surfaced
    ///     through <see cref="OnMessageReceived" />. Every public event is raised on the task scheduler
    ///     supplied to the constructor.
    /// </summary>
    public class MqttCommunication : IDisposable
    {
        public delegate void ClientAuthenticationFailed(object sender, MqttAuthenticationFailureEventArgs e);

        public delegate void ClientConnected(object sender, MqttClientConnectedEventArgs e);

        public delegate void ClientConnectionFailed(object sender, MqttManagedProcessFailedEventArgs e);

        public delegate void ClientDisconnected(object sender, MqttClientDisconnectedEventArgs e);

        public delegate void ClientSubscribed(object sender, MqttClientSubscribedTopicEventArgs e);

        public delegate void ClientUnsubscribed(object sender, MqttClientUnsubscribedTopicEventArgs e);

        public delegate void MessageReceived(object sender, MqttApplicationMessageReceivedEventArgs e);

        public delegate void ServerAuthenticationFailed(object sender, MqttAuthenticationFailureEventArgs e);

        public delegate void ServerClientConnected(object sender, MQTTnet.Server.MqttClientConnectedEventArgs e);

        public delegate void ServerClientDisconnected(object sender, MQTTnet.Server.MqttClientDisconnectedEventArgs e);

        public delegate void ServerStarted(object sender, EventArgs e);

        public delegate void ServerStopped(object sender, EventArgs e);

        /// <summary>
        ///     Creates both the managed client and the server. Neither is started until <c>Start</c> is called.
        /// </summary>
        /// <param name="taskScheduler">Scheduler on which every public event is raised.</param>
        /// <param name="cancellationToken">Once cancelled, pending and future events are no longer raised.</param>
        public MqttCommunication(TaskScheduler taskScheduler, CancellationToken cancellationToken)
        {
            TaskScheduler = taskScheduler;
            CancellationToken = cancellationToken;

            Client = new MqttFactory().CreateManagedMqttClient();
            Server = new MqttFactory().CreateMqttServer();
        }

        private TaskScheduler TaskScheduler { get; }

        private IManagedMqttClient Client { get; }

        private IMqttServer Server { get; }

        /// <summary>
        ///     Set to true after a successful start and to false by <see cref="Stop" /> or a failed server start.
        /// </summary>
        public bool Running { get; set; }

        public string Nick { get; set; }

        private IPAddress IpAddress { get; set; }

        private int Port { get; set; }

        // Shared secret used to encrypt outgoing and decrypt incoming payloads.
        private string Password { get; set; }

        private CancellationToken CancellationToken { get; }

        /// <summary>
        ///     Selects which of the client or server is used by <see cref="Stop" /> and <see cref="Broadcast" />.
        /// </summary>
        public MqttCommunicationType Type { get; set; }

        /// <summary>
        ///     Begins stopping the client or server.
        ///     <see cref="IDisposable.Dispose" /> is synchronous. Blocking on <see cref="Stop" /> here could
        ///     deadlock when called from a thread whose synchronization context the awaited continuations
        ///     need, so the shutdown is started but not waited for.
        ///     Any shutdown failure is observed and discarded, because Dispose must not throw. This replaces
        ///     an <c>async void</c> implementation whose failures were unobservable.
        /// </summary>
        public void Dispose()
        {
            Stop().ContinueWith(task => task.Exception?.Handle(_ => true),
                TaskContinuationOptions.OnlyOnFaulted);
        }

        /// <summary>Raised when a message received by the client cannot be decrypted or decompressed.</summary>
        public event ClientAuthenticationFailed OnClientAuthenticationFailed;

        /// <summary>
        ///     Raised when a message received by the server cannot be decrypted or decompressed. The sending
        ///     client has already been disconnected.
        /// </summary>
        public event ServerAuthenticationFailed OnServerAuthenticationFailed;

        /// <summary>Raised in both modes with the message payload already decrypted and decompressed.</summary>
        public event MessageReceived OnMessageReceived;

        public event ClientConnected OnClientConnected;

        public event ClientDisconnected OnClientDisconnected;

        public event ClientConnectionFailed OnClientConnectionFailed;

        /// <summary>Forwarded from the server's <c>ClientUnsubscribedTopic</c> event.</summary>
        public event ClientUnsubscribed OnClientUnsubscribed;

        /// <summary>Forwarded from the server's <c>ClientSubscribedTopic</c> event.</summary>
        public event ClientSubscribed OnClientSubscribed;

        public event ServerClientDisconnected OnServerClientDisconnected;

        public event ServerClientConnected OnServerClientConnected;

        public event ServerStarted OnServerStarted;

        public event ServerStopped OnServerStopped;

        /// <summary>
        ///     Stores the connection settings and starts the client or the server.
        /// </summary>
        /// <param name="type">Whether to run as client or server.</param>
        /// <param name="ipAddress">Broker address (client mode) or bind address (server mode).</param>
        /// <param name="port">Broker port (client mode) or listening port (server mode).</param>
        /// <param name="nick">Stored in <see cref="Nick" />.</param>
        /// <param name="password">Shared secret for payload encryption.</param>
        /// <returns>
        ///     True when started. False when the server failed to start or <paramref name="type" /> is not a
        ///     known mode. Client start-up failures are thrown, not reported through the return value.
        /// </returns>
        public async Task<bool> Start(MqttCommunicationType type, IPAddress ipAddress, int port, string nick,
            string password)
        {
            Type = type;
            IpAddress = ipAddress;
            Port = port;
            Nick = nick;
            Password = password;

            switch (type)
            {
                case MqttCommunicationType.Client:
                    return await StartClient();
                case MqttCommunicationType.Server:
                    return await StartServer();
            }

            return false;
        }

        /// <summary>
        ///     Binds the client handlers, subscribes to the "lobby", "exchange" and "execute" topics, and
        ///     starts the managed client. If any step throws, the handlers are unbound again and the
        ///     exception propagates, so a retried <c>Start</c> does not attach them twice.
        /// </summary>
        private async Task<bool> StartClient()
        {
            var options = new ManagedMqttClientOptionsBuilder()
                .WithClientOptions(new MqttClientOptionsBuilder()
                    .WithTcpServer(IpAddress.ToString(), Port)
                    .Build())
                .Build();

            BindClientHandlers();

            try
            {
                foreach (var topic in new[] {"lobby", "exchange", "execute"})
                    await Client.SubscribeAsync(new TopicFilterBuilder().WithTopic(topic).Build());

                await Client.StartAsync(options);
            }
            catch
            {
                UnbindClientHandlers();
                throw;
            }

            Running = true;

            return Running;
        }

        private async Task StopClient()
        {
            UnbindClientHandlers();

            await Client.StopAsync();
        }

        /// <summary>Attaches this instance's forwarding handlers to the managed client's events.</summary>
        public void BindClientHandlers()
        {
            Client.Connected += ClientOnConnected;
            Client.Disconnected += ClientOnDisconnected;
            Client.ConnectingFailed += ClientOnConnectingFailed;
            Client.ApplicationMessageReceived += ClientOnApplicationMessageReceived;
        }

        /// <summary>Detaches the handlers attached by <see cref="BindClientHandlers" />.</summary>
        public void UnbindClientHandlers()
        {
            Client.Connected -= ClientOnConnected;
            Client.Disconnected -= ClientOnDisconnected;
            Client.ConnectingFailed -= ClientOnConnectingFailed;
            Client.ApplicationMessageReceived -= ClientOnApplicationMessageReceived;
        }

        /// <summary>
        ///     Decodes the payload in place and raises <see cref="OnMessageReceived" />.
        ///     Only a decoding failure is reported as <see cref="OnClientAuthenticationFailed" />. Exceptions
        ///     thrown by <see cref="OnMessageReceived" /> subscribers are no longer misreported as one.
        /// </summary>
        private async void ClientOnApplicationMessageReceived(object sender, MqttApplicationMessageReceivedEventArgs e)
        {
            try
            {
                e.ApplicationMessage.Payload = await DecodePayload(e.ApplicationMessage.Payload);
            }
            catch (Exception ex)
            {
                await Dispatch(() =>
                    OnClientAuthenticationFailed?.Invoke(sender, new MqttAuthenticationFailureEventArgs(e, ex)));
                return;
            }

            await Dispatch(() => OnMessageReceived?.Invoke(sender, e));
        }

        private async void ClientOnConnectingFailed(object sender, MqttManagedProcessFailedEventArgs e)
        {
            await Dispatch(() => OnClientConnectionFailed?.Invoke(sender, e));
        }

        private async void ClientOnDisconnected(object sender, MqttClientDisconnectedEventArgs e)
        {
            await Dispatch(() => OnClientDisconnected?.Invoke(sender, e));
        }

        private async void ClientOnConnected(object sender, MqttClientConnectedEventArgs e)
        {
            await Dispatch(() => OnClientConnected?.Invoke(sender, e));
        }

        /// <summary>
        ///     Binds the server handlers and starts the server on <see cref="IpAddress" />:<see cref="Port" />.
        ///     A start failure is reported by returning false. In that case the handlers are unbound again,
        ///     so a retried <c>Start</c> does not attach them twice.
        /// </summary>
        private async Task<bool> StartServer()
        {
            var optionsBuilder = new MqttServerOptionsBuilder()
                .WithDefaultEndpointBoundIPAddress(IpAddress)
                .WithDefaultEndpointPort(Port);

            BindServerHandlers();

            try
            {
                await Server.StartAsync(optionsBuilder.Build());

                Running = true;
            }
            catch (Exception)
            {
                UnbindServerHandlers();

                Running = false;
            }

            return Running;
        }

        private async Task StopServer()
        {
            UnbindServerHandlers();

            await Server.StopAsync();
        }

        private void BindServerHandlers()
        {
            Server.Started += ServerOnStarted;
            Server.Stopped += ServerOnStopped;
            Server.ClientConnected += ServerOnClientConnected;
            Server.ClientDisconnected += ServerOnClientDisconnected;
            Server.ClientSubscribedTopic += ServerOnClientSubscribedTopic;
            Server.ClientUnsubscribedTopic += ServerOnClientUnsubscribedTopic;
            Server.ApplicationMessageReceived += ServerOnApplicationMessageReceived;
        }

        private void UnbindServerHandlers()
        {
            Server.Started -= ServerOnStarted;
            Server.Stopped -= ServerOnStopped;
            Server.ClientConnected -= ServerOnClientConnected;
            Server.ClientDisconnected -= ServerOnClientDisconnected;
            Server.ClientSubscribedTopic -= ServerOnClientSubscribedTopic;
            Server.ClientUnsubscribedTopic -= ServerOnClientUnsubscribedTopic;
            Server.ApplicationMessageReceived -= ServerOnApplicationMessageReceived;
        }

        /// <summary>
        ///     Decodes the payload in place and raises <see cref="OnMessageReceived" />.
        ///     If decoding fails, every session whose client id matches the sender is disconnected and
        ///     <see cref="OnServerAuthenticationFailed" /> is raised. Subscriber exceptions are kept out of
        ///     the <c>try</c>, so a faulty subscriber can no longer cause a legitimate client to be disconnected.
        /// </summary>
        private async void ServerOnApplicationMessageReceived(object sender, MqttApplicationMessageReceivedEventArgs e)
        {
            try
            {
                e.ApplicationMessage.Payload = await DecodePayload(e.ApplicationMessage.Payload);
            }
            catch (Exception ex)
            {
                foreach (var clientSessionStatus in await Server.GetClientSessionsStatusAsync())
                {
                    if (!string.Equals(clientSessionStatus.ClientId, e.ClientId, StringComparison.Ordinal))
                        continue;

                    await clientSessionStatus.DisconnectAsync();
                }

                await Dispatch(() =>
                    OnServerAuthenticationFailed?.Invoke(sender, new MqttAuthenticationFailureEventArgs(e, ex)));
                return;
            }

            await Dispatch(() => OnMessageReceived?.Invoke(sender, e));
        }

        private async void ServerOnClientUnsubscribedTopic(object sender, MqttClientUnsubscribedTopicEventArgs e)
        {
            await Dispatch(() => OnClientUnsubscribed?.Invoke(sender, e));
        }

        private async void ServerOnClientSubscribedTopic(object sender, MqttClientSubscribedTopicEventArgs e)
        {
            await Dispatch(() => OnClientSubscribed?.Invoke(sender, e));
        }

        private async void ServerOnClientDisconnected(object sender, MQTTnet.Server.MqttClientDisconnectedEventArgs e)
        {
            await Dispatch(() => OnServerClientDisconnected?.Invoke(sender, e));
        }

        private async void ServerOnClientConnected(object sender, MQTTnet.Server.MqttClientConnectedEventArgs e)
        {
            await Dispatch(() => OnServerClientConnected?.Invoke(sender, e));
        }

        private async void ServerOnStopped(object sender, EventArgs e)
        {
            await Dispatch(() => OnServerStopped?.Invoke(sender, e));
        }

        private async void ServerOnStarted(object sender, EventArgs e)
        {
            await Dispatch(() => OnServerStarted?.Invoke(sender, e));
        }

        /// <summary>
        ///     Runs <paramref name="raise" /> on <see cref="TaskScheduler" /> and waits for it to finish.
        ///     This behaves like the previous <c>Task.Delay(0, token).ContinueWith(..., scheduler)</c> form.
        ///     Once <see cref="CancellationToken" /> is cancelled, the event is dropped instead of letting
        ///     the cancellation escape the calling <c>async void</c> handler. Exceptions thrown by
        ///     subscribers still propagate to the caller.
        /// </summary>
        private async Task Dispatch(Action raise)
        {
            try
            {
                await Task.Factory.StartNew(raise, CancellationToken, TaskCreationOptions.None, TaskScheduler);
            }
            catch (OperationCanceledException) when (CancellationToken.IsCancellationRequested)
            {
                // Shutting down: nothing left to notify.
            }
        }

        /// <summary>
        ///     Reverses the encoding done by <see cref="Broadcast" />: decrypts <paramref name="payload" />
        ///     with <see cref="Password" />, then LZ4-decompresses the result.
        /// </summary>
        /// <returns>The decoded bytes.</returns>
        private async Task<byte[]> DecodePayload(byte[] payload)
        {
            using (var inputStream = new MemoryStream(payload))
            using (var decryptedStream = await AES.Decrypt(inputStream, Password))
            using (var lz4Decompress = new LZ4Stream(decryptedStream, CompressionMode.Decompress))
            using (var outputStream = new MemoryStream())
            {
                await lz4Decompress.CopyToAsync(outputStream);

                return outputStream.ToArray();
            }
        }

        /// <summary>
        ///     Stops the client or the server, depending on <see cref="Type" />, unbinds its handlers, and
        ///     clears <see cref="Running" />.
        /// </summary>
        public async Task Stop()
        {
            switch (Type)
            {
                case MqttCommunicationType.Server:
                    await StopServer();
                    break;
                case MqttCommunicationType.Client:
                    await StopClient();
                    break;
            }

            Running = false;
        }

        /// <summary>
        ///     LZ4-compresses <paramref name="payload" />, encrypts it with <see cref="Password" />, and
        ///     publishes it on <paramref name="topic" /> at QoS "at most once". It publishes through the
        ///     client or the server, depending on <see cref="Type" />.
        /// </summary>
        /// <param name="topic">MQTT topic to publish on.</param>
        /// <param name="payload">Raw bytes to send.</param>
        /// <exception cref="ArgumentNullException"><paramref name="payload" /> is null.</exception>
        public async Task Broadcast(string topic, byte[] payload)
        {
            if (payload == null)
                throw new ArgumentNullException(nameof(payload));

            using (var compressStream = new MemoryStream())
            using (var lz4Stream = new LZ4Stream(compressStream, CompressionMode.Compress))
            {
                await lz4Stream.WriteAsync(payload, 0, payload.Length);
                await lz4Stream.FlushAsync();

                // Rewind so the encryptor reads the compressed bytes from the start.
                compressStream.Position = 0L;

                using (var outputStream = await AES.Encrypt(compressStream, Password))
                {
                    var message = new MqttApplicationMessage
                    {
                        Topic = topic,
                        Payload = outputStream.ToArray(),
                        QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce
                    };

                    switch (Type)
                    {
                        case MqttCommunicationType.Client:
                            await Client.PublishAsync(new[] {message});
                            break;
                        case MqttCommunicationType.Server:
                            await Server.PublishAsync(new[] {message});
                            break;
                    }
                }
            }
        }
    }
}