aspnetcore/src/Microsoft.AspNetCore.Socket.../LongPollingTransport.cs

142 lines
5.2 KiB
C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets.Client.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
namespace Microsoft.AspNetCore.Sockets.Client
{
public class LongPollingTransport : ITransport
{
private readonly HttpClient _httpClient;
private readonly ILogger _logger;
private Channel<byte[], SendMessage> _application;
private Task _sender;
private Task _poller;
private string _connectionId;
private readonly CancellationTokenSource _transportCts = new CancellationTokenSource();
public Task Running { get; private set; } = Task.CompletedTask;
public TransferMode? Mode { get; private set; }
public LongPollingTransport(HttpClient httpClient)
: this(httpClient, null)
{ }
public LongPollingTransport(HttpClient httpClient, ILoggerFactory loggerFactory)
{
_httpClient = httpClient;
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<LongPollingTransport>();
}
public Task StartAsync(Uri url, Channel<byte[], SendMessage> application, TransferMode requestedTransferMode, string connectionId)
{
if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text)
{
throw new ArgumentException("Invalid transfer mode.", nameof(requestedTransferMode));
}
_application = application;
Mode = requestedTransferMode;
_connectionId = connectionId;
_logger.StartTransport(_connectionId, Mode.Value);
// Start sending and polling (ask for binary if the server supports it)
_poller = Poll(url, _transportCts.Token);
_sender = SendUtils.SendMessages(url, _application, _httpClient, _transportCts, _logger, _connectionId);
Running = Task.WhenAll(_sender, _poller).ContinueWith(t =>
{
_logger.TransportStopped(_connectionId, t.Exception?.InnerException);
_application.Out.TryComplete(t.IsFaulted ? t.Exception.InnerException : null);
return t;
}).Unwrap();
return Task.CompletedTask;
}
public async Task StopAsync()
{
_logger.TransportStopping(_connectionId);
_transportCts.Cancel();
try
{
await Running;
}
catch
{
// exceptions have been handled in the Running task continuation by closing the channel with the exception
}
}
private async Task Poll(Uri pollUrl, CancellationToken cancellationToken)
{
_logger.StartReceive(_connectionId);
try
{
while (!cancellationToken.IsCancellationRequested)
{
var request = new HttpRequestMessage(HttpMethod.Get, pollUrl);
request.Headers.UserAgent.Add(SendUtils.DefaultUserAgentHeader);
var response = await _httpClient.SendAsync(request, cancellationToken);
response.EnsureSuccessStatusCode();
if (response.StatusCode == HttpStatusCode.NoContent || cancellationToken.IsCancellationRequested)
{
_logger.ClosingConnection(_connectionId);
// Transport closed or polling stopped, we're done
break;
}
else
{
_logger.ReceivedMessages(_connectionId);
// Until Pipeline starts natively supporting BytesReader, this is the easiest way to do this.
var payload = await response.Content.ReadAsByteArrayAsync();
if (payload.Length > 0)
{
while (!_application.Out.TryWrite(payload))
{
if (cancellationToken.IsCancellationRequested || !await _application.Out.WaitToWriteAsync(cancellationToken))
{
return;
}
}
}
}
}
}
catch (OperationCanceledException)
{
// transport is being closed
_logger.ReceiveCanceled(_connectionId);
}
catch (Exception ex)
{
_logger.ErrorPolling(_connectionId, pollUrl, ex);
throw;
}
finally
{
// Make sure the send loop is terminated
_transportCts.Cancel();
_logger.ReceiveStopped(_connectionId);
}
}
}
}