Gradient descent là gì trong machine learning?
Gradient descent là một thuật toán tối ưu dùng để tìm ra bộ tham số tốt hơn cho mô hình machine learning bằng cách giảm dần lỗi dự đoán. Nói đơn giản, nó giống như cách mô hình "đi xuống dốc" từng bước để tìm điểm có sai số thấp hơn.
Nếu bạn đang thắc mắc gradient descent là gì trong machine learning, hãy nhớ câu ngắn gọn này: đó là cách mô hình điều chỉnh tham số để học tốt hơn từ dữ liệu.
Gradient descent dùng để làm gì?
Gradient descent dùng để làm gì? Nó được dùng để tối ưu hàm mất mát, tức là hàm đo mức sai của mô hình.
Ví dụ:
- Linear regression muốn giảm sai số giữa giá trị thật và giá trị dự đoán.
- Logistic regression muốn giảm lỗi phân loại.
- Neural network cũng dùng các biến thể của gradient descent để học.
Vì vậy, gradient descent không phải là một mô hình riêng, mà là công cụ giúp mô hình học.
Cách gradient descent hoạt động
Cách gradient descent hoạt động có thể hiểu qua 3 bước lặp lại:
- Tính lỗi hiện tại của mô hình.
- Tính xem cần thay đổi tham số theo hướng nào để lỗi giảm.
- Cập nhật tham số từng chút một.
Ví dụ một bước cập nhật rất đơn giản:
w = 5.0
learning_rate = 0.1
gradient = 2.0
w = w - learning_rate * gradient
print(w)
Ở đây:
wlà tham số.gradientcho biết hướng tăng giảm.learning_ratequyết định bước đi lớn hay nhỏ.
Mô hình sẽ lặp lại việc này nhiều lần cho đến khi lỗi giảm đủ nhiều hoặc gần như không giảm thêm nữa.
Ví dụ gradient descent trực quan
Hãy tưởng tượng bạn đang đứng trên một ngọn đồi trong sương mù và muốn xuống điểm thấp nhất. Bạn không nhìn thấy toàn bộ bản đồ, nhưng bạn có thể cảm nhận độ dốc ngay chỗ mình đứng.
- Nếu mặt đất nghiêng về bên trái, bạn đi sang trái.
- Nếu nghiêng về bên phải, bạn đi sang phải.
- Nếu dốc lớn, bạn biết mình đang ở khá xa điểm thấp.
Gradient descent hoạt động rất giống hình ảnh đó.
Ví dụ gradient descent bằng vòng lặp Python
Ví dụ dưới đây chỉ để minh họa cách một tham số được cập nhật qua nhiều vòng:
w = 10.0
learning_rate = 0.2
for step in range(5):
gradient = 2 * w
w = w - learning_rate * gradient
print(f"buoc {step + 1}: w = {w}")
Trong ví dụ này, w sẽ giảm dần về gần 0. Bạn có thể xem đó là quá trình tối ưu hóa để giảm lỗi trong một bài toán rất đơn giản.
Learning rate là gì?
Learning rate là kích thước bước đi của gradient descent.
- Nếu learning rate quá lớn, mô hình có thể nhảy qua lại và không hội tụ.
- Nếu learning rate quá nhỏ, mô hình học rất chậm.
Ví dụ so sánh:
learning_rate_nho = 0.001
learning_rate_lon = 1.0
print(learning_rate_nho)
print(learning_rate_lon)
Bạn chưa cần nhớ con số nào là tốt nhất. Điều quan trọng là hiểu learning rate ảnh hưởng trực tiếp đến tốc độ và độ ổn định khi học.
Thuật toán gradient descent trong linear regression
Trong linear regression, mô hình có các tham số như hệ số góc và intercept. Gradient descent giúp điều chỉnh các tham số đó sao cho sai số dự đoán nhỏ dần.
Đây là một trong những ví dụ kinh điển nhất để hiểu thuật toán gradient descent vì bài toán khá trực quan: tìm đường thẳng sao cho gần dữ liệu thật nhất.
Batch gradient descent là gì?
Batch gradient descent là gì? Đây là cách dùng toàn bộ dữ liệu train ở mỗi lần cập nhật tham số.
Ưu điểm:
- Ổn định hơn.
- Hướng đi mượt hơn.
Nhược điểm:
- Chậm nếu dữ liệu rất lớn.
Stochastic gradient descent là gì?
Stochastic gradient descent là gì? Đây là cách cập nhật tham số sau từng mẫu dữ liệu một.
Ưu điểm:
- Nhanh hơn trên dữ liệu lớn.
- Có thể thoát một số điểm kẹt cục bộ tốt hơn.
Nhược điểm:
- Dao động nhiều hơn.
- Đường học không mượt bằng batch gradient descent.
Mini-batch gradient descent là gì?
Mini-batch gradient descent là gì? Đây là cách dùng một nhóm nhỏ dữ liệu cho mỗi lần cập nhật. Nó là lựa chọn cân bằng giữa batch và stochastic.
Trong thực tế hiện đại, mini-batch gradient descent được dùng rất nhiều vì vừa đủ nhanh vừa đủ ổn định.
So sánh batch, stochastic và mini-batch gradient descent
Bạn có thể nhớ nhanh như sau:
- Batch: dùng toàn bộ dữ liệu.
- Stochastic: dùng từng mẫu một.
- Mini-batch: dùng một nhóm nhỏ dữ liệu.
Nếu chỉ mới học, bạn chưa cần đào quá sâu vào tối ưu hiệu năng. Chỉ cần nắm rõ ý tưởng này là đã rất tốt.
Những lỗi thường gặp khi mới học gradient descent
- Nghĩ gradient descent là một mô hình riêng.
- Không phân biệt gradient descent với linear regression hoặc logistic regression.
- Quên vai trò của learning rate.
- Cố nhớ công thức phức tạp trước khi hiểu ý tưởng tối ưu hóa.
Bài tập thực hành
Hãy tự viết một đoạn code ngắn cập nhật tham số w trong 5 bước bằng gradient descent. Sau đó thử:
- Đổi learning rate thành nhỏ hơn.
- Đổi learning rate thành lớn hơn.
- Quan sát
wthay đổi khác nhau ra sao.
Bạn có thể bắt đầu với mẫu sau:
w = 8.0
learning_rate = 0.1
for step in range(5):
gradient = 2 * w
w = w - learning_rate * gradient
print(step + 1, w)
Gợi ý: hãy tự giải thích vì sao learning rate quá lớn có thể làm quá trình học trở nên không ổn định.
Câu hỏi thường gặp về gradient descent
Gradient descent là gì?
Gradient descent là thuật toán tối ưu giúp mô hình điều chỉnh tham số để giảm lỗi dự đoán. Nó là một nền tảng rất quan trọng trong machine learning.
Gradient descent dùng để làm gì?
Nó dùng để tối ưu hàm mất mát, tức là giảm sai số của mô hình trong quá trình huấn luyện. Nhiều mô hình nổi tiếng đều dùng gradient descent hoặc biến thể của nó.
Gradient descent cho người mới nên hiểu thế nào?
Bạn có thể hình dung nó như việc đi xuống dốc từng bước để tìm điểm thấp nhất. Mỗi bước đi là một lần điều chỉnh tham số theo hướng làm lỗi nhỏ hơn.
Batch, stochastic và mini-batch gradient descent khác nhau ở đâu?
Chúng khác nhau ở lượng dữ liệu được dùng cho mỗi lần cập nhật tham số. Batch dùng toàn bộ dữ liệu, stochastic dùng từng mẫu, còn mini-batch dùng một nhóm nhỏ.
Tóm tắt
Trong bài này, bạn đã hiểu gradient descent là gì trong machine learning, cách nó hoạt động, vai trò của learning rate và sự khác nhau giữa batch, stochastic, mini-batch gradient descent. Đây là nền tảng quan trọng để hiểu sâu hơn cách mô hình thực sự học từ dữ liệu. Ở bài tiếp theo, chúng ta sẽ làm quen với KNN để thấy một mô hình đơn giản có thể phân loại bằng cách nhìn vào những hàng xóm gần nhất.
Bài viết liên quan

Next.js là gì? Tại sao nên dùng Next.js để làm web?
Giới thiệu Next.js — framework React phổ biến nhất. Tìm hiểu ưu điểm, tính năng nổi bật và khi nào nên dùng.

Con bug đầu tiên trong cuộc đời lập trình viên
Câu chuyện hài hước về lần đầu gặp bug và mất 3 tiếng để tìm ra nguyên nhân chỉ là... thiếu dấu chấm phẩy.

Hướng dẫn cài đặt Python chi tiết trên Windows, macOS, Linux
Hướng dẫn từng bước cài đặt Python trên mọi hệ điều hành. Kèm cách kiểm tra và chạy chương trình đầu tiên.