// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;

namespace Microsoft.AspNetCore.SignalR.Tests
{
    /// <summary>
    /// Test-side client that talks to a connection handler over an in-memory duplex pipe pair.
    /// The instance also registers itself on the connection as the transfer-format and
    /// heartbeat features.
    /// </summary>
    public class TestClient : ITransferFormatFeature, IConnectionHeartbeatFeature, IDisposable
    {
        private readonly object _heartbeatLock = new object();
        private List<(Action<object> handler, object state)> _heartbeatHandlers;

        // Shared counter used to give every client a distinct claim value.
        private static int _id;
        private readonly IHubProtocol _protocol;
        private readonly IInvocationBinder _invocationBinder;
        private readonly CancellationTokenSource _cts;

        /// <summary>The connection context whose application side this client reads from and writes to.</summary>
        public DefaultConnectionContext Connection { get; }

        /// <summary>Task of the completion source stored under the "ConnectedTask" item key.</summary>
        // NOTE(review): the completion source's type argument is assumed to be bool — confirm it
        // matches the code that completes this task.
        public Task Connected => ((TaskCompletionSource<bool>)Connection.Items["ConnectedTask"]).Task;

        /// <summary>The handshake response captured by <see cref="ConnectAsync"/>, if one was read.</summary>
        public HandshakeResponseMessage HandshakeResponseMessage { get; private set; }

        /// <inheritdoc />
        public TransferFormat SupportedFormats { get; set; } = TransferFormat.Text | TransferFormat.Binary;

        /// <inheritdoc />
        public TransferFormat ActiveFormat { get; set; }

        /// <summary>
        /// Creates a client backed by a fresh in-memory connection pair.
        /// </summary>
        /// <param name="synchronousCallbacks">When true, both pipes use <see cref="PipeScheduler.Inline"/>.</param>
        /// <param name="protocol">Hub protocol to use; defaults to <see cref="JsonHubProtocol"/>.</param>
        /// <param name="invocationBinder">Binder used when parsing incoming messages; defaults to a binder that binds everything to <see cref="object"/>.</param>
        /// <param name="addClaimId">When true, a <see cref="ClaimTypes.NameIdentifier"/> claim is added alongside the name claim.</param>
        public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false)
        {
            var scheduler = synchronousCallbacks ? PipeScheduler.Inline : null;
            var options = new PipeOptions(readerScheduler: scheduler, writerScheduler: scheduler, useSynchronizationContext: false);
            var pair = DuplexPipe.CreateConnectionPair(options, options);
            Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application);

            // Add features SignalR needs for testing
            Connection.Features.Set<ITransferFormatFeature>(this);
            Connection.Features.Set<IConnectionHeartbeatFeature>(this);

