2016-08-04 2 views
2

я реализовал простую линейную регрессию, и я хочу попробовать его, установив нелинейную модельПопытки построить простую функцию - питон

специально я пытаюсь подобрать модель для функции y = x^3 + 5, например

это мой код

import numpy as np 
import numpy.matlib 
import matplotlib.pyplot as plt 

def predict(X,W): 
    return np.dot(X,W) 

def gradient(X, Y, W, regTerm=0): 
    return (-np.dot(X.T, Y) + np.dot(np.dot(X.T,X),W))/(m*k) + regTerm * W /(n*k) 

def cost(X, Y, W, regTerm=0): 
    m, k = Y.shape 
    n, k = W.shape 
    Yhat = predict(X, W) 
    return np.trace(np.dot(Y-Yhat,(Y-Yhat).T))/(2*m*k) + regTerm * np.trace(np.dot(W,W.T))/(2*n*k) 

def Rsquared(X, Y, W): 
    m, k = Y.shape 
    SSres = cost(X, Y, W) 
    Ybar = np.mean(Y,axis=0) 
    Ybar = np.matlib.repmat(Ybar, m, 1) 
    SStot = np.trace(np.dot(Y-Ybar,(Y-Ybar).T)) 

    return 1-SSres/SStot 

m = 10 
n = 200 
k = 1 

trX = np.random.rand(m, n) 
trX[:, 0] = 1 

for i in range(2, n): 
    trX[:, i] = trX[:, 1] ** i 

trY = trX[:, 1] ** 3 + 5 
trY = np.reshape(trY, (m, k)) 

W = np.random.rand(n, k) 

numIter = 10000 
learningRate = 0.5 

for i in range(0, numIter): 
    W = W - learningRate * gradient(trX, trY, W) 

domain = np.linspace(0,1,100000) 
powerDomain = np.copy(domain) 
m = powerDomain.shape[0] 
powerDomain = np.reshape(powerDomain, (m, 1)) 
powerDomain = np.matlib.repmat(powerDomain, 1, n) 

for i in range(1, n): 
    powerDomain[:, i] = powerDomain[:, 0] ** i 

print(Rsquared(trX, trY, W)) 
plt.plot(trX[:, 1],trY,'o', domain, predict(powerDomain, W),'r') 
plt.show() 

R^2 Я получаю очень близко к 1, то есть я нашел очень хорошо подходит для тренировочных данных, но это не показано на графиках. Когда я сюжет данные, как правило, выглядит следующим образом:

enter image description here

это выглядит, как если бы я underfitting данных, но с такой сложной гипотезой, 200 функций (то есть я позволить многочлены вплоть до х^200) и только 10 учебных примеров, я должен очень четко перерабатывать данные, поэтому я ожидаю, что красная линия пройдет через все синие точки и разнесется между ними.

Это не то, что я получаю, что меня смущает. Что случилось?

+0

Прежде всего, исправьте это: 'trX = np.random.rand (m) ** np.arange (n)' – Julien

+0

и 'powerDomain = domain ** np.arange (n)' – Julien

ответ

0

Вы забыли установить powerDomain[:,0]=1, поэтому ваш участок идет не так, как 0. И да, вы слишком уместны: посмотрите, как быстро ваш заговор загорится, как только вы выйдете из своего учебного домена.

Смежные вопросы