aspnetcore/src/Microsoft.AspNet.Session/DistributedSession.cs

294 lines
9.9 KiB
C#

// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.AspNet.Http;
using Microsoft.Framework.Caching.Distributed;
using Microsoft.Framework.Internal;
using Microsoft.Framework.Logging;
namespace Microsoft.AspNet.Session
{
/// <summary>
/// An <see cref="ISession"/> whose entries are held in an in-memory dictionary for the current request and
/// persisted to an <see cref="IDistributedCache"/> as a single serialized payload stored under the session id.
/// The payload is read lazily on first access (<see cref="Load"/>) and written back by <see cref="Commit"/>
/// only when the session was modified.
/// </summary>
public class DistributedSession : ISession
{
    // Version marker written as the first byte of every payload; any other leading byte is treated as unreadable.
    private const byte SerializationRevision = 1;
    // Key lengths are serialized in two bytes, so a key's UTF-8 form may not exceed ushort.MaxValue bytes.
    private const int KeyLengthLimit = ushort.MaxValue;
    private readonly IDistributedCache _cache;
    private readonly string _sessionId;
    private readonly TimeSpan _idleTimeout;
    private readonly Func<bool> _tryEstablishSession;
    private readonly IDictionary<EncodedKey, byte[]> _store;
    private readonly ILogger _logger;
    private readonly bool _isNewSessionKey;
    private bool _isModified;
    private bool _loaded;

    /// <summary>
    /// Creates a session backed by <paramref name="cache"/>. Nothing is read from the cache here;
    /// loading is deferred until the session is first accessed.
    /// </summary>
    /// <param name="cache">The distributed cache that stores the serialized session.</param>
    /// <param name="sessionId">The cache key under which this session's payload is stored.</param>
    /// <param name="idleTimeout">Sliding expiration applied to the cache entry on each <see cref="Commit"/>.</param>
    /// <param name="tryEstablishSession">
    /// Called by <see cref="Set"/>; if it returns false, <see cref="Set"/> throws
    /// <see cref="InvalidOperationException"/> (the message indicates this is the response-already-started case).
    /// </param>
    /// <param name="loggerFactory">Factory used to create this session's logger.</param>
    /// <param name="isNewSessionKey">
    /// True when the session id was freshly issued; suppresses the "expired session" warning when the cache has no entry.
    /// </param>
    public DistributedSession([NotNull] IDistributedCache cache, [NotNull] string sessionId, TimeSpan idleTimeout,
        [NotNull] Func<bool> tryEstablishSession, [NotNull] ILoggerFactory loggerFactory, bool isNewSessionKey)
    {
        _cache = cache;
        _sessionId = sessionId;
        _idleTimeout = idleTimeout;
        _tryEstablishSession = tryEstablishSession;
        _store = new Dictionary<EncodedKey, byte[]>();
        _logger = loggerFactory.CreateLogger<DistributedSession>();
        _isNewSessionKey = isNewSessionKey;
    }

    /// <summary>
    /// Gets the keys currently stored in the session, loading from the cache on first access.
    /// The returned sequence is a lazy projection over the live store, not a snapshot.
    /// </summary>
    public IEnumerable<string> Keys
    {
        get
        {
            Load(); // TODO: Silent failure
            return _store.Keys.Select(key => key.KeyString);
        }
    }

    /// <summary>
    /// Looks up the raw bytes stored for <paramref name="key"/>, loading from the cache on first access.
    /// </summary>
    /// <returns>True if the key exists; <paramref name="value"/> then holds the stored array (not a copy).</returns>
    public bool TryGetValue(string key, out byte[] value)
    {
        Load(); // TODO: Silent failure
        return _store.TryGetValue(new EncodedKey(key), out value);
    }

