WingMan

Subversion Repositories:
Compare Path: Rev
With Path: Rev
?path1? @ 7  →  ?path2? @ 8
/trunk/WingMan/Communication/MQTTCommunication.cs
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
@@ -5,7 +7,9 @@
using MQTTnet;
using MQTTnet.Client;
using MQTTnet.Extensions.ManagedClient;
using MQTTnet.Protocol;
using MQTTnet.Server;
using WingMan.Utilities;
using MqttClientConnectedEventArgs = MQTTnet.Client.MqttClientConnectedEventArgs;
using MqttClientDisconnectedEventArgs = MQTTnet.Client.MqttClientDisconnectedEventArgs;
 
@@ -13,6 +17,8 @@
{
public class MQTTCommunication : IDisposable
{
/// <summary>
/// Signature for <see cref="OnClientAuthenticationFailed"/> handlers; that event is raised when a
/// message received in client mode cannot be decrypted with the shared password.
/// </summary>
public delegate void ClientAuthenticationFailed(object sender, EventArgs e);

/// <summary>Signature for <see cref="OnClientConnected"/> handlers.</summary>
public delegate void ClientConnected(object sender, MqttClientConnectedEventArgs e);

/// <summary>
/// Handler signature carrying the managed client's process-failed arguments.
/// NOTE(review): presumably backs a client connection-failure event — confirm against its declaration.
/// </summary>
public delegate void ClientConnectionFailed(object sender, MqttManagedProcessFailedEventArgs e);
@@ -25,6 +31,8 @@
 
/// <summary>
/// Signature for <see cref="OnMessageReceived"/> handlers; the message they receive has already had
/// its payload decrypted by this class.
/// </summary>
public delegate void MessageReceived(object sender, MqttApplicationMessageReceivedEventArgs e);

/// <summary>
/// Signature for <see cref="OnServerAuthenticationFailed"/> handlers; that event is raised when the
/// embedded server cannot decrypt a message it received.
/// </summary>
public delegate void ServerAuthenticationFailed(object sender, EventArgs e);

/// <summary>Handler signature for server-side "client connected" notifications (MQTTnet server event args).</summary>
public delegate void ServerClientConnected(object sender, MQTTnet.Server.MqttClientConnectedEventArgs e);

/// <summary>Handler signature for server-side "client disconnected" notifications (MQTTnet server event args).</summary>
public delegate void ServerClientDisconnected(object sender, MQTTnet.Server.MqttClientDisconnectedEventArgs e);
@@ -38,10 +46,14 @@
TaskScheduler = taskScheduler;
CancellationToken = cancellationToken;
 
TrackedClients = new TrackedClients {Clients = new List<TrackedClient>()};
 
Client = new MqttFactory().CreateManagedMqttClient();
Server = new MqttFactory().CreateMqttServer();
}
 
// Clients recorded by the server's connection validator, together with their ban state.
private TrackedClients TrackedClients { get; }

// Scheduler on which event notifications (message received, authentication failed) are dispatched.
private TaskScheduler TaskScheduler { get; }

// Managed MQTT client used when running as MQTTCommunicationType.Client.
private IManagedMqttClient Client { get; }
@@ -56,6 +68,8 @@
 
// Port supplied to Start; used as the server's default endpoint port.
private int Port { get; set; }

// Shared secret used to AES-encrypt outgoing payloads and decrypt incoming ones.
private string Password { get; set; }

// Token passed to the continuations that dispatch event notifications onto TaskScheduler.
private CancellationToken CancellationToken { get; }

// Whether this instance acts as an MQTT client or server; set by Start and consulted by Broadcast.
public MQTTCommunicationType Type { get; set; }
@@ -65,6 +79,10 @@
await Stop();
}
 
/// <summary>Raised on the configured scheduler when a message received in client mode fails to decrypt.</summary>
public event ClientAuthenticationFailed OnClientAuthenticationFailed;

/// <summary>
/// Raised on the configured scheduler when a message received by the server fails to decrypt,
/// after the sending client has been marked banned and its sessions disconnected.
/// </summary>
public event ServerAuthenticationFailed OnServerAuthenticationFailed;

/// <summary>Raised on the configured scheduler with the decrypted message, in both client and server modes.</summary>
public event MessageReceived OnMessageReceived;

/// <summary>Client-connected notification.</summary>
public event ClientConnected OnClientConnected;
@@ -85,12 +103,13 @@
 
/// <summary>Server-stopped notification.</summary>
public event ServerStopped OnServerStopped;
 
