Java NIO client

2019-01-12 08:29发布

问题:

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.*;

/**
 * Single-threaded NIO echo server that can also open outgoing connections.
 *
 * <p>One selector loop (started by the constructor, which therefore never
 * returns normally) handles accepting, connecting, reading and writing for
 * every channel. Data read from a channel is queued and written back to the
 * same channel.
 *
 * <p>All channel state is touched only by the selector thread. The one method
 * meant to be called from elsewhere, {@link #connect(String, int)}, merely
 * queues a request and wakes the selector up.
 */
public class EchoServer {

    /** An outgoing connection waiting to be registered by the selector thread. */
    private static final class PendingConnect {
        final SocketChannel channel;
        final byte[] greeting;

        PendingConnect(SocketChannel channel, byte[] greeting) {
            this.channel = channel;
            this.greeting = greeting;
        }
    }

    private InetAddress addr;
    private int port;
    private Selector selector;

    /** Per-channel queue of buffers still to be written; selector thread only. */
    private Map<SocketChannel, Deque<ByteBuffer>> dataMap;

    /** Outgoing connections requested by connect(); guarded by its own monitor. */
    private final Queue<PendingConnect> pendingConnects = new LinkedList<PendingConnect>();

    /** Scratch buffer for reads; safe to share because only the selector thread reads. */
    private final ByteBuffer readBuffer = ByteBuffer.allocate(8192);

    /**
     * Creates the server and immediately runs its selector loop on the calling
     * thread. This call only returns by throwing.
     *
     * @param addr local address to bind to, or {@code null} for the wildcard address
     * @param port local port to listen on
     * @throws IOException if the selector or listening socket cannot be set up,
     *                     or if selecting/accepting fails later on
     */
    public EchoServer(InetAddress addr, int port) throws IOException {
        this.addr = addr;
        this.port = port;
        dataMap = new HashMap<SocketChannel, Deque<ByteBuffer>>();
        startServer();
    }

    /**
     * Opens the listening socket and runs the selector loop forever.
     *
     * @throws IOException on a failure of the selector or the listening socket;
     *                     failures of individual client channels only close that channel
     */
    private void startServer() throws IOException {
        // create selector and channel
        this.selector = Selector.open();
        ServerSocketChannel serverChannel = ServerSocketChannel.open();
        serverChannel.configureBlocking(false);

        // bind to port
        InetSocketAddress listenAddr = new InetSocketAddress(this.addr, this.port);
        serverChannel.socket().bind(listenAddr);
        serverChannel.register(this.selector, SelectionKey.OP_ACCEPT);

        log("Echo server ready. Ctrl-C to stop.");

        // processing
        while (true) {
            // Register outgoing connections here, on the selector thread, so
            // that register() never competes with a select() in progress.
            registerPendingConnects();

            // wait for events (or for a wakeup() issued by connect())
            this.selector.select();

            Iterator<SelectionKey> keys = this.selector.selectedKeys().iterator();
            while (keys.hasNext()) {
                SelectionKey key = keys.next();

                // this is necessary to prevent the same key from coming up
                // again the next time around.
                keys.remove();

                if (!key.isValid()) {
                    continue;
                }

                try {
                    if (key.isAcceptable()) {
                        this.accept(key);
                    } else if (key.isReadable()) {
                        this.read(key);
                    } else if (key.isWritable()) {
                        this.write(key);
                    } else if (key.isConnectable()) {
                        this.doConnect(key);
                    }
                } catch (IOException e) {
                    if (key.channel() instanceof ServerSocketChannel) {
                        // The listening socket itself is broken: nothing to recover.
                        throw e;
                    }
                    // A single misbehaving peer must not take the server down.
                    log("I/O error, closing channel: " + e);
                    closeChannel(key);
                }
            }
        }
    }

    /**
     * Registers every queued outgoing connection with the selector and queues
     * its greeting so it is sent once the connection is established.
     */
    private void registerPendingConnects() {
        while (true) {
            PendingConnect pending;
            synchronized (pendingConnects) {
                pending = pendingConnects.poll();
            }
            if (pending == null) {
                return;
            }
            try {
                Deque<ByteBuffer> queue = new ArrayDeque<ByteBuffer>();
                queue.add(ByteBuffer.wrap(pending.greeting));
                dataMap.put(pending.channel, queue);
                if (pending.channel.isConnected()) {
                    // connect() completed immediately, so OP_CONNECT will never fire.
                    pending.channel.register(selector, SelectionKey.OP_WRITE);
                    System.out.println("Connected");
                } else {
                    pending.channel.register(selector, SelectionKey.OP_CONNECT);
                }
            } catch (IOException e) {
                System.out.println("failure");
                dataMap.remove(pending.channel);
                closeQuietly(pending.channel);
            }
        }
    }

