// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal;
using Newtonsoft.Json;

namespace Microsoft.AspNetCore.SignalR.Tests
{
    /// <summary>
    /// Test double for the client side of a hub connection. Bytes travel over a pair of
    /// in-memory unbounded channels: <see cref="Application"/> is the end this client reads
    /// from and writes to, and the opposite end is handed to <see cref="Connection"/> as its transport.
    /// </summary>
    public class TestClient : IDisposable
    {
        // Shared counter so every TestClient gets a distinct claim value.
        private static int _id;
        private readonly HubProtocolReaderWriter _protocolReaderWriter;
        private readonly IInvocationBinder _invocationBinder;
        private readonly ChannelConnection<byte[]> _transport;

        // One incoming buffer can decode to several hub messages; the ones not yet handed
        // out by TryRead wait here instead of being discarded.
        private readonly Queue<HubMessage> _pendingMessages = new Queue<HubMessage>();

        /// <summary>The connection context wired to this client's in-memory transport.</summary>
        public DefaultConnectionContext Connection { get; }

        /// <summary>The client's end of the in-memory channel pair.</summary>
        public Channel<byte[]> Application { get; }

        /// <summary>
        /// Task of the <see cref="TaskCompletionSource{TResult}"/> stored in
        /// <c>Connection.Metadata["ConnectedTask"]</c>. This class never completes it itself;
        /// presumably the code under test does once the connection is established — verify.
        /// </summary>
        public Task Connected => ((TaskCompletionSource<bool>)Connection.Metadata["ConnectedTask"]).Task;

        /// <summary>
        /// Creates the in-memory transport, the connection context and its user principal, then
        /// queues the negotiation message for the chosen protocol so it is the first thing read
        /// from the client side.
        /// </summary>
        /// <param name="synchronousCallbacks">Passed to <see cref="UnboundedChannelOptions.AllowSynchronousContinuations"/> for both channels.</param>
        /// <param name="protocol">Hub protocol to speak; defaults to <c>JsonHubProtocol</c>.</param>
        /// <param name="invocationBinder">Binder used to decode incoming messages; defaults to one that treats every value as <see cref="object"/>.</param>
        /// <param name="addClaimId">When true, the user also gets a <see cref="ClaimTypes.NameIdentifier"/> claim with the same value as its name claim.</param>
        public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false)
        {
            var options = new UnboundedChannelOptions { AllowSynchronousContinuations = synchronousCallbacks };
            var transportToApplication = Channel.CreateUnbounded<byte[]>(options);
            var applicationToTransport = Channel.CreateUnbounded<byte[]>(options);

            Application = ChannelConnection.Create<byte[]>(input: applicationToTransport, output: transportToApplication);
            _transport = ChannelConnection.Create<byte[]>(input: transportToApplication, output: applicationToTransport);

            Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application);

