Have you ever used an app that can tell what kind of fruit is in a photo — like apple, banana, or orange? That’s a perfect example of multiclass classification — a type of machine learning where a computer learns how to choose one answer from three or more options.
Let’s break it down step-by-step, using plain English.
What Does It Do?
Multiclass classification helps a computer pick the right category when there are three or more possibilities. It’s kind of like a smart quiz machine that guesses the correct answer based on clues.
- 📊 Example: Based on your age, location, and shopping habits, a computer might guess what kind of shopper you are:
- Budget Shopper
- Brand Loyalist
- Occasional Buyer
The machine looks at the facts and then picks the most likely type.
What Kind of Data Does It Need?
Just as a teacher needs both questions and answers to make a practice test, the machine needs two things:
- Input data (like age, income, or where someone lives), and
- Correct labels (what kind of shopper each person really is).
This is called labeled data, and it helps the machine learn from examples.
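As a minimal sketch (with made-up shopper data), labeled data is just inputs paired with their known answers:

```python
# Each entry in `inputs` describes one shopper; `labels` holds the true category.
# (All values here are invented for illustration.)
inputs = [
    {"age": 23, "income": 28000, "city": "Austin"},
    {"age": 41, "income": 95000, "city": "Boston"},
    {"age": 35, "income": 52000, "city": "Denver"},
]
labels = ["Budget Shopper", "Brand Loyalist", "Occasional Buyer"]

# The machine learns from many (input, correct answer) pairs like these.
training_examples = list(zip(inputs, labels))
print(len(training_examples))  # 3 labeled examples
```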
How Does the Machine Make a Decision?
The machine doesn’t just guess blindly. It uses math to figure out the probability for each possible answer.
Let’s say you give it some info, and it replies with:
- Budget Shopper: 10% chance
- Brand Loyalist: 60% chance
- Occasional Buyer: 30% chance
The machine chooses the one with the highest chance — in this case, “Brand Loyalist.”
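In code, "pick the answer with the highest chance" is a one-liner. Here is a sketch using the hypothetical percentages above:

```python
# Hypothetical probabilities a model might return for one shopper
probabilities = {
    "Budget Shopper": 0.10,
    "Brand Loyalist": 0.60,
    "Occasional Buyer": 0.30,
}

# Choose the category with the highest probability
prediction = max(probabilities, key=probabilities.get)
print(prediction)  # Brand Loyalist
```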
How Does It Actually Work?
There are two common ways the computer can be taught to make these decisions:
1. One-vs-Rest (OvR)
- The machine builds a separate mini-model for each category.
- Each one says “yes” or “no” to its category.
- The one that’s most confident wins.
It’s like asking three judges, “Is this person your type of shopper?” The loudest “yes” gets the vote.
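A rough sketch of that idea, using three made-up "judge" functions whose confidence scores are invented for illustration:

```python
# Each mini-model answers one question: "Is this shopper MY category?"
# The confidence scores returned here are hard-coded stand-ins.
def budget_model(shopper):
    return 0.2   # "probably not a Budget Shopper"

def loyalist_model(shopper):
    return 0.7   # "quite likely a Brand Loyalist"

def occasional_model(shopper):
    return 0.4   # "maybe an Occasional Buyer"

shopper = {"age": 41, "income": 95000}
scores = {
    "Budget Shopper": budget_model(shopper),
    "Brand Loyalist": loyalist_model(shopper),
    "Occasional Buyer": occasional_model(shopper),
}

# The loudest "yes" (highest confidence) wins the vote
winner = max(scores, key=scores.get)
print(winner)  # Brand Loyalist
```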
2. Multinomial Method
- Instead of making several mini-models, the machine uses one big model.
- This model looks at all the categories at once and gives a full set of probabilities.
- It always makes sure the total adds up to 100%.
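The usual trick for turning one model's raw scores into probabilities that add up to 100% is the softmax function. A minimal sketch, with made-up raw scores:

```python
import math

# Raw (unnormalized) scores the model might compute for each category
raw_scores = {
    "Budget Shopper": 1.0,
    "Brand Loyalist": 3.0,
    "Occasional Buyer": 2.0,
}

# Softmax: exponentiate every score, then divide by the total,
# so each value becomes a share of 100%
exp_scores = {name: math.exp(s) for name, s in raw_scores.items()}
total = sum(exp_scores.values())
probabilities = {name: v / total for name, v in exp_scores.items()}

print(round(sum(probabilities.values()), 10))  # 1.0 -- always adds up to 100%
```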
How Do We Know If It’s Working?
To check how smart the machine is, we use some of the same tools as in binary classification (where there are just two choices like yes/no):
- Accuracy: How often does it guess right?
- Precision: When it says “Brand Loyalist,” how often is it correct?
- Recall: How many “Brand Loyalists” did it catch out of all the real ones?
- F1-score: A mix of precision and recall in one number.
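These metrics are easy to compute by hand. A sketch with six made-up shoppers, computing precision and recall for one class ("Loyalist"):

```python
# Hypothetical true labels vs. model guesses for 6 shoppers
actual    = ["Loyalist", "Budget", "Loyalist", "Occasional", "Loyalist", "Budget"]
predicted = ["Loyalist", "Budget", "Budget",   "Occasional", "Loyalist", "Loyalist"]

# Accuracy: how often was the guess right overall?
accuracy = sum(a == p for a, p in zip(actual, predicted)) / len(actual)

# Precision for "Loyalist": of the times it SAID Loyalist, how many were right?
said_loyalist = [a for a, p in zip(actual, predicted) if p == "Loyalist"]
precision = said_loyalist.count("Loyalist") / len(said_loyalist)

# Recall for "Loyalist": of all the REAL Loyalists, how many did it catch?
real_loyalist = [p for a, p in zip(actual, predicted) if a == "Loyalist"]
recall = real_loyalist.count("Loyalist") / len(real_loyalist)

# F1-score: a single number balancing precision and recall
f1 = 2 * precision * recall / (precision + recall)
print(accuracy, precision, recall, f1)
```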
We also use a confusion matrix, which is like a table that shows:
- What the machine guessed, vs.
- What the correct answer actually was.
This helps us see which classes it’s doing well on, and where it might be getting confused.
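A confusion matrix is just a count of every (real answer, guess) pair. A tiny sketch with made-up data:

```python
from collections import Counter

# Hypothetical true labels vs. model guesses for 5 shoppers
actual    = ["Budget", "Loyalist", "Loyalist", "Budget", "Occasional"]
predicted = ["Budget", "Loyalist", "Budget",   "Budget", "Occasional"]

# Count how often each (actual, guessed) pair occurred
counts = Counter(zip(actual, predicted))

# Lay the counts out as a table: rows = real answer, columns = guess
classes = ["Budget", "Loyalist", "Occasional"]
matrix = [[counts[(true, guess)] for guess in classes] for true in classes]
for true, row in zip(classes, matrix):
    print(true, row)

# The diagonal holds correct guesses; off-diagonal entries show confusion
# (here, one real Loyalist was mistaken for a Budget shopper).
```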
Final Thoughts
Multiclass classification helps machines do things like:
- Recognize different types of customers,
- Sort images into categories (dog, cat, bird),
- Suggest movie genres you might like.
It’s like giving the computer multiple choice questions — and teaching it how to get the answer right.
Python code to implement Multiclass Classification
# Step 1: Import libraries
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
# Step 2: Load the iris dataset
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = pd.Series(iris.target, name="species") # 0, 1, 2
# Optional: Replace numbers with actual species names for readability
species_names = {0: 'setosa', 1: 'versicolor', 2: 'virginica'}
y_named = y.map(species_names)
# Step 3: Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y_named, test_size=0.2, random_state=42)
# Step 4: Train a Logistic Regression model
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)
# Step 5: Predict on test data
y_pred = model.predict(X_test)
# Step 6: Evaluate model
print("Classification Report:")
print(classification_report(y_test, y_pred))
print("Confusion Matrix:")
conf_matrix = confusion_matrix(y_test, y_pred, labels=['setosa', 'versicolor', 'virginica'])
print(conf_matrix)
# Step 7: Visualize confusion matrix
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="YlGnBu", xticklabels=list(species_names.values()), yticklabels=list(species_names.values()))
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix - Iris Classification")
plt.show()