I have two applications that interact using Thrift. They share the same secret key, and I need to encrypt their messages. It makes sense to use a symmetric algorithm (AES, for example), but I haven't found any library that does this. So I did some research and found the following options:
Use built-in SSL support
I can use the built-in SSL support, establish a secure connection, and use my secret key just as an authentication token. This requires installing certificates in addition to the secret key the applications already have, but I don't need to implement anything except checking that the secret key received from the client is the same as the secret key stored locally.
Implement symmetric encryption
So far, I see the following options:
- Extend `TSocket`, override its `write()` and `read()` methods, and encrypt/decrypt the data in them. This increases traffic on small writes: for example, if `TBinaryProtocol` writes a 4-byte integer, it takes a whole block (16 bytes) once encrypted.
- Extend `TSocket` and wrap its `InputStream` and `OutputStream` with `CipherInputStream` and `CipherOutputStream`. `CipherOutputStream` does not encrypt small byte arrays immediately; it only updates the `Cipher` with them. Once it has enough data, it encrypts the data and writes it to the underlying `OutputStream`. For example, it waits until you add four 4-byte ints and only then encrypts them. This avoids wasting traffic, but it also causes a problem: if the last value does not fill a block, it is never encrypted and written to the underlying stream. `CipherOutputStream` expects me to write a number of bytes divisible by its block size (16 bytes), but I can't guarantee that when using `TBinaryProtocol`.
- Re-implement `TBinaryProtocol`, caching all writes instead of writing them to the stream, and encrypting them in the `writeMessageEnd()` method. Implement decryption in `readMessageBegin()`. However, I think encryption should be performed on the transport layer, not the protocol layer.
Please share your thoughts with me.
UPDATE
Java Implementation on Top of TFramedTransport
TEncryptedFramedTransport.java
package tutorial;
import java.security.GeneralSecurityException;
import java.security.Key;
import javax.crypto.Cipher;
import org.apache.thrift.TByteArrayOutputStream;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.TTransportFactory;
/**
 * A buffered, framed {@link TTransport} that encrypts every outgoing message with a shared
 * symmetric key.
 *
 * <p>Writes are accumulated in memory. On {@link #flush()} the buffered bytes are encrypted with
 * {@value #ALGORITHM} and written to the wrapped transport, preceded by a 4-byte big-endian frame
 * size that holds the ciphertext length. Reads pull one whole frame from the wrapped transport,
 * decrypt it, and serve the plaintext from an in-memory buffer.
 *
 * <p><b>Security note:</b> ECB mode maps identical 16-byte plaintext blocks to identical
 * ciphertext blocks and provides no integrity protection, so repeated content is visible on the
 * wire and blocks can be replayed or spliced undetected. An authenticated mode (for example
 * AES/GCM with a fresh nonce per frame) would be stronger, but changing {@link #ALGORITHM} changes
 * the wire format and must be done on every peer at the same time.
 *
 * <p>Not intended for concurrent use: both ciphers and the frame-size scratch buffer are shared,
 * mutable per-instance state.
 */
public class TEncryptedFramedTransport extends TTransport {

    /** Cipher transformation used in both directions; every peer must use the same one. */
    public static final String ALGORITHM = "AES/ECB/PKCS5Padding";

    /** Default cap on the size of an incoming encrypted frame. */
    protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF;

    private final Cipher encryptingCipher;
    private final Cipher decryptingCipher;

    /** Largest encrypted frame, in bytes, accepted from the peer. */
    private final int maxLength_;

    /** Underlying transport that carries the encrypted frames. */
    private final TTransport transport_;

    /** Plaintext written since the last flush. */
    private final TByteArrayOutputStream writeBuffer_ = new TByteArrayOutputStream(1024);

    /** Decrypted plaintext of the current incoming frame. */
    private final TMemoryInputTransport readBuffer_ = new TMemoryInputTransport(new byte[0]);

    /** Scratch buffer for encoding/decoding the 4-byte frame size. */
    private final byte[] i32buf = new byte[4];

    /** Creates {@link TEncryptedFramedTransport}s that wrap a given transport with a fixed key. */
    public static class Factory extends TTransportFactory {
        private final int maxLength_;
        private final Key secretKey_;

