I am training 3 lists of data L1, L2, L3. First i train all one them with SGDClassifier fit() and later instance by instance with partial_fit(). I I test the data with L4, L5. [The data in lists is image data and L4, L5 images are same as L2].
The predictions with fit() is correct and it is what i am expecting with partial_fit(). However the output of below code shows that both behave differently irrespective of 10,000 number of iterations for partial_fit().
Output:
fit
[1] #Tested L1. Predicts label as 1 correctly
[2] #Tested L2. Predicts label as 2 correctly
[3] #Tested L3. Predicts label as 3 correctly
[2] #Tested L4. Predicts label as 2 correctly [Data close to L2]
[2] #Tested L5. Predicts label as 2 correctly [Data close to L2]
partial_fit
[3] #Tested L1. Predicts label as 3 incorrectly
[3] #Tested L2. Predicts label as 3 incorrectly
[3] #Tested L3. Predicts label as 3 incorrectly
[3] #Tested L4. Predicts label as 3 incorrectly
[3] #Tested L5. Predicts label as 3 incorrectly
Code:
from sklearn import linear_model, neighbors
import numpy as np
L1 = [-1.98257446e-01, 1.02612168e-01, 1.06458694e-01, -4.44016755e-02,
-1.25126377e-01, -1.03119195e-01, -1.89867821e-02, -5.70720285e-02,
1.65993825e-01, -4.91751768e-02, 1.35020703e-01, 5.58929071e-02,
-1.79934561e-01, -1.61055699e-02, -3.67883481e-02, 7.28202313e-02,
-8.59514326e-02, -1.19364798e-01, -6.03461489e-02, -9.60081592e-02,
9.60884690e-02, 7.37309158e-02, -4.95407730e-02, -2.30211094e-02,
-1.59170195e-01, -3.23998809e-01, -8.31042454e-02, -7.68149048e-02,
3.26708518e-03, -5.57898730e-02, 3.65743786e-02, 3.37894261e-02,
-1.61165833e-01, -9.21991318e-02, 3.83259654e-02, 1.30853474e-01,
2.16114409e-02, 1.56024918e-02, 1.63483590e-01, 3.55638564e-04,
-1.01068482e-01, 3.11988778e-02, 2.79297493e-02, 3.43645960e-01,
7.68225491e-02, 7.39665255e-02, 9.03626233e-02, -4.77984771e-02,
1.46613032e-01, -2.24640951e-01, 9.37603638e-02, 1.30618230e-01,
5.41394278e-02, 3.57956365e-02, 9.59608406e-02, -1.01410612e-01,
1.15592867e-01, 7.47590065e-02, -2.77784020e-01, 1.61038041e-01,
2.08325848e-01, -1.48789823e-01, -9.12107825e-02, -2.09741015e-02,
2.12046385e-01, 4.47734147e-02, -8.59520137e-02, -8.20810571e-02,
1.37491941e-01, -1.57671914e-01, -1.28236525e-02, -2.89905779e-02,
-9.23343226e-02, -1.41179219e-01, -2.73343533e-01, 8.64235312e-02,
4.51376319e-01, 2.13798493e-01, -1.68360874e-01, 7.94294775e-02,
-1.16615891e-01, 4.44242992e-02, 1.32415727e-01, -1.00808069e-02,
-7.62857720e-02, 4.50578667e-02, -1.62037611e-01, 8.80152583e-02,
2.10405558e-01, 5.48043177e-02, -2.42764503e-03, 2.23779172e-01,
1.04215354e-01, 6.21869229e-03, 4.02947590e-02, 1.28729194e-02,
-1.31998569e-01, -8.53061676e-02, -7.21085370e-02, 3.05483658e-02,
7.17334375e-02, -1.21093884e-01, 4.04045768e-02, 8.53371918e-02,
-1.82588950e-01, 1.95098877e-01, -3.77971642e-02, 2.39514187e-02,
-6.40425161e-02, 2.60147993e-02, -1.23514839e-01, -5.75782135e-02,
1.23560801e-01, -1.81436151e-01, 1.73729539e-01, 1.55140847e-01,
9.45670251e-03, 1.76663831e-01, 4.24060002e-02, 5.23296222e-02,
-2.61488743e-02, -1.90883875e-04, -1.07142523e-01, -1.19456224e-01,
-4.72589768e-03, -1.22928023e-02, 1.22105561e-01, 1.08871996e-01]
L2 = [-0.13126934, 0.04299157, 0.03283413, -0.07268133, -0.0575216 ,
-0.05970731, -0.04122763, -0.12341423, 0.23687837, -0.19369504,
0.18289158, -0.02773106, -0.17346333, -0.03682114, -0.01798879,
0.12592959, -0.13210742, -0.14877586, -0.03237661, -0.08512233,
0.03863079, -0.0244094 , 0.03298262, 0.07976148, -0.14883795,
-0.41100848, -0.17795764, -0.08934171, 0.00651174, -0.0744134 ,
0.0313075 , 0.08470915, -0.18205762, -0.01133199, -0.0155912 ,
0.11513804, 0.00782543, -0.05359597, 0.18193047, -0.00212595,
-0.20811354, -0.16053183, 0.05181924, 0.23603486, 0.10422225,
0.02778829, 0.05380247, -0.04042226, 0.0341601 , -0.17557909,
0.05018872, 0.11027649, 0.05657898, 0.02233699, 0.08839077,
-0.15501094, 0.01485735, 0.04386368, -0.11386063, -0.01646214,
0.00378657, -0.10775882, -0.12292566, -0.02450235, 0.25261074,
0.14213347, -0.09663931, -0.11174012, 0.22364001, -0.17145677,
-0.00569641, 0.02280853, -0.12527066, -0.18559724, -0.29374081,
-0.00162096, 0.42862758, 0.12023295, -0.12319036, 0.10102081,
-0.05752999, -0.02222615, 0.04897028, 0.1726429 , -0.09291326,
0.12992594, -0.05943635, 0.1127295 , 0.13184965, -0.02819252,
-0.02569888, 0.13797338, -0.05463714, 0.07084383, 0.03620753,
0.02154547, -0.09113872, -0.00730729, -0.11946794, -0.00743609,
0.13593611, 0.01564942, -0.02297226, 0.11888021, -0.18092889,
0.11661324, 0.02172676, -0.09794122, 0.01236411, 0.0558071 ,
-0.1001874 , -0.1216456 , 0.13321149, -0.22005031, 0.08024856,
0.19123463, -0.06378062, 0.2226923 , 0.07309284, 0.11730921,
0.0262427 , -0.03699137, -0.1887596 , -0.02048384, 0.04079603,
-0.02144403, 0.00859149, -0.01283618]
L3 = [-1.39073551e-01, 5.75132817e-02, 1.06875971e-01, -4.47942242e-02,
6.49299771e-02, -8.30453411e-02, 3.50628048e-02, -4.86568436e-02,
1.11577645e-01, -9.53562111e-02, 2.84853131e-01, -5.57231307e-02,
-2.10671812e-01, -1.03007048e-01, 1.96518339e-02, 7.77831525e-02,
-7.90358335e-02, -3.00030578e-02, -7.82457143e-02, -1.04805976e-01,
8.18016306e-02, 6.47072643e-02, 1.21586584e-02, 8.08022916e-04,
-8.00280571e-02, -3.14502358e-01, -1.17208570e-01, -9.81831551e-02,
2.68037282e-02, -1.33987337e-01, 1.33101437e-02, 2.91747972e-02,
-1.87404498e-01, -5.92408441e-02, -7.84080178e-02, 1.05799856e-02,
-6.32970333e-02, -2.37192065e-02, 1.31071255e-01, 5.25641590e-02,
-8.04402679e-02, -9.32691842e-02, -2.31102034e-02, 2.82592803e-01,
1.47951603e-01, 8.49031657e-03, -6.55979887e-02, -1.86005980e-03,
2.86830403e-03, -2.48319194e-01, -5.38104884e-02, 1.02639243e-01,
5.23314849e-02, 7.83263296e-02, 7.35125244e-02, -5.58062941e-02,
3.26449387e-02, -2.09478531e-02, -1.95044577e-01, 9.34160873e-03,
-2.26898044e-02, -8.78838003e-02, -6.57741576e-02, -2.00360566e-02,
1.71352893e-01, 6.89927936e-02, -7.95211121e-02, -8.00146461e-02,
1.32486463e-01, -1.35504007e-01, 2.61258446e-02, 1.05848603e-01,
-9.21048969e-02, -1.80963904e-01, -1.98812112e-01, 7.26982281e-02,
3.29640329e-01, 1.04015507e-01, -1.24389552e-01, 2.69887168e-02,
-1.54598460e-01, -5.56088090e-02, 1.01781934e-01, -3.85247841e-02,
-3.20458487e-02, 3.86849903e-02, -8.98609757e-02, 8.27674717e-02,
1.06020764e-01, -7.34615028e-02, -4.03962284e-02, 1.98970288e-01,
-5.60568720e-02, 5.78189567e-02, 4.93795872e-02, -2.47523189e-04,
-6.07730448e-02, 2.19929889e-02, -1.10751927e-01, 6.69334084e-04,
8.69397819e-02, -1.09967209e-01, 1.43145397e-03, 8.74901861e-02,
-1.14516295e-01, 1.38158470e-01, 7.43495077e-02, -3.98697220e-02,
3.39040905e-02, 2.46684682e-02, -1.51388928e-01, -7.87943155e-02,
1.09218210e-01, -2.05471277e-01, 1.49658069e-01, 1.86885983e-01,
-3.31082232e-02, 1.01324990e-01, 3.32798958e-02, 5.33202365e-02,
-6.65426776e-02, -2.35776380e-02, -1.32266074e-01, -2.31741816e-02,
3.98471728e-02, 4.69821505e-02, -2.74340808e-02, -5.45420833e-02]
L4 = [-9.80433971e-02, -7.03648664e-03, -8.67843628e-04, -1.18527517e-01,
-5.99347353e-02, -3.52256261e-02, -4.00453769e-02, -9.58476141e-02,
2.23521233e-01, -1.88561112e-01, 1.72594860e-01, -4.11576033e-02,
-1.52830154e-01, -5.84353730e-02, -4.33000550e-03, 1.20912530e-01,
-1.34689406e-01, -1.79964483e-01, -3.15833911e-02, -9.25036967e-02,
-1.05666816e-02, -4.42105718e-03, 2.60549188e-02, 9.88835841e-02,
-1.62467003e-01, -4.19883490e-01, -1.71131760e-01, -9.64985639e-02,
-1.19223613e-02, -9.55987573e-02, 2.25513764e-02, 1.07761353e-01,
-2.36451998e-01, -1.74359381e-02, 5.71147725e-03, 1.24660656e-01,
6.69890456e-03, -1.86523274e-02, 1.85175732e-01, 2.91687660e-02,
-2.09594339e-01, -1.34366542e-01, 4.75538447e-02, 2.49922469e-01,
1.22993328e-01, 2.24278457e-02, 1.52391801e-02, -1.24563389e-02,
4.96755280e-02, -1.92227215e-01, 9.83141586e-02, 1.23155341e-01,
3.48911509e-02, 1.25203300e-02, 6.06377572e-02, -1.32613182e-01,
-5.22616133e-03, 7.46049434e-02, -1.53830111e-01, 4.96822223e-03,
-6.75934367e-03, -9.12150443e-02, -1.03079259e-01, -2.60316133e-02,
2.52563179e-01, 1.48371726e-01, -9.73276347e-02, -1.42138824e-01,
2.50091761e-01, -1.66190103e-01, 1.91132445e-02, 3.98359001e-02,
-1.27865523e-01, -1.90915748e-01, -2.90090829e-01, 2.87051760e-02,
4.39558297e-01, 1.14880979e-01, -1.23038329e-01, 1.02565333e-01,
-6.96414784e-02, -4.86778058e-02, 3.95676941e-02, 1.31223276e-01,
-7.37062097e-02, 1.40905678e-01, -4.61848751e-02, 1.32415891e-01,
1.50173992e-01, 1.56789012e-02, -6.01302609e-02, 1.37784094e-01,
-8.30642357e-02, 7.05572739e-02, 8.34304839e-02, 4.12208587e-02,
-8.44793320e-02, -2.76077650e-02, -1.74217999e-01, -7.80004263e-03,
7.51234069e-02, -2.18363479e-04, -4.15662788e-02, 1.44352645e-01,
-1.46695063e-01, 1.61359623e-01, 2.00959761e-02, -1.15739897e-01,
-4.57503423e-02, 8.08721706e-02, -1.02865808e-01, -1.25917166e-01,
1.34963557e-01, -2.33383894e-01, 1.03095181e-01, 1.53916180e-01,
-2.00787671e-02, 2.26398230e-01, 5.59305362e-02, 9.53603685e-02,
1.47923566e-02, -5.58686256e-02, -2.01987177e-01, -2.75421105e-02,
4.75574993e-02, -1.08102616e-02, 5.95078953e-02, 1.26588587e-02]#Close to L2
L5 = [-0.09945749, -0.00729111, 0.0092897 , -0.13243762, -0.06422047,
-0.02094417, -0.04948308, -0.12064691, 0.25643739, -0.19205171,
0.15657693, -0.03121898, -0.15308823, -0.02828152, -0.00710347,
0.11809425, -0.14299625, -0.16806611, -0.03130123, -0.08865803,
-0.0071869 , -0.00937061, 0.06185013, 0.10348818, -0.18077886,
-0.43158019, -0.17442586, -0.08369756, 0.00713679, -0.08146362,
-0.00203652, 0.09452251, -0.24805595, -0.02332739, -0.00440642,
0.13737108, 0.00089538, -0.04461086, 0.17354517, 0.02099614,
-0.22964232, -0.14414147, 0.07377731, 0.21512158, 0.12966961,
0.03000744, 0.01046804, -0.0051102 , 0.04499209, -0.1823051 ,
0.07896246, 0.11629909, 0.02137423, 0.02415319, 0.06205415,
-0.12419473, 0.01515957, 0.06340452, -0.1500473 , -0.01087676,
0.02246305, -0.0924818 , -0.09429674, -0.01974701, 0.25166726,
0.16988155, -0.09064031, -0.15273461, 0.21510246, -0.17729256,
0.00261592, 0.02652721, -0.13491498, -0.17640282, -0.31118405,
-0.00512062, 0.41723928, 0.13354909, -0.09930452, 0.10033775,
-0.06307391, -0.02699157, 0.04080637, 0.13098213, -0.08033849,
0.16044492, -0.04734115, 0.12942326, 0.14534265, 0.0249849 ,
-0.06554834, 0.13151604, -0.07915305, 0.08410332, 0.07018198,
0.06627715, -0.11851253, -0.02576792, -0.18880717, -0.00411349,
0.08233207, 0.04832725, -0.01709246, 0.15401676, -0.15097997,
0.16647491, 0.01185772, -0.11977788, -0.02823763, 0.08750527,
-0.10837749, -0.12731393, 0.11664411, -0.22722226, 0.09817819,
0.16637388, -0.01940754, 0.21179773, 0.06896579, 0.0847318 ,
0.00796246, -0.01696757, -0.19169487, -0.03898101, 0.0400917 ,
-0.03423833, 0.08150289, 0.0139573 ]#Close to L2
sgd_clf = linear_model.SGDClassifier(loss="modified_huber",max_iter =100)
classes = np.arange(5)
sgd_clf_fit = linear_model.SGDClassifier(loss="modified_huber",max_iter =100)
sgd_clf_fit.fit([L1,L2,L3],[1,2,3])
print("fit")
print(sgd_clf_fit.predict([L1]))
print(sgd_clf_fit.predict([L2]))
print(sgd_clf_fit.predict([L3]))
print(sgd_clf_fit.predict([L4]))
print(sgd_clf_fit.predict([L5]))
idx1 = 1
for i in range(10000):
sgd_clf.partial_fit([L1], [idx1], classes=classes)
idx2 = 2
for i in range(10000):
sgd_clf.partial_fit([L2],[idx2])
idx3 = 3
for i in range(10000):
sgd_clf.partial_fit([L3],[idx3])
print("partial_fit")
print(sgd_clf.predict([L1]))
print(sgd_clf.predict([L2]))
print(sgd_clf.predict([L3]))
print(sgd_clf.predict([L4]))
print(sgd_clf.predict([L5]))
How to improve my prediction result of partial_fit() to match with fit() ? I want to learn instance by instance and still predict accurately. I tried with different number of iterations but it didnot work.