public async Task Start(MQTTCommunicationType type, IPAddress ipAddress, int port, string nick)
public async Task Start(MQTTCommunicationType type, IPAddress ipAddress, int port, string nick, string password)
{
Type = type;
IPAddress = ipAddress;
Port = port;
Nick = nick;
Password = password;
 
switch (type)
{
@@ -158,6 +177,18 @@
 
/// <summary>
/// Handles application messages received by the managed MQTT client. The payload is decrypted in
/// place with <see cref="Password"/>; on success <see cref="OnMessageReceived"/> is raised, otherwise
/// <see cref="OnClientAuthenticationFailed"/> is raised. Either event is invoked on
/// <see cref="TaskScheduler"/>.
/// </summary>
/// <param name="sender">Event source, forwarded unchanged to subscribers.</param>
/// <param name="e">The received message; its payload is replaced with the plaintext on success.</param>
private async void ClientOnApplicationMessageReceived(object sender, MqttApplicationMessageReceivedEventArgs e)
{
    bool decrypted;
    try
    {
        e.ApplicationMessage.Payload = await AES.Decrypt(e.ApplicationMessage.Payload, Password);
        decrypted = true;
    }
    catch (Exception)
    {
        // Any decryption failure is reported as an authentication failure.
        decrypted = false;
    }

    try
    {
        // Marshal the notification onto the configured scheduler.
        await Task.Delay(0).ContinueWith(_ =>
            {
                if (decrypted)
                    OnMessageReceived?.Invoke(sender, e);
                else
                    OnClientAuthenticationFailed?.Invoke(sender, e);
            },
            CancellationToken, TaskContinuationOptions.None, TaskScheduler);
    }
    catch (OperationCanceledException)
    {
        // CancellationToken was cancelled, so the continuation never ran and awaiting it threw.
        // In an async void handler that exception cannot be observed by any caller and would tear
        // down the process, so the notification is dropped instead.
    }
}
@@ -185,6 +216,7 @@
var optionsBuilder = new MqttServerOptionsBuilder()
.WithDefaultEndpointBoundIPAddress(IPAddress)
.WithSubscriptionInterceptor(MQTTSubscriptionIntercept)
.WithConnectionValidator(MQTTConnectionValidator)
.WithDefaultEndpointPort(Port);
 
BindServerHandlers();
@@ -194,6 +226,23 @@
Running = true;
}
 
/// <summary>
/// Connection validator installed on the embedded MQTT server.
/// <para>
/// Refuses the connection when a banned tracked client matches either its endpoint (case-insensitive)
/// or its client id (case-sensitive). Otherwise it records the (client id, endpoint) pair for later
/// ban tracking and accepts the connection.
/// </para>
/// </summary>
/// <param name="context">Validation context supplied by MQTTnet; its ReturnCode is set here.</param>
private void MQTTConnectionValidator(MqttConnectionValidatorContext context)
{
    // The tracked-client list is shared with the server message handler and validation may run
    // for several connections at once, so check-and-add happens under a single lock.
    lock (TrackedClients)
    {
        // Do not accept connections from banned clients.
        if (TrackedClients.Clients.Any(client =>
                client.Banned &&
                (string.Equals(client.EndPoint, context.Endpoint, StringComparison.OrdinalIgnoreCase) ||
                 string.Equals(client.ClientId, context.ClientId, StringComparison.Ordinal))))
        {
            context.ReturnCode = MqttConnectReturnCode.ConnectionRefusedNotAuthorized;
            return;
        }

        // Record each (client id, endpoint) pair only once so that reconnects do not grow the list
        // without bound; ban handling matches on client id, so one entry per pair is sufficient.
        var alreadyTracked = TrackedClients.Clients.Any(client =>
            string.Equals(client.ClientId, context.ClientId, StringComparison.Ordinal) &&
            string.Equals(client.EndPoint, context.Endpoint, StringComparison.OrdinalIgnoreCase));

        if (!alreadyTracked)
            TrackedClients.Clients.Add(new TrackedClient {ClientId = context.ClientId, EndPoint = context.Endpoint});
    }

    context.ReturnCode = MqttConnectReturnCode.ConnectionAccepted;
}
 
private async Task StopServer()
{
UnbindServerHandlers();
@@ -239,6 +288,37 @@
 
/// <summary>
/// Handles application messages received by the embedded MQTT server. The payload is decrypted in
/// place with <see cref="Password"/>.
/// <para>
/// On success, <see cref="OnMessageReceived"/> is raised. On failure, the sender is treated as rogue:
/// <list type="bullet">
/// <item>every tracked entry with its client id is marked banned;</item>
/// <item>sessions matching that client id or a tracked endpoint have pending messages cleared and are disconnected;</item>
/// <item><see cref="OnServerAuthenticationFailed"/> is raised.</item>
/// </list>
/// </para>
/// <para>Both events are invoked on <see cref="TaskScheduler"/>.</para>
/// </summary>
/// <param name="sender">Event source, forwarded unchanged to subscribers.</param>
/// <param name="e">The received message; its payload is replaced with the plaintext on success.</param>
private async void ServerOnApplicationMessageReceived(object sender, MqttApplicationMessageReceivedEventArgs e)
{
    bool decrypted;
    try
    {
        e.ApplicationMessage.Payload = await AES.Decrypt(e.ApplicationMessage.Payload, Password);
        decrypted = true;
    }
    catch (Exception)
    {
        decrypted = false;
    }

    if (!decrypted)
    {
        // Decryption failed, assume a rogue client and ban the client.
        // Work on a snapshot: MQTTConnectionValidator appends to the shared list from connection
        // threads, and the awaits below would otherwise let the list change mid-enumeration
        // ("Collection was modified").
        List<TrackedClient> offenders;
        lock (TrackedClients)
        {
            offenders = TrackedClients.Clients
                .Where(client => string.Equals(client.ClientId, e.ClientId, StringComparison.Ordinal))
                .ToList();
        }

        foreach (var client in offenders)
        {
            client.Banned = true;

            foreach (var clientSessionStatus in await Server.GetClientSessionsStatusAsync())
            {
                // Target sessions sharing either the offending client id or the tracked endpoint.
                // Endpoints compare case-insensitively, as in MQTTConnectionValidator.
                if (!string.Equals(clientSessionStatus.ClientId, e.ClientId, StringComparison.Ordinal) &&
                    !string.Equals(clientSessionStatus.Endpoint, client.EndPoint, StringComparison.OrdinalIgnoreCase))
                    continue;

                await clientSessionStatus.ClearPendingApplicationMessagesAsync();
                await clientSessionStatus.DisconnectAsync();
            }
        }
    }

    try
    {
        // Marshal the notification onto the configured scheduler.
        await Task.Delay(0).ContinueWith(_ =>
            {
                if (decrypted)
                    OnMessageReceived?.Invoke(sender, e);
                else
                    OnServerAuthenticationFailed?.Invoke(sender, e);
            },
            CancellationToken, TaskContinuationOptions.None, TaskScheduler).ConfigureAwait(false);
    }
    catch (OperationCanceledException)
    {
        // CancellationToken was cancelled, so the continuation never ran and awaiting it threw.
        // In an async void handler that exception cannot be observed by any caller and would tear
        // down the process, so the notification is dropped instead.
    }
}
@@ -294,19 +374,23 @@
 
/// <summary>
/// Encrypts <paramref name="payload"/> with the shared password and publishes it on
/// <paramref name="topic"/>. Publishing goes through the managed client or the embedded server,
/// depending on <see cref="Type"/>. Only the encrypted bytes are sent; the receive handlers in this
/// class decrypt them again.
/// </summary>
/// <param name="topic">MQTT topic to publish on.</param>
/// <param name="payload">Plaintext payload; it is encrypted before leaving this method.</param>
public async Task Broadcast(string topic, byte[] payload)
{
    // Encrypt the payload.
    var encryptedPayload = await AES.Encrypt(payload, Password);

    switch (Type)
    {
        case MQTTCommunicationType.Client:
            await Client.PublishAsync(new ManagedMqttApplicationMessage
            {
                ApplicationMessage = new MqttApplicationMessage {Topic = topic, Payload = encryptedPayload}
            }).ConfigureAwait(false);
            break;
        case MQTTCommunicationType.Server:
            await Server.PublishAsync(new MqttApplicationMessage {Topic = topic, Payload = encryptedPayload})
                .ConfigureAwait(false);
            break;
    }
}
}
}
}
/trunk/WingMan/Communication/TrackedClient.cs
@@ -0,0 +1,11 @@
namespace WingMan.Communication
{
    /// <summary>
    /// A client observed by the MQTT server: the client id and endpoint it connected with, plus a
    /// flag recording whether it has been banned.
    /// </summary>
    public class TrackedClient
    {
        /// <summary>MQTT client identifier presented at connect time.</summary>
        public string ClientId { get; set; }

        /// <summary>Remote endpoint string the client connected from.</summary>
        public string EndPoint { get; set; }

        /// <summary>Set once the client has been banned; defaults to <c>false</c>.</summary>
        public bool Banned { get; set; }
    }
}
/trunk/WingMan/Communication/TrackedClients.cs
@@ -0,0 +1,9 @@
using System.Collections.Generic;
 
namespace WingMan.Communication
{
    /// <summary>
    /// Holder for the clients tracked by the MQTT server.
    /// </summary>
    public class TrackedClients
    {
        /// <summary>
        /// The tracked clients. Starts as an empty list so a default-constructed instance never
        /// exposes a null collection; callers may still replace it through the setter.
        /// </summary>
        public List<TrackedClient> Clients { get; set; } = new List<TrackedClient>();
    }
}