Я пытаюсь реализовать логистический классификатор с помощью python. Цель состоит в том, чтобы обучить алго, чтобы идентифицировать цифры 0-9, используя набор данных рукописных цифр mnist. Однако fmin_cg, кажется, меняет размеры моих входных аргументов. Я пробовал переделывать аргументы внутри cost() и градиента() без везения; просто больше ошибок.Логистическая регрессия с использованием Scipy's fmin_cg
from scipy.io import loadmat
from numpy import shape, zeros, ones, dot, hstack, vstack, log, transpose, kron
from scipy.special import expit as sigmoid
import scipy.optimize
def cost(theta, X, y):
h = sigmoid(X.dot(theta))
pos_class = y.T.dot(log(h))
neg_class = (1.0-y).T.dot(log(1.0-h))
cost = ((-1.0/m)*(pos_class+neg_class))
return cost
def gradient(theta, X, y):
h = sigmoid(X.dot(theta))
grad = (1.0/m)*(X.T.dot((h-y)))
return grad
def one_vs_all(X, y, theta):
# add x1 feature,x1 = 1, to each example set
X = hstack((ones((m,1)),X))
# train the classifier for digit 9.0
temp_y = (y == 9.0)+0
result = scipy.optimize.fmin_cg(cost, fprime=gradient, x0=theta, \
args=(X, temp_y), maxiter=50, disp=False, full_output=True)
print result[1]
# Load data from Matlab file
data = loadmat('data.mat')
X,y = data['X'],data['y']
m,n = shape(X)
theta = zeros((n+1, 1))
one_vs_all(X, y, theta)
Ошибка получаю:
Traceback (most recent call last):
File "/Users/jkarimi91/Documents/Digit Recognizer/Digit_Recognizer.py", line 36, in <module>
one_vs_all(X, y, theta)
File "/Users/jkarimi91/Documents/Digit Recognizer/Digit_Recognizer.py", line 26, in one_vs_all
args=(X, temp_y), maxiter=50, disp=False, full_output=True)
File "/anaconda/lib/python2.7/site-packages/scipy/optimize/optimize.py", line 1092, in fmin_cg
res = _minimize_cg(f, x0, args, fprime, callback=callback, **opts)
File "/anaconda/lib/python2.7/site-packages/scipy/optimize/optimize.py", line 1156, in _minimize_cg
deltak = numpy.dot(gfk, gfk)
ValueError: shapes (401,5000) and (401,5000) not aligned: 5000 (dim 1) != 401 (dim 0)
[Finished in 1.0s with exit code 1]