Is deep learning bad at fitting simple non-linear functions?

Posted 2019-01-24 20:23

I am trying to create a simple deep-learning-based model to predict y = x**2, but it looks like deep learning is not able to learn the general function outside the scope of its training set.

Intuitively, I suspect that a neural network might not be able to fit y = x**2 because no multiplication between the inputs is involved.
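To make that intuition concrete: the model in the question is one hidden Dense layer of 8 ReLU units feeding a linear output unit, so it only scales, shifts and clips the input at zero, and never multiplies the input by itself. Its output is therefore piecewise linear in x, and beyond the outermost "kink" it is exactly a straight line. The NumPy sketch below assumes random weights (not the trained weights from the notebook) and checks that the second difference of the network vanishes far from the origin, while for x**2 it stays at 2*h**2.

```python
import numpy as np

def relu_net(x, w1, b1, w2, b2):
    """Forward pass of Dense(n, relu) -> Dense(1) for a 1-D array of scalar inputs."""
    hidden = np.maximum(0.0, np.outer(x, w1) + b1)  # shape (len(x), n)
    return hidden @ w2 + b2                          # shape (len(x),)

# Random weights (illustrative assumption; not the trained weights from the notebook)
rng = np.random.default_rng(0)
n_hidden = 8  # same width as the question's hidden layer
w1, b1 = rng.normal(size=n_hidden), rng.normal(size=n_hidden)
w2, b2 = rng.normal(size=n_hidden), rng.normal()

# Unit j switches on or off at x = -b1[j] / w1[j]. Past the outermost switch point
# every unit is either always off or purely linear, so the sum is a straight line.
x0 = np.abs(-b1 / w1).max() + 1000.0
h = 10.0
xs = np.array([x0 - h, x0, x0 + h])

f = relu_net(xs, w1, b1, w2, b2)
second_diff_net = f[2] - 2 * f[1] + f[0]             # ~0: no curvature left
second_diff_sq = xs[2]**2 - 2 * xs[1]**2 + xs[0]**2  # 2 * h**2 for y = x**2
print(second_diff_net, second_diff_sq)
```

A straight line can match x**2 only locally, which is why the gap between the two widens as x moves further away from the training data.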

Please note I am not asking how to create a model to fit x**2; I have already achieved that. I want to know the answers to the following questions:

  1. Is my analysis correct?
  2. If the answer to 1 is yes, then isn't the prediction scope of deep learning very limited?
  3. Is there a better algorithm for predicting functions like y = x**2 both inside and outside the scope of training data?

Link to the complete notebook: https://github.com/krishansubudhi/MyPracticeProjects/blob/master/KerasBasic-nonlinear.ipynb

Training input:

import numpy as np

x = np.random.random((10000, 1)) * 1000 - 500  # 10,000 samples drawn uniformly from [-500, 500)
y = x**2
x_train = x

[Figure: input data]

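Training code:

from keras import layers, optimizers, regularizers
from keras.callbacks import EarlyStopping
from keras.models import Sequential

def getSequentialModel():
    model = Sequential()
    # One hidden layer of 8 ReLU units with L2 weight regularization
    model.add(layers.Dense(8, kernel_regularizer=regularizers.l2(0.001), activation='relu', input_shape=(1,)))
    model.add(layers.Dense(1))  # linear output unit
    model.summary()  # summary() prints the table itself and returns None
    return model

def runmodel(model):
    model.compile(optimizer=optimizers.RMSprop(learning_rate=0.01), loss='mse')
    # Stop once val_loss has not improved for 5 consecutive epochs
    early_stopping_monitor = EarlyStopping(patience=5)
    h = model.fit(x_train, y, validation_split=0.2,
                  epochs=300,
                  batch_size=32,
                  verbose=False,
                  callbacks=[early_stopping_monitor])
    return h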


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_18 (Dense)             (None, 8)                 16        
_________________________________________________________________
dense_19 (Dense)             (None, 1)                 9         
=================================================================
Total params: 25
Trainable params: 25
Non-trainable params: 0
_________________________________________________________________
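The parameter counts in the summary follow from "one weight per input-output pair plus one bias per output"; the L2 regularizer adds no parameters. The helper below (a hypothetical name, used only for this arithmetic) reproduces the 16 + 9 = 25 total.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Trainable parameters of a fully connected layer: n_in * n_out weights plus n_out biases."""
    return n_in * n_out + n_out

hidden = dense_params(1, 8)  # 1*8 weights + 8 biases
output = dense_params(8, 1)  # 8*1 weights + 1 bias
print(hidden, output, hidden + output)  # 16 9 25
```

With only 25 parameters and a single hidden layer, this network is small and shallow by any standard.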

[Figure: evaluation on a random test set]

Deep learning in this example is not good at predicting a simple non-linear function outside the training range, but it is good at predicting values within the sample space of the training data.

1 answer
祖国的老花朵
Post #2 · 2019-01-24 20:43
  1. Is my analysis correct?

Given my remarks in the comments that your network is certainly not deep, let's accept that your analysis is indeed correct (after all, your model does seem to do a good job inside its training scope), in order to get to your 2nd question, which is the interesting one.

  2. If the answer to 1 is yes, then isn't the prediction scope of deep learning very limited?

Well, this is the kind of question not exactly suitable for SO, since the exact meaning of "very limited" is arguably unclear...

So, let's try to rephrase it: should we expect DL models to predict such numerical functions outside the numeric domain on which they have been trained?

An example from a different domain may be enlightening here: suppose we have built a model able to detect & recognize animals in photos with very high accuracy (it is not hypothetical; such models do exist indeed); should we complain when the very same model cannot detect and recognize airplanes (or trees, refrigerators etc - you name it) in these same photos?

Put like that, the answer is a clear & obvious no - we should not complain, and in fact we are certainly not even surprised by such a behavior in the first place.

It is tempting for us humans to think that such models should be able to extrapolate, especially in the numeric domain, since this is something we do very "easily" ourselves; but ML models, while exceptionally good at interpolating, fail miserably at extrapolation tasks, such as the one you present here.

Trying to make it more intuitive, think of the whole "world" of such models as confined to the domain of their training sets: my example model above would be able to generalize and recognize animals in unseen photos as long as these animals are "between" (mind the quotes) the ones it has seen during training; in a similar manner, your model does a good job predicting the function value for arguments between the samples you have used for training. But in neither case are these models expected to go beyond their training domain (i.e. extrapolate). There is no "world" for my example model beyond animals, and similarly for your model beyond [-500, 500]...
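A deterministic stand-in for the trained Keras model shows the same picture. The sketch below assumes a one-hidden-layer ReLU network whose hidden weights are fixed (19 evenly spaced hinges, a hypothetical choice) and whose output layer is fitted by least squares on the same [-500, 500] range as the question. Inside that range the error stays small; past the last hinge the model continues as a straight line, so the error grows without bound.

```python
import numpy as np

def relu_features(x, knots):
    """Fixed hidden layer: columns [1, x, max(0, x - t) for each hinge t]."""
    x = np.asarray(x, dtype=float)
    return np.column_stack([np.ones_like(x), x, np.maximum(0.0, x[:, None] - knots)])

rng = np.random.default_rng(0)
x_train = rng.uniform(-500, 500, size=10_000)  # same range as the question
y_train = x_train**2

knots = np.linspace(-450, 450, 19)  # hypothetical: hinges spaced 50 apart
# Least-squares fit of the linear output layer on top of the fixed ReLU features
coef, *_ = np.linalg.lstsq(relu_features(x_train, knots), y_train, rcond=None)

def predict(x):
    return relu_features(x, knots) @ coef

x_in = np.linspace(-500, 500, 1001)        # inside the training range
x_out = np.array([750.0, 1000.0, 2000.0])  # outside it

err_in = np.abs(predict(x_in) - x_in**2).max()
err_out = np.abs(predict(x_out) - x_out**2)
print(f"max |error| inside [-500, 500]: {err_in:.0f}")
print("|error| outside:", np.round(err_out))
```

Adding more hinges shrinks the error inside [-500, 500] but does nothing for x beyond 500, because no hinge there was ever constrained by training data.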

For corroboration, consider the very recent paper Neural Arithmetic Logic Units, by DeepMind; quoting from the abstract:

Neural networks can learn to represent and manipulate numerical information, but they seldom generalize well outside of the range of numerical values encountered during training.
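For intuition on why NALU-style units can extrapolate products, here is a NumPy sketch of the multiplicative path of the NALU as defined in that paper: m = exp(W log(|x| + ε)) with W = tanh(Ŵ) ⊙ σ(M̂). Assumptions of this sketch: the additive path and the learned gate g = σ(Gx) are omitted, the weights are set by hand rather than trained, and each scalar is fed as the pair (x, x) because the constrained weights lie in [-1, 1]. Once W ≈ [1, 1], the unit computes |x|^2 = x^2 for any x, not only for values inside a training range.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def nac_weights(w_hat, m_hat):
    """NAC weight matrix W = tanh(W_hat) * sigmoid(M_hat); entries lie in [-1, 1]."""
    return np.tanh(w_hat) * sigmoid(m_hat)

def nalu_multiplicative_path(x, w_hat, m_hat, eps=1e-7):
    """m = exp(W log(|x| + eps)): adding in log space multiplies in linear space.

    x has shape (batch, in_dim); W has shape (out_dim, in_dim).
    """
    W = nac_weights(w_hat, m_hat)
    return np.exp(np.log(np.abs(x) + eps) @ W.T)

# Hand-set (not learned) parameters: large values saturate W to about [[1, 1]]
w_hat = np.full((1, 2), 20.0)
m_hat = np.full((1, 2), 20.0)

# Feed each scalar as the pair (x, x) so that W @ log|x| is about 2 log|x|
xs = np.array([3.0, -40.0, 450.0, 1e4])  # 1e4 is far outside [-500, 500]
pairs = np.stack([xs, xs], axis=1)
pred = nalu_multiplicative_path(pairs, w_hat, m_hat)[:, 0]
print(np.column_stack([xs, pred, xs**2]))
```

Working in log space discards the sign of the inputs, which is harmless for squaring but is one of the known limitations of this path.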

See also a relevant tweet by a prominent practitioner.

On to your third question:

  3. Is there a better algorithm for predicting functions like y = x**2 both inside and outside the scope of training data?

As it should be clear by now, this is a (hot) area of current research; see the above paper for starters...


So, are DL models limited? Definitely - forget the scary tales about AGI for the foreseeable future. Are they very limited, as you put it? Well, I don't know... But, given their limitation in extrapolating, are they useful?

This is arguably the real question of interest, and the answer is obviously - hell, yeah!
