using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
    /// <summary>
    /// One stage of the in-memory HTTP pipeline: receives the outgoing request and produces its response.
    /// </summary>
    /// <param name="requestMessage">The request being sent.</param>
    /// <param name="cancellationToken">Token supplied by the caller of <c>SendAsync</c>.</param>
    delegate Task<HttpResponseMessage> RequestDelegate(HttpRequestMessage requestMessage, CancellationToken cancellationToken);

    /// <summary>
    /// An <see cref="HttpMessageHandler"/> for tests. Requests are answered by a chain of handlers registered
    /// through the <c>On*</c> methods, and every request passed to <c>SendAsync</c> is recorded.
    /// </summary>
    /// <remarks>
    /// Handlers registered earlier wrap handlers registered later, so the first registered handler sees a request
    /// first and decides whether to answer it or call <c>next</c>. The chain is composed once, on the first
    /// request; handlers registered after that point are not part of it. A request that no handler answers
    /// results in a faulted task carrying an <see cref="InvalidOperationException"/>.
    /// </remarks>
    public class TestHttpMessageHandler : HttpMessageHandler
    {
        // Also serves as the lock object guarding both the list itself and the lazy construction of _app.
        private readonly List<HttpRequestMessage> _receivedRequests = new List<HttpRequestMessage>();
        private RequestDelegate _app;
        private readonly List<Func<RequestDelegate, RequestDelegate>> _middleware = new List<Func<RequestDelegate, RequestDelegate>>();

        /// <summary>Gets a value indicating whether <see cref="Dispose(bool)"/> has been called.</summary>
        public bool Disposed { get; private set; }

        /// <summary>Gets a snapshot (copy) of the requests received so far, in arrival order.</summary>
        public IReadOnlyList<HttpRequestMessage> ReceivedRequests
        {
            get
            {
                lock (_receivedRequests)
                {
                    return _receivedRequests.ToArray();
                }
            }
        }

        /// <summary>
        /// Creates a handler, optionally pre-registering negotiate and first-poll handlers.
        /// </summary>
        /// <param name="autoNegotiate">
        /// When true, negotiate requests are answered with 200 OK and the content produced by
        /// <c>ResponseUtils.CreateNegotiationContent()</c>.
        /// </param>
        /// <param name="handleFirstPoll">
        /// When true, the first long-poll request is answered with 200 OK; later long-poll requests fall
        /// through to the next handler.
        /// </param>
        public TestHttpMessageHandler(bool autoNegotiate = true, bool handleFirstPoll = true)
        {
            if (autoNegotiate)
            {
                OnNegotiate((_, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()));
            }

            if (handleFirstPoll)
            {
                var firstPoll = true;
                OnRequest(async (request, next, cancellationToken) =>
                {
                    if (ResponseUtils.IsLongPollRequest(request) && firstPoll)
                    {
                        firstPoll = false;
                        return ResponseUtils.CreateResponse(HttpStatusCode.OK);
                    }
                    else
                    {
                        return await next();
                    }
                });
            }
        }

        /// <summary>Records that the handler was disposed, then defers to the base implementation.</summary>
        protected override void Dispose(bool disposing)
        {
            Disposed = true;
            base.Dispose(disposing);
        }

        /// <summary>
        /// Records the request and dispatches it through the registered handler chain.
        /// </summary>
        /// <param name="request">The request to record and dispatch.</param>
        /// <param name="cancellationToken">Token forwarded to the handlers.</param>
        /// <returns>The response produced by the first handler that answers the request.</returns>
        protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            // Never complete synchronously on the caller's stack.
            await Task.Yield();

            RequestDelegate app;
            lock (_receivedRequests)
            {
                _receivedRequests.Add(request);

                if (_app == null)
                {
                    // Compose from the last-registered handler outwards so the first-registered handler ends up
                    // outermost. Walking backwards leaves _middleware untouched (no in-place Reverse).
                    RequestDelegate handler = BaseHandler;
                    for (var i = _middleware.Count - 1; i >= 0; i--)
                    {
                        handler = _middleware[i](handler);
                    }

                    _app = handler;
                }

                app = _app;
            }

            return await app(request, cancellationToken);
        }

        /// <summary>
        /// Creates a handler with the constructor defaults plus handlers for socket sends, long polls and
        /// the DELETE request.
        /// </summary>
        /// <remarks>
        /// Socket sends are answered with 202 Accepted. Long polls (after the first, which the constructor's
        /// default handler answers) block until either the request token is canceled or a DELETE request whose
        /// path-and-query starts with <c>/?id=</c> arrives, and then return 204 No Content. That DELETE request
        /// itself is answered with 202 Accepted.
        /// </remarks>
        public static TestHttpMessageHandler CreateDefault()
        {
            var testHttpMessageHandler = new TestHttpMessageHandler();

            var deleteCts = new CancellationTokenSource();

            testHttpMessageHandler.OnSocketSend((_, __) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted));
            testHttpMessageHandler.OnLongPoll(async cancellationToken =>
            {
                // Dispose the linked source once the poll finishes so its registrations on the parent tokens are released.
                using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, deleteCts.Token))
                {
                    // Just block until canceled. Continuations run asynchronously so the remainder of the poll
                    // does not execute inline inside the Cancel() call that triggers it.
                    var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
                    using (cts.Token.Register(() => tcs.TrySetResult(null)))
                    {
                        await tcs.Task;
                    }
                }

                return ResponseUtils.CreateResponse(HttpStatusCode.NoContent);
            });
            testHttpMessageHandler.OnRequest((request, next, cancellationToken) =>
            {
                if (request.Method.Equals(HttpMethod.Delete) && request.RequestUri.PathAndQuery.StartsWith("/?id=", StringComparison.Ordinal))
                {
                    deleteCts.Cancel();
                    return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.Accepted));
                }

                return next();
            });

            return testHttpMessageHandler;
        }

        /// <summary>
        /// Registers a handler that sees every request and may answer it or delegate to the rest of the chain.
        /// </summary>
        /// <param name="handler">
        /// Receives the request, a <c>next</c> callback that invokes the remainder of the chain with the same
        /// request and token, and the cancellation token.
        /// </param>
        public void OnRequest(Func<HttpRequestMessage, Func<Task<HttpResponseMessage>>, CancellationToken, Task<HttpResponseMessage>> handler)
        {
            void OnRequestCore(Func<RequestDelegate, RequestDelegate> middleware)
            {
                _middleware.Add(middleware);
            }

            OnRequestCore(next =>
            {
                return (request, cancellationToken) =>
                {
                    return handler(request, () => next(request, cancellationToken), cancellationToken);
                };
            });
        }

        /// <summary>Registers a handler for GET requests whose path-and-query equals <paramref name="pathAndQuery"/>.</summary>
        public void OnGet(string pathAndQuery, Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) => OnRequest(HttpMethod.Get, pathAndQuery, handler);

        /// <summary>Registers a handler for POST requests whose path-and-query equals <paramref name="pathAndQuery"/>.</summary>
        public void OnPost(string pathAndQuery, Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) => OnRequest(HttpMethod.Post, pathAndQuery, handler);

        /// <summary>Registers a handler for PUT requests whose path-and-query equals <paramref name="pathAndQuery"/>.</summary>
        public void OnPut(string pathAndQuery, Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) => OnRequest(HttpMethod.Put, pathAndQuery, handler);

        /// <summary>Registers a handler for DELETE requests whose path-and-query equals <paramref name="pathAndQuery"/>.</summary>
        public void OnDelete(string pathAndQuery, Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) => OnRequest(HttpMethod.Delete, pathAndQuery, handler);

        /// <summary>Registers a handler for HEAD requests whose path-and-query equals <paramref name="pathAndQuery"/>.</summary>
        public void OnHead(string pathAndQuery, Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) => OnRequest(HttpMethod.Head, pathAndQuery, handler);

        /// <summary>Registers a handler for OPTIONS requests whose path-and-query equals <paramref name="pathAndQuery"/>.</summary>
        public void OnOptions(string pathAndQuery, Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) => OnRequest(HttpMethod.Options, pathAndQuery, handler);

        /// <summary>Registers a handler for TRACE requests whose path-and-query equals <paramref name="pathAndQuery"/>.</summary>
        public void OnTrace(string pathAndQuery, Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) => OnRequest(HttpMethod.Trace, pathAndQuery, handler);

        /// <summary>
        /// Registers a handler for requests with the given method whose path-and-query equals
        /// <paramref name="pathAndQuery"/>; all other requests fall through to the next handler.
        /// </summary>
        /// <param name="method">HTTP method to match.</param>
        /// <param name="pathAndQuery">Exact value to compare with <c>RequestUri.PathAndQuery</c>.</param>
        /// <param name="handler">Produces the response for matching requests.</param>
        public void OnRequest(HttpMethod method, string pathAndQuery, Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler)
        {
            OnRequest((request, next, cancellationToken) =>
            {
                if (request.Method.Equals(method) && string.Equals(request.RequestUri.PathAndQuery, pathAndQuery))
                {
                    return handler(request, cancellationToken);
                }
                else
                {
                    return next();
                }
            });
        }

        /// <summary>Registers a synchronous handler for requests that <c>ResponseUtils.IsNegotiateRequest</c> accepts.</summary>
        public void OnNegotiate(Func<HttpRequestMessage, CancellationToken, HttpResponseMessage> handler) => OnNegotiate((req, cancellationToken) => Task.FromResult(handler(req, cancellationToken)));

        /// <summary>Registers an asynchronous handler for requests that <c>ResponseUtils.IsNegotiateRequest</c> accepts.</summary>
        public void OnNegotiate(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler)
        {
            OnRequest((request, next, cancellationToken) =>
            {
                if (ResponseUtils.IsNegotiateRequest(request))
                {
                    return handler(request, cancellationToken);
                }
                else
                {
                    return next();
                }
            });
        }

        /// <summary>Registers a synchronous handler for requests that <c>ResponseUtils.IsLongPollDeleteRequest</c> accepts.</summary>
        public void OnLongPollDelete(Func<CancellationToken, HttpResponseMessage> handler) => OnLongPollDelete((cancellationToken) => Task.FromResult(handler(cancellationToken)));

        /// <summary>Registers an asynchronous handler for requests that <c>ResponseUtils.IsLongPollDeleteRequest</c> accepts.</summary>
        public void OnLongPollDelete(Func<CancellationToken, Task<HttpResponseMessage>> handler)
        {
            OnRequest((request, next, cancellationToken) =>
            {
                if (ResponseUtils.IsLongPollDeleteRequest(request))
                {
                    return handler(cancellationToken);
                }
                else
                {
                    return next();
                }
            });
        }

        /// <summary>Registers a synchronous long-poll handler that only needs the cancellation token.</summary>
        public void OnLongPoll(Func<CancellationToken, HttpResponseMessage> handler) => OnLongPoll(cancellationToken => Task.FromResult(handler(cancellationToken)));

        /// <summary>Registers an asynchronous long-poll handler that only needs the cancellation token.</summary>
        public void OnLongPoll(Func<CancellationToken, Task<HttpResponseMessage>> handler)
        {
            OnLongPoll((request, token) => handler(token));
        }

        /// <summary>Registers a synchronous long-poll handler that also receives the request.</summary>
        public void OnLongPoll(Func<HttpRequestMessage, CancellationToken, HttpResponseMessage> handler)
        {
            OnLongPoll((request, token) => Task.FromResult(handler(request, token)));
        }

        /// <summary>
        /// Registers an asynchronous handler for requests that <c>ResponseUtils.IsLongPollRequest</c> accepts;
        /// other requests fall through to the next handler.
        /// </summary>
        public void OnLongPoll(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler)
        {
            OnRequest((request, next, cancellationToken) =>
            {
                if (ResponseUtils.IsLongPollRequest(request))
                {
                    return handler(request, cancellationToken);
                }
                else
                {
                    return next();
                }
            });
        }

        /// <summary>Registers a synchronous handler that receives the body bytes of socket-send requests.</summary>
        public void OnSocketSend(Func<byte[], CancellationToken, HttpResponseMessage> handler) => OnSocketSend((data, cancellationToken) => Task.FromResult(handler(data, cancellationToken)));

        /// <summary>
        /// Registers an asynchronous handler for requests that <c>ResponseUtils.IsSocketSendRequest</c> accepts.
        /// The request content is read into a byte array and passed to <paramref name="handler"/>.
        /// </summary>
        public void OnSocketSend(Func<byte[], CancellationToken, Task<HttpResponseMessage>> handler)
        {
            OnRequest(async (request, next, cancellationToken) =>
            {
                if (ResponseUtils.IsSocketSendRequest(request))
                {
                    var data = await request.Content.ReadAsByteArrayAsync();
                    return await handler(data, cancellationToken);
                }
                else
                {
                    return await next();
                }
            });
        }

        /// <summary>
        /// Terminal stage of the chain: reached only when no registered handler answered the request.
        /// </summary>
        /// <returns>A faulted task carrying an <see cref="InvalidOperationException"/> that names the method and URI.</returns>
        private Task<HttpResponseMessage> BaseHandler(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            return Task.FromException<HttpResponseMessage>(new InvalidOperationException($"Http endpoint not implemented: {request.Method} {request.RequestUri}"));
        }
    }
}