        /**
         * @param secretKey the shared symmetric key used by every created transport
         */
        public Factory(Key secretKey) {
            this(secretKey, DEFAULT_MAX_LENGTH);
        }

        /**
         * @param secretKey the shared symmetric key used by every created transport
         * @param maxLength the largest encrypted frame size, in bytes, accepted when reading
         */
        public Factory(Key secretKey, int maxLength) {
            maxLength_ = maxLength;
            secretKey_ = secretKey;
        }

        @Override
        public TTransport getTransport(TTransport base) {
            return new TEncryptedFramedTransport(base, secretKey_, maxLength_);
        }
    }

    /**
     * Wraps {@code transport}, encrypting outgoing and decrypting incoming frames with
     * {@code secretKey}.
     *
     * @param transport the underlying transport that carries the encrypted frames
     * @param secretKey the shared symmetric key; must be usable with {@value #ALGORITHM}
     * @param maxLength the largest encrypted frame size, in bytes, accepted when reading
     * @throws RuntimeException if the ciphers cannot be created or initialized with the key
     *     (the underlying security exception is attached as the cause)
     */
    public TEncryptedFramedTransport(TTransport transport, Key secretKey, int maxLength) {
        transport_ = transport;
        maxLength_ = maxLength;
        encryptingCipher = newCipher(Cipher.ENCRYPT_MODE, secretKey);
        decryptingCipher = newCipher(Cipher.DECRYPT_MODE, secretKey);
    }

    /**
     * Wraps {@code transport} using {@link #DEFAULT_MAX_LENGTH} as the frame size limit.
     *
     * @param transport the underlying transport that carries the encrypted frames
     * @param secretKey the shared symmetric key; must be usable with {@value #ALGORITHM}
     */
    public TEncryptedFramedTransport(TTransport transport, Key secretKey) {
        this(transport, secretKey, DEFAULT_MAX_LENGTH);
    }

    /** Creates a {@value #ALGORITHM} cipher initialized for {@code mode} with {@code key}. */
    private static Cipher newCipher(int mode, Key key) {
        try {
            Cipher cipher = Cipher.getInstance(ALGORITHM);
            cipher.init(mode, key);
            return cipher;
        } catch (GeneralSecurityException e) {
            // Keep the cause: it tells whether the algorithm or the key was the problem.
            throw new RuntimeException("Unable to initialize ciphers.", e);
        }
    }

    @Override
    public void open() throws TTransportException {
        transport_.open();
    }

    @Override
    public boolean isOpen() {
        return transport_.isOpen();
    }

    @Override
    public void close() {
        transport_.close();
    }

    /**
     * Reads decrypted bytes, pulling and decrypting a new frame once the current one is used up.
     */
    @Override
    public int read(byte[] buf, int off, int len) throws TTransportException {
        int got = readBuffer_.read(buf, off, len);
        if (got > 0) {
            return got;
        }
        // Current frame exhausted: fetch and decrypt the next one.
        readFrame();
        return readBuffer_.read(buf, off, len);
    }

    @Override
    public byte[] getBuffer() {
        return readBuffer_.getBuffer();
    }

    @Override
    public int getBufferPosition() {
        return readBuffer_.getBufferPosition();
    }

    @Override
    public int getBytesRemainingInBuffer() {
        return readBuffer_.getBytesRemainingInBuffer();
    }

    @Override
    public void consumeBuffer(int len) {
        readBuffer_.consumeBuffer(len);
    }

    /**
     * Reads one length-prefixed encrypted frame, decrypts it, and makes the plaintext available
     * through the read buffer.
     *
     * @throws TTransportException if the frame size is negative or exceeds the limit, if the
     *     underlying read fails, or if decryption fails (e.g. wrong key or corrupt data)
     */
    private void readFrame() throws TTransportException {
        transport_.readAll(i32buf, 0, 4);
        int size = decodeFrameSize(i32buf);
        if (size < 0) {
            throw new TTransportException("Read a negative frame size (" + size + ")!");
        }
        // Check before allocating so a hostile size cannot force a huge allocation.
        if (size > maxLength_) {
            throw new TTransportException("Frame size (" + size + ") larger than max length (" + maxLength_ + ")!");
        }
        byte[] ciphertext = new byte[size];
        transport_.readAll(ciphertext, 0, size);
        byte[] plaintext;
        try {
            plaintext = decryptingCipher.doFinal(ciphertext);
        } catch (GeneralSecurityException e) {
            throw new TTransportException(TTransportException.UNKNOWN, e);
        }
        readBuffer_.reset(plaintext);
    }

