Transfer Learning with PyTorch for classifying bird species
Posted on Nov 05, 2024 @ 02:29 AM under Python Machine Learning Deep Learning Computer Vision
(NB: This is a work in progress)
Objective
This project aims to leverage transfer learning with EfficientNet B0 to address the classification of bird species. We will begin by utilizing PyTorch's Dataset and DataLoader for efficient data manipulation. Subsequently, we will construct the model and implement a training loop, monitoring both training and validation losses throughout the process. Following the training phase, the model will be evaluated against a test set, and its accuracy will be calculated. This notebook will be improved over time.
- Let's begin by importing the necessary packages and libraries
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm
import matplotlib.pyplot as plt
import sys
from tqdm.notebook import tqdm
1. Create a class to extend PyTorch's Dataset
class BirdDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data = ImageFolder(data_dir, transform=transform)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]
    @property
    def classes(self):
        return self.data.classes
The BirdDataset class defines the dataset for birds, using images from the specified directory data_dir. Transformations will be used to ensure the images are all the same size. The folders must follow a standard layout for the Dataset to work properly: each of the train, test, and valid folders contains one subfolder per class, named after the class and holding the images that relate to that class.
- Initialization: __init__ loads images from data_dir using ImageFolder, with optional transformations.
- Length: __len__ returns the number of images in the dataset.
- Item Access: __getitem__ retrieves an (image, label) tuple by its index.
- Classes Property: classes returns the different bird classes (labels) in the dataset.
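To make the folder convention concrete, here is a minimal standard-library sketch of how an ImageFolder-style loader derives integer labels from subfolder names: the class folders are sorted alphabetically and numbered from 0. The helper name find_classes_sketch and the throwaway folder names are assumptions for illustration; torchvision's own implementation does more (for example, it also checks file extensions).

```python
import os
import tempfile

def find_classes_sketch(root):
    """Map each class subfolder of `root` to an integer index (illustrative helper)."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx

# Build a tiny throwaway layout that mirrors train/<CLASS NAME>/
with tempfile.TemporaryDirectory() as root:
    for name in ["BALD EAGLE", "ABBOTTS BABBLER", "ALBATROSS"]:
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes_sketch(root)

print(classes)
print(class_to_idx)
```

Because the indices follow the sorted folder names, the same class gets the same label in train, valid, and test only if all three folders contain the same set of class subfolders.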
Let's create the dataset and explore it
train_dir = './bird-species/train'
"""This code will create dataset object which will allow access to all images and other properties as defined in BirdDataset.
If we select an index, this will return a tuple with the image (PIL) and the associated class label.
"""
dataset = BirdDataset(
    data_dir=train_dir
)
"""
Check the number of images in the dataset. You see we have 70,626. I am working on my personal computer and although I have a GPU,
I will probably need to first train on a few images here and later transfer to Google Colab
"""
len(dataset)
70626
# Extract the image and label at an index and display both. The image is a PIL image, so the notebook can render it directly
image, label = dataset[600]
print(f'Label: {label}')
image
Label: 3
I want to visualize the number of images in each class to get an idea of any imbalances. I will do this now.
# I will import Counter to help with this
from collections import Counter
# Get the labels for images
labels = [dataset[i][1] for i in range(len(dataset))]
#count occurrences in each label
label_counts = Counter(labels)
print(label_counts)
Counter({262: 248, 163: 233, 333: 233, 410: 217, 446: 214, 118: 213, 334: 207, 166: 203, 365: 202, 431: 201, 396: 200, 157: 198, 109: 197, 307: 197, 355: 197, 324: 196, 54: 194, 354: 194, 426: 193, 400: 192, 40: 190, 180: 190, 297: 190, 427: 190, 380: 189, 34: 188, 77: 188, 310: 188, 7: 187, 21: 187, 38: 187, 89: 186, 171: 186, 270: 186, 87: 185, 316: 185, 96: 184, 100: 184, 235: 184, 271: 183, 391: 182, 419: 181, 1: 180, 120: 180, 2: 179, 13: 179, 19: 179, 97: 177, 136: 177, 138: 177, 207: 177, 217: 176, 238: 176, 337: 176, 213: 175, 249: 175, 287: 175, 296: 175, 389: 175, 442: 175, 86: 173, 260: 173, 351: 173, 434: 172, 14: 170, 288: 170, 384: 170, 230: 169, 433: 169, 42: 168, 57: 168, 81: 168, 129: 168, 88: 167, 251: 167, 359: 167, 360: 167, 405: 167, 0: 166, 11: 166, 52: 166, 55: 166, 135: 166, 172: 166, 192: 166, 346: 166, 10: 165, 46: 165, 74: 165, 99: 165, 125: 165, 191: 165, 232: 165, 264: 165, 418: 165, 16: 164, 35: 164, 41: 164, 73: 164, 119: 164, 130: 164, 154: 164, 226: 164, 283: 164, 369: 164, 390: 164, 423: 164, 443: 164, 53: 163, 75: 163, 121: 163, 131: 163, 132: 163, 134: 163, 152: 163, 158: 163, 187: 163, 190: 163, 229: 163, 244: 163, 245: 163, 246: 163, 268: 163, 284: 163, 301: 163, 335: 163, 368: 163, 105: 162, 117: 162, 127: 162, 160: 162, 206: 162, 219: 162, 242: 162, 275: 162, 281: 162, 413: 162, 432: 162, 9: 161, 108: 161, 133: 161, 141: 161, 142: 161, 221: 161, 303: 161, 332: 161, 339: 161, 394: 161, 398: 161, 402: 161, 441: 161, 48: 160, 58: 160, 63: 160, 72: 160, 93: 160, 103: 160, 107: 160, 110: 160, 148: 160, 159: 160, 181: 160, 198: 160, 208: 160, 227: 160, 250: 160, 252: 160, 253: 160, 292: 160, 298: 160, 321: 160, 327: 160, 328: 160, 336: 160, 349: 160, 361: 160, 371: 160, 378: 160, 392: 160, 397: 160, 403: 160, 408: 160, 101: 159, 164: 159, 215: 159, 239: 159, 265: 159, 280: 159, 364: 159, 415: 159, 440: 159, 449: 159, 15: 158, 222: 158, 240: 158, 259: 158, 282: 158, 387: 158, 36: 157, 91: 157, 139: 157, 173: 157, 182: 157, 185: 157, 
234: 157, 293: 157, 315: 157, 330: 157, 69: 156, 78: 156, 155: 156, 237: 156, 261: 156, 277: 156, 322: 156, 341: 156, 342: 156, 379: 156, 383: 156, 424: 156, 6: 155, 18: 155, 33: 155, 56: 155, 85: 155, 104: 155, 114: 155, 149: 155, 194: 155, 211: 155, 223: 155, 225: 155, 241: 155, 256: 155, 258: 155, 263: 155, 278: 155, 300: 155, 305: 155, 317: 155, 338: 155, 425: 155, 429: 155, 445: 155, 448: 155, 4: 154, 25: 154, 26: 154, 31: 154, 39: 154, 49: 154, 66: 154, 92: 154, 94: 154, 95: 154, 98: 154, 111: 154, 113: 154, 156: 154, 176: 154, 177: 154, 189: 154, 199: 154, 286: 154, 294: 154, 311: 154, 312: 154, 314: 154, 320: 154, 326: 154, 343: 154, 350: 154, 385: 154, 409: 154, 416: 154, 421: 154, 438: 154, 444: 154, 447: 154, 32: 153, 84: 153, 106: 153, 122: 153, 150: 153, 170: 153, 209: 153, 273: 153, 344: 153, 347: 153, 348: 153, 358: 153, 381: 153, 430: 153, 439: 153, 67: 152, 153: 152, 179: 152, 204: 152, 224: 152, 274: 152, 331: 152, 352: 152, 367: 152, 370: 152, 375: 152, 428: 152, 247: 151, 255: 151, 356: 151, 382: 151, 12: 150, 27: 150, 29: 150, 47: 150, 90: 150, 183: 150, 201: 150, 202: 150, 203: 150, 205: 150, 220: 150, 254: 150, 272: 150, 401: 150, 233: 149, 276: 149, 372: 147, 168: 146, 24: 144, 65: 144, 83: 144, 126: 144, 144: 144, 145: 144, 186: 144, 197: 144, 228: 144, 267: 144, 285: 144, 399: 144, 407: 144, 62: 143, 102: 143, 279: 143, 291: 143, 302: 143, 406: 143, 420: 143, 64: 142, 68: 142, 151: 142, 210: 142, 299: 142, 376: 142, 436: 142, 37: 141, 43: 141, 128: 141, 143: 141, 266: 141, 313: 141, 353: 141, 357: 141, 411: 141, 175: 140, 216: 140, 236: 140, 323: 140, 345: 140, 20: 139, 28: 139, 30: 139, 137: 139, 147: 139, 169: 139, 174: 139, 184: 139, 188: 139, 318: 139, 374: 139, 161: 138, 196: 138, 212: 138, 231: 138, 243: 138, 290: 138, 295: 138, 386: 138, 435: 138, 437: 138, 3: 137, 5: 137, 45: 137, 51: 137, 80: 137, 115: 137, 123: 137, 162: 137, 306: 137, 319: 137, 422: 137, 44: 136, 60: 136, 165: 136, 200: 136, 289: 136, 412: 136, 417: 136, 76: 
135, 146: 135, 193: 135, 195: 135, 218: 135, 257: 135, 304: 135, 325: 135, 377: 135, 404: 135, 414: 135, 23: 134, 82: 134, 178: 134, 248: 134, 8: 133, 17: 133, 70: 133, 79: 133, 112: 133, 116: 133, 140: 133, 362: 133, 50: 132, 59: 132, 61: 132, 124: 132, 167: 132, 308: 132, 329: 132, 363: 132, 373: 132, 388: 132, 393: 132, 22: 131, 71: 131, 214: 131, 269: 131, 309: 131, 340: 130, 366: 130, 395: 130})
#check the number of classes
len(label_counts) #number of classes
450
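Indexing dataset[i] decodes every image just to read its label, which is slow for 70,626 files. As a sketch of a lighter alternative: torchvision's ImageFolder keeps the integer labels in a targets list, so (assuming the BirdDataset wrapper defined earlier) dataset.data.targets could be counted directly. The block below uses a small made-up list in place of the real targets so that it runs on its own; count_labels is a hypothetical helper name.

```python
from collections import Counter

def count_labels(targets):
    """Count how many samples carry each integer class label."""
    return Counter(targets)

# Stand-in for dataset.data.targets (hypothetical values, not the real data)
example_targets = [0, 0, 1, 2, 2, 2]
label_counts = count_labels(example_targets)

print(label_counts.most_common())  # classes ordered from most to least frequent
print(len(label_counts))           # number of distinct classes
```

With the real dataset, the call would look like count_labels(dataset.data.targets), and no image files need to be opened.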
Let's use matplotlib to visualize
# Get classes and then values in separate lists.
classes = list(label_counts.keys())
counts = list(label_counts.values())
plt.figure(figsize=(10, 5))
plt.bar(classes, counts, color='green')
plt.xlabel('Classes')
plt.ylabel('Counts')
plt.title('Class Distribution in Bird Dataset')
plt.xticks(rotation=45) # Rotate x-axis labels for better visibility
plt.show()
We have some imbalances that we will address. There are also 450 classes. I think I might take only 50 classes to train at first and see how that works out before attempting more. I would also like to see the actual species names associated with the counts. To do this, I first need a dictionary that associates each target class index with its folder name.
# Get a dictionary that associates the target class with folder names
class_targets = {v:k for k, v in dataset.data.class_to_idx.items()}
print(class_targets)
#first_10_classes = dict(list(class_targets.items())[:10])
#first_10_classes
{0: 'ABBOTTS BABBLER', 1: 'ABBOTTS BOOBY', 2: 'ABYSSINIAN GROUND HORNBILL', 3: 'AFRICAN CROWNED CRANE', 4: 'AFRICAN EMERALD CUCKOO', 5: 'AFRICAN FIREFINCH', 6: 'AFRICAN OYSTER CATCHER', 7: 'AFRICAN PIED HORNBILL', 8: 'ALBATROSS', 9: 'ALBERTS TOWHEE', 10: 'ALEXANDRINE PARAKEET', 11: 'ALPINE CHOUGH', 12: 'ALTAMIRA YELLOWTHROAT', 13: 'AMERICAN AVOCET', 14: 'AMERICAN BITTERN', 15: 'AMERICAN COOT', 16: 'AMERICAN FLAMINGO', 17: 'AMERICAN GOLDFINCH', 18: 'AMERICAN KESTREL', 19: 'AMERICAN PIPIT', 20: 'AMERICAN REDSTART', 21: 'AMERICAN WIGEON', 22: 'AMETHYST WOODSTAR', 23: 'ANDEAN GOOSE', 24: 'ANDEAN LAPWING', 25: 'ANDEAN SISKIN', 26: 'ANHINGA', 27: 'ANIANIAU', 28: 'ANNAS HUMMINGBIRD', 29: 'ANTBIRD', 30: 'ANTILLEAN EUPHONIA', 31: 'APAPANE', 32: 'APOSTLEBIRD', 33: 'ARARIPE MANAKIN', 34: 'ASHY STORM PETREL', 35: 'ASHY THRUSHBIRD', 36: 'ASIAN CRESTED IBIS', 37: 'ASIAN DOLLARD BIRD', 38: 'AUCKLAND SHAQ', 39: 'AUSTRAL CANASTERO', 40: 'AUSTRALASIAN FIGBIRD', 41: 'AVADAVAT', 42: 'AZARAS SPINETAIL', 43: 'AZURE BREASTED PITTA', 44: 'AZURE JAY', 45: 'AZURE TANAGER', 46: 'AZURE TIT', 47: 'BAIKAL TEAL', 48: 'BALD EAGLE', 49: 'BALD IBIS', 50: 'BALI STARLING', 51: 'BALTIMORE ORIOLE', 52: 'BANANAQUIT', 53: 'BAND TAILED GUAN', 54: 'BANDED BROADBILL', 55: 'BANDED PITA', 56: 'BANDED STILT', 57: 'BAR-TAILED GODWIT', 58: 'BARN OWL', 59: 'BARN SWALLOW', 60: 'BARRED PUFFBIRD', 61: 'BARROWS GOLDENEYE', 62: 'BAY-BREASTED WARBLER', 63: 'BEARDED BARBET', 64: 'BEARDED BELLBIRD', 65: 'BEARDED REEDLING', 66: 'BELTED KINGFISHER', 67: 'BIRD OF PARADISE', 68: 'BLACK & YELLOW BROADBILL', 69: 'BLACK BAZA', 70: 'BLACK COCKATO', 71: 'BLACK FRANCOLIN', 72: 'BLACK SKIMMER', 73: 'BLACK SWAN', 74: 'BLACK TAIL CRAKE', 75: 'BLACK THROATED BUSHTIT', 76: 'BLACK THROATED WARBLER', 77: 'BLACK VENTED SHEARWATER', 78: 'BLACK VULTURE', 79: 'BLACK-CAPPED CHICKADEE', 80: 'BLACK-NECKED GREBE', 81: 'BLACK-THROATED SPARROW', 82: 'BLACKBURNIAM WARBLER', 83: 'BLONDE CRESTED WOODPECKER', 84: 'BLOOD PHEASANT', 85: 'BLUE COAU', 86: 
'BLUE DACNIS', 87: 'BLUE GROUSE', 88: 'BLUE HERON', 89: 'BLUE MALKOHA', 90: 'BLUE THROATED TOUCANET', 91: 'BOBOLINK', 92: 'BORNEAN BRISTLEHEAD', 93: 'BORNEAN LEAFBIRD', 94: 'BORNEAN PHEASANT', 95: 'BRANDT CORMARANT', 96: 'BREWERS BLACKBIRD', 97: 'BROWN CREPPER', 98: 'BROWN NOODY', 99: 'BROWN THRASHER', 100: 'BUFFLEHEAD', 101: 'BULWERS PHEASANT', 102: 'BURCHELLS COURSER', 103: 'BUSH TURKEY', 104: 'CAATINGA CACHOLOTE', 105: 'CACTUS WREN', 106: 'CALIFORNIA CONDOR', 107: 'CALIFORNIA GULL', 108: 'CALIFORNIA QUAIL', 109: 'CAMPO FLICKER', 110: 'CANARY', 111: 'CAPE GLOSSY STARLING', 112: 'CAPE LONGCLAW', 113: 'CAPE MAY WARBLER', 114: 'CAPE ROCK THRUSH', 115: 'CAPPED HERON', 116: 'CAPUCHINBIRD', 117: 'CARMINE BEE-EATER', 118: 'CASPIAN TERN', 119: 'CASSOWARY', 120: 'CEDAR WAXWING', 121: 'CERULEAN WARBLER', 122: 'CHARA DE COLLAR', 123: 'CHATTERING LORY', 124: 'CHESTNET BELLIED EUPHONIA', 125: 'CHINESE BAMBOO PARTRIDGE', 126: 'CHINESE POND HERON', 127: 'CHIPPING SPARROW', 128: 'CHUCAO TAPACULO', 129: 'CHUKAR PARTRIDGE', 130: 'CINNAMON ATTILA', 131: 'CINNAMON FLYCATCHER', 132: 'CINNAMON TEAL', 133: 'CLARKS NUTCRACKER', 134: 'COCK OF THE ROCK', 135: 'COCKATOO', 136: 'COLLARED ARACARI', 137: 'COMMON FIRECREST', 138: 'COMMON GRACKLE', 139: 'COMMON HOUSE MARTIN', 140: 'COMMON IORA', 141: 'COMMON LOON', 142: 'COMMON POORWILL', 143: 'COMMON STARLING', 144: 'COPPERY TAILED COUCAL', 145: 'CRAB PLOVER', 146: 'CRANE HAWK', 147: 'CREAM COLORED WOODPECKER', 148: 'CRESTED AUKLET', 149: 'CRESTED CARACARA', 150: 'CRESTED COUA', 151: 'CRESTED FIREBACK', 152: 'CRESTED KINGFISHER', 153: 'CRESTED NUTHATCH', 154: 'CRESTED OROPENDOLA', 155: 'CRESTED SHRIKETIT', 156: 'CRIMSON CHAT', 157: 'CRIMSON SUNBIRD', 158: 'CROW', 159: 'CROWNED PIGEON', 160: 'CUBAN TODY', 161: 'CUBAN TROGON', 162: 'CURL CRESTED ARACURI', 163: 'D-ARNAUDS BARBET', 164: 'DALMATIAN PELICAN', 165: 'DARJEELING WOODPECKER', 166: 'DARK EYED JUNCO', 167: 'DARWINS FLYCATCHER', 168: 'DAURIAN REDSTART', 169: 'DEMOISELLE CRANE', 170: 
'DOUBLE BARRED FINCH', 171: 'DOUBLE BRESTED CORMARANT', 172: 'DOUBLE EYED FIG PARROT', 173: 'DOWNY WOODPECKER', 174: 'DUSKY LORY', 175: 'DUSKY ROBIN', 176: 'EARED PITA', 177: 'EASTERN BLUEBIRD', 178: 'EASTERN BLUEBONNET', 179: 'EASTERN GOLDEN WEAVER', 180: 'EASTERN MEADOWLARK', 181: 'EASTERN ROSELLA', 182: 'EASTERN TOWEE', 183: 'EASTERN WIP POOR WILL', 184: 'ECUADORIAN HILLSTAR', 185: 'EGYPTIAN GOOSE', 186: 'ELEGANT TROGON', 187: 'ELLIOTS PHEASANT', 188: 'EMERALD TANAGER', 189: 'EMPEROR PENGUIN', 190: 'EMU', 191: 'ENGGANO MYNA', 192: 'EURASIAN BULLFINCH', 193: 'EURASIAN GOLDEN ORIOLE', 194: 'EURASIAN MAGPIE', 195: 'EUROPEAN GOLDFINCH', 196: 'EUROPEAN TURTLE DOVE', 197: 'EVENING GROSBEAK', 198: 'FAIRY BLUEBIRD', 199: 'FAIRY PENGUIN', 200: 'FAIRY TERN', 201: 'FAN TAILED WIDOW', 202: 'FASCIATED WREN', 203: 'FIERY MINIVET', 204: 'FIORDLAND PENGUIN', 205: 'FIRE TAILLED MYZORNIS', 206: 'FLAME BOWERBIRD', 207: 'FLAME TANAGER', 208: 'FRIGATE', 209: 'GAMBELS QUAIL', 210: 'GANG GANG COCKATOO', 211: 'GILA WOODPECKER', 212: 'GILDED FLICKER', 213: 'GLOSSY IBIS', 214: 'GO AWAY BIRD', 215: 'GOLD WING WARBLER', 216: 'GOLDEN BOWER BIRD', 217: 'GOLDEN CHEEKED WARBLER', 218: 'GOLDEN CHLOROPHONIA', 219: 'GOLDEN EAGLE', 220: 'GOLDEN PARAKEET', 221: 'GOLDEN PHEASANT', 222: 'GOLDEN PIPIT', 223: 'GOULDIAN FINCH', 224: 'GRANDALA', 225: 'GRAY CATBIRD', 226: 'GRAY KINGBIRD', 227: 'GRAY PARTRIDGE', 228: 'GREAT GRAY OWL', 229: 'GREAT JACAMAR', 230: 'GREAT KISKADEE', 231: 'GREAT POTOO', 232: 'GREAT TINAMOU', 233: 'GREAT XENOPS', 234: 'GREATER PEWEE', 235: 'GREATOR SAGE GROUSE', 236: 'GREEN BROADBILL', 237: 'GREEN JAY', 238: 'GREEN MAGPIE', 239: 'GREY CUCKOOSHRIKE', 240: 'GREY PLOVER', 241: 'GROVED BILLED ANI', 242: 'GUINEA TURACO', 243: 'GUINEAFOWL', 244: 'GURNEYS PITTA', 245: 'GYRFALCON', 246: 'HAMERKOP', 247: 'HARLEQUIN DUCK', 248: 'HARLEQUIN QUAIL', 249: 'HARPY EAGLE', 250: 'HAWAIIAN GOOSE', 251: 'HAWFINCH', 252: 'HELMET VANGA', 253: 'HEPATIC TANAGER', 254: 'HIMALAYAN BLUETAIL', 255: 
'HIMALAYAN MONAL', 256: 'HOATZIN', 257: 'HOODED MERGANSER', 258: 'HOOPOES', 259: 'HORNED GUAN', 260: 'HORNED LARK', 261: 'HORNED SUNGEM', 262: 'HOUSE FINCH', 263: 'HOUSE SPARROW', 264: 'HYACINTH MACAW', 265: 'IBERIAN MAGPIE', 266: 'IBISBILL', 267: 'IMPERIAL SHAQ', 268: 'INCA TERN', 269: 'INDIAN BUSTARD', 270: 'INDIAN PITTA', 271: 'INDIAN ROLLER', 272: 'INDIAN VULTURE', 273: 'INDIGO BUNTING', 274: 'INDIGO FLYCATCHER', 275: 'INLAND DOTTEREL', 276: 'IVORY BILLED ARACARI', 277: 'IVORY GULL', 278: 'IWI', 279: 'JABIRU', 280: 'JACK SNIPE', 281: 'JANDAYA PARAKEET', 282: 'JAPANESE ROBIN', 283: 'JAVA SPARROW', 284: 'JOCOTOCO ANTPITTA', 285: 'KAGU', 286: 'KAKAPO', 287: 'KILLDEAR', 288: 'KING EIDER', 289: 'KING VULTURE', 290: 'KIWI', 291: 'KOOKABURRA', 292: 'LARK BUNTING', 293: 'LAZULI BUNTING', 294: 'LESSER ADJUTANT', 295: 'LILAC ROLLER', 296: 'LITTLE AUK', 297: 'LOGGERHEAD SHRIKE', 298: 'LONG-EARED OWL', 299: 'MAGPIE GOOSE', 300: 'MALABAR HORNBILL', 301: 'MALACHITE KINGFISHER', 302: 'MALAGASY WHITE EYE', 303: 'MALEO', 304: 'MALLARD DUCK', 305: 'MANDRIN DUCK', 306: 'MANGROVE CUCKOO', 307: 'MARABOU STORK', 308: 'MASKED BOOBY', 309: 'MASKED LAPWING', 310: 'MCKAYS BUNTING', 311: 'MIKADO PHEASANT', 312: 'MOURNING DOVE', 313: 'MYNA', 314: 'NICOBAR PIGEON', 315: 'NOISY FRIARBIRD', 316: 'NORTHERN BEARDLESS TYRANNULET', 317: 'NORTHERN CARDINAL', 318: 'NORTHERN FLICKER', 319: 'NORTHERN FULMAR', 320: 'NORTHERN GANNET', 321: 'NORTHERN GOSHAWK', 322: 'NORTHERN JACANA', 323: 'NORTHERN MOCKINGBIRD', 324: 'NORTHERN PARULA', 325: 'NORTHERN RED BISHOP', 326: 'NORTHERN SHOVELER', 327: 'OCELLATED TURKEY', 328: 'OKINAWA RAIL', 329: 'ORANGE BRESTED BUNTING', 330: 'ORIENTAL BAY OWL', 331: 'OSPREY', 332: 'OSTRICH', 333: 'OVENBIRD', 334: 'OYSTER CATCHER', 335: 'PAINTED BUNTING', 336: 'PALILA', 337: 'PARADISE TANAGER', 338: 'PARAKETT AKULET', 339: 'PARUS MAJOR', 340: 'PATAGONIAN SIERRA FINCH', 341: 'PEACOCK', 342: 'PEREGRINE FALCON', 343: 'PHILIPPINE EAGLE', 344: 'PINK ROBIN', 345: 'POMARINE JAEGER', 
346: 'PUFFIN', 347: 'PURPLE FINCH', 348: 'PURPLE GALLINULE', 349: 'PURPLE MARTIN', 350: 'PURPLE SWAMPHEN', 351: 'PYGMY KINGFISHER', 352: 'QUETZAL', 353: 'RAINBOW LORIKEET', 354: 'RAZORBILL', 355: 'RED BEARDED BEE EATER', 356: 'RED BELLIED PITTA', 357: 'RED BROWED FINCH', 358: 'RED FACED CORMORANT', 359: 'RED FACED WARBLER', 360: 'RED FODY', 361: 'RED HEADED DUCK', 362: 'RED HEADED WOODPECKER', 363: 'RED HONEY CREEPER', 364: 'RED NAPED TROGON', 365: 'RED TAILED HAWK', 366: 'RED TAILED THRUSH', 367: 'RED WINGED BLACKBIRD', 368: 'RED WISKERED BULBUL', 369: 'REGENT BOWERBIRD', 370: 'RING-NECKED PHEASANT', 371: 'ROADRUNNER', 372: 'ROBIN', 373: 'ROCK DOVE', 374: 'ROSY FACED LOVEBIRD', 375: 'ROUGH LEG BUZZARD', 376: 'ROYAL FLYCATCHER', 377: 'RUBY THROATED HUMMINGBIRD', 378: 'RUDY KINGFISHER', 379: 'RUFOUS KINGFISHER', 380: 'RUFUOS MOTMOT', 381: 'SAMATRAN THRUSH', 382: 'SAND MARTIN', 383: 'SANDHILL CRANE', 384: 'SATYR TRAGOPAN', 385: 'SCARLET CROWNED FRUIT DOVE', 386: 'SCARLET IBIS', 387: 'SCARLET MACAW', 388: 'SCARLET TANAGER', 389: 'SHOEBILL', 390: 'SHORT BILLED DOWITCHER', 391: 'SKUA', 392: 'SMITHS LONGSPUR', 393: 'SNOWY EGRET', 394: 'SNOWY OWL', 395: 'SNOWY PLOVER', 396: 'SORA', 397: 'SPANGLED COTINGA', 398: 'SPLENDID WREN', 399: 'SPOON BILED SANDPIPER', 400: 'SPOONBILL', 401: 'SPOTTED CATBIRD', 402: 'SRI LANKA BLUE MAGPIE', 403: 'STEAMER DUCK', 404: 'STORK BILLED KINGFISHER', 405: 'STRAWBERRY FINCH', 406: 'STRIPED OWL', 407: 'STRIPPED MANAKIN', 408: 'STRIPPED SWALLOW', 409: 'SUPERB STARLING', 410: 'SWINHOES PHEASANT', 411: 'TAILORBIRD', 412: 'TAIWAN MAGPIE', 413: 'TAKAHE', 414: 'TASMANIAN HEN', 415: 'TEAL DUCK', 416: 'TIT MOUSE', 417: 'TOUCHAN', 418: 'TOWNSENDS WARBLER', 419: 'TREE SWALLOW', 420: 'TRICOLORED BLACKBIRD', 421: 'TROPICAL KINGBIRD', 422: 'TRUMPTER SWAN', 423: 'TURKEY VULTURE', 424: 'TURQUOISE MOTMOT', 425: 'UMBRELLA BIRD', 426: 'VARIED THRUSH', 427: 'VEERY', 428: 'VENEZUELIAN TROUPIAL', 429: 'VERMILION FLYCATHER', 430: 'VICTORIA CROWNED PIGEON', 431: 
'VIOLET GREEN SWALLOW', 432: 'VIOLET TURACO', 433: 'VULTURINE GUINEAFOWL', 434: 'WALL CREAPER', 435: 'WATTLED CURASSOW', 436: 'WATTLED LAPWING', 437: 'WHIMBREL', 438: 'WHITE BROWED CRAKE', 439: 'WHITE CHEEKED TURACO', 440: 'WHITE CRESTED HORNBILL', 441: 'WHITE NECKED RAVEN', 442: 'WHITE TAILED TROPIC', 443: 'WHITE THROATED BEE EATER', 444: 'WILD TURKEY', 445: 'WILSONS BIRD OF PARADISE', 446: 'WOOD DUCK', 447: 'YELLOW BELLIED FLOWERPECKER', 448: 'YELLOW CACIQUE', 449: 'YELLOW HEADED BLACKBIRD'}
#Now we can see the counts for each bird species: House Finch has the most (248), and Snowy Plover is tied with Patagonian Sierra Finch and Red Tailed Thrush for the fewest (130)
labels = [class_targets[dataset[i][1]] for i in range(len(dataset))]
#count occurrences in each label
label_counts = Counter(labels)
print(label_counts)
Counter({'HOUSE FINCH': 248, 'D-ARNAUDS BARBET': 233, 'OVENBIRD': 233, 'SWINHOES PHEASANT': 217, 'WOOD DUCK': 214, 'CASPIAN TERN': 213, 'OYSTER CATCHER': 207, 'DARK EYED JUNCO': 203, 'RED TAILED HAWK': 202, 'VIOLET GREEN SWALLOW': 201, 'SORA': 200, 'CRIMSON SUNBIRD': 198, 'CAMPO FLICKER': 197, 'MARABOU STORK': 197, 'RED BEARDED BEE EATER': 197, 'NORTHERN PARULA': 196, 'BANDED BROADBILL': 194, 'RAZORBILL': 194, 'VARIED THRUSH': 193, 'SPOONBILL': 192, 'AUSTRALASIAN FIGBIRD': 190, 'EASTERN MEADOWLARK': 190, 'LOGGERHEAD SHRIKE': 190, 'VEERY': 190, 'RUFUOS MOTMOT': 189, 'ASHY STORM PETREL': 188, 'BLACK VENTED SHEARWATER': 188, 'MCKAYS BUNTING': 188, 'AFRICAN PIED HORNBILL': 187, 'AMERICAN WIGEON': 187, 'AUCKLAND SHAQ': 187, 'BLUE MALKOHA': 186, 'DOUBLE BRESTED CORMARANT': 186, 'INDIAN PITTA': 186, 'BLUE GROUSE': 185, 'NORTHERN BEARDLESS TYRANNULET': 185, 'BREWERS BLACKBIRD': 184, 'BUFFLEHEAD': 184, 'GREATOR SAGE GROUSE': 184, 'INDIAN ROLLER': 183, 'SKUA': 182, 'TREE SWALLOW': 181, 'ABBOTTS BOOBY': 180, 'CEDAR WAXWING': 180, 'ABYSSINIAN GROUND HORNBILL': 179, 'AMERICAN AVOCET': 179, 'AMERICAN PIPIT': 179, 'BROWN CREPPER': 177, 'COLLARED ARACARI': 177, 'COMMON GRACKLE': 177, 'FLAME TANAGER': 177, 'GOLDEN CHEEKED WARBLER': 176, 'GREEN MAGPIE': 176, 'PARADISE TANAGER': 176, 'GLOSSY IBIS': 175, 'HARPY EAGLE': 175, 'KILLDEAR': 175, 'LITTLE AUK': 175, 'SHOEBILL': 175, 'WHITE TAILED TROPIC': 175, 'BLUE DACNIS': 173, 'HORNED LARK': 173, 'PYGMY KINGFISHER': 173, 'WALL CREAPER': 172, 'AMERICAN BITTERN': 170, 'KING EIDER': 170, 'SATYR TRAGOPAN': 170, 'GREAT KISKADEE': 169, 'VULTURINE GUINEAFOWL': 169, 'AZARAS SPINETAIL': 168, 'BAR-TAILED GODWIT': 168, 'BLACK-THROATED SPARROW': 168, 'CHUKAR PARTRIDGE': 168, 'BLUE HERON': 167, 'HAWFINCH': 167, 'RED FACED WARBLER': 167, 'RED FODY': 167, 'STRAWBERRY FINCH': 167, 'ABBOTTS BABBLER': 166, 'ALPINE CHOUGH': 166, 'BANANAQUIT': 166, 'BANDED PITA': 166, 'COCKATOO': 166, 'DOUBLE EYED FIG PARROT': 166, 'EURASIAN BULLFINCH': 166, 'PUFFIN': 166, 
'ALEXANDRINE PARAKEET': 165, 'AZURE TIT': 165, 'BLACK TAIL CRAKE': 165, 'BROWN THRASHER': 165, 'CHINESE BAMBOO PARTRIDGE': 165, 'ENGGANO MYNA': 165, 'GREAT TINAMOU': 165, 'HYACINTH MACAW': 165, 'TOWNSENDS WARBLER': 165, 'AMERICAN FLAMINGO': 164, 'ASHY THRUSHBIRD': 164, 'AVADAVAT': 164, 'BLACK SWAN': 164, 'CASSOWARY': 164, 'CINNAMON ATTILA': 164, 'CRESTED OROPENDOLA': 164, 'GRAY KINGBIRD': 164, 'JAVA SPARROW': 164, 'REGENT BOWERBIRD': 164, 'SHORT BILLED DOWITCHER': 164, 'TURKEY VULTURE': 164, 'WHITE THROATED BEE EATER': 164, 'BAND TAILED GUAN': 163, 'BLACK THROATED BUSHTIT': 163, 'CERULEAN WARBLER': 163, 'CINNAMON FLYCATCHER': 163, 'CINNAMON TEAL': 163, 'COCK OF THE ROCK': 163, 'CRESTED KINGFISHER': 163, 'CROW': 163, 'ELLIOTS PHEASANT': 163, 'EMU': 163, 'GREAT JACAMAR': 163, 'GURNEYS PITTA': 163, 'GYRFALCON': 163, 'HAMERKOP': 163, 'INCA TERN': 163, 'JOCOTOCO ANTPITTA': 163, 'MALACHITE KINGFISHER': 163, 'PAINTED BUNTING': 163, 'RED WISKERED BULBUL': 163, 'CACTUS WREN': 162, 'CARMINE BEE-EATER': 162, 'CHIPPING SPARROW': 162, 'CUBAN TODY': 162, 'FLAME BOWERBIRD': 162, 'GOLDEN EAGLE': 162, 'GUINEA TURACO': 162, 'INLAND DOTTEREL': 162, 'JANDAYA PARAKEET': 162, 'TAKAHE': 162, 'VIOLET TURACO': 162, 'ALBERTS TOWHEE': 161, 'CALIFORNIA QUAIL': 161, 'CLARKS NUTCRACKER': 161, 'COMMON LOON': 161, 'COMMON POORWILL': 161, 'GOLDEN PHEASANT': 161, 'MALEO': 161, 'OSTRICH': 161, 'PARUS MAJOR': 161, 'SNOWY OWL': 161, 'SPLENDID WREN': 161, 'SRI LANKA BLUE MAGPIE': 161, 'WHITE NECKED RAVEN': 161, 'BALD EAGLE': 160, 'BARN OWL': 160, 'BEARDED BARBET': 160, 'BLACK SKIMMER': 160, 'BORNEAN LEAFBIRD': 160, 'BUSH TURKEY': 160, 'CALIFORNIA GULL': 160, 'CANARY': 160, 'CRESTED AUKLET': 160, 'CROWNED PIGEON': 160, 'EASTERN ROSELLA': 160, 'FAIRY BLUEBIRD': 160, 'FRIGATE': 160, 'GRAY PARTRIDGE': 160, 'HAWAIIAN GOOSE': 160, 'HELMET VANGA': 160, 'HEPATIC TANAGER': 160, 'LARK BUNTING': 160, 'LONG-EARED OWL': 160, 'NORTHERN GOSHAWK': 160, 'OCELLATED TURKEY': 160, 'OKINAWA RAIL': 160, 'PALILA': 160, 
'PURPLE MARTIN': 160, 'RED HEADED DUCK': 160, 'ROADRUNNER': 160, 'RUDY KINGFISHER': 160, 'SMITHS LONGSPUR': 160, 'SPANGLED COTINGA': 160, 'STEAMER DUCK': 160, 'STRIPPED SWALLOW': 160, 'BULWERS PHEASANT': 159, 'DALMATIAN PELICAN': 159, 'GOLD WING WARBLER': 159, 'GREY CUCKOOSHRIKE': 159, 'IBERIAN MAGPIE': 159, 'JACK SNIPE': 159, 'RED NAPED TROGON': 159, 'TEAL DUCK': 159, 'WHITE CRESTED HORNBILL': 159, 'YELLOW HEADED BLACKBIRD': 159, 'AMERICAN COOT': 158, 'GOLDEN PIPIT': 158, 'GREY PLOVER': 158, 'HORNED GUAN': 158, 'JAPANESE ROBIN': 158, 'SCARLET MACAW': 158, 'ASIAN CRESTED IBIS': 157, 'BOBOLINK': 157, 'COMMON HOUSE MARTIN': 157, 'DOWNY WOODPECKER': 157, 'EASTERN TOWEE': 157, 'EGYPTIAN GOOSE': 157, 'GREATER PEWEE': 157, 'LAZULI BUNTING': 157, 'NOISY FRIARBIRD': 157, 'ORIENTAL BAY OWL': 157, 'BLACK BAZA': 156, 'BLACK VULTURE': 156, 'CRESTED SHRIKETIT': 156, 'GREEN JAY': 156, 'HORNED SUNGEM': 156, 'IVORY GULL': 156, 'NORTHERN JACANA': 156, 'PEACOCK': 156, 'PEREGRINE FALCON': 156, 'RUFOUS KINGFISHER': 156, 'SANDHILL CRANE': 156, 'TURQUOISE MOTMOT': 156, 'AFRICAN OYSTER CATCHER': 155, 'AMERICAN KESTREL': 155, 'ARARIPE MANAKIN': 155, 'BANDED STILT': 155, 'BLUE COAU': 155, 'CAATINGA CACHOLOTE': 155, 'CAPE ROCK THRUSH': 155, 'CRESTED CARACARA': 155, 'EURASIAN MAGPIE': 155, 'GILA WOODPECKER': 155, 'GOULDIAN FINCH': 155, 'GRAY CATBIRD': 155, 'GROVED BILLED ANI': 155, 'HOATZIN': 155, 'HOOPOES': 155, 'HOUSE SPARROW': 155, 'IWI': 155, 'MALABAR HORNBILL': 155, 'MANDRIN DUCK': 155, 'NORTHERN CARDINAL': 155, 'PARAKETT AKULET': 155, 'UMBRELLA BIRD': 155, 'VERMILION FLYCATHER': 155, 'WILSONS BIRD OF PARADISE': 155, 'YELLOW CACIQUE': 155, 'AFRICAN EMERALD CUCKOO': 154, 'ANDEAN SISKIN': 154, 'ANHINGA': 154, 'APAPANE': 154, 'AUSTRAL CANASTERO': 154, 'BALD IBIS': 154, 'BELTED KINGFISHER': 154, 'BORNEAN BRISTLEHEAD': 154, 'BORNEAN PHEASANT': 154, 'BRANDT CORMARANT': 154, 'BROWN NOODY': 154, 'CAPE GLOSSY STARLING': 154, 'CAPE MAY WARBLER': 154, 'CRIMSON CHAT': 154, 'EARED PITA': 154, 
'EASTERN BLUEBIRD': 154, 'EMPEROR PENGUIN': 154, 'FAIRY PENGUIN': 154, 'KAKAPO': 154, 'LESSER ADJUTANT': 154, 'MIKADO PHEASANT': 154, 'MOURNING DOVE': 154, 'NICOBAR PIGEON': 154, 'NORTHERN GANNET': 154, 'NORTHERN SHOVELER': 154, 'PHILIPPINE EAGLE': 154, 'PURPLE SWAMPHEN': 154, 'SCARLET CROWNED FRUIT DOVE': 154, 'SUPERB STARLING': 154, 'TIT MOUSE': 154, 'TROPICAL KINGBIRD': 154, 'WHITE BROWED CRAKE': 154, 'WILD TURKEY': 154, 'YELLOW BELLIED FLOWERPECKER': 154, 'APOSTLEBIRD': 153, 'BLOOD PHEASANT': 153, 'CALIFORNIA CONDOR': 153, 'CHARA DE COLLAR': 153, 'CRESTED COUA': 153, 'DOUBLE BARRED FINCH': 153, 'GAMBELS QUAIL': 153, 'INDIGO BUNTING': 153, 'PINK ROBIN': 153, 'PURPLE FINCH': 153, 'PURPLE GALLINULE': 153, 'RED FACED CORMORANT': 153, 'SAMATRAN THRUSH': 153, 'VICTORIA CROWNED PIGEON': 153, 'WHITE CHEEKED TURACO': 153, 'BIRD OF PARADISE': 152, 'CRESTED NUTHATCH': 152, 'EASTERN GOLDEN WEAVER': 152, 'FIORDLAND PENGUIN': 152, 'GRANDALA': 152, 'INDIGO FLYCATCHER': 152, 'OSPREY': 152, 'QUETZAL': 152, 'RED WINGED BLACKBIRD': 152, 'RING-NECKED PHEASANT': 152, 'ROUGH LEG BUZZARD': 152, 'VENEZUELIAN TROUPIAL': 152, 'HARLEQUIN DUCK': 151, 'HIMALAYAN MONAL': 151, 'RED BELLIED PITTA': 151, 'SAND MARTIN': 151, 'ALTAMIRA YELLOWTHROAT': 150, 'ANIANIAU': 150, 'ANTBIRD': 150, 'BAIKAL TEAL': 150, 'BLUE THROATED TOUCANET': 150, 'EASTERN WIP POOR WILL': 150, 'FAN TAILED WIDOW': 150, 'FASCIATED WREN': 150, 'FIERY MINIVET': 150, 'FIRE TAILLED MYZORNIS': 150, 'GOLDEN PARAKEET': 150, 'HIMALAYAN BLUETAIL': 150, 'INDIAN VULTURE': 150, 'SPOTTED CATBIRD': 150, 'GREAT XENOPS': 149, 'IVORY BILLED ARACARI': 149, 'ROBIN': 147, 'DAURIAN REDSTART': 146, 'ANDEAN LAPWING': 144, 'BEARDED REEDLING': 144, 'BLONDE CRESTED WOODPECKER': 144, 'CHINESE POND HERON': 144, 'COPPERY TAILED COUCAL': 144, 'CRAB PLOVER': 144, 'ELEGANT TROGON': 144, 'EVENING GROSBEAK': 144, 'GREAT GRAY OWL': 144, 'IMPERIAL SHAQ': 144, 'KAGU': 144, 'SPOON BILED SANDPIPER': 144, 'STRIPPED MANAKIN': 144, 'BAY-BREASTED WARBLER': 143, 
'BURCHELLS COURSER': 143, 'JABIRU': 143, 'KOOKABURRA': 143, 'MALAGASY WHITE EYE': 143, 'STRIPED OWL': 143, 'TRICOLORED BLACKBIRD': 143, 'BEARDED BELLBIRD': 142, 'BLACK & YELLOW BROADBILL': 142, 'CRESTED FIREBACK': 142, 'GANG GANG COCKATOO': 142, 'MAGPIE GOOSE': 142, 'ROYAL FLYCATCHER': 142, 'WATTLED LAPWING': 142, 'ASIAN DOLLARD BIRD': 141, 'AZURE BREASTED PITTA': 141, 'CHUCAO TAPACULO': 141, 'COMMON STARLING': 141, 'IBISBILL': 141, 'MYNA': 141, 'RAINBOW LORIKEET': 141, 'RED BROWED FINCH': 141, 'TAILORBIRD': 141, 'DUSKY ROBIN': 140, 'GOLDEN BOWER BIRD': 140, 'GREEN BROADBILL': 140, 'NORTHERN MOCKINGBIRD': 140, 'POMARINE JAEGER': 140, 'AMERICAN REDSTART': 139, 'ANNAS HUMMINGBIRD': 139, 'ANTILLEAN EUPHONIA': 139, 'COMMON FIRECREST': 139, 'CREAM COLORED WOODPECKER': 139, 'DEMOISELLE CRANE': 139, 'DUSKY LORY': 139, 'ECUADORIAN HILLSTAR': 139, 'EMERALD TANAGER': 139, 'NORTHERN FLICKER': 139, 'ROSY FACED LOVEBIRD': 139, 'CUBAN TROGON': 138, 'EUROPEAN TURTLE DOVE': 138, 'GILDED FLICKER': 138, 'GREAT POTOO': 138, 'GUINEAFOWL': 138, 'KIWI': 138, 'LILAC ROLLER': 138, 'SCARLET IBIS': 138, 'WATTLED CURASSOW': 138, 'WHIMBREL': 138, 'AFRICAN CROWNED CRANE': 137, 'AFRICAN FIREFINCH': 137, 'AZURE TANAGER': 137, 'BALTIMORE ORIOLE': 137, 'BLACK-NECKED GREBE': 137, 'CAPPED HERON': 137, 'CHATTERING LORY': 137, 'CURL CRESTED ARACURI': 137, 'MANGROVE CUCKOO': 137, 'NORTHERN FULMAR': 137, 'TRUMPTER SWAN': 137, 'AZURE JAY': 136, 'BARRED PUFFBIRD': 136, 'DARJEELING WOODPECKER': 136, 'FAIRY TERN': 136, 'KING VULTURE': 136, 'TAIWAN MAGPIE': 136, 'TOUCHAN': 136, 'BLACK THROATED WARBLER': 135, 'CRANE HAWK': 135, 'EURASIAN GOLDEN ORIOLE': 135, 'EUROPEAN GOLDFINCH': 135, 'GOLDEN CHLOROPHONIA': 135, 'HOODED MERGANSER': 135, 'MALLARD DUCK': 135, 'NORTHERN RED BISHOP': 135, 'RUBY THROATED HUMMINGBIRD': 135, 'STORK BILLED KINGFISHER': 135, 'TASMANIAN HEN': 135, 'ANDEAN GOOSE': 134, 'BLACKBURNIAM WARBLER': 134, 'EASTERN BLUEBONNET': 134, 'HARLEQUIN QUAIL': 134, 'ALBATROSS': 133, 'AMERICAN GOLDFINCH': 
133, 'BLACK COCKATO': 133, 'BLACK-CAPPED CHICKADEE': 133, 'CAPE LONGCLAW': 133, 'CAPUCHINBIRD': 133, 'COMMON IORA': 133, 'RED HEADED WOODPECKER': 133, 'BALI STARLING': 132, 'BARN SWALLOW': 132, 'BARROWS GOLDENEYE': 132, 'CHESTNET BELLIED EUPHONIA': 132, 'DARWINS FLYCATCHER': 132, 'MASKED BOOBY': 132, 'ORANGE BRESTED BUNTING': 132, 'RED HONEY CREEPER': 132, 'ROCK DOVE': 132, 'SCARLET TANAGER': 132, 'SNOWY EGRET': 132, 'AMETHYST WOODSTAR': 131, 'BLACK FRANCOLIN': 131, 'GO AWAY BIRD': 131, 'INDIAN BUSTARD': 131, 'MASKED LAPWING': 131, 'PATAGONIAN SIERRA FINCH': 130, 'RED TAILED THRUSH': 130, 'SNOWY PLOVER': 130})
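One way to put a number on the imbalance is to compare the largest and smallest classes. The sketch below does that for a handful of counts copied from the output above; the helper name imbalance_summary is an assumption, and with the real data you would pass the full label_counts Counter instead.

```python
from collections import Counter

def imbalance_summary(label_counts):
    """Return the largest class, the smallest classes, and the max/min ratio."""
    largest_name, largest = label_counts.most_common(1)[0]
    smallest = min(label_counts.values())
    smallest_names = sorted(n for n, c in label_counts.items() if c == smallest)
    return largest_name, largest, smallest_names, smallest, largest / smallest

# A few entries copied from the printed Counter above
sample_counts = Counter({
    'HOUSE FINCH': 248,
    'OVENBIRD': 233,
    'ROBIN': 147,
    'PATAGONIAN SIERRA FINCH': 130,
    'RED TAILED THRUSH': 130,
    'SNOWY PLOVER': 130,
})

name, largest, small_names, smallest, ratio = imbalance_summary(sample_counts)
print(f'Largest: {name} ({largest})')
print(f'Smallest: {small_names} ({smallest})')
print(f'Max/min ratio: {ratio:.2f}')
```

A ratio below 2 means the largest class has fewer than twice as many images as the smallest, so the imbalance here is moderate rather than severe.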
Below I create a convenient function to view sample of birds from the images, given a dataset.
import random
# view multiple images
def view_multiple_images(dataset, count=8, figsize=(40,30)):
    #The number of rows for display
    rows = count//4
    if count%4 > 0: # if we cannot divide into equal rows. Then we have to add an additional row for the extra images
        rows += 1
    # We could shorten the code above, like this: rows = count // 4 + (count % 4 > 0)
    class_targets = {v:k for k, v in dataset.data.class_to_idx.items()}
    idx_list = random.sample(range(len(dataset)), count)
    for column, value in enumerate(idx_list):
        image, label = dataset[value]
        plt.subplot(rows, 4, column+1)
        plt.title(f'Label: {class_targets[label]}', fontsize=7)
        plt.axis('off')
        plt.imshow(image)
    plt.subplots_adjust(wspace=0.8, hspace=0.5)
    plt.tight_layout()
    plt.show()
view_multiple_images(dataset)
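The row calculation in view_multiple_images is a ceiling division: enough rows of 4 to hold count images. A small standalone sketch (the helper name grid_rows is an assumption) shows an integer-only way to compute it and checks it against math.ceil.

```python
import math

def grid_rows(count, columns=4):
    """Number of rows needed to lay out `count` images in `columns` columns."""
    return -(-count // columns)  # ceiling division using integer arithmetic only

for count in (8, 9, 3):
    # math.ceil on the float quotient gives the same result for these sizes
    assert grid_rows(count) == math.ceil(count / 4)
    print(count, grid_rows(count))
```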
Create a dataset with a smaller subset of data
I wrote a small standalone Python script that selects a given number of class folders. It sorts the folders by image count and picks those with the fewest images; even these classes still have a good number of images (at least 130 in the training set). Because the selected classes all sit at the low end of the distribution, their image counts are nearly balanced. Each run creates two new folders: one for training and one for validation or test, depending on the folder names set at the bottom of the script. The script is shown below.
import os
import shutil

def count_items_in_folder(folder):
    """Count all files under folder, including nested subfolders."""
    return sum(len(files) for _, _, files in os.walk(folder))

def select_least_items_folder(input_folder, n=10):
    """Return the paths of the n subfolders that contain the fewest files."""
    subfolders = [f.path for f in os.scandir(input_folder) if f.is_dir()]
    # Create a list of tuples (folder_path, item_count)
    folder_item_counts = [(subfolder, count_items_in_folder(subfolder)) for subfolder in subfolders]
    # Sort the list by item count and select the least n folders
    least_items_folders = sorted(folder_item_counts, key=lambda x: x[1])[:n]
    return [folder for folder, _ in least_items_folders]

def copy_selected_folders(selected_folders, output_folder_train, input_folder_test_valid, output_folder_test_valid):
    """Copy the selected class folders, and their test/valid counterparts, to the output folders"""
    os.makedirs(output_folder_train, exist_ok=True)
    os.makedirs(output_folder_test_valid, exist_ok=True)
    for folder in selected_folders:
        folder_name = os.path.basename(folder)
        destination = os.path.join(output_folder_train, folder_name)
        # dirs_exist_ok=True (Python 3.8+) lets the script be rerun, e.g. once for test and once for valid,
        # without failing because train_small already exists
        shutil.copytree(folder, destination, dirs_exist_ok=True)
        input_test_valid_dir = os.path.join(input_folder_test_valid, folder_name)
        destination_test_valid = os.path.join(output_folder_test_valid, folder_name)
        shutil.copytree(input_test_valid_dir, destination_test_valid, dirs_exist_ok=True)

def main(input_folder_train, output_folder_train, input_folder_test_valid, output_folder_test_valid):
    selected_folders = select_least_items_folder(input_folder_train, 50)
    copy_selected_folders(selected_folders, output_folder_train, input_folder_test_valid, output_folder_test_valid)

if __name__ == "__main__":
    img_dir_train = 'train'
    img_dir_small_train = 'train_small'
    # change `test` to `valid` and `test_small` to `valid_small` if you want to create the validation folder
    img_dir_test_valid = 'test'
    img_dir_small_test_valid = 'test_small'
    main(img_dir_train, img_dir_small_train, img_dir_test_valid, img_dir_small_test_valid)
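After copying, it can be worth confirming that the reduced splits contain the same class folders, since ImageFolder assigns label indices from the sorted folder names. The following is a minimal sketch under that assumption. The helpers class_image_counts and same_classes are hypothetical names, and the folder tree is a throwaway stand-in built in a temporary directory rather than the real dataset.

```python
import os
import tempfile

def class_image_counts(split_dir):
    """Map each class subfolder of split_dir to the number of files it contains."""
    counts = {}
    for entry in os.scandir(split_dir):
        if entry.is_dir():
            counts[entry.name] = sum(len(files) for _, _, files in os.walk(entry.path))
    return counts

def same_classes(train_dir, other_dir):
    """True when both splits contain exactly the same class folder names."""
    return set(class_image_counts(train_dir)) == set(class_image_counts(other_dir))

# Tiny throwaway folder tree standing in for train_small / test_small.
with tempfile.TemporaryDirectory() as root:
    for split, n_images in (("train_small", 3), ("test_small", 1)):
        for cls in ("ALBATROSS", "ROCK DOVE"):
            cls_dir = os.path.join(root, split, cls)
            os.makedirs(cls_dir)
            for i in range(n_images):
                # Empty placeholder files; only the file count matters here.
                open(os.path.join(cls_dir, f"{i}.jpg"), "w").close()

    train_counts = class_image_counts(os.path.join(root, "train_small"))
    classes_match = same_classes(os.path.join(root, "train_small"),
                                 os.path.join(root, "test_small"))
    print(train_counts, classes_match)
```

Pointing the two helpers at the real train_small and test_small (or valid_small) folders gives the per-class counts and a quick check that no class is missing from a split.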
# Variables pointing to folders
train_dir_small = './bird-species/train_small'
valid_dir_small = './bird-species/valid_small'
test_dir_small = './bird-species/test_small'
# Create new datasets with a transform. The transform resizes each image for faster processing; let's try 128x128.
# The transform also converts each image to a Tensor so that PyTorch can process the data.
# Wrapping each dataset in a DataLoader lets us batch our data for speed. We will use a batch size of 32.
batch_size = 32
transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor(),
])
train_dataset = BirdDataset(train_dir_small, transform=transform)
valid_dataset = BirdDataset(valid_dir_small, transform=transform)
test_dataset = BirdDataset(test_dir_small, transform=transform)
# We need to shuffle the training loader. ImageFolder orders samples by class folder, so without shuffling
# each batch would contain mostly one class, which hurts optimization and generalization.
# Shuffling is not necessary for the validation and test sets, so it is turned off there.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
Let's inspect the dataset and dataloader to confirm that the shapes and data types are what we expect.
print(len(train_dataset))
print(train_dataset[300])
print(train_dataset[300][0].shape)
6671
(tensor([[[0.4078, 0.4118, 0.4118, ..., 0.4510, 0.4471, 0.4471],
[0.4039, 0.4078, 0.4078, ..., 0.4510, 0.4510, 0.4471],
[0.4078, 0.4078, 0.4118, ..., 0.4510, 0.4549, 0.4510],
...,
[0.4314, 0.4353, 0.4314, ..., 0.3569, 0.3490, 0.3490],
[0.4235, 0.4196, 0.4235, ..., 0.3529, 0.3529, 0.3490],
[0.4196, 0.4157, 0.4157, ..., 0.3529, 0.3569, 0.3490]],
[[0.5216, 0.5176, 0.5176, ..., 0.5922, 0.5882, 0.5882],
[0.5098, 0.5137, 0.5137, ..., 0.5922, 0.5922, 0.5882],
[0.5059, 0.5098, 0.5098, ..., 0.5922, 0.5961, 0.5922],
...,
[0.5647, 0.5686, 0.5725, ..., 0.4706, 0.4627, 0.4627],
[0.5608, 0.5569, 0.5569, ..., 0.4667, 0.4667, 0.4627],
[0.5569, 0.5529, 0.5490, ..., 0.4667, 0.4706, 0.4627]],
[[0.2471, 0.2471, 0.2471, ..., 0.2471, 0.2431, 0.2353],
[0.2392, 0.2431, 0.2431, ..., 0.2431, 0.2392, 0.2353],
[0.2392, 0.2431, 0.2471, ..., 0.2392, 0.2431, 0.2392],
...,
[0.2667, 0.2667, 0.2627, ..., 0.2118, 0.2039, 0.2039],
[0.2510, 0.2510, 0.2588, ..., 0.2078, 0.2078, 0.2078],
[0.2392, 0.2431, 0.2510, ..., 0.2078, 0.2157, 0.2118]]]), 2)
torch.Size([3, 128, 128])
# Get a dictionary that associates the target class with folder names
class_targets_small = {v:k for k, v in train_dataset.data.class_to_idx.items()}
print(class_targets_small)
{0: 'AFRICAN CROWNED CRANE', 1: 'ALBATROSS', 2: 'AMERICAN GOLDFINCH', 3: 'AMETHYST WOODSTAR', 4: 'ANDEAN GOOSE', 5: 'AZURE JAY', 6: 'BALI STARLING', 7: 'BARN SWALLOW', 8: 'BARRED PUFFBIRD', 9: 'BARROWS GOLDENEYE', 10: 'BLACK COCKATO', 11: 'BLACK FRANCOLIN', 12: 'BLACK THROATED WARBLER', 13: 'BLACK-CAPPED CHICKADEE', 14: 'BLACKBURNIAM WARBLER', 15: 'CAPE LONGCLAW', 16: 'CAPUCHINBIRD', 17: 'CHESTNET BELLIED EUPHONIA', 18: 'COMMON IORA', 19: 'CRANE HAWK', 20: 'DARJEELING WOODPECKER', 21: 'DARWINS FLYCATCHER', 22: 'EASTERN BLUEBONNET', 23: 'EURASIAN GOLDEN ORIOLE', 24: 'EUROPEAN GOLDFINCH', 25: 'FAIRY TERN', 26: 'GO AWAY BIRD', 27: 'GOLDEN CHLOROPHONIA', 28: 'HARLEQUIN QUAIL', 29: 'HOODED MERGANSER', 30: 'INDIAN BUSTARD', 31: 'KING VULTURE', 32: 'MALLARD DUCK', 33: 'MASKED BOOBY', 34: 'MASKED LAPWING', 35: 'NORTHERN RED BISHOP', 36: 'ORANGE BRESTED BUNTING', 37: 'PATAGONIAN SIERRA FINCH', 38: 'RED HEADED WOODPECKER', 39: 'RED HONEY CREEPER', 40: 'RED TAILED THRUSH', 41: 'ROCK DOVE', 42: 'RUBY THROATED HUMMINGBIRD', 43: 'SCARLET TANAGER', 44: 'SNOWY EGRET', 45: 'SNOWY PLOVER', 46: 'STORK BILLED KINGFISHER', 47: 'TAIWAN MAGPIE', 48: 'TASMANIAN HEN', 49: 'TOUCHAN'}
- Looks good so far.
# Use break in for loop to give us access to the first item so that we can check it.
for images, labels in valid_loader:
    break
images.shape, labels.shape
(torch.Size([32, 3, 128, 128]), torch.Size([32]))
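The shapes above are for a full batch. The last batch of an epoch can be smaller when the dataset size is not a multiple of the batch size. A minimal sketch of that arithmetic for the 6671 training images reported earlier, assuming the DataLoader default drop_last=False (the helper batch_layout is a hypothetical name):

```python
import math

def batch_layout(n_samples, batch_size):
    """Return (number of batches, size of the last batch) when no samples are dropped."""
    n_batches = math.ceil(n_samples / batch_size)
    last_batch = n_samples - (n_batches - 1) * batch_size
    return n_batches, last_batch

n_batches, last_batch = batch_layout(6671, 32)
print(n_batches, last_batch)  # 209 15
```

This is why the training loop multiplies each batch loss by labels.size(0) instead of assuming 32 samples per batch.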
2. Create our Model with PyTorch
Here we will use efficientnet_b0 from timm as our base model. We need to know the size of the base model's output features so that the first Linear layer of our classifier can accept them; for EfficientNet B0 this is 1280. We also cut off the base model's last layer so that the remaining layers act as a feature extractor whose output feeds our classifier.
class BirdClassifierModel(nn.Module):
    def __init__(self, num_classes=50):
        super(BirdClassifierModel, self).__init__()
        # define structure
        self.base_model = timm.create_model('efficientnet_b0', pretrained=True)
        # Below we get all the children from the model, turn them into a list, then cut off the last layer.
        # We then create a new model from the layers returned.
        # self.features is a derived model with all the layers from the model except the last one.
        # We use * to unpack them as arguments to the sequential model.
        self.features = nn.Sequential(*list(self.base_model.children())[:-1])
        # Create the classifier head using Sequential
        self.classifier = nn.Sequential(
            # Flatten tensor to 2D since the Linear layer requires 2D
            nn.Flatten(),
            # Map the 1280 output features of efficientnet to a 512-dimensional space
            nn.Linear(1280, 512),
            nn.ReLU(), # activation function that applies a non-linear transformation to learn more complex patterns
            nn.Dropout(0.5), # helps prevent overfitting
            nn.Linear(512, num_classes) # final Linear layer maps to the number of predicted classes
        )

    def forward(self, x):
        # connect features to classifier from setup
        x = self.features(x)
        # Final output of classifier above is returned.
        x = self.classifier(x)
        return x
#create model
model = BirdClassifierModel()
#view portion of parameters
print(str(model.parameters)[:400])
<bound method Module.parameters of BirdClassifierModel(
(base_model): EfficientNet(
(conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True)
)
(blocks): Sequential(
(0): Sequential(
# Model info is long, so lets only show a few lines.
print(str(model)[:300])
BirdClassifierModel(
(base_model): EfficientNet(
(conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNormAct2d(
32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
(drop): Identity()
(act): SiLU(inplace=True
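The classifier head can be sanity-checked on its own, without downloading pretrained weights. The sketch below rebuilds only the head and feeds it a random tensor standing in for the base model's features. The exact shape produced by self.features is an assumption here; nn.Flatten accepts either (N, 1280) or (N, 1280, 1, 1).

```python
import torch
import torch.nn as nn

num_classes = 50

# Same layer stack as self.classifier in BirdClassifierModel.
head = nn.Sequential(
    nn.Flatten(),
    nn.Linear(1280, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes),
)
head.eval()  # disable dropout so the check is deterministic

# Random stand-in for pooled features: batch of 4, 1280 channels, 1x1 spatial size.
dummy_features = torch.randn(4, 1280, 1, 1)
with torch.no_grad():
    logits = head(dummy_features)

print(logits.shape)  # one row of 50 class scores per image
```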
3. Train Model
Before training, we need to set up the loss function and an optimizer, and define a learning rate. We will also set the number of epochs and create lists to keep track of the training loss and validation loss. Finally, we create a device so that we can use cuda when it is available. Let's go.
epochs = 20
training_loss = []
validation_loss = []
lr = 0.001
# create our device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# send model to device
model.to(device)
# Loss Function
criterion = nn.CrossEntropyLoss()
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(epochs):
    # Training
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc='Training'):
        imgs, labels = imgs.to(device), labels.to(device)
        # reset gradients from the previous backward pass
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # loss is the batch mean, so weight it by the batch size
        running_loss += loss.item() * labels.size(0)
    train_loss = running_loss / len(train_loader.dataset)
    training_loss.append(train_loss)
    # Validation
    model.eval()
    running_loss = 0.0
    # disable gradient tracking for validation
    with torch.no_grad():
        for imgs, labels in tqdm(valid_loader, desc="Validation"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * labels.size(0)
    valid_loss = running_loss / len(valid_loader.dataset)
    validation_loss.append(valid_loss)
    print(f"Epoch {epoch+1}/{epochs} - Training loss: {train_loss}, Validation loss: {valid_loss}")
Epoch 1/20 - Training loss: 1.2816189220598158, Validation loss: 0.25513468569517134
.
.
.
Epoch 20/20 - Training loss: 0.08152641907136154, Validation loss: 0.19318124890327454
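The loop multiplies each batch loss by labels.size(0) because CrossEntropyLoss returns the mean over the batch by default, and the last batch can be smaller than the rest. A minimal sketch of the difference between that weighted average and a plain average of batch means follows. The loss values and batch sizes are made up for illustration, and dataset_mean_loss is a hypothetical helper.

```python
def dataset_mean_loss(batch_mean_losses, batch_sizes):
    """Per-sample mean loss over the whole dataset, from per-batch mean losses."""
    running_loss = 0.0
    for batch_loss, batch_size in zip(batch_mean_losses, batch_sizes):
        # Undo the per-batch averaging so every sample counts equally.
        running_loss += batch_loss * batch_size
    return running_loss / sum(batch_sizes)

# Hypothetical values: two full batches and one smaller final batch.
batch_mean_losses = [1.0, 1.0, 4.0]
batch_sizes = [32, 32, 16]

weighted = dataset_mean_loss(batch_mean_losses, batch_sizes)
naive = sum(batch_mean_losses) / len(batch_mean_losses)
print(weighted, naive)  # 1.6 2.0
```

The plain average over-weights the small final batch, while the weighted form matches the per-sample mean over the dataset.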
4. Visualize Training Results
plt.plot(training_loss, label='Training Loss')
plt.plot(validation_loss, label='Validation Loss')
plt.legend()
plt.title('Loss')
plt.show()
This model has potential, as both the training loss and the validation loss have decreased. Let's check the accuracy and do some more evaluation.
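The final epoch is not necessarily the one with the lowest validation loss, so it can help to look up where the minimum occurred in the recorded list. A minimal sketch follows. The helper best_epoch is a hypothetical name and the loss values are invented for illustration, not taken from the run above.

```python
def best_epoch(validation_loss):
    """Return (1-based epoch number, loss) for the lowest recorded validation loss."""
    idx = min(range(len(validation_loss)), key=validation_loss.__getitem__)
    return idx + 1, validation_loss[idx]

# Hypothetical per-epoch validation losses.
example_validation_loss = [0.26, 0.21, 0.18, 0.19, 0.20]

epoch, loss = best_epoch(example_validation_loss)
print(epoch, loss)  # 3 0.18
```

Calling best_epoch(validation_loss) on the list filled by the training loop would show whether later epochs still improved on the validation set.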
5. Test Accuracy
def calculate_accuracy(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in data_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct/total

# After training, calculate accuracy on the test set
test_accuracy = calculate_accuracy(model, test_loader, device)
print(f"Test Accuracy: {test_accuracy}%")
Test Accuracy: 98.4%
6. Conclusion
The accuracy on the test set is fairly good: the model achieved 98.4%. Accuracy alone does not tell the whole story, however. I will need to consider the false positives and false negatives, and explore other metrics such as precision, recall and F1-score. For now I will save this model and its state dictionary.
# Saving only the model's state dictionary is more efficient than saving the entire model; here we save both
torch.save(model, 'model.pth')
torch.save(model.state_dict(), 'model_state_dict.pth')
Reminder on how to reuse this model in an application
model = BirdClassifierModel()
model.load_state_dict(torch.load('model_state_dict.pth'))
model.eval()
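As a starting point for the additional metrics mentioned in the conclusion, per-class precision, recall and F1-score can be computed from lists of true and predicted labels. This is a minimal pure-Python sketch, not the evaluation used in this notebook. The helper per_class_metrics is a hypothetical name and the labels are made up for illustration.

```python
from collections import Counter

def per_class_metrics(y_true, y_pred):
    """Return {class: (precision, recall, f1)} from true and predicted label lists."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for true_label, pred_label in zip(y_true, y_pred):
        if true_label == pred_label:
            tp[true_label] += 1
        else:
            fp[pred_label] += 1  # predicted this class, but it was wrong
            fn[true_label] += 1  # missed the actual class
    metrics = {}
    for cls in sorted(set(y_true) | set(y_pred)):
        predicted_total = tp[cls] + fp[cls]
        actual_total = tp[cls] + fn[cls]
        precision = tp[cls] / predicted_total if predicted_total else 0.0
        recall = tp[cls] / actual_total if actual_total else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics[cls] = (precision, recall, f1)
    return metrics

# Hypothetical labels for three classes.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

metrics = per_class_metrics(y_true, y_pred)
for cls, (precision, recall, f1) in metrics.items():
    print(cls, round(precision, 3), round(recall, 3), round(f1, 3))
```

In practice, y_true and y_pred would be collected inside a loop like the one in calculate_accuracy, by appending labels and predicted for each batch.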