Make the http end points more resty (#470)

- Use HTTP verbs to describe functionality for endpoints
- Updated TransportProtocols.md
This commit is contained in:
David Fowler 2017-05-19 23:37:17 -07:00 committed by GitHub
parent 8f18ff4423
commit 240a88f7af
21 changed files with 308 additions and 180 deletions

View File

@ -9,10 +9,10 @@ describe("Connection", () => {
it("starting connection fails if getting id fails", async (done) => {
let options: ISignalROptions = {
httpClient: <IHttpClient>{
options(url: string): Promise<string> {
return Promise.reject("error");
},
get(url: string): Promise<string> {
if (url.indexOf("/negotiate") >= 0) {
return Promise.reject("error");
}
return Promise.resolve("");
}
}
@ -34,9 +34,8 @@ describe("Connection", () => {
it("cannot start a running connection", async (done) => {
let options: ISignalROptions = {
httpClient: <IHttpClient>{
get(url: string): Promise<string> {
if (url.indexOf("/negotiate") >= 0) {
connection.start()
options(url: string): Promise<string> {
connection.start()
.then(() => {
fail();
done();
@ -47,7 +46,8 @@ describe("Connection", () => {
});
return Promise.reject("error");
}
},
get(url: string): Promise<string> {
return Promise.resolve("");
}
}
@ -67,10 +67,10 @@ describe("Connection", () => {
it("cannot start a stopped connection", async (done) => {
let options: ISignalROptions = {
httpClient: <IHttpClient>{
options(url: string): Promise<string> {
return Promise.reject("error");
},
get(url: string): Promise<string> {
if (url.indexOf("/negotiate") >= 0) {
return Promise.reject("error");
}
return Promise.resolve("");
}
}
@ -100,6 +100,10 @@ describe("Connection", () => {
it("can stop a starting connection", async (done) => {
let options: ISignalROptions = {
httpClient: <IHttpClient>{
options(url: string): Promise<string> {
connection.stop();
return Promise.resolve("");
},
get(url: string): Promise<string> {
connection.stop();
return Promise.resolve("");
@ -128,10 +132,10 @@ describe("Connection", () => {
it("preserves users connection string", async done => {
let options: ISignalROptions = {
httpClient: <IHttpClient>{
options(url: string): Promise<string> {
return Promise.resolve("42");
},
get(url: string): Promise<string> {
if (url.includes("negotiate")) {
return Promise.resolve("42");
}
return Promise.resolve("");
}
}

View File

@ -40,7 +40,8 @@ export class Connection implements IConnection {
private async startInternal(transportType: TransportType | ITransport): Promise<void> {
try {
this.connectionId = await this.httpClient.get(`${this.url}/negotiate?${this.queryString}`);
var negotiateUrl = this.url + (this.queryString ? "?" + this.queryString : "");
this.connectionId = await this.httpClient.options(negotiateUrl);
// the user tries to stop the the connection when it is being started
if (this.connectionState == ConnectionState.Disconnected) {

View File

@ -1,5 +1,6 @@
export interface IHttpClient {
get(url: string, headers?: Map<string, string>): Promise<string>;
options(url: string, headers?: Map<string, string>): Promise<string>;
post(url: string, content: string, headers?: Map<string, string>): Promise<string>;
}
@ -8,6 +9,10 @@ export class HttpClient implements IHttpClient {
return this.xhr("GET", url, headers);
}
options(url: string, headers?: Map<string, string>): Promise<string> {
return this.xhr("OPTIONS", url, headers);
}
post(url: string, content: string, headers?: Map<string, string>): Promise<string> {
return this.xhr("POST", url, headers, content);
}

View File

@ -22,7 +22,7 @@ export class WebSocketTransport implements ITransport {
connect(url: string, queryString: string = ""): Promise<void> {
return new Promise<void>((resolve, reject) => {
url = url.replace(/^http/, "ws");
let connectUrl = url + "/ws?" + queryString;
let connectUrl = url + (queryString ? "?" + queryString : "");
let webSocket = new WebSocket(connectUrl);
@ -81,6 +81,7 @@ export class ServerSentEventsTransport implements ITransport {
private eventSource: EventSource;
private url: string;
private queryString: string;
private fullUrl: string;
private httpClient: IHttpClient;
constructor(httpClient: IHttpClient) {
@ -94,10 +95,10 @@ export class ServerSentEventsTransport implements ITransport {
this.queryString = queryString;
this.url = url;
let tmp = `${this.url}/sse?${this.queryString}`;
this.fullUrl = url + (queryString ? "?" + queryString : "");
return new Promise<void>((resolve, reject) => {
let eventSource = new EventSource(`${this.url}/sse?${this.queryString}`);
let eventSource = new EventSource(this.fullUrl);
try {
eventSource.onmessage = (e: MessageEvent) => {
@ -139,7 +140,7 @@ export class ServerSentEventsTransport implements ITransport {
}
async send(data: any): Promise<void> {
return send(this.httpClient, `${this.url}/send?${this.queryString}`, data);
return send(this.httpClient, this.fullUrl, data);
}
stop(): void {
@ -156,6 +157,7 @@ export class ServerSentEventsTransport implements ITransport {
export class LongPollingTransport implements ITransport {
private url: string;
private queryString: string;
private fullUrl: string;
private httpClient: IHttpClient;
private pollXhr: XMLHttpRequest;
private shouldPoll: boolean;
@ -168,7 +170,8 @@ export class LongPollingTransport implements ITransport {
this.url = url;
this.queryString = queryString;
this.shouldPoll = true;
this.poll(url + "/poll?" + this.queryString)
this.fullUrl = url + (queryString ? "?" + queryString : "");
this.poll(this.fullUrl);
return Promise.resolve();
}
@ -231,7 +234,7 @@ export class LongPollingTransport implements ITransport {
}
async send(data: any): Promise<void> {
return send(this.httpClient, `${this.url}/send?${this.queryString}`, data);
return send(this.httpClient, this.fullUrl, data);
}
stop(): void {

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\..\build\dependencies.props" />
@ -6,6 +6,7 @@
<TargetFramework>netcoreapp2.0</TargetFramework>
<!-- Don't create a NuGet package -->
<IsPackable>false</IsPackable>
<OutputType>Exe</OutputType>
</PropertyGroup>
<ItemGroup>

View File

@ -8,6 +8,9 @@
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Redis\Microsoft.AspNetCore.SignalR.Redis.csproj" />
<ProjectReference Include="..\..\client-ts\Microsoft.AspNetCore.SignalR.Client.TS\Microsoft.AspNetCore.SignalR.Client.TS.csproj" />
@ -18,6 +21,7 @@
<PackageReference Include="Microsoft.AspNetCore.Server.IISIntegration" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Server.Kestrel" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.StaticFiles" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Cors" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Configuration.CommandLine" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufVersion)" />

View File

@ -21,6 +21,16 @@ namespace SocketsSample
services.AddSignalR();
// .AddRedis();
services.AddCors(o =>
{
o.AddPolicy("Everything", p =>
{
p.AllowAnyHeader()
.AllowAnyMethod()
.AllowAnyOrigin();
});
});
services.AddEndPoint<MessagesEndPoint>();
}
@ -34,6 +44,8 @@ namespace SocketsSample
app.UseDeveloperExceptionPage();
}
app.UseCors("Everything");
app.UseSignalR(routes =>
{
routes.MapHub<Chat>("/hubs");

View File

@ -5,7 +5,7 @@
<title></title>
<script>
document.addEventListener('DOMContentLoaded', () => {
var ws = new WebSocket(`ws://${document.location.host}/chat/ws`);
var ws = new WebSocket(`ws://${document.location.host}/chat`);
ws.onopen = function () {
console.log('Opened!');

View File

@ -18,7 +18,7 @@ Multi-frame messages (where a message is split into multiple frames) may not ove
The only transport which fully implements the duplex requirement is WebSockets, the others are "half-transports" which implement one end of the duplex connection. They are used in combination to achieve a duplex connection.
Throughout this document, the term `[endpoint-base]` is used to refer to the route assigned to a particular end point. The term `[connection-id]` is used to refer to the connection ID provided by the `[endpoint-base]/negotiate` end point.
Throughout this document, the term `[endpoint-base]` is used to refer to the route assigned to a particular end point. The term `[connection-id]` is used to refer to the connection ID provided by the `OPTIONS [endpoint-base]` request.
**NOTE on errors:** In all error cases, by default, the detailed exception message is **never** provided; a short description string may be provided. However, an application developer may elect to allow detailed exception messages to be emitted, which should only be used in the `Development` environment. Unexpected errors are communicated by HTTP `500 Server Error` status codes or WebSockets `1008 Policy Violation` close frames; in these cases the connection should be considered to be terminated.
@ -28,9 +28,9 @@ For the Long-Polling and Server-Sent events transports, there are two additional
## WebSockets (Full Duplex)
The WebSockets transport is unique in that it is full duplex, and a persistent connection that can be established in a single operation. As a result, the client is not required to use the `[endpoint-base]/negotiate` endpoint to establish a connection in advance. It also includes all the necessary metadata in it's own frame metadata.
The WebSockets transport is unique in that it is full duplex, and a persistent connection that can be established in a single operation. As a result, the client is not required to use the `OPTIONS [endpoint-base]` request to establish a connection in advance. It also includes all the necessary metadata in it's own frame metadata.
The WebSocket transport is activated by making a WebSocket connection to `[endpoint-base]/ws`. The **optional** `connectionId` query string value is used to identify the connection to attach to. If there is no `connectionId` query string value, a new connection is established. If the parameter is specified but there is no connection with the specified ID value, a `404 Not Found` response is returned. Upon receiving this request, the connection is established and the server responds with a WebSocket upgrade (`101 Switching Protocols`) immediately ready for frames to be sent/received. The WebSocket OpCode field is used to indicate the type of the frame (Text or Binary) and the WebSocket "FIN" flag is used to indicate the end of a message.
The WebSocket transport is activated by making a WebSocket connection to `[endpoint-base]`. The **optional** `connectionId` query string value is used to identify the connection to attach to. If there is no `connectionId` query string value, a new connection is established. If the parameter is specified but there is no connection with the specified ID value, a `404 Not Found` response is returned. Upon receiving this request, the connection is established and the server responds with a WebSocket upgrade (`101 Switching Protocols`) immediately ready for frames to be sent/received. The WebSocket OpCode field is used to indicate the type of the frame (Text or Binary) and the WebSocket "FIN" flag is used to indicate the end of a message.
Establishing a second WebSocket connection when there is already a WebSocket connection associated with the Endpoints connection is not permitted and will fail with a `409 Conflict` status code.
@ -40,9 +40,9 @@ Errors while establishing the connection are handled by returning a `500 Server
HTTP Post is a half-transport, it is only able to send messages from the Client to the Server, as such it is always used with one of the other half-transports which can send from Server to Client (Server Sent Events and Long Polling).
This transport requires that a connection be established using the `[endpoint-base]/negotiate` end point.
This transport requires that a connection be established using the `OPTIONS [endpoint-base]` request.
The HTTP POST request is made to the URL `[endpoint-base]/send`. The **mandatory** `connectionId` query string value is used to identify the connection to send to. If there is no `connectionId` query string value, a `400 Bad Request` response is returned. The content consists of frames in the same format as the Long Polling transport. It is up to the client which of the Text or Binary protocol they use, and the server is able to detect which they use either via the `Content-Type` (see the Long Polling transport section) or via the first byte (`T` for the Text-based protocol, `B` for the binary protocol). Upon receipt of the **entire** request, the server will process and deliver all the messages, responding with `202 Accepted` if all the messages are successfully processed. If a client makes another request to `/send` while an existing one is outstanding, the new request is immediately terminated by the server with the `409 Conflict` status code.
The HTTP POST request is made to the URL `[endpoint-base]`. The **mandatory** `connectionId` query string value is used to identify the connection to send to. If there is no `connectionId` query string value, a `400 Bad Request` response is returned. The content consists of frames in the same format as the Long Polling transport. It is up to the client which of the Text or Binary protocol they use, and the server is able to detect which they use either via the `Content-Type` (see the Long Polling transport section) or via the first byte (`T` for the Text-based protocol, `B` for the binary protocol). Upon receipt of the **entire** request, the server will process and deliver all the messages, responding with `202 Accepted` if all the messages are successfully processed. If a client makes another request to `/` while an existing one is outstanding, the new request is immediately terminated by the server with the `409 Conflict` status code.
If the client transmits a `Close` or `Error` frame, the server will ignore and discard any frames following that frame and immediately return `202 Accepted`. Any further attempts to send to that connection will receive a `404 Not Found` response, as the connection will have been terminated.
@ -54,7 +54,7 @@ If the relevant connection has been terminated, a `404 Not Found` status code is
## Server-Sent Events (Server-to-Client only)
Server-Sent Events (SSE) is a protocol specified by WHATWG at [https://html.spec.whatwg.org/multipage/comms.html#server-sent-events](https://html.spec.whatwg.org/multipage/comms.html#server-sent-events). It is capable of sending data from server to client only, so it must be paired with the HTTP Post transport. It also requires a connection already be established using the `[endpoint-base]/negotiate` endpoint.
Server-Sent Events (SSE) is a protocol specified by WHATWG at [https://html.spec.whatwg.org/multipage/comms.html#server-sent-events](https://html.spec.whatwg.org/multipage/comms.html#server-sent-events). It is capable of sending data from server to client only, so it must be paired with the HTTP Post transport. It also requires a connection already be established using the `OPTIONS [endpoint-base]` request.
The protocol is similar to Long Polling in that the client opens a request to an endpoint and leaves it open. The server transmits frames as "events" using the SSE protocol. The protocol encodes a single event as a sequence of key-value pair lines, separated by `:` and using any of `\r\n`, `\n` or `\r` as line-terminators, followed by a final blank line. Keys can be duplicated and their values are concatenated with `\n`. So the following represents two events:
@ -72,7 +72,7 @@ foo: boz
In the first event, the value of `baz` would be `boz\nbiz\nflarg`, due to the concatenation behavior above. Full details can be found in the spec linked above.
In this transport, the client establishes an SSE connection to `[endpoint-base]/[connection-id]/sse`, and the server responds with an HTTP response with a `Content-Type` of `text/event-stream`. The **mandatory** `connectionId` query string value is used to identify the connection to send to. If there is no `connectionId` query string value, a `400 Bad Request` response is returned, if there is no connection with the specified ID, a `404 Not Found` response is returned. Each SSE event represents a single frame from client to server. The transport uses unnamed events, which means only the `data` field is available. Thus we use the first line of the `data` field for frame metadata. The frame body starts on the **second** line of the `data` field value. The first line has the following format (Identifiers in square brackets `[]` indicate fields defined below):
In this transport, the client establishes an SSE connection to `[endpoint-base]` with an `Accept` header of `text/event-stream`, and the server responds with an HTTP response with a `Content-Type` of `text/event-stream`. The **mandatory** `connectionId` query string value is used to identify the connection to send to. If there is no `connectionId` query string value, a `400 Bad Request` response is returned, if there is no connection with the specified ID, a `404 Not Found` response is returned. Each SSE event represents a single frame from client to server. The transport uses unnamed events, which means only the `data` field is available. Thus we use the first line of the `data` field for frame metadata. The frame body starts on the **second** line of the `data` field value. The first line has the following format (Identifiers in square brackets `[]` indicate fields defined below):
```
[Type]
@ -111,22 +111,28 @@ This transport will buffer incomplete frames sent by the server until the full m
## Long Polling (Server-to-Client only)
Long Polling is a server-to-client half-transport, so it is always paired with HTTP Post. It requires a connection already be established using the `[endpoint-base]/negotiate` endpoint.
Long Polling is a server-to-client half-transport, so it is always paired with HTTP Post. It requires a connection already be established using the `OPTIONS [endpoint-base]` request.
Long Polling requires that the client poll the server for new messages. Unlike traditional polling, if there are no messages available, the server will simply block the request waiting for messages to be dispatched. At some point, the server, client or an upstream proxy will likely terminate the connection, at which point the client should immediately re-send the request. Long Polling is the only transport that allows a "reconnection" where a new request can be received while the server believes an existing request is in process. This can happen because of a time out. When this happens, the existing request is immediately terminated with status code `204 No Content`. Any messages which have already been written to the existing request will be flushed and considered sent.
Since there is such a long round-trip-time for messages, given that the client must issue a request before the server can transmit a message back, Long Polling responses contain batches of multiple messages. Also, in order to support browsers which do not support XHR2, which provides the ability to read binary data, there are two different modes for the polling transport.
A Poll is established by sending an HTTP GET request to `[endpoint-base]/poll` with the following query string parameters
A Poll is established by sending an HTTP GET request to `[endpoint-base]` with the following query string parameters
* `connectionId` (Required) - The Connection ID of the destination connection.
* `supportsBinary` (Optional: default `false`) - A boolean indicating if the client supports raw binary data in responses
When messages are available, the server responds with a body in one of the two formats below (depending upon the value of `supportsBinary`). The response may be chunked, as per the chunked encoding part of the HTTP spec.
The following headers are also supported:
* `Accept: application/vnd.microsoft.aspnetcore.endpoint-messages.v1+binary` - indicates if the client supports raw binary data in responses.
* `Accept: application/vnd.microsoft.aspnetcore.endpoint-messages.v1+text` - indicates if the client only supports text in responses.
If the Accept header doesn't specify one of the above formats, `application/vnd.microsoft.aspnetcore.endpoint-messages.v1+text` is assumed.
When messages are available, the server responds with a body in one of the two formats below (depending upon the value of the `Accept` header). The response may be chunked, as per the chunked encoding part of the HTTP spec.
If the `connectionId` parameter is missing, a `400 Bad Request` response is returned. If there is no connection with the ID specified in `connectionId`, a `404 Not Found` response is returned.
### Text-based encoding (`supportsBinary` = `false` or not present)
### Text-based encoding
The body will be formatted as below and encoded in UTF-8. The `Content-Type` response header is set to `application/vnd.microsoft.aspnetcore.endpoint-messages.v1+text`. Identifiers in square brackets `[]` indicate fields defined below, and parenthesis `()` indicate grouping.
@ -164,7 +170,7 @@ Note that the final frame still ends with the `;` terminator, and that since the
This transport will buffer incomplete frames sent by the server until the full message is available and then send the message in a single frame.
### Binary encoding (`supportsBinary` = `true`)
### Binary encoding
In JavaScript/Browser clients, this encoding requires XHR2 (or similar HTTP request functionality which allows binary data) and TypedArray support.

View File

@ -187,18 +187,20 @@ namespace Microsoft.AspNetCore.Sockets.Client
private static async Task<string> GetConnectionId(Uri url, HttpClient httpClient, ILogger logger)
{
var negotiateUrl = Utils.AppendPath(url, "negotiate");
try
{
// Get a connection ID from the server
logger.LogDebug("Establishing Connection at: {0}", negotiateUrl);
var connectionId = await httpClient.GetStringAsync(negotiateUrl);
logger.LogDebug("Establishing Connection at: {0}", url);
var request = new HttpRequestMessage(HttpMethod.Options, url);
var response = await httpClient.SendAsync(request);
response.EnsureSuccessStatusCode();
var connectionId = await response.Content.ReadAsStringAsync();
logger.LogDebug("Connection Id: {0}", connectionId);
return connectionId;
}
catch (Exception ex)
{
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex);
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", url, ex);
throw;
}
}

View File

@ -45,9 +45,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
_application = application;
// Start sending and polling (ask for binary if the server supports it)
var pollUrl = Utils.AppendQueryString(Utils.AppendPath(url, "poll"), "supportsBinary=true");
_poller = Poll(pollUrl, _transportCts.Token);
_sender = SendUtils.SendMessages(Utils.AppendPath(url, "send"), _application, _httpClient, _transportCts, _logger);
_poller = Poll(url, _transportCts.Token);
_sender = SendUtils.SendMessages(url, _application, _httpClient, _transportCts, _logger);
Running = Task.WhenAll(_sender, _poller).ContinueWith(t =>
{
@ -87,6 +86,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
var request = new HttpRequestMessage(HttpMethod.Get, pollUrl);
request.Headers.UserAgent.Add(SendUtils.DefaultUserAgentHeader);
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue(MessageFormatter.BinaryContentType));
var response = await _httpClient.SendAsync(request, cancellationToken);
response.EnsureSuccessStatusCode();

View File

@ -44,10 +44,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger.LogInformation("Starting {transportName}", nameof(ServerSentEventsTransport));
_application = application;
var sseUrl = Utils.AppendPath(url, "sse");
var sendUrl = Utils.AppendPath(url, "send");
var sendTask = SendUtils.SendMessages(sendUrl, _application, _httpClient, _transportCts, _logger);
var receiveTask = OpenConnection(_application, sseUrl, _transportCts.Token);
var sendTask = SendUtils.SendMessages(url, _application, _httpClient, _transportCts, _logger);
var receiveTask = OpenConnection(_application, url, _transportCts.Token);
Running = Task.WhenAll(sendTask, receiveTask).ContinueWith(t =>
{

View File

@ -182,8 +182,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
uriBuilder.Scheme = "wss";
}
uriBuilder.Path += "/ws";
await _webSocket.ConnectAsync(uriBuilder.Uri, _cancellationToken);
}

View File

@ -13,6 +13,7 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
@ -42,19 +43,34 @@ namespace Microsoft.AspNetCore.Sockets
return;
}
if (context.Request.Path.StartsWithSegments(path + "/negotiate"))
if (context.Request.Path.Equals(path, StringComparison.OrdinalIgnoreCase))
{
await ProcessNegotiate(context, options);
}
else if (context.Request.Path.StartsWithSegments(path + "/send"))
{
await ProcessSend(context);
if (HttpMethods.IsOptions(context.Request.Method))
{
// OPTIONS /{path}
await ProcessNegotiate(context, options);
}
else if (HttpMethods.IsPost(context.Request.Method))
{
// POST /{path}
await ProcessSend(context);
}
else if (HttpMethods.IsGet(context.Request.Method))
{
// GET /{path}
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
await ExecuteEndpointAsync(path, context, endpoint, options);
}
else
{
context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed;
}
}
else
{
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
await ExecuteEndpointAsync(path, context, endpoint, options);
context.Response.StatusCode = StatusCodes.Status400BadRequest;
}
}
@ -63,7 +79,10 @@ namespace Microsoft.AspNetCore.Sockets
var supportedTransports = options.Transports;
// Server sent events transport
if (context.Request.Path.StartsWithSegments(path + "/sse"))
// GET /{path}
// Accept: text/event-stream
var headers = context.Request.GetTypedHeaders();
if (headers.Accept?.Contains(new Net.Http.Headers.MediaTypeHeaderValue("text/event-stream")) == true)
{
// Connection must already exist
var state = await GetConnectionAsync(context);
@ -84,7 +103,7 @@ namespace Microsoft.AspNetCore.Sockets
await DoPersistentConnection(endpoint, sse, context, state);
}
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
else if (context.Features.Get<IHttpWebSocketConnectionFeature>()?.IsWebSocketRequest == true)
{
// Connection can be established lazily
var state = await GetOrCreateConnectionAsync(context);
@ -104,8 +123,10 @@ namespace Microsoft.AspNetCore.Sockets
await DoPersistentConnection(endpoint, ws, context, state);
}
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
else
{
// GET /{path} maps to long polling
// Connection must already exist
var state = await GetConnectionAsync(context);
if (state == null)
@ -302,6 +323,11 @@ namespace Microsoft.AspNetCore.Sockets
private Task ProcessNegotiate<TEndPoint>(HttpContext context, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
{
// Set the allowed headers for this resource
context.Response.Headers.AppendCommaSeparatedValues("Allow", "GET", "POST", "OPTIONS");
context.Response.ContentType = "text/plain";
// Establish the connection
var state = CreateConnection(context);
@ -352,7 +378,7 @@ namespace Microsoft.AspNetCore.Sockets
var messages = ParseSendBatch(ref reader, messageFormat);
// REVIEW: Do we want to return a specific status code here if the connection has ended?
_logger.LogDebug("Received batch of {count} message(s) in '/send'", messages.Count);
_logger.LogDebug("Received batch of {count} message(s)", messages.Count);
foreach (var message in messages)
{
while (!state.Application.Output.TryWrite(message))

View File

@ -37,10 +37,10 @@ namespace Microsoft.AspNetCore.Sockets.Transports
return;
}
// REVIEW: We could also use the 'Accept' header, in theory...
var messageFormat = string.Equals(context.Request.Query["supportsBinary"], "true", StringComparison.OrdinalIgnoreCase) ?
MessageFormat.Binary :
MessageFormat.Text;
var headers = context.Request.GetTypedHeaders();
var messageFormat = headers.Accept?.Contains(new Net.Http.Headers.MediaTypeHeaderValue(MessageFormatter.BinaryContentType)) == true ?
MessageFormat.Binary :
MessageFormat.Text;
context.Response.ContentType = MessageFormatter.GetContentType(messageFormat);
var writer = context.Response.Body.AsPipelineWriter();

View File

@ -50,11 +50,8 @@ namespace Microsoft.AspNetCore.Sockets.Transports
public async Task ProcessRequestAsync(HttpContext context, CancellationToken token)
{
var feature = context.Features.Get<IHttpWebSocketConnectionFeature>();
if (feature == null || !feature.IsWebSocketRequest)
{
_logger.LogWarning("Unable to handle WebSocket request, there is no WebSocket feature available.");
return;
}
Debug.Assert(feature != null, $"The {nameof(IHttpWebSocketConnectionFeature)} feature is missing!");
using (var ws = await feature.AcceptWebSocketConnectionAsync(_emptyContext))
{

View File

@ -298,7 +298,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.AbsolutePath.EndsWith("/poll"))
if (request.Method == HttpMethod.Get)
{
return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) };
}
@ -479,7 +479,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.AbsolutePath.EndsWith("/send"))
if (request.Method == HttpMethod.Post)
{
sendTcs.SetResult(await request.Content.ReadAsByteArrayAsync());
}
@ -525,7 +525,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
await Task.Yield();
var content = string.Empty;
if (request.RequestUri.AbsolutePath.EndsWith("/poll"))
if (request.Method == HttpMethod.Get)
{
content = "T2:T:42;";
}
@ -555,7 +555,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.AbsolutePath.EndsWith("/send"))
if (request.Method == HttpMethod.Post)
{
return ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError);
}
@ -586,7 +586,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
await Task.Yield();
var content = string.Empty;
if (request.RequestUri.AbsolutePath.EndsWith("/poll"))
if (request.Method == HttpMethod.Get)
{
content = "T2:T:42;";
}
@ -632,7 +633,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.AbsolutePath.EndsWith("/poll"))
if (request.Method == HttpMethod.Get)
{
return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) };
}

View File

@ -137,7 +137,7 @@ namespace Microsoft.AspNetCore.Client.Tests
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
var statusCode = request.RequestUri.AbsolutePath.EndsWith("send")
var statusCode = request.Method == HttpMethod.Post
? HttpStatusCode.InternalServerError
: HttpStatusCode.OK;
return ResponseUtils.CreateResponse(statusCode);
@ -313,7 +313,7 @@ namespace Microsoft.AspNetCore.Client.Tests
// Check the provided request
Assert.Equal(2, sentRequests.Count);
Assert.Equal("?supportsBinary=true", sentRequests[0].RequestUri.Query);
Assert.Contains(MessageFormatter.BinaryContentType, sentRequests[0].Headers.Accept.FirstOrDefault()?.ToString());
// Check the messages received
Assert.Equal(2, messages.Count);
@ -340,7 +340,7 @@ namespace Microsoft.AspNetCore.Client.Tests
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.LocalPath.EndsWith("send"))
if (request.Method == HttpMethod.Post)
{
// Build a new request object, but convert the entire payload to string
sentRequests.Add(await request.Content.ReadAsByteArrayAsync());

View File

@ -52,7 +52,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
const string message = "Hello, World!";
using (var ws = new ClientWebSocket())
{
string socketUrl = _serverFixture.WebSocketsUrl + "/echo/ws";
string socketUrl = _serverFixture.WebSocketsUrl + "/echo";
logger.LogInformation("Connecting WebSocket to {socketUrl}", socketUrl);
await ws.ConnectAsync(new Uri(socketUrl), CancellationToken.None).OrTimeout();

View File

@ -39,9 +39,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
var ms = new MemoryStream();
context.Request.Path = "/negotiate";
context.Request.Path = "/foo";
context.Request.Method = "OPTIONS";
context.Response.Body = ms;
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
var id = Encoding.UTF8.GetString(ms.ToArray());
@ -51,11 +52,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
[Theory]
[InlineData("/send")]
[InlineData("/sse")]
[InlineData("/poll")]
[InlineData("/ws")]
public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvided(string path)
[InlineData(TransportType.WebSockets)]
[InlineData(TransportType.ServerSentEvents)]
[InlineData(TransportType.LongPolling)]
public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvided(TransportType transportType)
{
var manager = CreateConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
@ -69,13 +69,46 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddEndPoint<TestEndPoint>();
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = "unknown";
var qs = new QueryCollection(values);
context.Request.Query = qs;
SetTransport(context, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
await strm.FlushAsync();
Assert.Equal("No Connection with that ID", Encoding.UTF8.GetString(strm.ToArray()));
}
}
[Fact]
public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvidedForPost()
{
var manager = CreateConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
using (var strm = new MemoryStream())
{
var context = new DefaultHttpContext();
context.Response.Body = strm;
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "POST";
var values = new Dictionary<string, StringValues>();
values["id"] = "unknown";
var qs = new QueryCollection(values);
context.Request.Query = qs;
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
await strm.FlushAsync();
@ -84,10 +117,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
[Theory]
[InlineData("/send")]
[InlineData("/sse")]
[InlineData("/poll")]
public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvided(string path)
[InlineData(TransportType.ServerSentEvents)]
[InlineData(TransportType.LongPolling)]
public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvided(TransportType transportType)
{
var manager = CreateConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
@ -99,9 +131,36 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
context.Request.Path = "/foo";
context.Request.Method = "GET";
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
SetTransport(context, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
Assert.Equal("Connection ID required", Encoding.UTF8.GetString(strm.ToArray()));
}
}
[Fact]
public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvidedForPost()
{
var manager = CreateConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
using (var strm = new MemoryStream())
{
var context = new DefaultHttpContext();
context.Response.Body = strm;
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "POST";
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
@ -122,12 +181,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/send";
context.Request.Path = "/foo";
context.Request.Method = "POST";
context.Request.QueryString = new QueryString($"?id={connectionState.Connection.ConnectionId}");
context.Request.ContentType = "text/plain";
context.Response.Body = strm;
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
@ -177,9 +237,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/sse", state);
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
SetTransport(context, TransportType.ServerSentEvents);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("", context);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
@ -196,9 +257,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<SynchronusExceptionEndPoint>("/sse", state);
var context = MakeRequest<SynchronusExceptionEndPoint>("/foo", state);
SetTransport(context, TransportType.ServerSentEvents);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("", context);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
@ -215,9 +277,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<SynchronusExceptionEndPoint>("/poll", state);
var context = MakeRequest<SynchronusExceptionEndPoint>("/foo", state);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("", context);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
@ -234,9 +296,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/poll", state);
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("", context);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
@ -253,37 +315,41 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/ws", state, isWebSocketRequest: true);
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
SetTransport(context, TransportType.WebSockets);
var task = dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("", context);
var task = dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("/foo", context);
await task.OrTimeout();
}
[Theory]
[InlineData("/ws", true)]
[InlineData("/sse", false)]
public async Task RequestToActiveConnectionId409ForStreamingTransports(string path, bool isWebSocketRequest)
[InlineData(TransportType.WebSockets)]
[InlineData(TransportType.ServerSentEvents)]
public async Task RequestToActiveConnectionId409ForStreamingTransports(TransportType transportType)
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest<TestEndPoint>(path, state, isWebSocketRequest: isWebSocketRequest);
var context2 = MakeRequest<TestEndPoint>(path, state, isWebSocketRequest: isWebSocketRequest);
var context1 = MakeRequest<TestEndPoint>("/foo", state);
var context2 = MakeRequest<TestEndPoint>("/foo", state);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>("", context1);
SetTransport(context1, transportType);
SetTransport(context2, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>("", context2);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context1);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context2);
Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode);
var webSocketTask = Task.CompletedTask;
if (isWebSocketRequest)
var ws = (TestWebSocketConnectionFeature)context1.Features.Get<IHttpWebSocketConnectionFeature>();
if (ws != null)
{
var ws = (TestWebSocketConnectionFeature)context1.Features.Get<IHttpWebSocketConnectionFeature>();
webSocketTask = ws.Client.ExecuteAsync(frame => Task.CompletedTask);
await ws.Client.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure), CancellationToken.None);
}
@ -303,11 +369,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest<TestEndPoint>("/poll", state);
var context2 = MakeRequest<TestEndPoint>("/poll", state);
var context1 = MakeRequest<TestEndPoint>("/foo", state);
var context2 = MakeRequest<TestEndPoint>("/foo", state);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>("", context1);
var request2 = dispatcher.ExecuteAsync<TestEndPoint>("", context2);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context1);
var request2 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context2);
await request1;
@ -322,9 +388,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
[Theory]
[InlineData("/sse")]
[InlineData("/poll")]
public async Task RequestToDisposedConnectionIdReturns404(string path)
[InlineData(TransportType.ServerSentEvents)]
[InlineData(TransportType.LongPolling)]
public async Task RequestToDisposedConnectionIdReturns404(TransportType transportType)
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
@ -332,9 +398,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<TestEndPoint>(path, state);
var context = MakeRequest<TestEndPoint>("/foo", state);
SetTransport(context, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
}
@ -347,9 +414,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<TestEndPoint>("/poll", state);
var context = MakeRequest<TestEndPoint>("/foo", state);
var task = dispatcher.ExecuteAsync<TestEndPoint>("", context);
var task = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -372,9 +439,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<BlockingEndPoint>("/sse", state);
var context = MakeRequest<BlockingEndPoint>("/foo", state);
SetTransport(context, TransportType.ServerSentEvents);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("/foo", context);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -397,9 +465,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<BlockingEndPoint>("/poll", state);
var context = MakeRequest<BlockingEndPoint>("/foo", state);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("/foo", context);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -422,10 +490,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest<BlockingEndPoint>("/poll", state);
var task1 = dispatcher.ExecuteAsync<BlockingEndPoint>("", context1);
var context2 = MakeRequest<BlockingEndPoint>("/poll", state);
var task2 = dispatcher.ExecuteAsync<BlockingEndPoint>("", context2);
var context1 = MakeRequest<TestEndPoint>("/foo", state);
var task1 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context1);
var context2 = MakeRequest<TestEndPoint>("/foo", state);
var task2 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context2);
// Task 1 should finish when request 2 arrives
await task1.OrTimeout();
@ -487,7 +555,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
services.AddEndPoint<TestEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
});
@ -498,7 +566,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -508,7 +577,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
// would hang if EndPoint was running
await dispatcher.ExecuteAsync<BlockingEndPoint>("", context).OrTimeout();
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
@ -522,7 +591,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
services.AddEndPoint<TestEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
});
@ -536,7 +605,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -549,7 +619,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -567,7 +637,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
services.AddEndPoint<TestEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
options.AuthorizationPolicyNames.Add("secondPolicy");
@ -580,7 +650,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -594,14 +665,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would hang if EndPoint was running
await dispatcher.ExecuteAsync<BlockingEndPoint>("", context).OrTimeout();
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
// fully "authorize" user
context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") }));
var endPointTask = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -618,7 +689,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
services.AddEndPoint<TestEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
});
@ -633,7 +704,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -646,7 +718,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -664,7 +736,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
services.AddEndPoint<TestEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
});
@ -679,7 +751,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -693,7 +766,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would block if EndPoint was executed
await dispatcher.ExecuteAsync<BlockingEndPoint>("", context).OrTimeout();
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
@ -747,22 +820,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
private static async Task CheckTransportSupported(TransportType supportedTransports, TransportType transportType, int status)
{
var path = "";
switch (transportType)
{
case TransportType.WebSockets:
path = "/ws";
break;
case TransportType.ServerSentEvents:
path = "/sse";
break;
case TransportType.LongPolling:
path = "/poll";
break;
default:
break;
}
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
@ -777,13 +834,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
options.Transports = supportedTransports;
});
SetTransport(context, transportType);
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("", context);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("/foo", context);
Assert.Equal(status, context.Response.StatusCode);
await strm.FlushAsync();
@ -802,7 +861,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<TestEndPoint>("/send", state, format);
var context = MakeRequest<TestEndPoint>("/foo", state, format);
context.Request.Method = "POST";
context.Request.ContentType = contentType;
var endPoint = context.RequestServices.GetRequiredService<TestEndPoint>();
@ -812,7 +872,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var messages = new List<Message>();
using (context.Request.Body = new MemoryStream(buffer, writable: false))
{
await dispatcher.ExecuteAsync<TestEndPoint>("", context).OrTimeout();
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context).OrTimeout();
}
while (state.Connection.Transport.Input.TryRead(out var message))
@ -823,7 +883,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
return messages;
}
private static DefaultHttpContext MakeRequest<TEndPoint>(string path, ConnectionState state, string format = null, bool isWebSocketRequest = false) where TEndPoint : EndPoint
private static DefaultHttpContext MakeRequest<TEndPoint>(string path, ConnectionState state, string format = null) where TEndPoint : EndPoint
{
var context = new DefaultHttpContext();
var services = new ServiceCollection();
@ -836,6 +896,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
if (format != null)
@ -845,16 +906,24 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
if (isWebSocketRequest)
{
// Add Test WebSocket feature
context.Features.Set<IHttpWebSocketConnectionFeature>(new TestWebSocketConnectionFeature());
}
return context;
}
private static void SetTransport(HttpContext context, TransportType transportType)
{
switch (transportType)
{
case TransportType.WebSockets:
context.Features.Set<IHttpWebSocketConnectionFeature>(new TestWebSocketConnectionFeature());
break;
case TransportType.ServerSentEvents:
context.Request.Headers["Accept"] = "text/event-stream";
break;
default:
break;
}
}
private static ConnectionManager CreateConnectionManager()
{
return new ConnectionManager(new Logger<ConnectionManager>(new LoggerFactory()));

View File

@ -8,6 +8,7 @@ using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.Logging;
using Xunit;
@ -61,7 +62,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
if (format == MessageFormat.Binary)
{
context.Request.QueryString = QueryString.Create("supportsBinary", "true");
context.Request.Headers["Accept"] = MessageFormatter.BinaryContentType;
}
var poll = new LongPollingTransport(channel, new LoggerFactory());
var ms = new MemoryStream();