    /** Buffers {@code len} plaintext bytes; nothing is sent until {@link #flush()}. */
    @Override
    public void write(byte[] buf, int off, int len) throws TTransportException {
        writeBuffer_.write(buf, off, len);
    }

    /**
     * Encrypts everything written since the last flush and sends it as one frame: a 4-byte
     * big-endian ciphertext length followed by the ciphertext. The wrapped transport is then
     * flushed.
     */
    @Override
    public void flush() throws TTransportException {
        byte[] plaintext = writeBuffer_.get();
        int len = writeBuffer_.len();
        // Reset first so the buffer is empty for the next message even if encryption fails.
        writeBuffer_.reset();
        byte[] ciphertext;
        try {
            ciphertext = encryptingCipher.doFinal(plaintext, 0, len);
        } catch (GeneralSecurityException e) {
            throw new TTransportException(TTransportException.UNKNOWN, e);
        }
        encodeFrameSize(ciphertext.length, i32buf);
        transport_.write(i32buf, 0, 4);
        transport_.write(ciphertext);
        transport_.flush();
    }

    /** Writes {@code frameSize} into {@code buf[0..3]} in big-endian order. */
    public static void encodeFrameSize(final int frameSize, final byte[] buf) {
        buf[0] = (byte) (0xff & (frameSize >> 24));
        buf[1] = (byte) (0xff & (frameSize >> 16));
        buf[2] = (byte) (0xff & (frameSize >> 8));
        buf[3] = (byte) (0xff & (frameSize));
    }

    /** Reads a big-endian 32-bit integer from {@code buf[0..3]}. */
    public static int decodeFrameSize(final byte[] buf) {
        return
            ((buf[0] & 0xff) << 24) |
            ((buf[1] & 0xff) << 16) |
            ((buf[2] & 0xff) << 8) |
            ((buf[3] & 0xff));
    }
}
MultiplicationServer.java
package tutorial;
import co.runit.prototype.CryptoTool;
import org.apache.thrift.server.TNonblockingServer;
import org.apache.thrift.server.TServer;
import org.apache.thrift.transport.TNonblockingServerSocket;
import org.apache.thrift.transport.TNonblockingServerTransport;
import java.security.Key;
/**
 * Demo server: serves {@code MultiplicationService} on port 9090 with a non-blocking server,
 * wrapping every connection in a {@link TEncryptedFramedTransport}.
 */
public class MultiplicationServer {
    public static MultiplicationHandler handler;
    public static MultiplicationService.Processor processor;

    /** Builds the handler and processor, then runs the server on a separate thread. */
    public static void main(String[] args) {
        try {
            handler = new MultiplicationHandler();
            processor = new MultiplicationService.Processor(handler);
            Thread serverThread = new Thread(() -> startServer(processor));
            serverThread.start();
        } catch (Exception x) {
            x.printStackTrace();
        }
    }

