Implement AsyncManualResetEvent using Lazy to d

Posted 2019-09-09 12:13

I'm implementing an AsyncManualResetEvent based on Stephen Toub's example. However, I would like to know whether the event, or more specifically the underlying Task<T>, has been waited on.

I've already investigated the Task class, and there doesn't seem to be a sensible way to determine whether it has ever been 'awaited' or whether a continuation has been added.

In this case, however, I control access to the underlying task source, so I can watch for calls to the WaitAsync method instead. To do this, I decided to use a Lazy<T> and simply check whether its value has been created.

using System;
using System.Threading;
using System.Threading.Tasks;

sealed class AsyncManualResetEvent {
    public bool HasWaiters => tcs.IsValueCreated;

    public AsyncManualResetEvent() {
        Reset();
    }

    public Task WaitAsync() => tcs.Value.Task;

    public void Set() {
        if (tcs.IsValueCreated) {
            tcs.Value.TrySetResult(result: true);
        }
    }

    public void Reset() {
        tcs = new Lazy<TaskCompletionSource<bool>>(LazyThreadSafetyMode.PublicationOnly);
    }

    Lazy<TaskCompletionSource<bool>> tcs;
}

My question, then, is whether this is a safe approach. Specifically, does it guarantee that no continuations are ever orphaned or lost while the event is being reset?
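One way to probe this is to drive the class deterministically on a single thread. The sketch below is not from the original post: it repeats the question's class with the using directives it needs, and the `Probe` class and its method names are hypothetical. Both methods return false. A waiter that is pending when `Reset` runs is never completed by a later `Set`, because `Reset` discards its `TaskCompletionSource` unconditionally, and a `Set` issued before anyone calls `WaitAsync` is not recorded at all. That suggests the `Lazy<T>` version does not, by itself, rule out lost continuations.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// The class from the question, unchanged apart from layout and using directives.
sealed class AsyncManualResetEvent
{
    public bool HasWaiters => tcs.IsValueCreated;

    public AsyncManualResetEvent() { Reset(); }

    public Task WaitAsync() => tcs.Value.Task;

    public void Set()
    {
        if (tcs.IsValueCreated)
        {
            tcs.Value.TrySetResult(result: true);
        }
    }

    public void Reset()
    {
        tcs = new Lazy<TaskCompletionSource<bool>>(LazyThreadSafetyMode.PublicationOnly);
    }

    Lazy<TaskCompletionSource<bool>> tcs;
}

// Hypothetical helper that drives the class on a single thread.
static class Probe
{
    // A waiter is pending when Reset runs; does a later Set complete it?
    public static bool WaiterSurvivesReset()
    {
        var ev = new AsyncManualResetEvent();
        Task waiter = ev.WaitAsync(); // creates the TaskCompletionSource and returns its Task
        ev.Reset();                   // replaces the Lazy, dropping that TaskCompletionSource
        ev.Set();                     // the new Lazy has no value yet, so nothing is completed
        return waiter.IsCompleted;    // false: the waiter is orphaned
    }

    // Set runs before anyone waits; does a later WaitAsync see the signal?
    public static bool WaitAfterSetCompletes()
    {
        var ev = new AsyncManualResetEvent();
        ev.Set();                          // no value created yet, so the signal is not recorded
        return ev.WaitAsync().IsCompleted; // false: a new, pending source is created instead
    }
}
```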

1 answer
贪生不怕死
Reply #2 · 2019-09-09 12:33

If you truly want to know whether anyone called await on your task (not just whether they called WaitAsync()), you could write a custom awaiter that wraps the TaskAwaiter of the underlying task (m_tcs.Task in Stephen's article).

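using System;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

public class AsyncManualResetEvent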
{
    private volatile Completion _completion = new Completion();

    public bool HasWaiters => _completion.HasWaiters;

    public Completion WaitAsync()
    {
        return _completion;
    }

    public void Set()
    {
        _completion.Set();
    }

    public void Reset()
    {
        while (true)
        {
            var completion = _completion;
            if (!completion.IsCompleted ||
                Interlocked.CompareExchange(ref _completion, new Completion(), completion) == completion)
                return;
        }
    }
}

public class Completion
{
    private readonly TaskCompletionSource<bool> _tcs;
    private readonly CompletionAwaiter _awaiter;

    public Completion()
    {
        _tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
        _awaiter = new CompletionAwaiter(_tcs.Task, this);
    }

    public CompletionAwaiter GetAwaiter() => _awaiter;
    public bool IsCompleted => _tcs.Task.IsCompleted;
    public bool HasWaiters { get; private set; }
    public void Set() => _tcs.TrySetResult(true);

    public struct CompletionAwaiter : ICriticalNotifyCompletion
    {
        private readonly TaskAwaiter _taskAwaiter;
        private readonly Completion _parent;

        internal CompletionAwaiter(Task task, Completion parent)
        {
            _parent = parent;
            _taskAwaiter = task.GetAwaiter();
        }

        public bool IsCompleted => _taskAwaiter.IsCompleted;
        public void GetResult() => _taskAwaiter.GetResult();
        public void OnCompleted(Action continuation)
        {
            _parent.HasWaiters = true;
            _taskAwaiter.OnCompleted(continuation);
        }

        public void UnsafeOnCompleted(Action continuation)
        {
            _parent.HasWaiters = true;
            _taskAwaiter.UnsafeOnCompleted(continuation);
        }
    }
}

Now, if anyone registers a continuation with OnCompleted or UnsafeOnCompleted, HasWaiters becomes true.
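This relies on how `await` drives an awaiter: it calls `GetAwaiter()`, checks `IsCompleted`, and only calls `UnsafeOnCompleted` (or `OnCompleted`) when the awaiter is not yet complete. One consequence is that awaiting a `Completion` that has already been set leaves `HasWaiters` false. The sketch below observes this with a hypothetical `CountingAwaitable` type (not part of the answer) that counts registrations: awaiting `Task.CompletedTask` registers nothing, while awaiting a pending task registers exactly once.

```csharp
using System;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical awaitable that counts how often `await` registers a continuation.
public sealed class CountingAwaitable
{
    private readonly Task _task;
    private int _registrations;

    public CountingAwaitable(Task task) => _task = task;

    public int Registrations => Volatile.Read(ref _registrations);

    public Awaiter GetAwaiter() => new Awaiter(this);

    public readonly struct Awaiter : ICriticalNotifyCompletion
    {
        private readonly CountingAwaitable _owner;

        internal Awaiter(CountingAwaitable owner) => _owner = owner;

        // `await` reads this first; when it is true the method continues synchronously
        // and neither OnCompleted nor UnsafeOnCompleted is called.
        public bool IsCompleted => _owner._task.IsCompleted;

        public void GetResult() => _owner._task.GetAwaiter().GetResult();

        public void OnCompleted(Action continuation)
        {
            Interlocked.Increment(ref _owner._registrations);
            _owner._task.GetAwaiter().OnCompleted(continuation);
        }

        // Async methods call this one because the awaiter implements ICriticalNotifyCompletion.
        public void UnsafeOnCompleted(Action continuation)
        {
            Interlocked.Increment(ref _owner._registrations);
            _owner._task.GetAwaiter().UnsafeOnCompleted(continuation);
        }
    }
}

public static class AwaitDemo
{
    private static async Task AwaitItAsync(CountingAwaitable awaitable) => await awaitable;

    // Awaiting something that is already complete: no continuation is registered.
    public static int RegistrationsForCompletedTask()
    {
        var awaitable = new CountingAwaitable(Task.CompletedTask);
        AwaitItAsync(awaitable).GetAwaiter().GetResult();
        return awaitable.Registrations;
    }

    // Awaiting something still pending: one continuation is registered.
    public static int RegistrationsForPendingTask()
    {
        var tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
        var awaitable = new CountingAwaitable(tcs.Task);
        Task running = AwaitItAsync(awaitable); // suspends at the await and registers
        tcs.TrySetResult(true);
        running.GetAwaiter().GetResult();       // wait for the async method to finish
        return awaitable.Registrations;
    }
}
```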

I also added TaskCreationOptions.RunContinuationsAsynchronously to address the problem Stephen works around with Task.Factory.StartNew at the end of the article: by default, awaiting code can resume synchronously inside Set. The option was added in .NET Framework 4.6, after the article was written.
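To see the difference the option makes, the sketch below (an illustration, not Stephen's code) uses a hypothetical `InlineContinuationProbe` helper and assumes no `SynchronizationContext` is installed, as in a plain console app. Without the option, the code after `await` typically runs synchronously inside `TrySetResult`, on the thread that is completing the task; with `RunContinuationsAsynchronously`, it is queued to the thread pool instead, so the completing call returns before any waiter resumes.

```csharp
using System;
using System.Threading.Tasks;

public static class InlineContinuationProbe
{
    // True only on the thread that is currently executing TrySetResult below.
    [ThreadStatic] private static bool t_insideSet;

    // The code after `await` is the continuation; it reports whether it is running
    // on the completing thread while that thread is still inside TrySetResult.
    private static async Task<bool> ObserveAsync(Task task)
    {
        await task;
        return t_insideSet;
    }

    public static bool ContinuationRanInsideSet(TaskCreationOptions options)
    {
        var tcs = new TaskCompletionSource<bool>(options);
        Task<bool> observer = ObserveAsync(tcs.Task); // runs up to the await, then registers

        t_insideSet = true;
        try
        {
            tcs.TrySetResult(true);
        }
        finally
        {
            t_insideSet = false;
        }

        return observer.GetAwaiter().GetResult();
    }
}
```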


If you just want to know whether anyone called WaitAsync, you can simplify this a lot: all you need is a class that holds your flag and your completion source.

using System.Threading;
using System.Threading.Tasks;

public class AsyncManualResetEvent
{
    private volatile CompletionWrapper _completionWrapper = new CompletionWrapper();

    public Task WaitAsync()
    {
        var wrapper = _completionWrapper;
        wrapper.WaitAsyncCalled = true;
        return wrapper.Tcs.Task;
    }

    public bool WaitAsyncCalled
    {
        get { return _completionWrapper.WaitAsyncCalled; }
    }

    public void Set()
    {
        _completionWrapper.Tcs.TrySetResult(true);
    }

    public void Reset()
    {
        while (true)
        {
            var wrapper = _completionWrapper;
            if (!wrapper.Tcs.Task.IsCompleted ||
                Interlocked.CompareExchange(ref _completionWrapper, new CompletionWrapper(), wrapper) == wrapper)
                return;
        }
    }
    private class CompletionWrapper
    {
        public TaskCompletionSource<bool> Tcs { get; } = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
        public bool WaitAsyncCalled { get; set; }
    }
}
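To connect this back to the question, the sketch below repeats the simplified class (with `Set` laid out on separate lines) and adds a hypothetical `ResetDemo` helper. Because `Reset` only swaps in a new wrapper once the current task has completed, a waiter that is pending during `Reset` is still completed by the next `Set`, and a `Set` that happens before any `WaitAsync` call is still observed. All three methods are expected to return true when run on a single thread.

```csharp
using System.Threading;
using System.Threading.Tasks;

// The simplified class from the answer.
public class AsyncManualResetEvent
{
    private volatile CompletionWrapper _completionWrapper = new CompletionWrapper();

    public Task WaitAsync()
    {
        var wrapper = _completionWrapper;
        wrapper.WaitAsyncCalled = true;
        return wrapper.Tcs.Task;
    }

    public bool WaitAsyncCalled => _completionWrapper.WaitAsyncCalled;

    public void Set()
    {
        _completionWrapper.Tcs.TrySetResult(true);
    }

    public void Reset()
    {
        while (true)
        {
            var wrapper = _completionWrapper;
            // Only replace the wrapper if its task has already completed.
            if (!wrapper.Tcs.Task.IsCompleted ||
                Interlocked.CompareExchange(ref _completionWrapper, new CompletionWrapper(), wrapper) == wrapper)
                return;
        }
    }

    private class CompletionWrapper
    {
        public TaskCompletionSource<bool> Tcs { get; } =
            new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
        public bool WaitAsyncCalled { get; set; }
    }
}

// Hypothetical helper that drives the class on a single thread.
public static class ResetDemo
{
    public static bool PendingWaiterSurvivesReset()
    {
        var ev = new AsyncManualResetEvent();
        Task waiter = ev.WaitAsync();
        ev.Reset();                // task not completed, so the wrapper is kept
        ev.Set();                  // completes the same source the waiter holds
        return waiter.IsCompleted; // true
    }

    public static bool SetBeforeWaitIsObserved()
    {
        var ev = new AsyncManualResetEvent();
        ev.Set();
        return ev.WaitAsync().IsCompleted; // true: the signal is remembered
    }

    public static bool ResetAfterSetUnsignals()
    {
        var ev = new AsyncManualResetEvent();
        ev.Set();
        ev.Reset();                         // task completed, so a fresh wrapper is swapped in
        return !ev.WaitAsync().IsCompleted; // true: the new task is pending
    }
}
```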