from model import build_modelfrom utils import save_model, save_plotsfrom datasets import train_loader, valid_loader, datasetfrom tqdm.notebook import tqdm
Classes: ['cloudy', 'desert', 'green_area', 'water']
Total number of images: 5631
Total training images: 4505
Total valid_images: 1126
Load the Weights for the ResNet Model
# Run hyperparameters.
lr = 0.001
epochs = 20

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"computation device: {device}\n")
computation device: cuda
model = build_model(
    pretrained=True,
    fine_tune=False,
    num_classes=len(dataset.classes),
).to(device)

# Gather (size, trainable?) for every parameter once, then report the overall
# count followed by the count of parameters that have requires_grad=True.
_param_stats = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(size for size, _ in _param_stats)
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(size for size, trainable in _param_stats if trainable)
print(f"{total_trainable_params:,} training parameters.\n")
[INFO]: Loading pre-trained weights
[INFO]: Freezing hidden layers...
21,286,724 total parameters.
2,052 training parameters.
# Loss: cross-entropy over the model's class logits.
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam over every parameter of the model, at the configured rate.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
Training and Validation Functions
def train(model, trainloader, optimizer, criterion):
    """Run one training epoch and return its loss and accuracy.

    Each batch is moved to the module-level ``device``, pushed through the
    model, scored with ``criterion``, and used for one ``optimizer`` step.

    Args:
        model: network to train (switched to train mode here).
        trainloader: DataLoader yielding ``(images, labels)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(outputs, labels)``.

    Returns:
        tuple ``(epoch_loss, epoch_acc)``: the loss averaged over batches and
        the accuracy as a percentage of ``len(trainloader.dataset)``.
    """
    model.train()
    print('Training')
    loss_sum = 0.0
    hits = 0
    n_batches = 0
    batches = tqdm(enumerate(trainloader, start=1), total=len(trainloader))
    for n_batches, (images, targets) in batches:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss_sum += loss.item()
        # Index of the highest logit per sample is the predicted class.
        predicted = logits.data.max(dim=1).indices
        hits += predicted.eq(targets).sum().item()
        loss.backward()
        optimizer.step()
    mean_loss = loss_sum / n_batches
    accuracy = 100.0 * (hits / len(trainloader.dataset))
    return mean_loss, accuracy
def validate(model, testloader, criterion, class_names):
    """Evaluate ``model`` on ``testloader`` and print per-class accuracy.

    Runs with the model in eval mode under ``torch.no_grad()``. Each batch is
    moved to the module-level ``device`` before the forward pass.

    Args:
        model: network to evaluate.
        testloader: DataLoader yielding ``(images, labels)`` batches; label
            values are used as indices into ``class_names``.
        criterion: loss called as ``criterion(outputs, labels)``.
        class_names: sequence of class names, indexed by label value.

    Returns:
        tuple ``(epoch_loss, epoch_acc)``: the loss averaged over batches and
        the overall accuracy as a percentage of ``len(testloader.dataset)``.
    """
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    counter = 0
    # Parallel tallies for class-wise accuracy: correct hits and samples seen.
    class_correct = [0.0] * len(class_names)
    class_total = [0.0] * len(class_names)
    with torch.no_grad():
        for image, labels in tqdm(testloader, total=len(testloader)):
            counter += 1
            image = image.to(device)
            labels = labels.to(device)
            # forward pass
            outputs = model(image)
            # calculate the loss
            loss = criterion(outputs, labels)
            valid_running_loss += loss.item()
            # calculate the accuracy
            _, preds = torch.max(outputs, 1)
            # Keep this 1-D: a `.squeeze()` here would collapse a batch of one
            # into a 0-d tensor, and indexing it would raise.
            correct = preds == labels
            valid_running_correct += correct.sum().item()
            # Copy to Python once per batch rather than indexing device
            # tensors (one sync each) once per sample.
            for label, hit in zip(labels.tolist(), correct.tolist()):
                class_correct[label] += hit
                class_total[label] += 1
    # loss and accuracy for the complete epoch
    epoch_loss = valid_running_loss / counter
    epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
    # print the accuracy for each class after every epoch
    print('\n')
    for name, n_correct, n_total in zip(class_names, class_correct, class_total):
        if n_total:
            print(f"Accuracy of class {name}: {100*n_correct/n_total}")
        else:
            # Avoid ZeroDivisionError when a class has no validation samples.
            print(f"Accuracy of class {name}: no validation samples")
    print('\n')
    return epoch_loss, epoch_acc
Train for 20 Epochs
# Per-epoch history; handed to save_plots once all epochs have run.
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []

for epoch in range(epochs):
    train_epoch_loss, train_epoch_acc = train(model, train_loader, optimizer, criterion)
    valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader, criterion, dataset.classes)

    # Record this epoch's metrics in their matching history lists.
    for history, value in (
        (train_loss, train_epoch_loss),
        (valid_loss, valid_epoch_loss),
        (train_acc, train_epoch_acc),
        (valid_acc, valid_epoch_acc),
    ):
        history.append(value)

    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    print('-' * 50)

