Why Transfer Learning is a Game-Changer for AI Development

Harsh Mishra
Mar 22, 2023

Hello people, I hope everyone is doing well. In this blog, we will learn about transfer learning and how you can implement it using TensorFlow.

What does it take to train a machine learning model or a neural network that yields the best results?

  1. A large amount of Data
  2. Enough training time
  3. Hardware and Software infrastructure

What if any of the above is missing? How can we train a neural network without an ample amount of data? And even if you have the data, can you afford to train a model for months? Do you have that kind of supercomputer at your home or office, and that kind of time? Most of the time, the answer to all these questions would be no… So what can we do?

We can use TRANSFER LEARNING.

What is Transfer Learning?

According to Wikipedia, “Transfer learning is a research problem in machine learning that focuses on storing knowledge gained while solving one problem and applying it to a different but related problem.”

Let me explain this in simple words. The name transfer learning comes from the process of using the learning from one problem to solve another. Learning is nothing but finding the values of the weights and biases of a neural network that give you the desired result. In every epoch, we adjust the weights and biases to bring the predictions closer to the true values, and we call that training (learning, in our case). This takes time and requires a huge amount of data. What if we use the weights and biases of another network that is already trained on huge data, and that data is similar to ours? This is exactly what we call transfer learning: we use an already-trained model as a starting point and fine-tune it for our specific task. This approach allows us to train models on smaller datasets and often achieve higher accuracy than training from scratch on the same data.
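
To make the idea of "starting from already-learned weights" concrete, here is a tiny NumPy sketch. It is not a neural network, just a one-weight, one-bias model trained with gradient descent, and the two toy tasks (y = 2x + 1 and y = 2.2x + 0.9) are made up for illustration. It compares a few epochs of training from zeros against the same number of epochs starting from weights learned on the related task.

```python
import numpy as np

def train(w, b, x, y, epochs=5, lr=0.1):
    """Plain gradient descent on mean squared error for the model y = w*x + b."""
    for _ in range(epochs):
        error = (w * x + b) - y
        w -= lr * 2 * np.mean(error * x)
        b -= lr * 2 * np.mean(error)
    loss = np.mean(((w * x + b) - y) ** 2)
    return w, b, loss

x = np.linspace(-1, 1, 21)

# "Source" task with plenty of training: learn y = 2x + 1
w_src, b_src, _ = train(0.0, 0.0, x, 2.0 * x + 1.0, epochs=500)

# "Target" task is related but not identical: y = 2.2x + 0.9
y_target = 2.2 * x + 0.9

# Same small training budget, two different starting points
_, _, loss_scratch = train(0.0, 0.0, x, y_target)        # start from zeros
_, _, loss_transfer = train(w_src, b_src, x, y_target)   # start from learned weights

print("loss after 5 epochs, from scratch :", loss_scratch)
print("loss after 5 epochs, with transfer:", loss_transfer)
```

With the same five epochs, the run that starts from the source-task weights ends with the lower loss, because it begins much closer to the solution.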

To understand this, let’s take the example of a widely used neural network, the CNN (Convolutional Neural Network).

A CNN is made of two parts:

  1. C — Convolution
  2. NN — Neural Network

The most important part of training is feature extraction, which is done by the convolutional part of the CNN. The process, although automatic, is very time-consuming and computationally heavy. You also need a huge amount of image data to learn accurate features. The neural network part of the CNN is just for prediction: once the convolutional layers have done the feature extraction, the neural network takes those features as input to make a decision.
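
As a rough illustration of what "feature extraction" means, the NumPy sketch below slides one 3x3 filter over a toy 5x5 image. The filter here is written by hand to respond to vertical edges; in a real CNN the filter values are learned during training. The function name `conv2d_valid` and the toy image are assumptions made for this example (and, like deep learning libraries, it computes a cross-correlation, that is, the kernel is not flipped).

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and return the feature map."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    feature_map = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Element-wise multiply the window with the kernel and sum the result
            feature_map[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return feature_map

# Toy 5x5 image: dark on the left (0), bright on the right (1)
image = np.array([
    [0, 0, 0, 1, 1],
    [0, 0, 0, 1, 1],
    [0, 0, 0, 1, 1],
    [0, 0, 0, 1, 1],
    [0, 0, 0, 1, 1],
], dtype=float)

# Hand-written vertical-edge filter; a real CNN learns such filters from data
vertical_edge = np.array([
    [-1, 0, 1],
    [-1, 0, 1],
    [-1, 0, 1],
], dtype=float)

feature_map = conv2d_valid(image, vertical_edge)
print(feature_map)
```

The feature map is zero where the window sees only the dark region and positive where the window covers the dark-to-bright edge. A dense layer placed after many such feature maps only has to decide what those responses mean.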

In transfer learning, we skip training the feature extraction part because we use a pre-trained model, that is, a model that has already been trained on a different, huge dataset. We reuse its convolutional layers, knowing that they are already trained well to identify features in images, and we add a neural network on top of those pre-trained convolutional layers. So when we train this model, we only train the neural network we added, which saves a huge amount of time and also removes the need for huge data. Whatever data we have is often sufficient to train the neural network, which only has to make the decision of predicting a class.
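
The effect of freezing can be checked by counting parameters. The sketch below is only an illustration: `base` is a small, randomly initialised stand-in for a pre-trained convolutional base (so nothing has to be downloaded), and the layer sizes and the 32x32 input are arbitrary assumptions.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical stand-in for a pre-trained convolutional base (random weights here)
base = tf.keras.Sequential([
    layers.Conv2D(8, 3, activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(16, 3, activation="relu"),
    layers.GlobalAveragePooling2D(),
], name="base")
base.trainable = False  # freeze every layer of the base

# New neural network (the "head") stacked on top of the frozen base
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    base,
    layers.Dense(16, activation="relu"),
    layers.Dense(3, activation="softmax"),
])

# Count how many individual parameters are frozen vs. updated during training
frozen = sum(w.numpy().size for w in model.non_trainable_weights)
trainable = sum(w.numpy().size for w in model.trainable_weights)
print("frozen parameters   :", frozen)
print("trainable parameters:", trainable)
```

Only the two dense layers contribute trainable parameters; everything in the base stays fixed. With a real base such as ResNet50 the frozen share is far larger, which is where the saving in training time comes from.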

This is how transfer learning works. You can take a neural network architecture pre-trained on huge data, use it directly in your own projects, and often get very good results with less time and data.

The benefit of using transfer learning is that the pre-trained model has already learned important features that are useful for our new task. By fine-tuning the model on our specific dataset, we can adapt the pre-trained features to our specific problem, which typically leads to higher accuracy and faster training times.
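
Fine-tuning is often done in two stages: first train only the new head with the base frozen, then unfreeze the last few base layers and continue training with a much smaller learning rate. The sketch below shows that pattern under some assumptions: `base` is again a small random stand-in for a pre-trained model, the choice to unfreeze the last two layers and the learning rate of 1e-5 are illustrative values, and the `fit` calls are left as comments because they need a dataset.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical stand-in for a pre-trained convolutional base (random weights here)
base = tf.keras.Sequential([
    layers.Conv2D(8, 3, activation="relu", name="conv_1"),
    layers.MaxPooling2D(name="pool_1"),
    layers.Conv2D(16, 3, activation="relu", name="conv_2"),
    layers.GlobalAveragePooling2D(name="gap"),
], name="base")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    base,
    layers.Dense(3, activation="softmax", name="head"),
])

# Stage 1: freeze the whole base and train only the new head
base.trainable = False
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_ds, validation_data=val_ds, epochs=...)

# Stage 2: unfreeze only the last two base layers and fine-tune gently
base.trainable = True
for layer in base.layers[:-2]:
    layer.trainable = False
# Recompile so the change in trainable layers takes effect; use a low learning rate
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_ds, validation_data=val_ds, epochs=...)

print([(layer.name, layer.trainable) for layer in base.layers])
```

The low learning rate in the second stage is there so that the pre-trained weights are only nudged, not overwritten, by the new data.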

TensorFlow Keras Implementation

All the code written in this blog is also available in a Colab notebook: https://colab.research.google.com/drive/1U0Qt0wNMZr1_osgHTgAXmPQZeLqaYWS-?usp=sharing

Link for the dataset used: https://www.kaggle.com/datasets/muhammadardiputra/potato-leaf-disease-dataset

Here, we are creating a CNN model to classify three types of potato leaf: Healthy, Early Blight, and Late Blight. Early blight and late blight are very common diseases found in potato crops.


#importing necessary libraries
import tensorflow as tf
from tensorflow.keras import models, layers
import matplotlib.pyplot as plt

#initializing some variables
IMAGE_SIZE = 256
BATCH_SIZE = 32
CHANNELS = 3
EPOCHS = 50

#importing training, testing and validation data from drive

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/content/drive/MyDrive/DL sem7/Potato/Train",
    shuffle=True,
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/content/drive/MyDrive/DL sem7/Potato/Valid",
    shuffle=True,
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE
)
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "/content/drive/MyDrive/DL sem7/Potato/Test",
    shuffle=True,
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE
)

# These layers are for preprocessing and we will use them later while
# building the model. The first resizes the image to 256x256 and the
# second divides each pixel value by 255 to get a number between 0 and 1.
resize_and_rescale = tf.keras.Sequential([
    layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
    layers.Rescaling(1.0 / 255)
])

# Data augmentation helps when we have little data: randomly flipping and
# rotating the training images usually helps the model generalize better.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
])

input_shape = (IMAGE_SIZE, IMAGE_SIZE, CHANNELS)

# let's start model building

resnet_model = tf.keras.models.Sequential()

# This is the pre-trained ResNet50 model. include_top is set to False because
# we add our own neural network on top of the pre-trained convolutions of
# ResNet50. weights is set to 'imagenet' (the dataset on which this model is
# pre-trained).
pretrained_model = tf.keras.applications.ResNet50(
    include_top=False,
    input_shape=input_shape,
    pooling='avg',
    weights='imagenet')
# freeze the pre-trained layers so that only the layers we add are trained
for layer in pretrained_model.layers:
    layer.trainable = False
# let's add the neural network on top to complete the CNN model
resnet_model.add(resize_and_rescale)
resnet_model.add(data_augmentation)
resnet_model.add(pretrained_model)
# Flatten changes nothing here because pooling='avg' already returns a flat vector
resnet_model.add(layers.Flatten())
resnet_model.add(layers.Dense(512, activation='relu'))
resnet_model.add(layers.Dense(3, activation='softmax'))
# None leaves the batch dimension flexible
resnet_model.build(input_shape=(None, IMAGE_SIZE, IMAGE_SIZE, CHANNELS))

resnet_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy']
)

# train_ds is already batched, so batch_size is not passed to fit()
history = resnet_model.fit(
    train_ds,
    validation_data=val_ds,
    verbose=1,
    epochs=EPOCHS,
)
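
One thing worth double-checking in the code above is the input scaling. The Keras documentation for ResNet50 says that inputs should be passed through `tf.keras.applications.resnet50.preprocess_input`, which converts RGB to BGR and subtracts the ImageNet channel means instead of scaling pixels to the range 0 to 1. The small sketch below only compares the two on a single white pixel. Whether replacing the `Rescaling(1.0 / 255)` layer with it improves accuracy on this dataset is an assumption you would need to test yourself.

```python
import numpy as np
import tensorflow as tf

# One white RGB pixel, shaped like a batch of images: (batch, height, width, channels)
pixel = np.full((1, 1, 1, 3), 255.0, dtype="float32")

# The 1/255 scaling used in the model above
rescaled = pixel / 255.0

# The preprocessing ResNet50's ImageNet weights expect.
# copy() is used because preprocess_input may modify a float array in place.
resnet_ready = tf.keras.applications.resnet50.preprocess_input(pixel.copy())

print("1/255 rescaling :", rescaled.ravel())
print("preprocess_input:", np.asarray(resnet_ready).ravel())
```

The two outputs are on very different scales, so a frozen base sees quite different inputs depending on which one is used.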

Popular Use Cases of Transfer Learning:

  1. Image Classification: Transfer learning has been widely used in image classification tasks. For example, a pre-trained model such as ResNet or VGG can be used as a starting point and fine-tuned on a specific dataset to classify images.
  2. Natural Language Processing: Transfer learning has also been used in natural language processing tasks such as sentiment analysis, text classification, and language modeling. A pre-trained model such as BERT or GPT can be used as a starting point and fine-tuned on a specific dataset to perform these tasks.
  3. Speech Recognition: Transfer learning has been used in speech recognition tasks where a pre-trained model such as DeepSpeech can be fine-tuned on a specific dataset to recognize speech.

So, this is the end of our project here. We used a pre-trained CNN architecture and understood the concept of transfer learning. And that is it for today guys.

I hope you guys got to learn something new and enjoyed this blog. If you did, like it and share it with your friends. Take care. Keep learning.

You could also reach me through my LinkedIn account- https://www.linkedin.com/in/harsh-mishra-4b79031b3/