            var claimValue = Interlocked.Increment(ref _id).ToString();
            var claims = new List<Claim> { new Claim(ClaimTypes.Name, claimValue) };
            if (addClaimId)
            {
                claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue));
            }

            Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));
            Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();

            protocol = protocol ?? new JsonHubProtocol();
            _protocolReaderWriter = new HubProtocolReaderWriter(protocol, new PassThroughEncoder());
            _invocationBinder = invocationBinder ?? new DefaultInvocationBinder();

            // Negotiation must precede any hub traffic, so it is written before the constructor returns.
            using (var memoryStream = new MemoryStream())
            {
                NegotiationProtocol.WriteMessage(new NegotiationMessage(protocol.Name), memoryStream);
                Application.Writer.TryWrite(memoryStream.ToArray());
            }
        }

        /// <summary>
        /// Sends a streaming invocation and collects every <see cref="StreamItemMessage"/> for it,
        /// followed by the terminating <see cref="CompletionMessage"/>. Pings are ignored.
        /// </summary>
        /// <returns>All stream items in arrival order, with the completion message last.</returns>
        /// <exception cref="InvalidOperationException">The connection ended before completion.</exception>
        /// <exception cref="NotSupportedException">A message for another invocation, or a non-stream message, arrived.</exception>
        public async Task<IList<HubMessage>> StreamAsync(string methodName, params object[] args)
        {
            var invocationId = await SendStreamInvocationAsync(methodName, args);

            var messages = new List<HubMessage>();
            while (true)
            {
                var message = await ReadInvocationResponseAsync(invocationId);

                switch (message)
                {
                    case StreamItemMessage _:
                        messages.Add(message);
                        break;
                    case CompletionMessage _:
                        messages.Add(message);
                        return messages;
                    default:
                        throw new NotSupportedException("TestClient does not support receiving invocations!");
                }
            }
        }

        /// <summary>
        /// Sends a blocking (non-streaming) invocation and waits for its completion. Pings are ignored.
        /// </summary>
        /// <returns>The <see cref="CompletionMessage"/> for the invocation.</returns>
        /// <exception cref="InvalidOperationException">The connection ended before completion.</exception>
        /// <exception cref="NotSupportedException">A stream item, a message for another invocation, or an unexpected message arrived.</exception>
        public async Task<CompletionMessage> InvokeAsync(string methodName, params object[] args)
        {
            var invocationId = await SendInvocationAsync(methodName, nonBlocking: false, args: args);

            var message = await ReadInvocationResponseAsync(invocationId);

            switch (message)
            {
                case StreamItemMessage _:
                    throw new NotSupportedException("Use 'StreamAsync' to call a streaming method");
                case CompletionMessage completion:
                    return completion;
                default:
                    throw new NotSupportedException("TestClient does not support receiving invocations!");
            }
        }

        /// <summary>Sends a blocking invocation without waiting for its result.</summary>
        /// <returns>The generated invocation id.</returns>
        public Task<string> SendInvocationAsync(string methodName, params object[] args)
        {
            return SendInvocationAsync(methodName, nonBlocking: false, args: args);
        }

        /// <summary>Sends an invocation without waiting for its result.</summary>
        /// <returns>The generated invocation id.</returns>
        public Task<string> SendInvocationAsync(string methodName, bool nonBlocking, params object[] args)
        {
            var invocationId = GetInvocationId();
            return SendHubMessageAsync(new InvocationMessage(invocationId, nonBlocking, methodName,
                argumentBindingException: null, arguments: args));
        }

        /// <summary>Sends a streaming invocation without waiting for its items.</summary>
        /// <returns>The generated invocation id.</returns>
        public Task<string> SendStreamInvocationAsync(string methodName, params object[] args)
        {
            var invocationId = GetInvocationId();
            return SendHubMessageAsync(new StreamInvocationMessage(invocationId, methodName,
                argumentBindingException: null, arguments: args));
        }

        /// <summary>Serializes <paramref name="message"/> with the configured protocol and writes it to the client end.</summary>
        /// <returns>The message's invocation id if it is a <see cref="HubInvocationMessage"/>; otherwise null.</returns>
        public async Task<string> SendHubMessageAsync(HubMessage message)
        {
            var payload = _protocolReaderWriter.WriteMessage(message);
            await Application.Writer.WriteAsync(payload);
            return message is HubInvocationMessage hubMessage ? hubMessage.InvocationId : null;
        }

        /// <summary>Waits for and returns the next decoded hub message.</summary>
        /// <returns>The next message, or null once the channel is completed and drained.</returns>
        public async Task<HubMessage> ReadAsync()
        {
            while (true)
            {
                var message = TryRead();
                if (message != null)
                {
                    return message;
                }

                if (!await Application.Reader.WaitToReadAsync())
                {
                    return null;
                }
            }
        }

        /// <summary>
        /// Returns the next decoded hub message without waiting. Messages left over from a
        /// previously decoded buffer are returned first.
        /// </summary>
        /// <returns>A message, or null if none is available right now (or the buffer read did not decode to any).</returns>
        public HubMessage TryRead()
        {
            if (_pendingMessages.Count > 0)
            {
                return _pendingMessages.Dequeue();
            }

            if (Application.Reader.TryRead(out var buffer) &&
                _protocolReaderWriter.ReadMessages(buffer, _invocationBinder, out var messages))
            {
                // Keep every decoded message; previously all but the first were lost.
                foreach (var decoded in messages)
                {
                    _pendingMessages.Enqueue(decoded);
                }

                if (_pendingMessages.Count > 0)
                {
                    return _pendingMessages.Dequeue();
                }
            }

            return null;
        }

        /// <summary>Disposes the transport end of the in-memory channel pair.</summary>
        public void Dispose()
        {
            _transport.Dispose();
        }

        /// <summary>
        /// Reads the next non-ping message and checks that it belongs to <paramref name="invocationId"/>.
        /// Shared by <see cref="InvokeAsync"/> and <see cref="StreamAsync"/> so both treat pings and
        /// connection loss the same way.
        /// </summary>
        /// <exception cref="InvalidOperationException">The connection ended.</exception>
        /// <exception cref="NotSupportedException">An invocation message with a different id arrived.</exception>
        private async Task<HubMessage> ReadInvocationResponseAsync(string invocationId)
        {
            while (true)
            {
                var message = await ReadAsync();

                if (message == null)
                {
                    throw new InvalidOperationException("Connection aborted!");
                }

                if (message is HubInvocationMessage hubInvocationMessage &&
                    !string.Equals(hubInvocationMessage.InvocationId, invocationId, StringComparison.Ordinal))
                {
                    throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
                }

                if (message is PingMessage)
                {
                    // Pings are keep-alives, not part of the invocation's response.
                    continue;
                }

                return message;
            }
        }

        private static string GetInvocationId()
        {
            return Guid.NewGuid().ToString("N");
        }

        /// <summary>Fallback binder that types every argument and return value as <see cref="object"/>.</summary>
        private class DefaultInvocationBinder : IInvocationBinder
        {
            public Type[] GetParameterTypes(string methodName)
            {
                // TODO: Possibly support actual client methods
                return new[] { typeof(object) };
            }

            public Type GetReturnType(string invocationId)
            {
                return typeof(object);
            }
        }
    }
}