    /**
     * Completes a pending outgoing connection.
     *
     * <p>{@code finishConnect()} returns {@code false} while the connection is
     * still in progress and throws when it has failed, so failure is handled in
     * the catch block rather than in an else branch.
     */
    private void doConnect(SelectionKey key) {
        SocketChannel channel = (SocketChannel) key.channel();
        boolean connected;
        try {
            connected = channel.finishConnect();
        } catch (IOException e) {
            /* failure */
            System.out.println("failure");
            closeChannel(key);
            return;
        }
        if (connected) {
            /* success */
            System.out.println("Connected");
            // Drop OP_CONNECT; flush the queued greeting first if there is one.
            Deque<ByteBuffer> pendingData = dataMap.get(channel);
            boolean hasOutput = pendingData != null && !pendingData.isEmpty();
            key.interestOps(hasOutput ? SelectionKey.OP_WRITE : SelectionKey.OP_READ);
        }
        // Not connected yet: keep OP_CONNECT and wait for the next event.
    }

    /**
     * Starts a non-blocking outgoing connection. The greeting is sent once the
     * connection is established; incoming data then arrives in {@code read()}.
     *
     * <p>May be called from any thread: the channel is handed to the selector
     * thread through a queue instead of being registered here.
     *
     * @param hostname remote host name
     * @param port     remote port
     * @throws IOException if the channel cannot be opened or the connect cannot be started
     */
    public void connect(String hostname, int port) throws IOException {
        // NOTE(review): UserInfo is not defined in this listing; assumed to be
        // declared elsewhere in the project.
        byte[] greeting = ("$Hello "+UserInfo[0]+"|").getBytes("US-ASCII");

        SocketChannel clientChannel = SocketChannel.open();
        boolean started = false;
        try {
            clientChannel.configureBlocking(false);
            clientChannel.connect(new InetSocketAddress(hostname, port));
            started = true;
        } finally {
            if (!started) {
                // Also covers unchecked failures such as an unresolved address.
                closeQuietly(clientChannel);
            }
        }

        synchronized (pendingConnects) {
            pendingConnects.add(new PendingConnect(clientChannel, greeting));
        }
        // Make select() return so the selector thread picks the request up.
        selector.wakeup();
    }

    /**
     * Accepts one incoming connection and queues the welcome message for it.
     *
     * @throws IOException if accepting or configuring the new channel fails
     */
    private void accept(SelectionKey key) throws IOException {
        ServerSocketChannel serverChannel = (ServerSocketChannel) key.channel();
        SocketChannel channel = serverChannel.accept();
        if (channel == null) {
            // Non-blocking accept: the connection vanished before we got to it.
            return;
        }
        channel.configureBlocking(false);

        // queue the welcome message; write() switches to OP_READ once it is sent
        Deque<ByteBuffer> pendingData = new ArrayDeque<ByteBuffer>();
        pendingData.add(ByteBuffer.wrap("Welcome, this is the echo server\r\n".getBytes("US-ASCII")));
        dataMap.put(channel, pendingData);

        Socket socket = channel.socket();
        SocketAddress remoteAddr = socket.getRemoteSocketAddress();
        log("Connected to: " + remoteAddr);

        // register channel with selector for further IO
        channel.register(this.selector, SelectionKey.OP_WRITE);
    }

    /**
     * Reads whatever is available from the channel and queues it to be echoed.
     * A read error or end-of-stream closes the channel.
     */
    private void read(SelectionKey key) throws IOException {
        SocketChannel channel = (SocketChannel) key.channel();

        readBuffer.clear();
        int numRead;
        try {
            numRead = channel.read(readBuffer);
        } catch (IOException e) {
            log("Read failed: " + e);
            numRead = -1; // treat a broken connection like end-of-stream
        }
        if (numRead == -1) {
            Socket socket = channel.socket();
            SocketAddress remoteAddr = socket.getRemoteSocketAddress();
            log("Connection closed by client: " + remoteAddr);
            closeChannel(key);
            return;
        }
        if (numRead == 0) {
            return; // spurious wakeup, nothing to echo
        }
        byte[] data = new byte[numRead];
        readBuffer.flip();
        readBuffer.get(data);
        log("Got: " + new String(data, "US-ASCII"));
        doEcho(key, data); // write back to client
    }