    /// <summary>
    /// Stores a private copy of the bytes in <paramref name="value"/> under <paramref name="key"/> and marks the session modified.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">The UTF-8 encoded key exceeds <see cref="KeyLengthLimit"/> bytes.</exception>
    /// <exception cref="ArgumentException"><paramref name="value"/> has a null underlying array.</exception>
    /// <exception cref="InvalidOperationException">The session could not be established.</exception>
    public void Set(string key, ArraySegment<byte> value)
    {
        var encodedKey = new EncodedKey(key);
        if (encodedKey.KeyBytes.Length > KeyLengthLimit)
        {
            throw new ArgumentOutOfRangeException("key", key,
                string.Format("The key cannot be longer than '{0}' when encoded with UTF-8.", KeyLengthLimit));
        }
        if (value.Array == null)
        {
            throw new ArgumentException("The ArraySegment<byte>.Array cannot be null.", "value");
        }

        Load();
        if (!_tryEstablishSession())
        {
            throw new InvalidOperationException("The session cannot be established after the response has started.");
        }

        _isModified = true;
        // Copy the segment so later mutation of the caller's buffer cannot change the stored value.
        byte[] copy = new byte[value.Count];
        Buffer.BlockCopy(value.Array, value.Offset, copy, 0, value.Count);
        _store[encodedKey] = copy;
    }

    /// <summary>Removes <paramref name="key"/>; the session is marked modified only if the key was present.</summary>
    public void Remove(string key)
    {
        Load();
        _isModified |= _store.Remove(new EncodedKey(key));
    }

    /// <summary>Removes all entries; the session is marked modified only if it was non-empty.</summary>
    public void Clear()
    {
        Load();
        _isModified |= _store.Count > 0;
        _store.Clear();
    }

    /// <summary>
    /// Populates the in-memory store from the cache the first time it is called; later calls do nothing.
    /// Logs a warning when no cache entry exists for a session id that was not freshly issued.
    /// </summary>
    // TODO: This should throw if called directly, but most other places it should fail silently (e.g. TryGetValue should just return null).
    public void Load()
    {
        if (_loaded)
        {
            return;
        }

        Stream data;
        if (_cache.TryGetValue(_sessionId, out data))
        {
            // The stream is only needed while deserializing; release it afterwards.
            using (data)
            {
                Deserialize(data);
            }
        }
        else if (!_isNewSessionKey)
        {
            _logger.LogWarning("Accessing expired session {0}", _sessionId);
        }
        _loaded = true;
    }

    /// <summary>
    /// Writes the session back to the cache, with sliding expiration set to the idle timeout,
    /// if it has been modified since it was loaded or last committed.
    /// </summary>
    public void Commit()
    {
        if (!_isModified)
        {
            return;
        }

        if (_logger.IsEnabled(LogLevel.Information))
        {
            // Probe the cache only to decide whether this commit creates the session;
            // the probe's stream is never read, so dispose it right away.
            Stream existing;
            var exists = _cache.TryGetValue(_sessionId, out existing);
            if (existing != null)
            {
                existing.Dispose();
            }
            if (!exists)
            {
                _logger.LogInformation("Session {0} started", _sessionId);
            }
        }

        _cache.Set(_sessionId, context =>
        {
            context.SetSlidingExpiration(_idleTimeout);
            Serialize(context.Data);
        });

        // Clear the dirty flag only after the write returns, so a write that throws
        // (e.g. too many entries to serialize) leaves the changes pending for a later Commit.
        _isModified = false;
    }

    // Format:
    // Serialization revision: 1 byte, range 0-255
    // Entry count: 3 bytes, range 0-16,777,215
    // foreach entry:
    //   key name byte length: 2 bytes, range 0-65,535
    //   UTF-8 encoded key name byte[]
    //   data byte length: 4 bytes, range 0-2,147,483,647
    //   data byte[]
    // All multi-byte numbers are big-endian (most significant byte first).
    private void Serialize(Stream output)
    {
        output.WriteByte(SerializationRevision);
        SerializeNumAs3Bytes(output, _store.Count);
        foreach (var entry in _store)
        {
            var keyBytes = entry.Key.KeyBytes;
            SerializeNumAs2Bytes(output, keyBytes.Length);
            output.Write(keyBytes, 0, keyBytes.Length);
            SerializeNumAs4Bytes(output, entry.Value.Length);
            output.Write(entry.Value, 0, entry.Value.Length);
        }
    }

    // Reads the format written by Serialize into _store. A null stream or an unknown revision byte
    // leaves the store untouched and marks the session modified so the next Commit overwrites the payload.
    private void Deserialize(Stream content)
    {
        if (content == null || content.ReadByte() != SerializationRevision)
        {
            // TODO: Throw?
            // Replace the un-readable format.
            _isModified = true;
            return;
        }

        int expectedEntries = DeserializeNumFrom3Bytes(content);
        for (int i = 0; i < expectedEntries; i++)
        {
            int keyLength = DeserializeNumFrom2Bytes(content);
            var key = new EncodedKey(content.ReadBytes(keyLength));
            int dataLength = DeserializeNumFrom4Bytes(content);
            _store[key] = content.ReadBytes(dataLength);
        }
    }

