2016-02-24 3 views
0

У меня есть ошибка в следующем коде, который возвращает inf, inf для Thetas.Non-vectorized Gradient Descent

def gradient_descent(x, y, t0, t1, alpha, num_iters): 
for i in range(num_iters): 
    t0_sum = 0 
    t1_sum = 0 
    for i in range(m_num): # I have a feeling that the following partial derivatives are wrong 
    t0_sum += ((t1*x[i])+t0 - y[i]) 
    t1_sum += (((t1*x[i])+t0 - y[i])*(x[i])) 
    t0 = t0 - (alpha/m_num * (t0_sum)) 
    t1 = t1 - (alpha/m_num * (t1_sum)) 
return t0, t1 

Благодаря

+0

Переменные х и у представляют собой списки целых чисел, если это делает никакой разницы. –

ответ

0

Ваш код почти правильно. Единственное: вы использовали одну и ту же переменную для обоих циклов.

Просто измените i на j в первом цикле и он будет работать.

Для проверки вы можете использовать normal equation, который обеспечивает наилучшее решение проблемы без использования каких-либо петель.

Вот мой код:

import numpy as np 

def gradient_descent(x, y, t0, t1, alpha, num_iters): 

    m_num = len(x); 

    for j in range(num_iters): 
     t0_sum = 0 
     t1_sum = 0 

     for i in range(m_num): 
      t0_sum += ((t1*x[i])+t0 - y[i]) 
      t1_sum += (((t1*x[i])+t0 - y[i])*(x[i])) 

     t0 = t0 - (alpha/m_num * (t0_sum)) 
     t1 = t1 - (alpha/m_num * (t1_sum)) 

    return t0, t1 


def norm_equation(x, y): 
    m = len(x); 
    x = np.asarray([x]).transpose() 
    y = np.asarray([y]).transpose() 

    x = np.hstack((np.ones((m, 1)), x)) 

    t = np.dot(np.dot(np.linalg.pinv(np.dot(x.transpose(), x)), x.transpose()), y) 
    return t 


x = [6, 5, 8, 7, 5, 8, 7, 8, 6, 5, 5, 14] 
y = [17, 9, 13, 11, 6, 11, 4, 12, 6, 3, 3, 15] 

t0 = 0 
t1 = 0 
alpha = 0.008 
num_iters = 10000 

t0, t1 = gradient_descent(x, y, t0, t1, alpha, num_iters) 
print("Gradient descent:") 
print("t0 = " + str(t0) + "; t1 = " + str(t1)) 

print 

t = norm_equation(x, y) 
print("Normal equation") 
print("t0 = " + str(t.item(0)) + "; t1 = " + str(t.item(1))) 

Результат:

Gradient descent: 
t0 = 1.56634355366; t1 = 1.08575561307 

Normal equation 
t0 = 1.56666666667; t1 = 1.08571428571 
+0

Большое вам спасибо! Я не могу поверить, что это была такая простая ошибка –

Смежные вопросы