    /**
     * Writes as much queued data as the socket accepts. If the socket buffer
     * fills up, the unwritten remainder stays queued and OP_WRITE stays set.
     */
    private void write(SelectionKey key) throws IOException {
        SocketChannel channel = (SocketChannel) key.channel();
        Deque<ByteBuffer> pendingData = this.dataMap.get(channel);
        if (pendingData != null) {
            while (!pendingData.isEmpty()) {
                ByteBuffer head = pendingData.peek();
                channel.write(head);
                if (head.hasRemaining()) {
                    return; // partial write: resume on the next OP_WRITE event
                }
                pendingData.remove();
            }
        }
        key.interestOps(SelectionKey.OP_READ);
    }

    /** Queues {@code data} for the key's channel and asks to be told when it is writable. */
    private void doEcho(SelectionKey key, byte[] data) {
        SocketChannel channel = (SocketChannel) key.channel();
        Deque<ByteBuffer> pendingData = this.dataMap.get(channel);
        if (pendingData == null) {
            pendingData = new ArrayDeque<ByteBuffer>();
            this.dataMap.put(channel, pendingData);
        }
        pendingData.add(ByteBuffer.wrap(data));
        key.interestOps(SelectionKey.OP_WRITE);
    }

    /** Forgets the channel's queued data, cancels its key and closes it. */
    private void closeChannel(SelectionKey key) {
        dataMap.remove(key.channel());
        key.cancel();
        try {
            key.channel().close();
        } catch (IOException e) {
            log("Error closing channel: " + e);
        }
    }

    /** Closes a channel that is not (or no longer) tracked by the selector. */
    private static void closeQuietly(SocketChannel channel) {
        try {
            channel.close();
        } catch (IOException e) {
            log("Error closing channel: " + e);
        }
    }

    private static void log(String s) {
        System.out.println(s);
    }

    /** Starts the echo server on port 8989 of the wildcard address. */
    public static void main(String[] args) throws Exception {
        new EchoServer(null, 8989);
    }
}

The program works with incoming connections, but outgoing connections do not work. I need to open several connections in a row through connect(String hostname, int port) and receive their data in the read() method. The program stops working on the line clientChannel.register(...).

回答1:

You need to check for a connectable key, e.g.

// Inside the selected-keys loop: a connectable key means the channel's
// non-blocking connect is ready to be completed.
if (key.isConnectable()) {
  this.doConnect(key);
}
...
/**
 * Completes a non-blocking connect for the channel behind {@code key}.
 *
 * <p>{@code finishConnect()} returns {@code false} while the connection is
 * still in progress and throws {@code IOException} when it has failed, so the
 * failure branch is the catch block, not an else branch.
 */
private void doConnect(SelectionKey key) {
  SocketChannel channel = (SocketChannel) key.channel();
  try {
    if (channel.finishConnect()) {
      /* success */
      // Stop selecting on OP_CONNECT and start waiting for incoming data.
      key.interestOps(SelectionKey.OP_READ);
    }
    // false: still connecting; keep OP_CONNECT and wait for the next event.
  } catch (IOException e) {
    /* failure */
    key.cancel();
    try {
      channel.close();
    } catch (IOException ignored) {
      // Best effort: the connection is already unusable.
    }
  }
}

Use SocketChannel.finishConnect to determine whether the connection was established successfully.



回答2:

This is the NIO client example I've been using. It provides open/write/read functions with timeouts and is suitable for request-response messaging.

The tricky part is always how the two parties recognize that a packet has been fully received. This example assumes:

  • Server gets one_line_command+newline (client->server packet)
  • Client receives 1..n lines followed by a ">>" terminator line that has no trailing newline (server->client packet)
  • you could specify terminator to be ">>\n" but readUntil needs small fix in a newline parser

You could instead use a 4-byte length header, a fixed-size packet splitter, or a delimiter byte such as 0x27, but make sure the delimiter can never appear as a value in the data payload. An NIO read() or write() never guarantees that you receive a full packet in one call; a single read() may also return the bytes of two or more packets in one buffer. It is up to us to write a packet parser that does not lose bytes.

import java.util.*;
import java.io.*;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;

/**
 * Small NIO client for line-oriented request/response messaging with
 * per-operation timeouts.
 *
 * <p>Usage: {@link #open(long)}, then alternate {@code write(...)} and
 * {@code readUntil(...)}, then {@link #close()}. All timeouts are in
 * milliseconds and bound the wall-clock duration of the whole call.
 *
 * <p>Limitation: bytes that arrive in the same read after a terminator has
 * matched are discarded, so the peer must not send anything further until the
 * next request.
 */
public class DPSocket {
    private boolean debug;
    private String host;
    private int port;
    private String charset;
    private ByteArrayOutputStream inBuffer;
    private ByteBuffer buf;
    private Selector selector;
    private SocketChannel channel;

