Sometimes Parallel.ForEach is not what we want - instead, we need to process pieces of a collection and have control over when all the processing is complete.
Introduction
Parallel.ForEach is great when you want to process each element of a collection in parallel, up to the number of logical processors. However, sometimes we need the ability to process a portion, or batch, of the entire collection, each in its own thread. A use case for this requirement is database transactions -- let's say you have a large number of transactions to insert or update, which can be done in parallel. With Parallel.ForEach, you would have to initialize the connection (granted, they are pooled) and perform other time- and memory-consuming activities, like initializing the Linq2SQL context, for every item. For example:
Parallel.ForEach(someCollection, item =>
{
    var conn = new SqlConnection(connectionString);

    using (var context = new ModelDataContext(conn))
    {
    }
});
This sort of defeats the purpose of processing the collection in parallel. And while you could create a ConcurrentDictionary with the thread ID as the key and the context as the value, that is both silly and not a general solution.
Let's Look at Parallel.ForEach
Let's take a quick look at Parallel.ForEach. Here's a simple test program that shows how many threads Parallel.ForEach actually creates to process a simple collection of integers. Notice we have a 1 ms delay in each iteration to force the threads into actually doing a little work.
class Program
{
    static ConcurrentDictionary<int, int> threadIdCounts;

    static void Main(string[] args)
    {
        threadIdCounts = new ConcurrentDictionary<int, int>();
        var plr = Parallel.ForEach(Enumerable.Range(0, 1000), DoSomething);
        threadIdCounts.ForEach(kvp => Console.WriteLine($"TID: {kvp.Key}, Count = {kvp.Value}"));
    }

    static void DoSomething(int n)
    {
        DoWork();
    }

    static void DoWork()
    {
        int tid = Thread.CurrentThread.ManagedThreadId;

        if (!threadIdCounts.TryGetValue(tid, out int count))
        {
            threadIdCounts[tid] = 0;
        }

        threadIdCounts[tid] = count + 1;
        Thread.Sleep(1);
    }
}
Notice that Parallel.ForEach ended up creating five threads even though my laptop only has four logical cores:
TID: 1, Count = 189
TID: 3, Count = 189
TID: 4, Count = 189
TID: 5, Count = 189
TID: 6, Count = 244
Press any key to continue . . .
Also, Parallel.ForEach blocks the calling thread until all the iterations are complete, which is also something you may not wish to do.
Introducing BatchParallel
BatchParallel is an extension method that splits the collection into numProcessors sub-collections of n / numProcessors elements each, invokes the action for each sub-collection in its own task, and adds one more task for any remainder.
public static Task[] BatchParallel<T>(this IEnumerable<T> collection,
    Action<IEnumerable<T>> action, bool singleThread = false)
{
    int processors = singleThread ? 1 : Environment.ProcessorCount;
    int n = collection.Count();
    int nPerProc = n / processors;
    Task[] tasks = new Task[processors + 1];

    processors.ForEach(p => tasks[p] =
        Task.Run(() => action(collection.Skip(p * nPerProc).Take(nPerProc))));

    int remainder = n - nPerProc * processors;
    var lastTask = Task.Run(() =>
        action(collection.Skip(nPerProc * processors).Take(remainder)));
    tasks[processors] = lastTask;

    return tasks;
}
Furthermore, it returns the Task array so you can choose when to wait for the completion of the tasks. There is also an option to run all the batches on a single thread, which I find makes debugging a lot easier.
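For example, because the method hands you the Task array rather than blocking, an async caller can await the batches with Task.WhenAll instead of calling Task.WaitAll. A minimal sketch, assuming the BatchParallel extension above and a hypothetical ProcessBatch handler:

```csharp
// Await the batch tasks without blocking the calling thread.
// ProcessBatch is a hypothetical per-batch handler; substitute your own.
static async Task ProcessAllAsync(IEnumerable<int> items)
{
    Task[] tasks = items.BatchParallel(batch => ProcessBatch(batch));

    // Execution resumes here once every batch task has finished.
    await Task.WhenAll(tasks);
    Console.WriteLine("All batches complete.");
}
```

This keeps the calling thread free (a UI thread, for instance) while the batches run.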
Usage Example
This example shows both Parallel.ForEach and BatchParallel usage:
class Program
{
    static ConcurrentDictionary<int, int> threadIdCounts;

    static void Main(string[] args)
    {
        Console.WriteLine("Parallel.ForEach example:");
        threadIdCounts = new ConcurrentDictionary<int, int>();
        var plr = Parallel.ForEach(Enumerable.Range(0, 1000), DoSomething);
        threadIdCounts.ForEach(kvp => Console.WriteLine($"TID: {kvp.Key}, Count = {kvp.Value}"));

        Console.WriteLine("\r\nBatchParallel example:");
        threadIdCounts = new ConcurrentDictionary<int, int>();
        var tasks = Enumerable.Range(0, 1000).BatchParallel(batch => DoSomething(batch));
        Task.WaitAll(tasks);
        threadIdCounts.ForEach(kvp => Console.WriteLine($"TID: {kvp.Key}, Count = {kvp.Value}"));
    }

    static void DoSomething(int n)
    {
        DoWork();
    }

    static void DoSomething<T>(IEnumerable<T> batch)
    {
        batch.ForEach(n => DoWork());
    }

    static void DoWork()
    {
        int tid = Thread.CurrentThread.ManagedThreadId;

        if (!threadIdCounts.TryGetValue(tid, out int count))
        {
            threadIdCounts[tid] = 0;
        }

        threadIdCounts[tid] = count + 1;
        Thread.Sleep(1);
    }
}
Result
Parallel.ForEach example:
TID: 1, Count = 244
TID: 3, Count = 189
TID: 4, Count = 189
TID: 5, Count = 189
TID: 6, Count = 189
BatchParallel example:
TID: 3, Count = 250
TID: 4, Count = 250
TID: 5, Count = 250
TID: 6, Count = 250
Press any key to continue . . .
Things to Note
Notice that BatchParallel created threads only for the number of logical cores that I have and split the work evenly. You can also pass the optional singleThread parameter as true if you want all the batches to run sequentially on a single thread. Lastly, the collection of Task objects is returned, giving you the choice as to when to wait for the completion of the tasks.
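For instance, forcing everything onto one worker while debugging is just a matter of the optional flag -- a quick sketch using the same DoSomething batch handler as above:

```csharp
// With singleThread: true, processors == 1, so the whole collection
// becomes a single batch on one worker task (plus an empty remainder task).
// Handy for stepping through the batch handler in a debugger.
var tasks = Enumerable.Range(0, 1000).BatchParallel(batch => DoSomething(batch), singleThread: true);
Task.WaitAll(tasks);
```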
Remainder Edge Case
Here's a simple (non-unit test) example of a case where there is a remainder. Given:
tasks = Enumerable.Range(0, 1003).BatchParallel(batch => DoSomething(batch));
We now see:
BatchParallel with remainder example:
TID: 3, Count = 250
TID: 4, Count = 250
TID: 5, Count = 253
TID: 6, Count = 250
Press any key to continue . . .
Note that the Task library reused thread ID 5 to process the first 250 and the remaining 3.
Additional Extension Methods I'm Using
I'm also using these extension methods:
public static void ForEach<T>(this IEnumerable<T> collection, Action<T> action)
{
    foreach (var item in collection)
    {
        action(item);
    }
}

public static void ForEach(this int n, Action<int> action)
{
    for (int i = 0; i < n; i++)
    {
        action(i);
    }
}
Conclusion
Not sure what to say here, the work should speak for itself. Hopefully, you find this useful, and you can always implement a static method instead of an extension method:
public static class Batch
{
    public static Task[] Parallel<T>(IEnumerable<T> collection,
        Action<IEnumerable<T>> action, bool singleThread = false)
    {
        int processors = singleThread ? 1 : Environment.ProcessorCount;
        int n = collection.Count();
        int nPerProc = n / processors;
        Task[] tasks = new Task[processors + 1];

        processors.ForEach(p => tasks[p] =
            Task.Run(() => action(collection.Skip(p * nPerProc).Take(nPerProc))));

        int remainder = n - nPerProc * processors;
        var lastTask = Task.Run(() =>
            action(collection.Skip(nPerProc * processors).Take(remainder)));
        tasks[processors] = lastTask;

        return tasks;
    }
}
Usage:
tasks = Batch.Parallel(Enumerable.Range(0, 1003), batch => DoSomething(batch));
Task.WaitAll(tasks);
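Coming back to the database scenario from the introduction, this is roughly what the batch-oriented version might look like -- a sketch only, where ModelDataContext is the Linq2SQL context mentioned earlier and InsertOrUpdate is a hypothetical helper standing in for your own data-access code:

```csharp
var tasks = transactions.BatchParallel(batch =>
{
    // One connection and one context per batch, not per item.
    var conn = new SqlConnection(connectionString);

    using (var context = new ModelDataContext(conn))
    {
        // InsertOrUpdate is a hypothetical per-record helper.
        batch.ForEach(txn => InsertOrUpdate(context, txn));
    }
});

Task.WaitAll(tasks);
```

Each worker pays the connection and context setup cost once per batch instead of once per record, which was the whole point of batching in the first place.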
Have fun!
History
- 13th December, 2020: Initial version