2015-11-24 2 views
1

Я делаю демонстрацию различных типов регрессии в numpy с ipython, и до сих пор мне удалось без сложностей построить простую линейную регрессию. Теперь, когда я продолжаю делать квадратичную привязку к моим данным и переходить к ее построению, я не получаю квадратичную кривую, а вместо этого получаю много строк. Вот код, я бегу, что порождает проблему:Пытается построить квадратичную регрессию, получая несколько строк

import numpy 
from numpy import random 
from matplotlib import pyplot as plt 
import math 

# Generate random data 
X = random.random((100,1)) 
epsilon=random.randn(100,1) 
f = 3+5*X+epsilon 

# least squares system 
A =numpy.array([numpy.ones((100,1)),X,X**2]) 
A = numpy.squeeze(A) 
A = A.T 
quadfit = numpy.linalg.solve(numpy.dot(A.transpose(),A),numpy.dot(A.transpose(),f)) 

# plot the data and the fitted parabola 
qdbeta0,qdbeta1,qdbeta2 = quadfit[0][0],quadfit[1][0],quadfit[2][0] 
plt.scatter(X,f) 
plt.plot(X,qdbeta0+qdbeta1*X+qdbeta2*X**2) 
plt.show() 

То, что я получаю эту картину (масштаб изображения, чтобы показать проблему): enter image description here

Вы можете видеть, что вместо того, чтобы иметь единую параболу что соответствует данным, у меня есть огромное количество отдельных линий, которые делают то, что я не уверен. Любая помощь будет принята с благодарностью.

+1

Вам нужно сделать 'x' будет 1D массив , а не 100x1 2D-массив. – BrenBarn

+0

@BrenBarn Если я все время сжимаю до X, я получаю массив (100,), но теперь получаю десятки парабол, соответствующих моим данным. – Matt

ответ

2

Ваш X упорядочен случайным образом, поэтому это не очень хороший набор значений x, чтобы использовать одну сплошную линию, потому что она должна удвоиться на себя. Вы можете отсортировать его, я думаю, но ТБЙ я просто создать новый массив координат х и использовать те:

plt.scatter(X,f) 
x = np.linspace(0, 1, 1000) 
plt.plot(x,qdbeta0+qdbeta1*x+qdbeta2*x**2) 

дает мне example plot

+0

Фантастический, спасибо огромное! Это была ошибка с ошибкой с моей стороны. – Matt

+1

@Matt: Кстати, есть ли причина, по которой вы это делаете вручную? 'np.polyfit (X.squeeze(), f.squeeze(), 2)' дает вам ваши три коэффициента намного проще. – DSM

+0

Да, две причины. Во-первых, я не знал о полифите. Во-вторых, я кодирую математику, которую я записал, поэтому я пишу свой питон, чтобы отразить нормальные уравнения, чтобы связь была ясной. – Matt

Смежные вопросы