    /**
     * @param host    remote host name
     * @param port    remote port
     * @param charset charset name used for String/byte conversion;
     *                {@code null} or empty selects UTF-8
     */
    public DPSocket(String host, int port, String charset) {
        this.charset = charset==null || charset.equals("") ? "UTF-8" : charset;
        this.host = host;
        this.port = port;
    }

    public boolean isDebug() { return debug; }
    public void setDebug(boolean b) { debug=b; }

    /**
     * Connects to the configured host and port.
     *
     * @param timeout maximum time to wait for the connection, in milliseconds
     * @throws IOException if the connection fails or is not established in time;
     *                     in that case the selector and channel are released
     */
    public void open(long timeout) throws IOException {
        boolean connected = false;
        try {
            selector = Selector.open();
            channel = SocketChannel.open();
            channel.configureBlocking(false);
            channel.register(selector, SelectionKey.OP_CONNECT);
            inBuffer = new ByteArrayOutputStream(1024);
            buf = ByteBuffer.allocate(1*1024);
            awaitConnect(timeout);
            connected = true;
        } finally {
            if (!connected) {
                close(); // do not leak the selector/channel of a failed attempt
            }
        }
    }

    /** Starts the connect and waits until it completes or the timeout expires. */
    private void awaitConnect(long timeout) throws IOException {
        if (channel.connect(new InetSocketAddress(host, port))) {
            // Established immediately (possible for local peers); OP_CONNECT
            // would never fire, so waiting for it would only time out.
            if (debug) System.out.println("connected immediately");
            return;
        }
        long start = System.nanoTime();
        long remaining;
        while ((remaining = remainingMillis(start, timeout)) > 0) {
            if (selector.select(remaining) < 1) {
                continue;
            }
            Iterator<SelectionKey> keys = selector.selectedKeys().iterator();
            while (keys.hasNext()) {
                SelectionKey key = keys.next();
                keys.remove();
                if (!key.isValid() || !key.isConnectable()) {
                    continue;
                }
                // finishConnect() throws on failure and returns false while
                // the connection is still in progress.
                if (channel.finishConnect()) {
                    if (debug) System.out.println("finishConnect");
                    return; // we are ready to receive bytes
                }
            }
        }
        throw new IOException("Connection timed out");
    }

    /** Releases the channel and selector; safe to call repeatedly or before open(). */
    public void close() {
        // Best effort: there is nothing useful a caller could do with a close failure.
        try { if(channel!=null) channel.close(); } catch(Exception ex) { }
        try { if(selector!=null) selector.close(); } catch(Exception ex) { }
        inBuffer=null;
        buf=null;
    }

    /**
     * Encodes {@code data} with the configured charset and writes it.
     *
     * @see #write(byte[], long)
     */
    public void write(String data, long timeout) throws IOException {
        write(data.getBytes(charset), timeout);
    }

    /**
     * Writes all of {@code bytes} to the peer.
     *
     * @param bytes   data to send
     * @param timeout maximum duration of the whole write, in milliseconds
     * @throws IOException if the socket is not open, the write fails, or not
     *                     every byte could be written in time
     */
    public void write(byte[] bytes, long timeout) throws IOException {
        ensureOpen();
        ByteBuffer outBuffer = ByteBuffer.wrap(bytes);
        // Re-registering an already registered channel just replaces its interest set.
        channel.register(selector, SelectionKey.OP_WRITE);
        long start = System.nanoTime();
        long remaining;
        while ((remaining = remainingMillis(start, timeout)) > 0) {
            if (selector.select(remaining) < 1) {
                continue;
            }
            Iterator<SelectionKey> keys = selector.selectedKeys().iterator();
            while (keys.hasNext()) {
                SelectionKey key = keys.next();
                keys.remove();
                if (!key.isValid() || !key.isWritable()) {
                    continue;
                }
                if (debug) System.out.println("write remaining="+outBuffer.remaining());
                channel.write(outBuffer);
                if (debug) System.out.println("write remaining="+outBuffer.remaining());
                if (!outBuffer.hasRemaining()) {
                    return;
                }
            }
        }
        throw new IOException("Write timed out");
    }

    /**
     * Single-terminator convenience overload.
     *
     * @see #readUntil(String[], long, boolean)
     */
    public List<String> readUntil(String terminator, long timeout, boolean trimLines) throws IOException {
        return readUntil(new String[]{terminator}, timeout, trimLines);
    }

