diff --git a/src/EntityFramework.MemoryJoin/MemoryJoiner.cs b/src/EntityFramework.MemoryJoin/MemoryJoiner.cs index c9f72df..173943a 100644 --- a/src/EntityFramework.MemoryJoin/MemoryJoiner.cs +++ b/src/EntityFramework.MemoryJoin/MemoryJoiner.cs @@ -77,11 +77,12 @@ public static IQueryable FromLocalList(this DbContext context, IList da /// public static IQueryable FromLocalList(this DbContext context, IList data, Type queryClass, ValuesInjectionMethod method) { + var tableName = EFHelper.GetTableName(context, queryClass); if (MemoryJoinerInterceptor.IsInterceptionEnabled( - new[] { context }, out InterceptionOptions opts)) + new[] {context}, false, out var opts) && opts.Any(x => x.QueryTableName == tableName)) { throw new InvalidOperationException( - "Only one data set can be applied to single DbContext before actuall DB request is done"); + "A table name (attribute on queryClass) can be applied only once to single DbContext before actual DB request is done"); } var propMapping = allowedMappingDict.GetOrAdd(queryClass, MappingHelper.GetPropertyMappings); @@ -117,20 +118,22 @@ static void PrepareInjection( Type queryClass, ValuesInjectionMethod method) { + var tableName = EFHelper.GetTableName(context, queryClass); + var opts = new InterceptionOptions { - QueryTableName = EFHelper.GetTableName(context, queryClass), + QueryTableName = tableName, ColumnNames = mapping.UserProperties.Keys.ToArray(), Data = data .Select(x => mapping.UserProperties.ToDictionary(y => y.Key, y => y.Value(x))) .ToList(), ContextType = context.GetType(), ValuesInjectMethod = (ValuesInjectionMethodInternal)method, - KeyColumnName = mapping.KeyColumnName + KeyColumnName = mapping.KeyColumnName, + DynamicTableName = tableName }; MemoryJoinerInterceptor.SetInterception(context, opts); } - } } diff --git a/src/EntityFramework.MemoryJoin/MemoryJoinerInterceptor.cs b/src/EntityFramework.MemoryJoin/MemoryJoinerInterceptor.cs index cb34ceb..deb0e5d 100644 --- a/src/EntityFramework.MemoryJoin/MemoryJoinerInterceptor.cs +++ b/src/EntityFramework.MemoryJoin/MemoryJoinerInterceptor.cs @@ -11,29 +11,47 @@ namespace EntityFramework.MemoryJoin { internal class MemoryJoinerInterceptor : IDbCommandInterceptor { - private static readonly ConcurrentDictionary InterceptionOptions = - new ConcurrentDictionary(); + private static readonly ConcurrentDictionary> InterceptionOptions = + new ConcurrentDictionary>(); + + private static readonly Object Locker = new object(); internal static void SetInterception(DbContext context, InterceptionOptions options) { - InterceptionOptions[context] = options; + lock (Locker) + { + if (!InterceptionOptions.TryGetValue(context, out var opts)) + { + opts = new List(); + InterceptionOptions[context] = opts; + } + + opts.Add(options); + } } - internal static bool IsInterceptionEnabled(IEnumerable contexts, out InterceptionOptions options) + internal static bool IsInterceptionEnabled(IEnumerable contexts, bool removeContextOptions, + out IReadOnlyList options) { - options = null; - using (var enumerator = contexts.GetEnumerator()) + lock (Locker) { - if (!enumerator.MoveNext()) return false; - - var firstOne = enumerator.Current; - var result = firstOne != null && - InterceptionOptions.TryGetValue(firstOne, out options) && - !enumerator.MoveNext(); - if (result) - InterceptionOptions.TryRemove(firstOne, out options); - - return result; + options = null; + List internalOptions = null; + using (var enumerator = contexts.GetEnumerator()) + { + if (!enumerator.MoveNext()) return false; + + var firstOne = enumerator.Current; + var result = firstOne != null && + InterceptionOptions.TryGetValue(firstOne, out internalOptions) && + !enumerator.MoveNext(); + options = internalOptions; + + if (result && removeContextOptions) + InterceptionOptions.TryRemove(firstOne, out _); + + return result; + } } } @@ -43,7 +61,7 @@ public void NonQueryExecuted(DbCommand command, DbCommandInterceptionContext interceptionContext) { - if (IsInterceptionEnabled(interceptionContext.DbContexts, out var opts)) + if (IsInterceptionEnabled(interceptionContext.DbContexts, true, out var opts)) ModifyQuery(command, opts); } @@ -53,7 +71,7 @@ public void ReaderExecuted(DbCommand command, DbCommandInterceptionContext interceptionContext) { - if (IsInterceptionEnabled(interceptionContext.DbContexts, out var opts)) + if (IsInterceptionEnabled(interceptionContext.DbContexts, true, out var opts)) ModifyQuery(command, opts); } @@ -63,34 +81,50 @@ public void ScalarExecuted(DbCommand command, DbCommandInterceptionContext interceptionContext) { - if (IsInterceptionEnabled(interceptionContext.DbContexts, out var opts)) + if (IsInterceptionEnabled(interceptionContext.DbContexts, true, out var opts)) ModifyQuery(command, opts); } - private static void ModifyQuery(DbCommand command, InterceptionOptions opts) + private static void ModifyQuery(DbCommand command, IReadOnlyList opts) { - var tableNamePosition = command.CommandText.IndexOf(opts.QueryTableName, StringComparison.Ordinal); - if (tableNamePosition < 0) - return; + var sb = new StringBuilder(100); + sb.Append("WITH "); + var counter = 0; + var commandStart = 0; + foreach (var currentOptions in opts) + { + var tableNamePosition = + command.CommandText.IndexOf(currentOptions.QueryTableName, StringComparison.Ordinal); + if (tableNamePosition < 0) + continue; - var nextSpace = command.CommandText.IndexOf(' ', tableNamePosition); - var prevSpace = command.CommandText.LastIndexOf(' ', tableNamePosition); - var tableFullName = command.CommandText.Substring(prevSpace + 1, nextSpace - prevSpace - 1); + commandStart = command.CommandText.LastIndexOf(';', tableNamePosition) + 1; - command.CommandText = command.CommandText.Replace(tableFullName, opts.DynamicTableName); + var nextSpace = command.CommandText.IndexOf(' ', tableNamePosition); + var prevSpace = command.CommandText.LastIndexOf(' ', tableNamePosition); + var tableFullName = command.CommandText.Substring(prevSpace + 1, nextSpace - prevSpace - 1); - var sb = new StringBuilder(100); - sb.Append("WITH ").Append(opts.DynamicTableName).Append(" AS (").AppendLine(); - MappingHelper.ComposeTableSql( - sb, opts, - command, - command.Parameters); + command.CommandText = command.CommandText.Replace(tableFullName, currentOptions.DynamicTableName); + + if (counter > 0) + { + sb.AppendLine(","); + } + + sb.Append(currentOptions.DynamicTableName).Append(" AS (").AppendLine(); - sb.AppendLine(); - sb.AppendLine(")"); - sb.Append(command.CommandText); + MappingHelper.ComposeTableSql( + sb, currentOptions, + command, + command.Parameters); + + sb.AppendLine(); + sb.AppendLine(")"); + + counter++; + } - command.CommandText = sb.ToString(); + command.CommandText = command.CommandText.Insert(commandStart, sb.ToString()); } } -} +} \ No newline at end of file