2017-02-07 8 views
2

У меня был следующий код, который возвращает правильные коэффициенты. Однако, независимо от того, где я положил вызов plot, я не могу получить какой-либо вывод графика.Функция спуска градиента с графиком вывода и строкой регрессии

Я не уверен, что здесь нужен воспроизводимый пример, так как я думаю, что это можно решить, посмотрев мою функцию gradientDescent ниже? Это моя первая попытка запуска этого алгоритма в R:

gradientDescent <- function(x, y, learn_rate, conv_threshold, n, max_iter) { 
    m <- runif(1, 0, 1) 
    c <- runif(1, 0, 1) 
    yhat <- m * x + c 
    cost_error <- (1/(n + 2)) * sum((y - yhat)^2) 
    converged = F 
    iterations = 0 
    while(converged == F) { 
    m_new <- m - learn_rate * ((1/n) * (sum((yhat - y) * x))) 
    c_new <- c - learn_rate * ((1/n) * (sum(yhat - y))) 
    m <- m_new 
    c <- c_new 
    yhat <- m * x + c 
    cost_error_new <- (1/(n + 2)) * sum((y - yhat)^2) 
    if(cost_error - cost_error_new <= conv_threshold) { 
     converged = T 
    } 
    iterations = iterations + 1 
    if(iterations > max_iter) { 
     converged = T 
    return(paste("Optimal intercept:", c, "Optimal slope:", m)) 
    } 
    } 
} 
+0

Необходимо воспроизвести пример. Код, который вы указали, даже не показывает, где и что вы пытаетесь построить, поэтому диагностика невозможна. –

+0

Я отправил на лету, так что не успел вынести все это. В любом случае, я все-таки сортировал. Спасибо что нашли время ответить. – Seanosapien

ответ

1

Неясно, что вы делали, что было неэффективным. Базовые графические функции plot и abline должны иметь возможность производить вывод даже при использовании внутри функций. Графики решетки и ggplot2 основаны на grid -grpahics и поэтому нуждаются в print(), обернутых вокруг вызовов функций для создания вывода (как описано в R-FAQ). Так что попробуйте следующее:

gradientDescent <- function(x, y, learn_rate, conv_threshold, n, max_iter) 
    { ## plot.new() perhaps not needed 
     plot(x,y) 
     m <- runif(1, 0, 1) 
     c <- runif(1, 0, 1) 
     yhat <- m * x + c 
     cost_error <- (1/(n + 2)) * sum((y - yhat)^2) 
     converged = F 
     iterations = 0 
     while(converged == F) { 
     m_new <- m - learn_rate * ((1/n) * (sum((yhat - y) * x))) 
     c_new <- c - learn_rate * ((1/n) * (sum(yhat - y))) 
     m <- m_new 
     c <- c_new 
     yhat <- m * x + c 
     cost_error_new <- (1/(n + 2)) * sum((y - yhat)^2) 
     if(cost_error - cost_error_new <= conv_threshold) { 
      converged = T 
     } 
     iterations = iterations + 1 
     if(iterations > max_iter) { abline(c, m) #calculated 
      dev.off() 
      converged = T 
     return(paste("Optimal intercept:", c, "Optimal slope:", m)) 
     } 
     } 
    } 
+0

Красивые !! Это работает хорошо. Первоначально я не получал никакого сюжета с регрессионной линией, но после удаления dev.off() он работает. Благодарю вас, сэр. – Seanosapien

+0

Кроме того, оболочка печати вокруг вызова ggplot отлично работает, чтобы создать вариант вышеприведенного примера. Еще раз спасибо. – Seanosapien