    /**
     * Reads newline-separated lines until a line begins with one of the terminators.
     *
     * @param terminators non-empty strings; reading stops as soon as the current
     *                    line starts with any of them
     * @param timeout     maximum duration of the whole read, in milliseconds
     * @param trimLines   whether to trim whitespace (e.g. a trailing '\r') from each line
     * @return the complete lines received before the terminator line; the
     *         terminator line itself is not included
     * @throws IOException              if the socket is not open, the peer disconnects,
     *                                  or no terminator arrives in time
     * @throws IllegalArgumentException if a terminator is empty
     */
    public List<String> readUntil(String[] terminators, long timeout, boolean trimLines) throws IOException {
        ensureOpen();
        List<String> lines = new ArrayList<String>(12);
        inBuffer.reset();

        // End of packet terminator strings, line startsWith "aabbcc" string.
        byte[][] arrTerminators = new byte[terminators.length][];
        int[] idxTerminators = new int[terminators.length]; // matched prefix length per terminator
        for(int idx=0; idx < terminators.length; idx++) {
            arrTerminators[idx] = terminators[idx].getBytes(charset);
            if (arrTerminators[idx].length == 0) {
                throw new IllegalArgumentException("Terminator must not be empty");
            }
        }
        int idxLineByte=-1; // index of the current byte within the current line

        channel.register(selector, SelectionKey.OP_READ);
        long start = System.nanoTime();
        long remaining;
        while ((remaining = remainingMillis(start, timeout)) > 0) {
            if (selector.select(remaining) < 1) {
                continue;
            }
            Iterator<SelectionKey> keys = selector.selectedKeys().iterator();
            while (keys.hasNext()) {
                SelectionKey key = keys.next();
                keys.remove();
                if (!key.isValid() || !key.isReadable()) {
                    continue;
                }
                buf.clear();
                int len = channel.read(buf);
                if (len == -1) throw new IOException("Socket disconnected");
                buf.flip();
                for(int idx=0; idx<len; idx++) {
                    byte cb = buf.get(idx);
                    if (cb!='\n') {
                        idxLineByte++;
                        inBuffer.write(cb);
                        for(int idxter=0; idxter < arrTerminators.length; idxter++) {
                            byte[] arrTerminator = arrTerminators[idxter];
                            // A terminator only matches at the start of a line, so the
                            // byte's position must equal the number of bytes matched so far.
                            if (idxLineByte==idxTerminators[idxter]
                                    && arrTerminator[ idxTerminators[idxter] ]==cb) {
                                idxTerminators[idxter]++;
                                if (idxTerminators[idxter]==arrTerminator.length)
                                    return lines;
                            } else idxTerminators[idxter]=0;
                        }
                    } else  {
                        String line = inBuffer.toString(charset);
                        lines.add(trimLines ? line.trim() : line);
                        inBuffer.reset();
                        idxLineByte=-1;
                        for(int idxter=0; idxter<arrTerminators.length; idxter++)
                            idxTerminators[idxter]=0;
                    }
                }
            }
        }
        throw new IOException("Read timed out");
    }

    /** Fails with an IOException instead of a NullPointerException when open() was not called. */
    private void ensureOpen() throws IOException {
        if (channel == null || selector == null || inBuffer == null || !channel.isOpen()) {
            throw new IOException("Socket is not open");
        }
    }

    /**
     * Milliseconds left of {@code timeoutMillis} since {@code startNanos}.
     * Measured against the clock so that time spent handling events counts too.
     */
    private static long remainingMillis(long startNanos, long timeoutMillis) {
        return timeoutMillis - (System.nanoTime() - startNanos) / 1000000L;
    }

    // **************************
    // *** test socket client ***
    // **************************

    /** Demo: sends two commands and reads each reply up to the ">>" prompt line. */
    public static void main(String[] args) throws Exception {
        String NEWLINE = "\n";
        int TIMEOUT=5000;

        DPSocket dps = new DPSocket("myserver.com", 1234, "UTF-8");
        dps.setDebug(true);

        try {
            List<String> lines;
            dps.open(15000);

            dps.write("Command1 arg1 arg2"+NEWLINE, TIMEOUT);
            lines = dps.readUntil(">>", TIMEOUT, true);

            dps.write("Command2 arg1 arg2"+NEWLINE, TIMEOUT);
            lines = dps.readUntil(">>", TIMEOUT, true);
        } catch (Exception ex) {
            String msg = ex.getMessage();
            if (msg==null) msg = ex.getClass().getName();
            if (msg.contains("timed out") || msg.contains("Invalid command ")) {
                System.out.println("ERROR: " + ex.getMessage());
            } else {
                System.out.print("ERROR: ");
                ex.printStackTrace();
            }
        } finally {
            dps.close();
        }
    }

}