    /**
     * Runs a non-blocking server for {@code processor} on port 9090 with encrypted framing.
     * Any startup or serving failure is printed and the method returns.
     *
     * <p>NOTE(review): the Base64 key is hard-coded for the demo and must equal the client's key;
     * real deployments should load it from configuration instead of source code.
     */
    public static void startServer(MultiplicationService.Processor processor) {
        try {
            Key sharedKey = CryptoTool.decodeKeyBase64("1OUXS3MczVFp3SdfX41U0A==");
            TNonblockingServerTransport listener = new TNonblockingServerSocket(9090);
            TNonblockingServer.Args serverArgs = new TNonblockingServer.Args(listener);
            serverArgs.transportFactory(new TEncryptedFramedTransport.Factory(sharedKey));
            serverArgs.processor(processor);
            TServer server = new TNonblockingServer(serverArgs);
            System.out.println("Starting the simple server...");
            server.serve();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
MultiplicationClient.java
package tutorial;
import co.runit.prototype.CryptoTool;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import java.security.Key;
/**
 * Demo client: calls {@code multiply(3, 5)} on the server at localhost:9090 over an encrypted
 * framed transport and prints the result.
 */
public class MultiplicationClient {

    /**
     * Connects, performs the call, and always closes the connection afterwards.
     *
     * <p>NOTE(review): the Base64 key is a hard-coded demo value that must match the server's key;
     * real deployments should not keep secrets in source code.
     */
    public static void main(String[] args) {
        Key key = CryptoTool.decodeKeyBase64("1OUXS3MczVFp3SdfX41U0A==");
        try {
            TSocket baseTransport = new TSocket("localhost", 9090);
            TTransport transport = new TEncryptedFramedTransport(baseTransport, key);
            transport.open();
            try {
                TProtocol protocol = new TBinaryProtocol(transport);
                MultiplicationService.Client client = new MultiplicationService.Client(protocol);
                perform(client);
            } finally {
                // Close even if the call fails so the socket is not leaked.
                transport.close();
            }
        } catch (TException x) {
            x.printStackTrace();
        }
    }

    /** Invokes {@code multiply(3, 5)} on the server and prints the product. */
    private static void perform(MultiplicationService.Client client) throws TException {
        int product = client.multiply(3, 5);
        System.out.println("3*5=" + product);
    }
}
Of course, the key must be the same on the client and the server. To generate a key and store it in Base64:
/**
 * Generates a fresh random 128-bit AES key and returns it Base64-encoded.
 *
 * @return the new key's encoded bytes as a Base64 string
 * @throws NoSuchAlgorithmException if no provider supports AES key generation
 */
public static String generateKey() throws NoSuchAlgorithmException, InvalidAlgorithmParameterException {
    KeyGenerator keyGen = KeyGenerator.getInstance("AES");
    keyGen.init(128);
    return encodeKeyBase64(keyGen.generateKey());
}
/**
 * Encodes a key's raw bytes as a standard Base64 string.
 *
 * @param key the key to encode
 * @return Base64 text of {@code key.getEncoded()}
 */
public static String encodeKeyBase64(Key key) {
    byte[] raw = key.getEncoded();
    return Base64.getEncoder().encodeToString(raw);
}
/**
 * Decodes a Base64 string into an AES secret key.
 *
 * @param encodedKey Base64 text of the raw key bytes
 * @return a key whose algorithm name is {@code "AES"}
 * @throws IllegalArgumentException if {@code encodedKey} is not valid Base64 or decodes to an
 *     empty array
 */
public static Key decodeKeyBase64(String encodedKey) {
    byte[] keyBytes = Base64.getDecoder().decode(encodedKey);
    // The key-spec algorithm must be the bare key algorithm ("AES"), not a full cipher
    // transformation such as "AES/ECB/PKCS5Padding": the SunJCE AES cipher rejects keys
    // that report any other algorithm name.
    return new SecretKeySpec(keyBytes, "AES");
}
UPDATE 2
Python Implementation on Top of TFramedTransport
TEncryptedTransport.py
from cStringIO import StringIO
from struct import pack, unpack
from Crypto.Cipher import AES
from thrift.transport.TTransport import TTransportBase, CReadableTransport
__author__ = 'Marboni'

# AES block size in bytes.
BLOCK_SIZE = 16


def pad(s):
    """Apply PKCS#7 padding so that len(result) is a multiple of BLOCK_SIZE.

    Always appends between 1 and BLOCK_SIZE bytes, each equal to the number of
    bytes appended (the scheme Java calls "PKCS5Padding" for AES). Accepts a
    native ``str`` (a byte string on Python 2) or ``bytes``/``bytearray``.
    """
    n = BLOCK_SIZE - len(s) % BLOCK_SIZE
    if isinstance(s, str):
        return s + chr(n) * n
    return s + bytes(bytearray([n]) * n)


def unpad(s):
    """Strip PKCS#7 padding added by ``pad``.

    An empty input is returned as an empty value. Raises ValueError if the
    trailing pad length is out of range or the pad bytes are inconsistent,
    which usually means a wrong key or corrupted ciphertext, instead of
    silently returning truncated garbage.
    """
    if not s:
        return s[:0]
    last = s[-1]
    # Indexing bytes yields an int on Python 3; a 1-char string otherwise.
    n = last if isinstance(last, int) else ord(last)
    if not 1 <= n <= BLOCK_SIZE or n > len(s) or s[-n:] != s[-1:] * n:
        raise ValueError('Invalid padding')
    return s[:-n]
class TEncryptedFramedTransportFactory:
    """Factory that wraps transports in TEncryptedFramedTransport with one shared key.

    Exposes ``getTransport(trans)``, the method Thrift transport factories
    provide.
    """

    def __init__(self, key):
        # Shared secret key, handed unchanged to every transport created.
        self.__key = key

    def getTransport(self, trans):
        """Return ``trans`` wrapped in a new TEncryptedFramedTransport."""
        wrapped = TEncryptedFramedTransport(trans, self.__key)
        return wrapped
class TEncryptedFramedTransport(TTransportBase, CReadableTransport):
    """Framed Thrift transport that AES-encrypts each message with a shared key.

    Wire format (matching the Java TEncryptedFramedTransport): a 4-byte
    big-endian signed frame length followed by that many bytes of AES/ECB
    ciphertext of the PKCS#7-padded message.

    NOTE(review): ECB leaks repeated plaintext blocks and offers no integrity
    protection; switching to an authenticated mode must be done on every peer.
    """

    def __init__(self, trans, key):
        self.__trans = trans
        self.__rbuf = StringIO()
        self.__wbuf = StringIO()
        # Pass the mode explicitly: PyCryptodome has no default mode and rejects
        # AES.new(key) without one, while classic PyCrypto defaulted to ECB.
        # One object per direction mirrors the Java implementation.
        self.__encryptor = AES.new(key, AES.MODE_ECB)
        self.__decryptor = AES.new(key, AES.MODE_ECB)

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        """Return up to ``sz`` decrypted bytes, reading a new frame when the buffer is empty."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret
        self.readFrame()
        return self.__rbuf.read(sz)

    def readFrame(self):
        """Read one length-prefixed frame, decrypt it, and make it the read buffer."""
        sz, = unpack('!i', self.__trans.readAll(4))
        if sz < 0:
            # Imported here because only this error path needs it.
            from thrift.transport.TTransport import TTransportException
            raise TTransportException(
                TTransportException.UNKNOWN,
                'Read a negative frame size (%d)!' % sz)
        encrypted = self.__trans.readAll(sz)
        self.__rbuf = StringIO(unpad(self.__decryptor.decrypt(encrypted)))

    def write(self, buf):
        """Buffer ``buf``; nothing is sent until flush()."""
        self.__wbuf.write(buf)

    def flush(self):
        """Encrypt everything written since the last flush and send it as one frame."""
        wout = self.__wbuf.getvalue()
        self.__wbuf = StringIO()
        encrypted = self.__encryptor.encrypt(pad(wout))
        buf = pack("!i", len(encrypted)) + encrypted
        self.__trans.write(buf)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Read whole frames until at least ``reqlen`` bytes (including ``prefix``) are buffered."""
        while len(prefix) < reqlen:
            self.readFrame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = StringIO(prefix)
        return self.__rbuf
MultiplicationClient.py
import base64
from thrift import Thrift
from thrift.transport import TSocket
from thrift.protocol import TBinaryProtocol
from tutorial import MultiplicationService, TEncryptedTransport
# Demo-only shared secret; must equal the server's key. Do not hard-code real
# keys in source.
key = base64.b64decode("1OUXS3MczVFp3SdfX41U0A==")

try:
    transport = TSocket.TSocket('localhost', 9090)
    transport = TEncryptedTransport.TEncryptedFramedTransport(transport, key)
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    client = MultiplicationService.Client(protocol)
    transport.open()
    try:
        # multiply takes two operands (the Java client calls multiply(3, 5));
        # the previous extra 'Echo!' argument would raise TypeError.
        product = client.multiply(4, 5)
        print('4*5=%d' % product)
    finally:
        # Close even if the call fails so the socket is not leaked.
        transport.close()
except Thrift.TException as tx:
    print(tx.message)