# Persist the final-epoch weights and the loss/accuracy curves.
save_model(epochs, model, optimizer, criterion)
save_plots(train_acc, valid_acc, train_loss, valid_loss)
print('TRAINING COMPLETE')
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 95.57522123893806
Accuracy of class green_area: 96.53979238754326
Accuracy of class water: 93.8566552901024
Training loss: 0.057, training acc: 98.180
Validation loss: 0.142, validation acc: 96.359
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.37106918238993
Accuracy of class desert: 92.92035398230088
Accuracy of class green_area: 97.57785467128028
Accuracy of class water: 83.2764505119454
Training loss: 0.049, training acc: 98.557
Validation loss: 0.208, validation acc: 93.428
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 96.90265486725664
Accuracy of class green_area: 96.88581314878893
Accuracy of class water: 84.98293515358361
Training loss: 0.048, training acc: 98.424
Validation loss: 0.189, validation acc: 94.405
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 98.23008849557522
Accuracy of class green_area: 95.50173010380622
Accuracy of class water: 91.46757679180887
Training loss: 0.047, training acc: 98.690
Validation loss: 0.162, validation acc: 96.004
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 98.74213836477988
Accuracy of class desert: 97.78761061946902
Accuracy of class green_area: 89.96539792387543
Accuracy of class water: 96.24573378839591
Training loss: 0.060, training acc: 98.091
Validation loss: 0.161, validation acc: 95.648
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 98.11320754716981
Accuracy of class desert: 98.67256637168141
Accuracy of class green_area: 96.88581314878893
Accuracy of class water: 88.39590443686006
Training loss: 0.049, training acc: 98.313
Validation loss: 0.159, validation acc: 95.382
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 97.34513274336283
Accuracy of class green_area: 97.92387543252595
Accuracy of class water: 91.12627986348123
Training loss: 0.049, training acc: 98.402
Validation loss: 0.143, validation acc: 96.359
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 97.78761061946902
Accuracy of class green_area: 93.77162629757785
Accuracy of class water: 93.8566552901024
Training loss: 0.040, training acc: 98.912
Validation loss: 0.161, validation acc: 96.092
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 92.47787610619469
Accuracy of class green_area: 96.19377162629758
Accuracy of class water: 91.80887372013652
Training loss: 0.039, training acc: 98.690
Validation loss: 0.163, validation acc: 95.115
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 96.90265486725664
Accuracy of class green_area: 96.19377162629758
Accuracy of class water: 89.419795221843
Training loss: 0.044, training acc: 98.468
Validation loss: 0.162, validation acc: 95.382
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 97.48427672955975
Accuracy of class desert: 98.67256637168141
Accuracy of class green_area: 87.5432525951557
Accuracy of class water: 93.51535836177474
Training loss: 0.043, training acc: 98.513
Validation loss: 0.178, validation acc: 94.139
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 98.42767295597484
Accuracy of class desert: 98.67256637168141
Accuracy of class green_area: 88.23529411764706
Accuracy of class water: 96.24573378839591
Training loss: 0.041, training acc: 98.602
Validation loss: 0.162, validation acc: 95.293
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 98.23008849557522
Accuracy of class green_area: 86.85121107266436
Accuracy of class water: 92.83276450511946
Training loss: 0.044, training acc: 98.513
Validation loss: 0.183, validation acc: 94.139
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 97.34513274336283
Accuracy of class green_area: 94.80968858131487
Accuracy of class water: 92.15017064846417
Training loss: 0.042, training acc: 98.668
Validation loss: 0.157, validation acc: 95.826
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 98.67256637168141
Accuracy of class green_area: 92.38754325259515
Accuracy of class water: 95.56313993174061
Training loss: 0.044, training acc: 98.446
Validation loss: 0.143, validation acc: 96.359
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 97.78761061946902
Accuracy of class green_area: 96.53979238754326
Accuracy of class water: 93.51535836177474
Training loss: 0.039, training acc: 98.468
Validation loss: 0.125, validation acc: 96.714
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 96.90265486725664
Accuracy of class green_area: 95.15570934256056
Accuracy of class water: 92.83276450511946
Training loss: 0.038, training acc: 98.713
Validation loss: 0.149, validation acc: 96.004
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 98.74213836477988
Accuracy of class desert: 98.67256637168141
Accuracy of class green_area: 89.27335640138408
Accuracy of class water: 94.53924914675768
Training loss: 0.041, training acc: 98.713
Validation loss: 0.154, validation acc: 95.204
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.37106918238993
Accuracy of class desert: 97.78761061946902
Accuracy of class green_area: 96.19377162629758
Accuracy of class water: 88.73720136518772
Training loss: 0.040, training acc: 98.602
Validation loss: 0.163, validation acc: 95.471
--------------------------------------------------
Training
Validation
Accuracy of class cloudy: 99.05660377358491
Accuracy of class desert: 97.78761061946902
Accuracy of class green_area: 97.2318339100346
Accuracy of class water: 91.46757679180887
Training loss: 0.042, training acc: 98.557
Validation loss: 0.133, validation acc: 96.359
--------------------------------------------------
TRAINING COMPLETE
Inference
import torchimport cv2import torchvision.transforms as transformsfrom model import build_model
device = 'cpu'

# Class names in index order; same order as the class list printed during
# training (['cloudy', 'desert', 'green_area', 'water']).
labels = ['cloudy', 'desert', 'green_area', 'water']

# Initialize the model and load the trained weights. The head size is derived
# from `labels` so the two cannot drift apart.
model = build_model(
    pretrained=False, fine_tune=False, num_classes=len(labels)
).to(device)
print('[INFO]: Loading custom-trained weights...')
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. If save_model also stores non-tensor objects (e.g. the criterion),
# newer PyTorch releases defaulting to weights_only=True may reject this
# file; confirm against utils.save_model.
checkpoint = torch.load('outputs/model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Preprocessing: array -> PIL image, resize, tensor, ImageNet mean/std normalization.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])