Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/EntityFramework.MemoryJoin/MemoryJoiner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,12 @@ public static IQueryable<T> FromLocalList<T>(this DbContext context, IList<T> da
/// <returns></returns>
public static IQueryable<T> FromLocalList<T>(this DbContext context, IList<T> data, Type queryClass, ValuesInjectionMethod method)
{
var tableName = EFHelper.GetTableName(context, queryClass);
if (MemoryJoinerInterceptor.IsInterceptionEnabled(
new[] { context }, out InterceptionOptions opts))
new[] {context}, false, out var opts) && opts.Any(x => x.QueryTableName == tableName))
{
throw new InvalidOperationException(
"Only one data set can be applied to single DbContext before actuall DB request is done");
"A table name (attribute on queryClass) can be applied only once to single DbContext before actual DB request is done");
}

var propMapping = allowedMappingDict.GetOrAdd(queryClass, MappingHelper.GetPropertyMappings);
Expand Down Expand Up @@ -117,20 +118,22 @@ static void PrepareInjection<T>(
Type queryClass,
ValuesInjectionMethod method)
{
var tableName = EFHelper.GetTableName(context, queryClass);

var opts = new InterceptionOptions
{
QueryTableName = EFHelper.GetTableName(context, queryClass),
QueryTableName = tableName,
ColumnNames = mapping.UserProperties.Keys.ToArray(),
Data = data
.Select(x => mapping.UserProperties.ToDictionary(y => y.Key, y => y.Value(x)))
.ToList(),
ContextType = context.GetType(),
ValuesInjectMethod = (ValuesInjectionMethodInternal)method,
KeyColumnName = mapping.KeyColumnName
KeyColumnName = mapping.KeyColumnName,
DynamicTableName = tableName
};

MemoryJoinerInterceptor.SetInterception(context, opts);
}

}
}
110 changes: 72 additions & 38 deletions src/EntityFramework.MemoryJoin/MemoryJoinerInterceptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,47 @@ namespace EntityFramework.MemoryJoin
{
internal class MemoryJoinerInterceptor : IDbCommandInterceptor
{
private static readonly ConcurrentDictionary<DbContext, InterceptionOptions> InterceptionOptions =
new ConcurrentDictionary<DbContext, InterceptionOptions>();
private static readonly ConcurrentDictionary<DbContext, List<InterceptionOptions>> InterceptionOptions =
new ConcurrentDictionary<DbContext, List<InterceptionOptions>>();

private static readonly Object Locker = new object();

internal static void SetInterception(DbContext context, InterceptionOptions options)
{
InterceptionOptions[context] = options;
lock (Locker)
{
if (!InterceptionOptions.TryGetValue(context, out var opts))
{
opts = new List<InterceptionOptions>();
InterceptionOptions[context] = opts;
}

opts.Add(options);
}
}

internal static bool IsInterceptionEnabled(IEnumerable<DbContext> contexts, out InterceptionOptions options)
internal static bool IsInterceptionEnabled(IEnumerable<DbContext> contexts, bool removeContextOptions,
out IReadOnlyList<InterceptionOptions> options)
{
options = null;
using (var enumerator = contexts.GetEnumerator())
lock (Locker)
{
if (!enumerator.MoveNext()) return false;

var firstOne = enumerator.Current;
var result = firstOne != null &&
InterceptionOptions.TryGetValue(firstOne, out options) &&
!enumerator.MoveNext();
if (result)
InterceptionOptions.TryRemove(firstOne, out options);

return result;
options = null;
List<InterceptionOptions> internalOptions = null;
using (var enumerator = contexts.GetEnumerator())
{
if (!enumerator.MoveNext()) return false;

var firstOne = enumerator.Current;
var result = firstOne != null &&
InterceptionOptions.TryGetValue(firstOne, out internalOptions) &&
!enumerator.MoveNext();
options = internalOptions;

if (result && removeContextOptions)
InterceptionOptions.TryRemove(firstOne, out _);

return result;
}
}
}

Expand All @@ -43,7 +61,7 @@ public void NonQueryExecuted(DbCommand command, DbCommandInterceptionContext<int

public void NonQueryExecuting(DbCommand command, DbCommandInterceptionContext<int> interceptionContext)
{
if (IsInterceptionEnabled(interceptionContext.DbContexts, out var opts))
if (IsInterceptionEnabled(interceptionContext.DbContexts, true, out var opts))
ModifyQuery(command, opts);
}

Expand All @@ -53,7 +71,7 @@ public void ReaderExecuted(DbCommand command, DbCommandInterceptionContext<DbDat

public void ReaderExecuting(DbCommand command, DbCommandInterceptionContext<DbDataReader> interceptionContext)
{
if (IsInterceptionEnabled(interceptionContext.DbContexts, out var opts))
if (IsInterceptionEnabled(interceptionContext.DbContexts, true, out var opts))
ModifyQuery(command, opts);
}

Expand All @@ -63,34 +81,50 @@ public void ScalarExecuted(DbCommand command, DbCommandInterceptionContext<objec

public void ScalarExecuting(DbCommand command, DbCommandInterceptionContext<object> interceptionContext)
{
if (IsInterceptionEnabled(interceptionContext.DbContexts, out var opts))
if (IsInterceptionEnabled(interceptionContext.DbContexts, true, out var opts))
ModifyQuery(command, opts);
}

private static void ModifyQuery(DbCommand command, InterceptionOptions opts)
private static void ModifyQuery(DbCommand command, IReadOnlyList<InterceptionOptions> opts)
{
var tableNamePosition = command.CommandText.IndexOf(opts.QueryTableName, StringComparison.Ordinal);
if (tableNamePosition < 0)
return;
var sb = new StringBuilder(100);
sb.Append("WITH ");
var counter = 0;
var commandStart = 0;
foreach (var currentOptions in opts)
{
var tableNamePosition =
command.CommandText.IndexOf(currentOptions.QueryTableName, StringComparison.Ordinal);
if (tableNamePosition < 0)
continue;

var nextSpace = command.CommandText.IndexOf(' ', tableNamePosition);
var prevSpace = command.CommandText.LastIndexOf(' ', tableNamePosition);
var tableFullName = command.CommandText.Substring(prevSpace + 1, nextSpace - prevSpace - 1);
commandStart = command.CommandText.LastIndexOf(';', tableNamePosition) + 1;

command.CommandText = command.CommandText.Replace(tableFullName, opts.DynamicTableName);
var nextSpace = command.CommandText.IndexOf(' ', tableNamePosition);
var prevSpace = command.CommandText.LastIndexOf(' ', tableNamePosition);
var tableFullName = command.CommandText.Substring(prevSpace + 1, nextSpace - prevSpace - 1);

var sb = new StringBuilder(100);
sb.Append("WITH ").Append(opts.DynamicTableName).Append(" AS (").AppendLine();
MappingHelper.ComposeTableSql(
sb, opts,
command,
command.Parameters);
command.CommandText = command.CommandText.Replace(tableFullName, currentOptions.DynamicTableName);

if (counter > 0)
{
sb.AppendLine(",");
}

sb.Append(currentOptions.DynamicTableName).Append(" AS (").AppendLine();

sb.AppendLine();
sb.AppendLine(")");
sb.Append(command.CommandText);
MappingHelper.ComposeTableSql(
sb, currentOptions,
command,
command.Parameters);

sb.AppendLine();
sb.AppendLine(")");

counter++;
}

command.CommandText = sb.ToString();
command.CommandText = command.CommandText.Insert(commandStart, sb.ToString());
}
}
}
}