/trunk/WingMan/Communication/MQTTCommunication.cs |
@@ -1,4 +1,6 @@ |
using System; |
using System.Collections.Generic; |
using System.Linq; |
using System.Net; |
using System.Threading; |
using System.Threading.Tasks; |
@@ -5,7 +7,9 @@ |
using MQTTnet; |
using MQTTnet.Client; |
using MQTTnet.Extensions.ManagedClient; |
using MQTTnet.Protocol; |
using MQTTnet.Server; |
using WingMan.Utilities; |
using MqttClientConnectedEventArgs = MQTTnet.Client.MqttClientConnectedEventArgs; |
using MqttClientDisconnectedEventArgs = MQTTnet.Client.MqttClientDisconnectedEventArgs; |
|
@@ -13,6 +17,8 @@ |
{ |
public class MQTTCommunication : IDisposable |
{ |
/// <summary>
///     Signature of <see cref="OnClientAuthenticationFailed"/>: raised when a message received by the
///     client cannot be decrypted with the shared password.
/// </summary>
public delegate void ClientAuthenticationFailed(object sender, EventArgs e); |
|
/// <summary>Signature of <see cref="OnClientConnected"/>.</summary>
public delegate void ClientConnected(object sender, MqttClientConnectedEventArgs e); |
|
/// <summary>Signature of client connection-failure notifications.</summary>
public delegate void ClientConnectionFailed(object sender, MqttManagedProcessFailedEventArgs e); |
@@ -25,6 +31,8 @@ |
|
/// <summary>
///     Signature of <see cref="OnMessageReceived"/>: raised with a message whose payload has already been
///     decrypted in place.
/// </summary>
public delegate void MessageReceived(object sender, MqttApplicationMessageReceivedEventArgs e); |
|
/// <summary>
///     Signature of <see cref="OnServerAuthenticationFailed"/>: raised when a message received by the
///     server cannot be decrypted, after any tracked entries for the sender have been banned.
/// </summary>
public delegate void ServerAuthenticationFailed(object sender, EventArgs e); |
|
/// <summary>Signature of server-side notifications that a client connected.</summary>
public delegate void ServerClientConnected(object sender, MQTTnet.Server.MqttClientConnectedEventArgs e); |
|
/// <summary>Signature of server-side notifications that a client disconnected.</summary>
public delegate void ServerClientDisconnected(object sender, MQTTnet.Server.MqttClientDisconnectedEventArgs e); |
@@ -38,10 +46,14 @@ |
TaskScheduler = taskScheduler; |
CancellationToken = cancellationToken; |
|
TrackedClients = new TrackedClients {Clients = new List<TrackedClient>()}; |
|
Client = new MqttFactory().CreateManagedMqttClient(); |
Server = new MqttFactory().CreateMqttServer(); |
} |
|
/// <summary>Clients admitted by the connection validator, together with their ban state.</summary>
private TrackedClients TrackedClients { get; } |
|
/// <summary>Scheduler on which event notifications are dispatched.</summary>
private TaskScheduler TaskScheduler { get; } |
|
/// <summary>Managed MQTT client used when running as <see cref="MQTTCommunicationType.Client"/>.</summary>
private IManagedMqttClient Client { get; } |
@@ -56,6 +68,8 @@ |
|
/// <summary>Port passed to Start; also used as the server's default endpoint port.</summary>
private int Port { get; set; } |
|
/// <summary>Shared secret used to AES-encrypt outgoing payloads and decrypt incoming ones.</summary>
private string Password { get; set; } |
|
/// <summary>Token passed to the continuations that dispatch event notifications.</summary>
private CancellationToken CancellationToken { get; } |
|
/// <summary>Role (client or server) this instance runs as; selects the publishing path in Broadcast.</summary>
public MQTTCommunicationType Type { get; set; } |
@@ -65,6 +79,10 @@ |
await Stop(); |
} |
|
/// <summary>Raised when a message received by the client fails to decrypt with the shared password.</summary>
public event ClientAuthenticationFailed OnClientAuthenticationFailed; |
|
/// <summary>Raised when a message received by the server fails to decrypt; the sender is banned first.</summary>
public event ServerAuthenticationFailed OnServerAuthenticationFailed; |
|
/// <summary>Raised for every successfully decrypted message, in client or server mode.</summary>
public event MessageReceived OnMessageReceived; |
|
/// <summary>Client-connected notification.</summary>
public event ClientConnected OnClientConnected; |
@@ -85,12 +103,13 @@ |
|
/// <summary>Server-stopped notification.</summary>
public event ServerStopped OnServerStopped; |
|
public async Task Start(MQTTCommunicationType type, IPAddress ipAddress, int port, string nick) |
public async Task Start(MQTTCommunicationType type, IPAddress ipAddress, int port, string nick, string password) |
{ |
Type = type; |
IPAddress = ipAddress; |
Port = port; |
Nick = nick; |
Password = password; |
|
switch (type) |
{ |
@@ -158,6 +177,18 @@ |
|
/// <summary>
///     Client-side handler for incoming messages. The payload is decrypted in place with the shared
///     password; a payload that cannot be decrypted is reported through
///     <see cref="OnClientAuthenticationFailed"/>, otherwise the message is forwarded through
///     <see cref="OnMessageReceived"/>. Both notifications are dispatched on <see cref="TaskScheduler"/>.
/// </summary>
private async void ClientOnApplicationMessageReceived(object sender, MqttApplicationMessageReceivedEventArgs e)
{
    bool authenticated;

    try
    {
        e.ApplicationMessage.Payload = await AES.Decrypt(e.ApplicationMessage.Payload, Password);
        authenticated = true;
    }
    catch (Exception)
    {
        // Any decryption failure means the sender does not share our password.
        authenticated = false;
    }

    if (!authenticated)
    {
        await Task.Delay(0).ContinueWith(_ => OnClientAuthenticationFailed?.Invoke(sender, e),
            CancellationToken, TaskContinuationOptions.None, TaskScheduler);
        return;
    }

    await Task.Delay(0).ContinueWith(_ => OnMessageReceived?.Invoke(sender, e),
        CancellationToken, TaskContinuationOptions.None, TaskScheduler);
}
@@ -185,6 +216,7 @@ |
var optionsBuilder = new MqttServerOptionsBuilder() |
.WithDefaultEndpointBoundIPAddress(IPAddress) |
.WithSubscriptionInterceptor(MQTTSubscriptionIntercept) |
.WithConnectionValidator(MQTTConnectionValidator) |
.WithDefaultEndpointPort(Port); |
|
BindServerHandlers(); |
@@ -194,6 +226,23 @@ |
Running = true; |
} |
|
/// <summary>
///     Server connection validator. Refuses any client whose endpoint or client id matches a tracked,
///     banned client; otherwise records the client in <see cref="TrackedClients"/> (once per
///     client id / endpoint pair) and accepts the connection.
/// </summary>
/// <param name="context">Validation context supplied by the MQTT server; its return code is set here.</param>
private void MQTTConnectionValidator(MqttConnectionValidatorContext context)
{
    // The tracked-client list is also read and updated by the server's message handler, so every
    // access to it happens under the list's lock.
    lock (TrackedClients.Clients)
    {
        // Do not accept connections from banned clients.
        // NOTE(review): if context.Endpoint includes the remote port, an endpoint ban will not catch a
        // reconnect from a new source port; only the client-id match would - confirm the intended policy.
        if (TrackedClients.Clients.Any(client =>
            client.Banned &&
            (string.Equals(client.EndPoint, context.Endpoint, StringComparison.OrdinalIgnoreCase) ||
             string.Equals(client.ClientId, context.ClientId, StringComparison.Ordinal))))
        {
            context.ReturnCode = MqttConnectReturnCode.ConnectionRefusedNotAuthorized;
            return;
        }

        // Track each client id / endpoint pair only once so repeated reconnects do not grow the list
        // without bound.
        var alreadyTracked = TrackedClients.Clients.Any(client =>
            string.Equals(client.ClientId, context.ClientId, StringComparison.Ordinal) &&
            string.Equals(client.EndPoint, context.Endpoint, StringComparison.OrdinalIgnoreCase));

        if (!alreadyTracked)
        {
            TrackedClients.Clients.Add(new TrackedClient {ClientId = context.ClientId, EndPoint = context.Endpoint});
        }
    }

    context.ReturnCode = MqttConnectReturnCode.ConnectionAccepted;
}
|
private async Task StopServer() |
{ |
UnbindServerHandlers(); |
@@ -239,6 +288,37 @@ |
|
/// <summary>
///     Server-side handler for incoming messages. The payload is decrypted in place with the shared
///     password and forwarded through <see cref="OnMessageReceived"/>. If decryption fails, the sender is
///     treated as rogue: it is banned and disconnected, and <see cref="OnServerAuthenticationFailed"/> is
///     raised instead. Both notifications are dispatched on <see cref="TaskScheduler"/>.
/// </summary>
private async void ServerOnApplicationMessageReceived(object sender, MqttApplicationMessageReceivedEventArgs e)
{
    try
    {
        e.ApplicationMessage.Payload = await AES.Decrypt(e.ApplicationMessage.Payload, Password);
    }
    catch (Exception)
    {
        // Decryption failed, assume a rogue client and ban the client.
        await BanAndDisconnectClientAsync(e.ClientId);

        await Task.Delay(0).ContinueWith(_ => OnServerAuthenticationFailed?.Invoke(sender, e),
            CancellationToken, TaskContinuationOptions.None, TaskScheduler);

        return;
    }

    await Task.Delay(0).ContinueWith(_ => OnMessageReceived?.Invoke(sender, e),
        CancellationToken, TaskContinuationOptions.None, TaskScheduler).ConfigureAwait(false);
}

/// <summary>
///     Marks every tracked client with <paramref name="clientId"/> as banned, then clears and disconnects
///     every server session that shares that client id or the endpoint of one of those tracked clients.
///     Does nothing when no tracked client has the given id.
/// </summary>
/// <param name="clientId">Client id of the offending sender.</param>
private async Task BanAndDisconnectClientAsync(string clientId)
{
    List<TrackedClient> offenders;

    // Snapshot and flag the matching entries under the list's lock. The live list must not be
    // enumerated across the awaits below: the connection validator may add entries while this
    // method is suspended, which would invalidate the enumerator.
    lock (TrackedClients.Clients)
    {
        offenders = TrackedClients.Clients
            .Where(client => string.Equals(client.ClientId, clientId, StringComparison.Ordinal))
            .ToList();

        foreach (var client in offenders)
        {
            client.Banned = true;
        }
    }

    if (offenders.Count == 0)
    {
        return;
    }

    // Query the sessions once (not once per tracked entry) and drop every session matching either the
    // client id or the endpoint of a banned entry. Endpoints are compared the same way the connection
    // validator compares them.
    foreach (var session in await Server.GetClientSessionsStatusAsync())
    {
        var isOffender = string.Equals(session.ClientId, clientId, StringComparison.Ordinal) ||
                         offenders.Any(client => string.Equals(session.Endpoint, client.EndPoint,
                             StringComparison.OrdinalIgnoreCase));
        if (!isOffender)
        {
            continue;
        }

        await session.ClearPendingApplicationMessagesAsync();
        await session.DisconnectAsync();
    }
}
@@ -294,19 +374,23 @@ |
|
/// <summary>
///     Encrypts <paramref name="payload"/> with the shared password and publishes it on
///     <paramref name="topic"/> through the role this instance runs as (<see cref="Type"/>).
///     Nothing is published when <see cref="Type"/> is neither client nor server.
/// </summary>
/// <param name="topic">MQTT topic to publish on.</param>
/// <param name="payload">Plain payload; only its encrypted form is sent.</param>
public async Task Broadcast(string topic, byte[] payload)
{
    // Encrypt the payload so that only peers holding the same password can read it.
    var encryptedPayload = await AES.Encrypt(payload, Password);

    switch (Type)
    {
        case MQTTCommunicationType.Client:
            await Client.PublishAsync(new ManagedMqttApplicationMessage
            {
                ApplicationMessage = new MqttApplicationMessage {Topic = topic, Payload = encryptedPayload}
            }).ConfigureAwait(false);
            break;
        case MQTTCommunicationType.Server:
            await Server.PublishAsync(new MqttApplicationMessage {Topic = topic, Payload = encryptedPayload})
                .ConfigureAwait(false);
            break;
    }
}
} |
} |
} |