diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs index a2007c6dfc..614265305a 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/ServerSentEventsTransport.cs @@ -44,9 +44,9 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports try { - var ms = new MemoryStream(); while (await _application.WaitToReadAsync(token)) { + var ms = new MemoryStream(); while (_application.TryRead(out var buffer)) { _logger.SSEWritingMessage(_connectionId, buffer.Length); @@ -61,10 +61,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports throw new InvalidOperationException("Ran out of space to format messages!"); } } - } - ms.Seek(0, SeekOrigin.Begin); - await ms.CopyToAsync(context.Response.Body); + ms.Seek(0, SeekOrigin.Begin); + await ms.CopyToAsync(context.Response.Body); + } await _application.Completion; } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs index 1d090f5c3f..fbdc4137b8 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs @@ -7,6 +7,7 @@ using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal.Transports; using Microsoft.Extensions.Logging; using Xunit; @@ -46,6 +47,30 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.True(feature.ResponseBufferingDisabled); } + [Fact] + public async Task SSEWritesMessages() + { + var channel = Channel.CreateUnbounded(new ChannelOptimizations + { + AllowSynchronousContinuations = true + }); + + var context = new DefaultHttpContext(); + var ms = new MemoryStream(); + context.Response.Body = ms; + var sse = new ServerSentEventsTransport(channel, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + + var task = sse.ProcessRequestAsync(context, context.RequestAborted); + + await channel.Out.WriteAsync(Encoding.ASCII.GetBytes("Hello")); + + Assert.Equal(":\r\ndata: Hello\r\n\r\n", Encoding.ASCII.GetString(ms.ToArray())); + + channel.Out.TryComplete(); + + await task.OrTimeout(); + } + [Theory] [InlineData("Hello World", ":\r\ndata: Hello World\r\n\r\n")] [InlineData("Hello\nWorld", ":\r\ndata: Hello\r\ndata: World\r\n\r\n")]