Splitting Data: Train vs Test
In machine learning, datasets are divided into training and testing sets to evaluate how well a model generalizes to unseen data.
- Training set: used to teach the model patterns in the data.
- Testing set: used to evaluate performance on data the model hasn’t seen before.
Without this separation, overfitting goes undetected: a model can score well by memorizing the training data rather than learning generalizable patterns.
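To make the overfitting risk concrete, here is a minimal sketch that scores one model twice: once on the rows it was trained on and once on held-out rows. The unconstrained `DecisionTreeClassifier`, the Iris dataset, and the 80/20 split are assumptions chosen for illustration, and the `train_test_split` call it relies on is introduced in the next part of this lesson.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Load features and labels (150 samples, 3 classes)
X, y = load_iris(return_X_y=True)

# Hold out 20% of the rows; the model never sees them during fit()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# An unconstrained tree can fit the training rows almost perfectly
model = DecisionTreeClassifier(random_state=0)
model.fit(X_train, y_train)

# Accuracy on seen data vs. accuracy on unseen data
train_acc = model.score(X_train, y_train)
test_acc = model.score(X_test, y_test)
print("Train accuracy:", train_acc)
print("Test accuracy:", test_acc)
```

The training score mostly reflects how well the model fits rows it has already seen. Only the test score estimates how it might behave on new data, and a large gap between the two is a common sign of overfitting.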
Using train_test_split in Scikit-learn
The train_test_split() function randomly divides data into training and testing sets with a single line of code.
Basic Train-Test Split
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# Load dataset
iris = load_iris()
X, y = iris.data, iris.target
# Split into train (80%) and test (20%)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print("Train size:", X_train.shape)
print("Test size:", X_test.shape)
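As a quick check of how `test_size` translates into row counts, the sketch below splits the 150-sample Iris dataset twice: once with a fraction and once with an absolute count. The variable names ending in `_n` and the count of 45 are illustrative choices, not part of the lesson.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)  # 150 samples, 4 features

# A float test_size is read as a fraction of the dataset
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print(X_train.shape, X_test.shape)  # (120, 4) (30, 4)

# An int test_size is read as an absolute number of test samples
X_train_n, X_test_n, y_train_n, y_test_n = train_test_split(
    X, y, test_size=45, random_state=42
)
print(X_train_n.shape, X_test_n.shape)  # (105, 4) (45, 4)
```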
Controlling Randomness
Use the random_state parameter to make your results reproducible.
Without it, each run shuffles the data differently, so the split (and any metric computed on it) changes from run to run.
Fixed Random State
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=123
)
print("Train size:", X_train.shape)
print("Test size:", X_test.shape)
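One way to confirm the reproducibility claim is to split the same data several times and compare the resulting test sets. This is a small sketch; the seed values 123 and 7 and the `_a`, `_b`, `_c` variable names are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# Two splits with the same seed
_, X_test_a, _, y_test_a = train_test_split(X, y, test_size=0.3, random_state=123)
_, X_test_b, _, y_test_b = train_test_split(X, y, test_size=0.3, random_state=123)

# One split with a different seed
_, X_test_c, _, y_test_c = train_test_split(X, y, test_size=0.3, random_state=7)

# Compare the test sets element by element
same_seed_identical = np.array_equal(X_test_a, X_test_b)
different_seed_identical = np.array_equal(X_test_a, X_test_c)

print("Same seed gives identical test set:", same_seed_identical)
print("Different seed gives identical test set:", different_seed_identical)
```

Fixing the seed makes the split repeatable, which helps when comparing models or debugging. It does not make one particular split more representative than another.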
Stratified Splits
For classification tasks, set stratify=y to keep class proportions consistent between training and testing sets.
Stratified Split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)
# Check distribution
import numpy as np
unique_train, counts_train = np.unique(y_train, return_counts=True)
unique_test, counts_test = np.unique(y_test, return_counts=True)
print("Train distribution:", dict(zip(unique_train, counts_train)))
print("Test distribution:", dict(zip(unique_test, counts_test)))
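Iris has balanced classes, so the effect of `stratify` is easier to see on imbalanced labels. The sketch below uses a made-up dataset (90 samples of class 0, 10 of class 1, with placeholder features); these values are hypothetical and only meant to show how the class counts are distributed.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels: 90% class 0, 10% class 1
y = np.array([0] * 90 + [1] * 10)
X = np.arange(100).reshape(-1, 1)  # placeholder features, one column

# Stratified split: each subset keeps the 90/10 class ratio
_, _, y_train_s, y_test_s = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
print("Stratified train counts:", np.bincount(y_train_s))  # [72  8]
print("Stratified test counts:", np.bincount(y_test_s))    # [18  2]

# Plain random split: the minority count in the test set depends on the shuffle
_, _, y_train_r, y_test_r = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print("Unstratified test counts:", np.bincount(y_test_r, minlength=2))
```

With a rare class, an unstratified split can leave the test set with very few minority samples, or none at all, which makes metrics for that class unreliable.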
Key Takeaways
- Always hold out a test set before training so you can detect overfitting.
- Use train_test_split(): it’s simple, flexible, and built into Scikit-learn.
- Apply stratify=y for classification to preserve label proportions.
- Set random_state for consistent, reproducible results.