Chương 3. Triển khai thực nghiệm
3.3.2. Tiến hành quá trình huấn luyện
Khai báo thiết bị
Bước này quan trọng vì nó thiết lập rõ ràng thiết bị tính toán để đào tạo mô hình mạng nơ-ron. Nó hoạt động như một công cụ cấu hình, chỉ đạo PyTorch phân bổ tenxơ trên CPU hoặc GPU. Ở đây đang dùng CPU làm thiết bị chính
b) Lớp CustomDataset
Lớp này sẽ kế thừa từ lớp datasets.ImageFolder của torch. Lớp này sẽ sử dụng bốn tham số: root cho đường dẫn hình ảnh, special_augment_transform để tăng cường hình ảnh giả mạo, general_augment_transform để tăng cường hình ảnh thực và cuối cùng là special_classes để xác định nhãn cho hình ảnh giả mạo.
Tiếp theo ta sẽ xác định các phép biến đổi hình ảnh. Đối với mỗi quy trình biến đổi, sẽ sử dụng transforms.Resize(224) và transforms.CenterCrop(224). Mặc dù hai quy trình này có vẻ giống nhau, nhưng chúng phục vụ các mục đích riêng biệt.
transforms.Resize(224) sẽ chia tỷ lệ hình ảnh sao cho cạnh ngắn hơn là 224 pixel, duy trì tỷ lệ khung hình. Sau đó, transforms.CenterCrop(224) sẽ cắt phần giữa của hình ảnh đã thay đổi kích thước này để làm cho nó có kích thước chính xác là 224x224 pixel.
Tiếp theo ta sẽ đọc dữ liệu từ cấu trúc đã tạo trước đó sau đó truyền vào lớp CustomDataset với BATCH_SIZE=16
c) Function huấn luyện EPOCH và validate EPOCH
model.train(): Phương thức khai báo để bắt đầu quá trình huấn luyện mô hình Sau đó ta tạo các biến để theo giõi số lượng dữ liệu mất mát(loss), điểm f1(f1_score), dự đoán chính xác
Vòng lặp enumerate trên tập train_loader từ dataset lấy chỉ số (idx) và theo dõi hiệu suất của mô hình trong suốt EPOCH. Điều này đặc biệt hữu ích nếu một EPOCH mất nhiều thời gian để hoàn thành
y_pred là tham số được tính toán sau mỗi lần dự báo của mô hình và so sánh với train_y_data để tính toán độ mất mát sau mỗi lần huấn luyện
Sau khi đã có đủ thông số sẽ tạo ra các biến để lưu trữ các tính toán như total_correct total_loss, tỷ lệ mất mát trung bình avg_loss, tỷ lệ chính xác accuracy
Hàm validate_one_epoch được sử dụng để đánh giá hiệu suất của mô hình trên tập dữ liệu kiểm tra (validation) sau mỗi epoch. Đây là một bước quan trọng trong quá trình huấn luyện mô hình học sâu, giúp kiểm tra xem mô hình có hoạt động tốt trên dữ liệu chưa từng thấy hay không, và từ đó tránh overfitting.
model.eval(): Chuyển mô hình sang chế độ đánh giá
Tắt tính toán gradient bằng torch.no_grad() để tăng hiệu quả và tiết kiệm bộ nhớ.
Các thông số được trả ra tương tự như hàm huấn luyện để đánh giá được hiệu suất sau mỗi EPOCH
d) Huấn luyện mô hình
Từ thư viện timm tải mô hình Vision Transformer (ViT)
vit_base_patch16_224.augreg_in21k_ft_in1k: Mô hình được huấn luyện trước trên tập dữ liệu ImageNet-21k (gồm 21.000 lớp). Sau đó, mô hình được fine- tune (tinh chỉnh) trên tập dữ liệu ImageNet-1k (1.000 lớp).
Sử dụng EarlyStopping để theo dõi chỉ số đánh giá (validation loss) và dừng huấn luyện nếu mô hình không cải thiện sau số epoch và lưu trọng số vào file vit_teacher.pth
Tiến hành huấn luyện theo vòng lặp với số lượng EPOCH được khai báo ban đầu (50)
Hình 3. 4 Quá trình huấn luyện bắt đầu
Learning rate (lr=3e-5) được chọn trong phạm vi khuyến nghị cho mô hình Transformer (1e-5 đến 1e-4).
e) Kết quả huấn luyện và đánh giá
Sử dụng matplotlib để trực quan hóa Độ mất mát và Chính xác của mô hình qua quá trình huấn luyện và đánh giá (validate) sau mỗi EPOCHS:
Kết quả thu được sau khi huấn luyện 50 EPOCHS như sau:
Hình 3. 5 Biểu đồ Losses and Accuracies
Có thể thấy qua mỗi EPOCH mô hình đều đã có cải thiện lớn về tính hiệu quả Tiếp theo sử dụng bộ dữ liệu test để kiểm qua tính hiệu quả của mô hình đối với dữ liệu chưa từng gặp:
Kết quả đo các thông số Accuracy (Độ chính xác) Recall(Bỏ sót) False Acceptance Rate(FAR: Độ nhầm lẫn) False Rejection Rate (FRR: Từ chối hợp lệ) Half Total Error Rate:
Hình 3. 6 Các chỉ số cơ bản của mô hình
Biểu đồ cột và ma trận nhầm lẫn như sau:
Hình 3. 7 Biểu đồ cột biểu hiện phần trăm các chỉ số
Hình 3. 8 Biểu đồ nhầm lẫn