            var claimValue = Interlocked.Increment(ref _id).ToString();
            var claims = new List<Claim> { new Claim(ClaimTypes.Name, claimValue) };
            if (addClaimId)
            {
                claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue));
            }

            Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));
            Connection.Items["ConnectedTask"] = new TaskCompletionSource<bool>();

            _protocol = protocol ?? new JsonHubProtocol();
            _invocationBinder = invocationBinder ?? new DefaultInvocationBinder();

            _cts = new CancellationTokenSource();
        }

        /// <summary>
        /// Starts <paramref name="handler"/> on this client's connection, optionally performing the handshake.
        /// </summary>
        /// <param name="handler">The connection handler to run against <see cref="Connection"/>.</param>
        /// <param name="sendHandshakeRequestMessage">When true, a handshake request is written before the handler starts.</param>
        /// <param name="expectedHandshakeResponseMessage">When true, waits for the handshake response and stores it in <see cref="HandshakeResponseMessage"/>.</param>
        /// <returns>
        /// A task that completes once the handshake steps are done; its result is the (possibly still
        /// running) task returned by the handler's <c>OnConnectedAsync</c>.
        /// </returns>
        public async Task<Task> ConnectAsync(
            Connections.ConnectionHandler handler,
            bool sendHandshakeRequestMessage = true,
            bool expectedHandshakeResponseMessage = true)
        {
            if (sendHandshakeRequestMessage)
            {
                var memoryBufferWriter = MemoryBufferWriter.Get();
                try
                {
                    HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter);
                    await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray());
                }
                finally
                {
                    // Always hand the writer back, even if the write fails.
                    MemoryBufferWriter.Return(memoryBufferWriter);
                }
            }

            var connection = handler.OnConnectedAsync(Connection);

            if (expectedHandshakeResponseMessage)
            {
                // note that the handshake response might not immediately be readable
                // e.g. server is waiting for request, times out after configured duration,
                // and sends response with timeout error
                HandshakeResponseMessage = (HandshakeResponseMessage)await ReadAsync(true).OrTimeout();
            }

            return connection;
        }

        /// <summary>
        /// Sends a stream invocation and collects every stream item up to and including the completion message.
        /// </summary>
        /// <param name="methodName">The hub method to invoke.</param>
        /// <param name="args">The invocation arguments.</param>
        /// <returns>The received stream items followed by the completion message.</returns>
        /// <exception cref="InvalidOperationException">The connection ended before a completion was received.</exception>
        /// <exception cref="NotSupportedException">A message for another invocation, or an unsupported message type, was received.</exception>
        // NOTE(review): the element-collection type of the return value is assumed to be IList — confirm against callers.
        public async Task<IList<HubMessage>> StreamAsync(string methodName, params object[] args)
        {
            var invocationId = await SendStreamInvocationAsync(methodName, args);

            var messages = new List<HubMessage>();
            while (true)
            {
                var message = await ReadResponseAsync(invocationId);

                switch (message)
                {
                    case StreamItemMessage _:
                        messages.Add(message);
                        break;
                    case CompletionMessage _:
                        messages.Add(message);
                        return messages;
                    case PingMessage _:
                        // Pings are ignored, consistent with InvokeAsync
                        break;
                    default:
                        throw new NotSupportedException("TestClient does not support receiving invocations!");
                }
            }
        }

        /// <summary>
        /// Sends a blocking invocation and waits for its completion message.
        /// </summary>
        /// <param name="methodName">The hub method to invoke.</param>
        /// <param name="args">The invocation arguments.</param>
        /// <returns>The completion message for the invocation.</returns>
        /// <exception cref="InvalidOperationException">The connection ended before a completion was received.</exception>
        /// <exception cref="NotSupportedException">A stream item, a message for another invocation, or an unsupported message type was received.</exception>
        public async Task<HubMessage> InvokeAsync(string methodName, params object[] args)
        {
            var invocationId = await SendInvocationAsync(methodName, nonBlocking: false, args: args);

            while (true)
            {
                var message = await ReadResponseAsync(invocationId);

                switch (message)
                {
                    case StreamItemMessage _:
                        throw new NotSupportedException("Use 'StreamAsync' to call a streaming method");
                    case CompletionMessage completion:
                        return completion;
                    case PingMessage _:
                        // Pings are ignored
                        break;
                    default:
                        throw new NotSupportedException("TestClient does not support receiving invocations!");
                }
            }
        }

        /// <summary>Sends a blocking invocation without waiting for its result.</summary>
        /// <returns>The invocation id that was sent.</returns>
        public Task<string> SendInvocationAsync(string methodName, params object[] args)
        {
            return SendInvocationAsync(methodName, nonBlocking: false, args: args);
        }

        /// <summary>Sends an invocation without waiting for its result.</summary>
        /// <param name="methodName">The hub method to invoke.</param>
        /// <param name="nonBlocking">When true, the message is sent with a null invocation id.</param>
        /// <param name="args">The invocation arguments.</param>
        /// <returns>The invocation id that was sent, or null for a non-blocking invocation.</returns>
        public Task<string> SendInvocationAsync(string methodName, bool nonBlocking, params object[] args)
        {
            var invocationId = nonBlocking ? null : GetInvocationId();
            return SendHubMessageAsync(new InvocationMessage(invocationId, methodName,
                argumentBindingException: null, arguments: args));
        }

        /// <summary>Sends a stream invocation without waiting for any items.</summary>
        /// <returns>The invocation id that was sent.</returns>
        public Task<string> SendStreamInvocationAsync(string methodName, params object[] args)
        {
            var invocationId = GetInvocationId();
            return SendHubMessageAsync(new StreamInvocationMessage(invocationId, methodName,
                argumentBindingException: null, arguments: args));
        }

        /// <summary>Serializes <paramref name="message"/> with the configured protocol and writes it to the connection.</summary>
        /// <returns>The message's invocation id when it is a <see cref="HubInvocationMessage"/>; otherwise null.</returns>
        public async Task<string> SendHubMessageAsync(HubMessage message)
        {
            var payload = _protocol.GetMessageBytes(message);

            await Connection.Application.Output.WriteAsync(payload);
            return message is HubInvocationMessage hubMessage ? hubMessage.InvocationId : null;
        }

        /// <summary>
        /// Reads the next message, waiting for more data when none can be parsed yet.
        /// </summary>
        /// <param name="isHandshake">When true, the data is parsed as a handshake response instead of a hub message.</param>
        /// <returns>The parsed message, or null when the pipe completed with an empty buffer.</returns>
        public async Task<HubMessage> ReadAsync(bool isHandshake = false)
        {
            while (true)
            {
                var message = TryRead(isHandshake);

                if (message == null)
                {
                    var result = await Connection.Application.Input.ReadAsync();
                    var buffer = result.Buffer;

                    try
                    {
                        if (!buffer.IsEmpty)
                        {
                            // Data is available; loop around and let TryRead parse it.
                            continue;
                        }

                        if (result.IsCompleted)
                        {
                            return null;
                        }
                    }
                    finally
                    {
                        // Nothing is consumed here so that TryRead sees the same bytes.
                        // NOTE(review): nothing is marked as examined either, so a buffer holding only a
                        // partial message presumably makes this loop spin until the rest arrives — confirm.
                        Connection.Application.Input.AdvanceTo(buffer.Start);
                    }
                }
                else
                {
                    return message;
                }
            }
        }

        /// <summary>
        /// Attempts to parse one message from data that is already buffered, without waiting.
        /// </summary>
        /// <param name="isHandshake">When true, the data is parsed as a handshake response instead of a hub message.</param>
        /// <returns>The parsed message, or null when no data is buffered or no message could be parsed.</returns>
        public HubMessage TryRead(bool isHandshake = false)
        {
            if (!Connection.Application.Input.TryRead(out var result))
            {
                return null;
            }

            var buffer = result.Buffer;

            try
            {
                if (!isHandshake)
                {
                    if (_protocol.TryParseMessage(ref buffer, _invocationBinder, out var message))
                    {
                        return message;
                    }
                }
                else
                {
                    // read first message out of the incoming data
                    if (HandshakeProtocol.TryParseResponseMessage(ref buffer, out var responseMessage))
                    {
                        return responseMessage;
                    }
                }
            }
            finally
            {
                // The parsers receive the buffer by ref; advance to wherever they left its start.
                Connection.Application.Input.AdvanceTo(buffer.Start);
            }

            return null;
        }

        /// <summary>Cancels the client's token source and completes the application-side output pipe.</summary>
        public void Dispose()
        {
            _cts.Cancel();
            Connection.Application.Output.Complete();
        }

        /// <summary>Registers a callback to run on every <see cref="TickHeartbeat"/> call.</summary>
        /// <param name="action">The callback to invoke.</param>
        /// <param name="state">The state object passed to <paramref name="action"/>.</param>
        public void OnHeartbeat(Action<object> action, object state)
        {
            lock (_heartbeatLock)
            {
                if (_heartbeatHandlers == null)
                {
                    _heartbeatHandlers = new List<(Action<object> handler, object state)>();
                }

                _heartbeatHandlers.Add((action, state));
            }
        }

        /// <summary>Invokes every registered heartbeat callback, in registration order, while holding the heartbeat lock.</summary>
        public void TickHeartbeat()
        {
            lock (_heartbeatLock)
            {
                if (_heartbeatHandlers == null)
                {
                    return;
                }

                foreach (var (handler, state) in _heartbeatHandlers)
                {
                    handler(state);
                }
            }
        }

        // Reads the next message and applies the checks shared by InvokeAsync and StreamAsync:
        // the connection must still be open and the message must belong to the given invocation.
        private async Task<HubMessage> ReadResponseAsync(string invocationId)
        {
            var message = await ReadAsync();

            if (message == null)
            {
                throw new InvalidOperationException("Connection aborted!");
            }

            if (message is HubInvocationMessage hubInvocationMessage && !string.Equals(hubInvocationMessage.InvocationId, invocationId))
            {
                throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
            }

            return message;
        }

        // Produces a fresh invocation id: a GUID in "N" format.
        private static string GetInvocationId()
        {
            return Guid.NewGuid().ToString("N");
        }

        // Binder that treats every parameter and return value as System.Object.
        private class DefaultInvocationBinder : IInvocationBinder
        {
            public IReadOnlyList<Type> GetParameterTypes(string methodName)
            {
                // TODO: Possibly support actual client methods
                return new[] { typeof(object) };
            }

            public Type GetReturnType(string invocationId)
            {
                return typeof(object);
            }
        }
    }
}