Limiting the number of concurrent tasks in .NET 4

Posted 2019-01-14 23:38

Observe the following function:

public Task RunInOrderAsync<TTaskSeed>(IEnumerable<TTaskSeed> taskSeedGenerator,
    CreateTaskDelegate<TTaskSeed> createTask,
    OnTaskErrorDelegate<TTaskSeed> onError = null,
    OnTaskSuccessDelegate<TTaskSeed> onSuccess = null) where TTaskSeed : class
{
    Action<Exception, TTaskSeed> onFailed = (exc, taskSeed) =>
    {
        if (onError != null)
        {
            onError(exc, taskSeed);
        }
    };

    Action<Task> onDone = t =>
    {
        var taskSeed = (TTaskSeed)t.AsyncState;
        if (t.Exception != null)
        {
            onFailed(t.Exception, taskSeed);
        }
        else if (onSuccess != null)
        {
            onSuccess(t, taskSeed);
        }
    };

    var enumerator = taskSeedGenerator.GetEnumerator();
    Task task = null;
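    // The first task is created directly; every subsequent task is chained as a
    // continuation of the previous one, so they run strictly one after another.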
    while (enumerator.MoveNext())
    {
        if (task == null)
        {
            try
            {
                task = createTask(enumerator.Current);
                Debug.Assert(ReferenceEquals(task.AsyncState, enumerator.Current));
            }
            catch (Exception exc)
            {
                onFailed(exc, enumerator.Current);
            }
        }
        else
        {
            task = task.ContinueWith((t, taskSeed) =>
            {
                onDone(t);
                var res = createTask((TTaskSeed)taskSeed);
                Debug.Assert(ReferenceEquals(res.AsyncState, taskSeed));
                return res;
            }, enumerator.Current).TaskUnwrap();
        }
    }

    if (task != null)
    {
        task = task.ContinueWith(onDone);
    }

    return task;
}

Where TaskUnwrap is a state-preserving version of the standard Task.Unwrap:

public static class Extensions
{
    public static Task TaskUnwrap(this Task<Task> task, object state = null)
    {
        return task.Unwrap().ContinueWith((t, _) =>
        {
            if (t.Exception != null)
            {
                throw t.Exception;
            }
        }, state ?? task.AsyncState);
    }
}

The RunInOrderAsync method allows running N tasks asynchronously but sequentially - one after another. In effect, it runs the tasks created from the given seeds with a concurrency limit of 1.

Let us assume that the tasks created from the seeds by the createTask delegate do not correspond themselves to multiple concurrent tasks.
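
For reference, the delegate types are not shown above; from the way they are used, they are assumed to look roughly like this:

// Assumed shapes, inferred from the usage in RunInOrderAsync (not part of the original code).
public delegate Task CreateTaskDelegate<TTaskSeed>(TTaskSeed seed);
public delegate void OnTaskErrorDelegate<TTaskSeed>(Exception exc, TTaskSeed seed);
public delegate void OnTaskSuccessDelegate<TTaskSeed>(Task task, TTaskSeed seed);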

Now, I would like to throw in the maxConcurrencyLevel parameter, so the function signature would look like this:

Task RunInOrderAsync<TTaskSeed>(int maxConcurrencyLevel,
  IEnumerable<TTaskSeed> taskSeedGenerator,
  CreateTaskDelegate<TTaskSeed> createTask,
  OnTaskErrorDelegate<TTaskSeed> onError = null,
  OnTaskSuccessDelegate<TTaskSeed> onSuccess = null) where TTaskSeed : class

And here I am a bit stuck.

SO already has similar questions, which basically propose two ways to attack the problem:

  1. Using Parallel.ForEach with ParallelOptions specifying the MaxDegreeOfParallelism property value as equal to the desired max concurrency level.
  2. Using a custom TaskScheduler with the desired MaximumConcurrencyLevel value.

The second approach doesn't cut it, because all the tasks involved must use the same task scheduler instance. For that, all the methods used to return a Task must have an overload accepting a custom TaskScheduler instance. Unfortunately, Microsoft is not very consistent in that respect. For instance, SqlConnection.OpenAsync does not accept such an argument (but TaskFactory.FromAsync does).
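
To make the second approach concrete, it would look roughly like the sketch below, using, for example, the LimitedConcurrencyLevelTaskScheduler sample from MSDN's ParallelExtensionsExtras (DoSomeWork is a placeholder):

// Rough sketch of approach 2 - only delegates started through this factory/scheduler
// are throttled; a Task returned by, say, SqlConnection.OpenAsync never runs on it.
var scheduler = new LimitedConcurrencyLevelTaskScheduler(maxConcurrencyLevel);
var factory = new TaskFactory(scheduler);
Task t = factory.StartNew(() => DoSomeWork());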

The first approach implies that I will have to convert tasks to actions, something like this:

() => t.Wait()

I am not sure it is a good idea, but I will be glad to get more input on that.
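
For completeness, the first approach would look roughly like this (a blocking sketch, assuming the same seeds and createTask delegate as above):

// Rough sketch of approach 1. Each iteration blocks a thread-pool thread
// while its task runs, which gives up the asynchrony of the original method.
Parallel.ForEach(
    taskSeedGenerator,
    new ParallelOptions { MaxDegreeOfParallelism = maxConcurrencyLevel },
    seed => createTask(seed).Wait());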

Another approach is to utilize TaskFactory.ContinueWhenAny, but that is messy.

Any ideas?

EDIT 1

I would like to clarify the reasons for wanting the limit. Our tasks ultimately execute SQL statements against the same SQL server. What we want is a way to limit the number of concurrent outgoing SQL statements. It is entirely possible that other SQL statements will be executing concurrently from other pieces of code, but this one is a batch processor and could potentially flood the server.

Now, be advised that although we are talking about the same SQL server, there are numerous databases on it. So it is not about limiting the number of open SQL connections to the same database, because the database may not be the same at all.

That is why doomsday solutions like ThreadPool.SetMaxThreads() are irrelevant.

Now, about SqlConnection.OpenAsync. It was made asynchronous for a reason - it may make a round trip to the server and thus be subject to network latency and the other lovely side effects of a distributed environment. As such, it is no different from other async methods that do accept a TaskScheduler parameter. I tend to think that not accepting one is just a bug.

EDIT 2

I would like to preserve the asynchronous spirit of the original function. Hence I wish to avoid any explicit blocking solutions.

EDIT 3

Thanks to @fsimonazzi's answer, I now have a working implementation of the desired functionality. Here is the code:

        var sem = new SemaphoreSlim(maxConcurrencyLevel);
        var tasks = new List<Task>();

        var enumerator = taskSeedGenerator.GetEnumerator();
        while (enumerator.MoveNext())
        {
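            // Asynchronously wait for a free slot; the real task is created only once the semaphore grants one.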
            tasks.Add(sem.WaitAsync().ContinueWith((_, taskSeed) =>
            {
                Task task = null;
                try
                {
                    task = createTask((TTaskSeed)taskSeed);
                    if (task != null)
                    {
                        Debug.Assert(ReferenceEquals(task.AsyncState, taskSeed));
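                        // Release the slot as soon as the inner task finishes, then report the outcome.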
                        task = task.ContinueWith(t =>
                        {
                            sem.Release();
                            onDone(t);
                        });
                    }
                }
                catch (Exception exc)
                {
                    sem.Release();
                    onFailed(exc, (TTaskSeed)taskSeed);
                }
                return task;
            }, enumerator.Current).TaskUnwrap());
        }

        return Task.Factory.ContinueWhenAll(tasks.ToArray(), _ => sem.Dispose());
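
For illustration, a hypothetical call site might look like the following (StatementSeed, ExecuteStatementAsync, and the logging calls are made up; note that createTask must return a Task whose AsyncState is the seed, e.g. a task built via TaskFactory.FromAsync or a TaskCompletionSource created with the seed as its state):

// Hypothetical usage sketch - the names below are placeholders, not part of the original post.
Task batch = RunInOrderAsync<StatementSeed>(
    4,                                       // at most 4 SQL statements in flight
    statementSeeds,
    seed => ExecuteStatementAsync(seed),     // must return a Task whose AsyncState == seed
    onError: (exc, seed) => LogError(exc, seed),
    onSuccess: (t, seed) => LogDone(seed));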

5 Answers
做个烂人 · 2019-01-15 00:13

The two best solutions available today are SemaphoreSlim (as per @fsimonazzi's answer) and a TPL Dataflow block (i.e., ActionBlock<T> or TransformBlock<TInput, TOutput>). Both of these blocks have a simple way to set the level of concurrency.

Parallel is not an ideal approach, because you would need to block on your asynchronous operations, using up a thread pool thread for each one.

Also, a TaskScheduler will not work here. FYI, the TaskScheduler is "inherited" through async methods, as I describe in my async intro blog post. The reason it won't work for your problem is that task schedulers only control executing (delegate) tasks, not event-based tasks - so SQL operations like OpenAsync don't "count" towards the concurrency limit.
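
For the Dataflow option, a TransformBlock sketch might look like this (QueryAsync here is a placeholder for the real asynchronous SQL work; this is not from the original answer):

private static async Task<int[]> RunThrottledAsync(IEnumerable<string> items, int maxConcurrency)
{
    // QueryAsync is a hypothetical Func<string, Task<int>> performing the actual work.
    var block = new TransformBlock<string, int>(
        item => QueryAsync(item),
        new ExecutionDataflowBlockOptions { MaxDegreeOfParallelism = maxConcurrency });

    foreach (var item in items)
    {
        block.Post(item);
    }
    block.Complete();

    // Drain the results as they become available.
    var results = new List<int>();
    while (await block.OutputAvailableAsync())
    {
        results.Add(block.Receive());
    }
    await block.Completion;
    return results.ToArray();
}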

神经病院院长 · 2019-01-15 00:15

Here is a variation of @fsimonazzi's answer without the SemaphoreSlim, as cool as that is.

private static async Task DoStuff<T>(int maxConcurrency, IEnumerable<T> items, Func<T, Task> createTask)
{
    var tasks = new List<Task>();
    foreach (var item in items)
    {
        if (tasks.Count >= maxConcurrency)
        {
            await Task.WhenAll(tasks);
            tasks.Clear();
        }
        var task = createTask(item);
        tasks.Add(task);
    }
    await Task.WhenAll(tasks);
}
我命由我不由天 · 2019-01-15 00:18

Here is a variation of @scott-turner's answer, as cool as that is. His answer submits work in chunks of maxConcurrency and waits until each chunk completes in full before submitting the next chunk. This variation submits new tasks as needed to keep maxConcurrency tasks in flight at all times. It also demonstrates working with Task<T> instead of Task.

Note that the benefit of this over the SemaphoreSlim variation is that with SemaphoreSlim you need to await two different kinds of tasks - the semaphore waits and the work itself. That's problematic if the work is of type Task<T> instead of Task.

    private static async Task<R[]> concurrentAsync<T, R>(int maxConcurrency, IEnumerable<T> items, Func<T, Task<R>> createTask)
    {
        var allTasks = new List<Task<R>>();
        var activeTasks = new List<Task<R>>();
        foreach (var item in items)
        {
            if (activeTasks.Count >= maxConcurrency)
            {
                var completedTask = await Task.WhenAny(activeTasks);
                activeTasks.Remove(completedTask);
            }
            var task = createTask(item);
            allTasks.Add(task);
            activeTasks.Add(task);
        }
        return await Task.WhenAll(allTasks);
    }
女痞 · 2019-01-15 00:23

There are already a lot of answers here. I want to address the comment you made on Stephen's answer, about an example of using TPL Dataflow to limit concurrency. Even though you mentioned in a comment on another answer that you no longer use the Task-based approach for this, it might help other people.

An example of using the ActionBlock<T> for this is:

private static async Task DoStuff<T>(int maxConcurrency, IEnumerable<T> items, Func<T, Task> createTask)
{
    var ab = new ActionBlock<T>(createTask, new ExecutionDataflowBlockOptions { MaxDegreeOfParallelism = maxConcurrency });

    foreach (var item in items)
    {
        ab.Post(item);
    }

    ab.Complete();
    await ab.Completion;
}

More information about the TPL Dataflow can be found here: https://msdn.microsoft.com/en-us/library/system.threading.tasks.dataflow(v=vs.110).aspx

beautiful° · 2019-01-15 00:27

You can use a semaphore to throttle the processing. Using the WaitAsync() method, you get the asynchrony you expect. Something like this (error handling removed for brevity):

private static async Task DoStuff<T>(int maxConcurrency, IEnumerable<T> items, Func<T, Task> createTask)
{
    using (var sem = new SemaphoreSlim(maxConcurrency))
    {
        var tasks = new List<Task>();

        foreach (var item in items)
        {
            await sem.WaitAsync();
            var task = createTask(item).ContinueWith(t => sem.Release());
            tasks.Add(task);
        }

        await Task.WhenAll(tasks);
    }
}

Edited to remove a bug where the semaphore could be disposed before all release operations had a chance to execute.
