I get a TypeError when attempting to train a TensorFlow random forest using TensorForestEstimator:
TypeError: Input 'input_data' of 'CountExtremelyRandomStats' Op has type float64 that does not match expected type of float32.
I've tried both Python 2.7 and Python 3, and I've tried using tf.cast() to convert everything to float32, but it doesn't help. I have checked the data type at execution time and it's float32. The problem doesn't seem to be the data I provide (a CSV of all floats), so I'm not sure where to go from here.
Any suggestions of things I can try would be much appreciated.
Code:
# Build an estimator.
def build_estimator(model_dir):
    """Create a TensorForest random-forest estimator wrapped in SKCompat.

    Args:
        model_dir: directory where the model will be stored.

    Returns:
        ``estimator.SKCompat`` wrapping a ``TensorForestEstimator`` configured
        for 2 classes and 760 features, with the tree count and node limit
        taken from ``FLAGS.num_trees`` and ``FLAGS.max_nodes``.
    """
    hparams = tensor_forest.ForestHParams(
        num_classes=2,
        num_features=760,
        num_trees=FLAGS.num_trees,
        max_nodes=FLAGS.max_nodes,
    )
    # FLAGS.use_training_loss switches to the TrainingLossForest graph builder;
    # otherwise the plain RandomForestGraphs builder is used.
    builder = (tensor_forest.TrainingLossForest
               if FLAGS.use_training_loss
               else tensor_forest.RandomForestGraphs)
    forest = random_forest.TensorForestEstimator(
        hparams, graph_builder_class=builder, model_dir=model_dir)
    # The SKCompat wrapper gives a convenient way to split in-memory data
    # into batches.
    return estimator.SKCompat(forest)
# Train and evaluate the model.
def train_and_eval(data_dir='/Users/carl/Dropbox/Docs/Python'):
    """Load the balanced train/test CSV files and fit the random forest.

    Args:
        data_dir: directory containing the four ``randomforest_balanced_*``
            CSV files. Defaults to the location previously hard-coded here.

    Note:
        Features and labels are handed to ``fit`` as NumPy arrays (via
        ``.values``) rather than as pandas DataFrames. When given a DataFrame,
        the contrib.learn data-feeder path behind ``SKCompat.fit`` performs
        its own conversion, which can up-cast the features to float64. That
        produces "Input 'input_data' of 'CountExtremelyRandomStats' Op has
        type float64". Passing the already-float32 array avoids the
        conversion. NOTE(review): confirm against the installed TF version.
    """
    import os

    def _path(filename):
        # Resolve one of the expected CSV files inside data_dir.
        return os.path.join(data_dir, filename)

    # Load feature matrices as float32 (the files have no header row).
    training_set = pd.read_csv(_path('randomforest_balanced_train.csv'),
                               dtype=np.float32, header=None)
    test_set = pd.read_csv(_path('randomforest_balanced_test.csv'),
                           dtype=np.float32, header=None)
    print('###########')
    # Prints float32: the DataFrame dtype is already correct at this point.
    print(training_set.loc[:, 1].dtype)

    # Load labels as int32 into the column name(s) given by LABEL.
    training_labels = pd.read_csv(_path('randomforest_balanced_train_class.csv'),
                                  dtype=np.int32, names=LABEL, header=None)
    # test_set / test_labels are loaded but not yet used in this function.
    test_labels = pd.read_csv(_path('randomforest_balanced_test_class.csv'),
                              dtype=np.int32, names=LABEL, header=None)

    # Where the model will be stored: FLAGS.model_dir if set, else a fresh
    # temporary directory.
    model_dir = FLAGS.model_dir or tempfile.mkdtemp()
    print('model directory = %s' % model_dir)

    # Build the random forest estimator.
    est = build_estimator(model_dir)

    # The former tf.cast(...) calls were removed: they created graph tensors
    # whose results were discarded and never modified the DataFrames.
    # Pass plain NumPy arrays so no float64 up-cast happens inside fit().
    est.fit(x=training_set.values, y=training_labels.values)