diff --git a/questions/47_implement-gradient-descent-variants-with-mse-loss/solution.py b/questions/47_implement-gradient-descent-variants-with-mse-loss/solution.py
index fa4579b3..5e296182 100644
--- a/questions/47_implement-gradient-descent-variants-with-mse-loss/solution.py
+++ b/questions/47_implement-gradient-descent-variants-with-mse-loss/solution.py
@@ -1,32 +1,36 @@
 import numpy as np
 
-def gradient_descent(X, y, weights, learning_rate, n_iterations, batch_size=1, method='batch'):
-    m = len(y)
-
+def gradient_descent(
+    X: np.ndarray,
+    y: np.ndarray,
+    weights: np.ndarray,
+    learning_rate: float,
+    n_iterations: int,
+    batch_size: int = 1,
+    method: str = "batch",
+) -> np.ndarray:
+    m = X.shape[0]
+    # Start from the provided initial weights, stored as a column vector.
+    w = np.asarray(weights, dtype=float).reshape((-1, 1))
+
+    # All three variants reduce to mini-batch updates with different batch sizes.
+    match method:
+        case "batch":
+            batch_size = m
+        case "stochastic":
+            batch_size = 1
+        case "mini_batch":
+            pass
+        case _:
+            return w.flatten()
+
     for _ in range(n_iterations):
-        if method == 'batch':
-            # Calculate the gradient using all data points
-            predictions = X.dot(weights)
-            errors = predictions - y
-            gradient = 2 * X.T.dot(errors) / m
-            weights = weights - learning_rate * gradient
-
-        elif method == 'stochastic':
-            # Update weights for each data point individually
-            for i in range(m):
-                prediction = X[i].dot(weights)
-                error = prediction - y[i]
-                gradient = 2 * X[i].T.dot(error)
-                weights = weights - learning_rate * gradient
-
-        elif method == 'mini_batch':
-            # Update weights using sequential batches of data points without shuffling
-            for i in range(0, m, batch_size):
-                X_batch = X[i:i+batch_size]
-                y_batch = y[i:i+batch_size]
-                predictions = X_batch.dot(weights)
-                errors = predictions - y_batch
-                gradient = 2 * X_batch.T.dot(errors) / batch_size
-                weights = weights - learning_rate * gradient
-
-    return weights
+        # Sweep the data in sequential batches without shuffling.
+        for i in range(0, m, batch_size):
+            x_batch = X[i:i + batch_size, :]
+            y_batch = y[i:i + batch_size].reshape((-1, 1))
+            y_hat = x_batch @ w
+            # MSE gradient on this batch: 2/B * X_b^T (X_b w - y_b).
+            gradient = 2 * x_batch.T @ (y_hat - y_batch) / batch_size
+            w = w - learning_rate * gradient
+    return w.flatten()
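As a quick sanity check of the refactored function, the snippet below fits a small noiseless two-feature regression with all three methods. The data values, hyperparameters, and variable names are made up for illustration and are not part of the repository's test suite.

    import numpy as np

    # Toy regression data where y = 2*x1 + 3*x2 exactly (assumed values).
    X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
    y = np.array([2.0, 3.0, 5.0, 7.0])
    w0 = np.zeros(2)

    for method in ("batch", "stochastic", "mini_batch"):
        w = gradient_descent(X, y, w0, learning_rate=0.05,
                             n_iterations=1000, batch_size=2, method=method)
        print(method, w)  # each variant should land near [2., 3.] on this noiseless data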