Waiting on Interlocked == 0?

Posted 2019-09-21 02:21

Question:

Disclaimer: My C# isn't even close to as good as my C++

I am trying to learn how to do async sockets in C# in order to write a test app for a component of mine. My earlier attempts using TcpClient failed; the outstanding questions on that are here:

TcpClient.NetworkStream Async operations - Canceling / Disconnect

Detect errors with NetworkStream.WriteAsync

Since I could not get that working, I tried Socket.BeginX and Socket.EndX instead, and got much further. My problem now is that, in the listing below, Disconnect calls Shutdown and Close on the socket while async operations are still outstanding, and their callbacks then throw an ObjectDisposedException or a NullReferenceException.

I found a similar post about that here:

After disposing async socket (.Net) callbacks still get called

However, I do not accept that answer, because if you use exceptions for intended behavior, then 1) they are not exceptional, and 2) you cannot tell whether the exception was thrown for your intended case or because your async method actually used some other disposed object or null reference, not the socket.

In C++ with async socket code, I'd keep track of the number of outstanding async operations with Interlocked and when it came time to disconnect, I'd call shutdown, then wait for the interlocked to hit 0, and then close and destroy any members I needed to.
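To make that pattern concrete, here is a minimal C# sketch of the same idea: an Interlocked counter that is incremented before each operation starts, decremented when it completes, and polled until it reads zero. The names OutstandingOps and OutstandingOpsDemo are hypothetical and not part of the listing below, and Task.Run with Thread.Sleep only stands in for real socket callbacks. It assumes no new operation is started once the wait begins.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper, not part of the question's listing:
// counts operations that have been started but not yet completed.
public sealed class OutstandingOps
{
    private int _count;

    // Call immediately before starting an async operation.
    public void Begin() => Interlocked.Increment(ref _count);

    // Call at the end of the completion callback.
    public void End() => Interlocked.Decrement(ref _count);

    // Blocks until the counter reads zero or the timeout elapses.
    // Returns false on timeout.
    public bool WaitForZero(TimeSpan timeout) =>
        SpinWait.SpinUntil(() => Volatile.Read(ref _count) == 0, timeout);
}

public static class OutstandingOpsDemo
{
    public static bool Run()
    {
        var ops = new OutstandingOps();
        for (int i = 0; i < 4; i++)
        {
            ops.Begin();                   // counted before the work is queued
            int delayMs = 50 * (i + 1);
            Task.Run(() =>
            {
                Thread.Sleep(delayMs);     // stand-in for a socket callback
                ops.End();
            });
        }

        // In Disconnect this would sit between Shutdown and Close.
        return ops.WaitForZero(TimeSpan.FromSeconds(5));
    }
}
```

SpinWait.SpinUntil polls rather than blocking on a signal, so this is only a rough equivalent of the C++ approach and is best suited to short waits.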

How would I go about waiting in my Disconnect method for all outstanding Async operations to complete in C# in the following listing?

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

using log4net;
using System.Net.Sockets;
using System.Net;

namespace IntegrationTests
{
    public class Client2
    {
        class ReceiveContext
        {
            public Socket    m_socket;
            public const int m_bufferSize = 1024;
            public byte[]    m_buffer = new byte[m_bufferSize];
        }

        private static readonly ILog log = LogManager.GetLogger("root");

        static private ulong m_lastId = 1;

        private ulong  m_id;
        private string m_host;
        private uint   m_port;
        private uint   m_timeoutMilliseconds;
        private string m_clientId;
        private Socket m_socket;
        private uint   m_numOutstandingAsyncOps;

        public Client2(string host, uint port, string clientId, uint timeoutMilliseconds)
        {
            m_id                     = m_lastId++;
            m_host                   = host;
            m_port                   = port;
            m_clientId              = clientId;
            m_timeoutMilliseconds    = timeoutMilliseconds;
            m_socket                 = null;
            m_numOutstandingAsyncOps = 0;
        }

        ~Client2()
        {
            Disconnect();
        }

        public void Connect()
        {
            IPHostEntry ipHostInfo = Dns.GetHostEntry(m_host);
            IPAddress[] ipV4Addresses = ipHostInfo.AddressList.Where(x => x.AddressFamily == AddressFamily.InterNetwork).ToArray();
            IPAddress[] ipV6Addresses = ipHostInfo.AddressList.Where(x => x.AddressFamily == AddressFamily.InterNetworkV6).ToArray();
            IPEndPoint endpoint = new IPEndPoint(ipV4Addresses[0], (int)m_port);

            m_socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
            m_socket.ReceiveTimeout = (int)m_timeoutMilliseconds;
            m_socket.SendTimeout    = (int)m_timeoutMilliseconds;

            try
            {
                m_socket.Connect(endpoint);

                log.Info(string.Format("Connected to: {0}", m_socket.RemoteEndPoint.ToString()));

                // Issue the next async receive
                ReceiveContext context = new ReceiveContext();
                context.m_socket = m_socket;
                m_socket.BeginReceive(context.m_buffer, 0, ReceiveContext.m_bufferSize, SocketFlags.None, new AsyncCallback(OnReceive), context);
            }
            catch (Exception e)
            {
                // Error
                log.Error(string.Format("Client #{0} Exception caught OnConnect. Exception: {1}"
                                       , m_id, e.ToString()));
            }
        }

        public void Disconnect()
        {
            if (m_socket != null)
            {
                m_socket.Shutdown(SocketShutdown.Both);

                // TODO - <--- Error here in the callbacks where they try to use the socket and it is disposed
                //        We need to wait here until all outstanding async operations complete
                //        Should we use Interlocked to keep track of them and wait on it somehow?
                m_socket.Close();
                m_socket = null;
            }
        }

        public void Login()
        {
            string loginRequest = string.Format("loginstuff{0}", m_clientId);
            var data = Encoding.ASCII.GetBytes(loginRequest);

            m_socket.BeginSend(data, 0, data.Length, 0, new AsyncCallback(OnSend), m_socket);
        }

        public void MakeRequest(string thingy)
        {
            string message = string.Format("requeststuff{0}", thingy);
            var data = Encoding.ASCII.GetBytes(message);

            m_socket.BeginSend(data, 0, data.Length, 0, new AsyncCallback(OnSend), m_socket);
        }

        void OnReceive(IAsyncResult asyncResult)
        {
            ReceiveContext context = (ReceiveContext)asyncResult.AsyncState;

            string data = null;
            try
            {
                int bytesReceived = context.m_socket.EndReceive(asyncResult);
                data = Encoding.ASCII.GetString(context.m_buffer, 0, bytesReceived);

                ReceiveContext newContext = new ReceiveContext();
                newContext.m_socket = context.m_socket;

                m_socket.BeginReceive(newContext.m_buffer, 0, ReceiveContext.m_bufferSize, SocketFlags.None, new AsyncCallback(OnReceive), newContext);
            }
            catch(SocketException e)
            {
                if(e.SocketErrorCode == SocketError.ConnectionAborted) // Check if we disconnected on our end
                {
                    return;
                }
            }
            catch (Exception e)
            {
                // Error
                log.Error(string.Format("Client #{0} Exception caught OnReceive. Exception: {1}"
                                       , m_id, e.ToString()));
            }
        }

        void OnSend(IAsyncResult asyncResult)
        {
            Socket socket = (Socket)asyncResult.AsyncState;

            try
            {
                int bytesSent = socket.EndSend(asyncResult);
            }
            catch(Exception e)
            {
                log.Error(string.Format("Client #{0} Exception caught OnSend. Exception: {1}"
                                       , m_id, e.ToString()));
            }
        }
    }
}

Main:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

using log4net;
using log4net.Config;

namespace IntegrationTests
{
    class Program
    {
        private static readonly ILog log = LogManager.GetLogger("root");

        static void Main(string[] args)
        {
            try
            {
                XmlConfigurator.Configure();
                log.Info("Starting Component Integration Tests...");

                Client2 client = new Client2("127.0.0.1", 24001, "MyClientId", 60000);
                client.Connect();
                client.Login();
                client.MakeRequest("StuffAndPuff");

                System.Threading.Thread.Sleep(60000); // Simulate work until the user shuts down

                client.Disconnect();
            }
            catch (Exception e)
            {
                log.Error(string.Format("Caught an exception in main. Exception: {0}"
                                      , e.ToString()));
            }
        }
    }
}

EDIT:

Here is my additional attempt using the answer proposed by Evk, to the best of my ability. It works fine as far as I can tell.

The problem is that I feel I have basically turned everything that was asynchronous into a synchronous call, because I have to lock around anything that changes the counter or the state of the socket. Again, I am a novice at C# compared to my C++, so please do point out if I completely missed the mark interpreting his answer.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

using log4net;

namespace IntegrationTests
{
    public class Client : IDisposable
    {
        class ReceiveContext
        {
            public const int     _bufferSize    = 1024;
            public byte[]        _buffer        = new byte[_bufferSize]; // Contains bytes from one receive
            public StringBuilder _stringBuilder = new StringBuilder();   // Contains bytes for multiple receives in order to build message up to delim
        }

        private static readonly ILog _log = LogManager.GetLogger("root");

        static private ulong _lastId = 1;
        private ulong  _id;

        protected string         _host;
        protected int            _port;
        protected int            _timeoutMilliseconds;
        protected string         _sessionId;
        protected Socket         _socket;
        protected object         _lockNumOutstandingAsyncOps;
        protected int            _numOutstandingAsyncOps;
        private bool             _disposed = false;

        public Client(string host, int port, string sessionId, int timeoutMilliseconds)
        {
            _id                         = _lastId++;
            _host                       = host;
            _port                       = port;
            _sessionId                  = sessionId;
            _timeoutMilliseconds        = timeoutMilliseconds;
            _socket                     = null;
            _numOutstandingAsyncOps     = 0;
            _lockNumOutstandingAsyncOps = new object();
        }

        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        protected virtual void Dispose(bool disposing)
        {
            if(_disposed)
            {
                return;
            }

            if (disposing && _socket != null)
            {
                _socket.Close();
            }

            _disposed = true;
        }

        public void Connect()
        {
            lock (_lockNumOutstandingAsyncOps)
            {
                IPHostEntry ipHostInfo = Dns.GetHostEntry(_host);
                IPAddress[] ipV4Addresses = ipHostInfo.AddressList.Where(x => x.AddressFamily == AddressFamily.InterNetwork).ToArray();
                IPAddress[] ipV6Addresses = ipHostInfo.AddressList.Where(x => x.AddressFamily == AddressFamily.InterNetworkV6).ToArray();
                IPEndPoint endpoint = new IPEndPoint(ipV4Addresses[0], _port);

                _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                _socket.ReceiveTimeout = _timeoutMilliseconds;
                _socket.SendTimeout = _timeoutMilliseconds;

                try
                {
                    _socket.Connect(endpoint);
                }
                catch (Exception e)
                {
                    // Error
                    Debug.WriteLine(string.Format("Client #{0} Exception caught OnConnect. Exception: {1}"
                                           , _id, e.ToString()));
                    return;
                }

                Debug.WriteLine(string.Format("Client #{0} connected to: {1}", _id, _socket.RemoteEndPoint.ToString()));

                // Issue the first async receive
                ReceiveContext context = new ReceiveContext();

                ++_numOutstandingAsyncOps;
                _socket.BeginReceive(context._buffer, 0, ReceiveContext._bufferSize, SocketFlags.None, new AsyncCallback(OnReceive), context);
            }
        }

        public void Disconnect()
        {
            if (_socket != null)
            {
                // We need to wait here until all outstanding async operations complete
                // In order to avoid getting 'Object was disposed' exceptions in those async ops that use the socket
                lock(_lockNumOutstandingAsyncOps)
                {
                    Debug.WriteLine(string.Format("Client #{0} Disconnecting...", _id));

                    _socket.Shutdown(SocketShutdown.Both);

                    while (_numOutstandingAsyncOps > 0)
                    {
                        Monitor.Wait(_lockNumOutstandingAsyncOps);
                    }

                    _socket.Close();
                    _socket = null;
                }
            }
        }

        public void Login()
        {
            lock (_lockNumOutstandingAsyncOps)
            {
                if (_socket != null && _socket.Connected)
                {
                    string loginRequest = string.Format("loginstuff{0}", _sessionId);
                    var data = Encoding.ASCII.GetBytes(loginRequest);

                    Debug.WriteLine(string.Format("Client #{0} Sending Login Request: {1}"
                                           , _id, loginRequest));

                    ++_numOutstandingAsyncOps;
                    _socket.BeginSend(data, 0, data.Length, 0, new AsyncCallback(OnSend), _socket);
                }
                else
                {
                    Debug.WriteLine(string.Format("Client #{0} Login was called, but Socket is null or no longer connected."
                                           , _id));
                }
            }
        }

        public void MakeRequest(string thingy)
        {
            lock (_lockNumOutstandingAsyncOps)
            {
                if (_socket != null && _socket.Connected)
                {
                    string message = string.Format("requeststuff{0}", thingy);
                    var data = Encoding.ASCII.GetBytes(message);

                    Debug.WriteLine(string.Format("Client #{0} Sending Request: {1}"
                                           , _id, message));

                    ++_numOutstandingAsyncOps;
                    _socket.BeginSend(data, 0, data.Length, 0, new AsyncCallback(OnSend), _socket);
                }
                else
                {
                    Debug.WriteLine(string.Format("Client #{0} MakeRequest was called, but Socket is null or no longer connected."
                                           , _id));
                }
            }
        }

        protected void OnReceive(IAsyncResult asyncResult)
        {
            lock (_lockNumOutstandingAsyncOps)
            {
                ReceiveContext context = (ReceiveContext)asyncResult.AsyncState;

                string data = null;

                try
                {
                    int bytesReceived = _socket.EndReceive(asyncResult);
                    data = Encoding.ASCII.GetString(context._buffer, 0, bytesReceived);

                    // If the remote host shuts down the Socket connection with the Shutdown method, and all available data has been received,
                    // the EndReceive method will complete immediately and return zero bytes
                    if (bytesReceived > 0)
                    {
                        StringBuilder stringBuilder = context._stringBuilder.Append(data);

                        int index = -1;
                        do
                        {
                            index = stringBuilder.ToString().IndexOf("#");
                            if (index != -1)
                            {
                                string message = stringBuilder.ToString().Substring(0, index + 1);
                                stringBuilder.Remove(0, index + 1);

                                Debug.WriteLine(string.Format("Client #{0} Received Data: {1}"
                                                       , _id, message));
                            }
                        } while (index != -1);
                    }
                }
                catch (SocketException e)
                {
                    // Check if we disconnected on our end
                    if (e.SocketErrorCode == SocketError.ConnectionAborted)
                    {
                        // Ignore
                    }
                    else
                    {
                        // Error
                        Debug.WriteLine(string.Format("Client #{0} SocketException caught OnReceive. Exception: {1}"
                                               , _id, e.ToString()));
                        Disconnect();
                    }
                }
                catch (Exception e)
                {
                    // Error
                    Debug.WriteLine(string.Format("Client #{0} Exception caught OnReceive. Exception: {1}"
                                           , _id, e.ToString()));
                    Disconnect();
                }
                finally
                {
                    --_numOutstandingAsyncOps;
                    Monitor.Pulse(_lockNumOutstandingAsyncOps);
                }
            }

            // Issue the next async receive
            lock (_lockNumOutstandingAsyncOps)
            {
                if (_socket != null && _socket.Connected)
                {
                    ++_numOutstandingAsyncOps;

                    ReceiveContext newContext = new ReceiveContext();
                    _socket.BeginReceive(newContext._buffer, 0, ReceiveContext._bufferSize, SocketFlags.None, new AsyncCallback(OnReceive), newContext);
                }
            }
        }

        protected void OnSend(IAsyncResult asyncResult)
        {
            lock (_lockNumOutstandingAsyncOps)
            {
                try
                {
                    int bytesSent = _socket.EndSend(asyncResult);
                }
                catch (Exception e)
                {
                    Debug.WriteLine(string.Format("Client #{0} Exception caught OnSend. Exception: {1}"
                                           , _id, e.ToString()));
                    Disconnect();
                }
                finally
                {
                    --_numOutstandingAsyncOps;
                    Monitor.Pulse(_lockNumOutstandingAsyncOps);
                }
            }
        }
    }
}

Answer 1:

You can use Monitor.Wait and Monitor.Pulse for that:

using System;
using System.Threading;
using System.Threading.Tasks;

class Program {
    static int _outstandingOperations;
    static readonly object _lock = new object();

    static void Main() {
        for (int i = 0; i < 100; i++) {
            var tmp = i;
            // count the operation before it is queued, so the wait below
            // cannot observe zero while tasks have not started yet
            lock (_lock) {
                _outstandingOperations++;
            }
            Task.Run(() =>
            {
                // some work
                Thread.Sleep(new Random(tmp).Next(0, 5000));
                lock (_lock) {
                    _outstandingOperations--;
                    // notify condition might have changed
                    Monitor.Pulse(_lock);
                }
            });
        }

        lock (_lock) {
            // condition check
            while (_outstandingOperations > 0)
                // will wait here until pulsed, lock will be released during wait
                Monitor.Wait(_lock);
        }
    }
}
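
If holding a lock around every operation feels too heavy, a possible alternative is System.Threading.CountdownEvent, which keeps the count internally and lets a thread wait for it to reach zero. The following is only a sketch under assumptions, not the asker's client: CountdownSketch is a hypothetical name, and Task.Run with Thread.Sleep stands in for the socket callbacks. The initial count of 1 is the owner's own reference, released just before waiting, so the event cannot become signalled while operations are still being started.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class CountdownSketch
{
    // Starts ten simulated operations, waits for all of them,
    // and returns how many completed.
    public static int Run()
    {
        int completed = 0;

        // Initial count of 1 is the owner's reference.
        // Not disposed here: CountdownEvent.Dispose is not thread-safe with
        // respect to a Signal call that may still be returning on another thread.
        var pending = new CountdownEvent(1);

        for (int i = 0; i < 10; i++)
        {
            pending.AddCount();              // one more outstanding operation
            int delayMs = 10 * i;
            Task.Run(() =>
            {
                try
                {
                    Thread.Sleep(delayMs);   // stand-in for EndReceive/EndSend work
                    Interlocked.Increment(ref completed);
                }
                finally
                {
                    pending.Signal();        // this operation has finished
                }
            });
        }

        pending.Signal();                    // release the owner's reference
        pending.Wait();                      // returns once every operation has signalled
        return completed;
    }
}
```

In a client like the one in the question, AddCount would go before each BeginReceive or BeginSend, Signal in a finally block of each callback, and the final Signal plus Wait between Shutdown and Close. Note that AddCount throws InvalidOperationException once the count has reached zero, so no new operation can be registered after shutdown has completed.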