    // Writes num as 2 big-endian bytes; throws if it does not fit in an unsigned 16-bit value.
    private void SerializeNumAs2Bytes(Stream output, int num)
    {
        if (num < 0 || ushort.MaxValue < num)
        {
            throw new ArgumentOutOfRangeException("num", num, "The value cannot be serialized in two bytes.");
        }
        output.WriteByte((byte)(num >> 8));
        output.WriteByte((byte)(0xFF & num));
    }

    // NOTE(review): Stream.ReadByte returns -1 at end of stream, and these decoders do not check for it,
    // so a truncated payload decodes to -1 (or garbage) rather than failing clearly.
    private int DeserializeNumFrom2Bytes(Stream content)
    {
        return (content.ReadByte() << 8) | content.ReadByte();
    }

    // Writes num as 3 big-endian bytes; throws if it is outside 0..0xFFFFFF.
    private void SerializeNumAs3Bytes(Stream output, int num)
    {
        if (num < 0 || 0xFFFFFF < num)
        {
            throw new ArgumentOutOfRangeException("num", num, "The value cannot be serialized in three bytes.");
        }
        output.WriteByte((byte)(num >> 16));
        output.WriteByte((byte)(0xFF & (num >> 8)));
        output.WriteByte((byte)(0xFF & num));
    }

    private int DeserializeNumFrom3Bytes(Stream content)
    {
        return (content.ReadByte() << 16) | (content.ReadByte() << 8) | content.ReadByte();
    }

    // Writes a non-negative num as 4 big-endian bytes.
    private void SerializeNumAs4Bytes(Stream output, int num)
    {
        if (num < 0)
        {
            throw new ArgumentOutOfRangeException("num", num, "The value cannot be negative.");
        }
        output.WriteByte((byte)(num >> 24));
        output.WriteByte((byte)(0xFF & (num >> 16)));
        output.WriteByte((byte)(0xFF & (num >> 8)));
        output.WriteByte((byte)(0xFF & num));
    }

    private int DeserializeNumFrom4Bytes(Stream content)
    {
        return (content.ReadByte() << 24) | (content.ReadByte() << 16) | (content.ReadByte() << 8) | content.ReadByte();
    }

    // Keys are stored in their utf-8 encoded state.
    // This saves us from de-serializing and re-serializing every key on every request.
    private class EncodedKey
    {
        private string _keyString;
        private int? _hashCode;

        internal EncodedKey(string key)
        {
            _keyString = key;
            KeyBytes = Encoding.UTF8.GetBytes(key);
        }

        public EncodedKey(byte[] key)
        {
            KeyBytes = key;
        }

        // The key as a string; decoded from KeyBytes on first use when the key came from the cache.
        internal string KeyString
        {
            get
            {
                if (_keyString == null)
                {
                    _keyString = Encoding.UTF8.GetString(KeyBytes, 0, KeyBytes.Length);
                }
                return _keyString;
            }
        }

        // The UTF-8 encoded key; this is what equality and hashing are based on.
        internal byte[] KeyBytes { get; private set; }

        // Byte-wise equality of the encoded keys. Already-computed hash codes are used as a cheap early reject.
        public override bool Equals(object obj)
        {
            var otherKey = obj as EncodedKey;
            if (otherKey == null)
            {
                return false;
            }
            if (KeyBytes.Length != otherKey.KeyBytes.Length)
            {
                return false;
            }
            if (_hashCode.HasValue && otherKey._hashCode.HasValue
                && _hashCode.Value != otherKey._hashCode.Value)
            {
                return false;
            }
            for (int i = 0; i < KeyBytes.Length; i++)
            {
                if (KeyBytes[i] != otherKey.KeyBytes[i])
                {
                    return false;
                }
            }
            return true;
        }

        // Hash of the encoded bytes computed by SipHash.GetHashCode, cached after the first call.
        public override int GetHashCode()
        {
            if (!_hashCode.HasValue)
            {
                _hashCode = SipHash.GetHashCode(KeyBytes);
            }
            return _hashCode.Value;
        }

        public override string ToString()
        {
            return KeyString;
        }
    }
}
}