Auto Image Classification
First, install gradsflow from the main branch on GitHub:
pip install git+https://github.com/gradsflow/gradsflow@main
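import os
import sys
import warnings
from pathlib import Path

warnings.filterwarnings("ignore")  # silence library warnings to keep the notebook output short
os.chdir("../../")  # run from two directories up, as in the original notebook setup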
import ray
from flash.core.data.utils import download_data
from gradsflow import AutoImageClassifier
from gradsflow.data.image import image_dataset_from_directory
Let's use the Hymenoptera dataset provided by Flash, which contains images of ants and bees, to build an image classification model.
data_dir = "/Users/aniket/personal/gradsflow/gradsflow/data/"  # replace with your own path
# Uncomment on the first run to download and extract the dataset into data_dir
# download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", data_dir)
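If you would rather not rely on Flash only for the download, the same step can be done with the standard library. The sketch below is an illustration, not part of gradsflow: `ensure_hymenoptera` is a hypothetical helper, and it assumes (as the paths used in the next cell imply) that the archive extracts into a top-level `hymenoptera_data/` folder. It skips the download when that folder already exists, so the cell is safe to re-run.

```python
import urllib.request
import zipfile
from pathlib import Path

# URL taken from the download_data call above.
DATA_URL = "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip"


def ensure_hymenoptera(data_dir: str, url: str = DATA_URL) -> Path:
    """Download and unzip the dataset into data_dir unless it is already there."""
    root = Path(data_dir)
    target = root / "hymenoptera_data"
    if target.is_dir():
        # Extracted on an earlier run: skip the network round trip.
        return target
    root.mkdir(parents=True, exist_ok=True)
    archive = root / "hymenoptera_data.zip"
    urllib.request.urlretrieve(url, str(archive))
    with zipfile.ZipFile(archive) as zf:
        # Assumes the archive's top-level folder is hymenoptera_data/.
        zf.extractall(root)
    return target
```

In the notebook you would call `ensure_hymenoptera(data_dir)` in place of the commented-out `download_data` line.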
train_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/train/", transform=True)
train_dl = train_data["dl"]
val_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/val/", transform=True)
val_dl = val_data["dl"]
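`image_dataset_from_directory` is pointed at one split folder at a time (`train/` or `val/`). A common convention for such loaders, used for example by torchvision's `ImageFolder`, is one sub-folder per class, with integer labels assigned in alphabetical order of the folder names. We assume, without having checked gradsflow's source, that the same layout applies here. The hypothetical helpers below list the classes and count the images, so you can confirm that the `num_classes=2` passed later matches the data; the set of image extensions is also an assumption of this sketch.

```python
from pathlib import Path
from typing import Dict

# Extensions treated as images; an assumption of this sketch.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def class_index(split_dir: str) -> Dict[str, int]:
    """Map each class sub-folder name to an integer label in alphabetical order."""
    classes = sorted(p.name for p in Path(split_dir).iterdir() if p.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


def count_images(split_dir: str) -> Dict[str, int]:
    """Count files with an image suffix inside each class sub-folder."""
    counts = {}
    for name in class_index(split_dir):
        folder = Path(split_dir) / name
        counts[name] = sum(
            1 for f in folder.iterdir() if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

For a `train/` folder that contains exactly `ants/` and `bees/`, `class_index(f"{data_dir}/hymenoptera_data/train/")` returns `{'ants': 0, 'bees': 1}`, and its length should equal `num_classes`.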
To run Gradsflow on a remote server, first set up a Ray cluster and then initialize Ray with the cluster's address.
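# Connect to an existing Ray cluster:
# ray.init(address="REMOTE_IP_ADDR")
# Or run Ray in local mode, in a single process (useful for debugging):
# ray.init(local_mode=True)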
To train an image classifier, create an AutoImageClassifier object and provide the number of trials, a timeout, and the metric to optimize.
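Typically, `n_trials` caps how many hyperparameter configurations are tried and `timeout` caps how long the search may run, so the search ends at whichever limit is reached first. Gradsflow runs its search on Ray; the toy random search below does not reproduce that implementation and only illustrates how the two budgets and `optimization_metric` interact. `toy_search`, the example search space, and the use of seconds for `timeout` are assumptions of this sketch, which treats a higher metric (such as `val_accuracy`) as better.

```python
import random
import time
from typing import Callable, Dict, List, Optional, Tuple

Config = Dict[str, object]


def toy_search(
    objective: Callable[[Config], float],
    search_space: Dict[str, List[object]],
    n_trials: int,
    timeout: float,
    seed: int = 0,
) -> Tuple[Optional[Config], float, int]:
    """Random search that stops after n_trials or timeout seconds, whichever comes first."""
    rng = random.Random(seed)
    deadline = time.monotonic() + timeout
    best_config: Optional[Config] = None
    best_score = float("-inf")
    trials_run = 0
    while trials_run < n_trials and time.monotonic() < deadline:
        # Sample one value uniformly from each list of candidate values.
        config = {name: rng.choice(values) for name, values in search_space.items()}
        score = objective(config)  # stands in for e.g. val_accuracy of a trained model
        trials_run += 1
        if score > best_score:  # higher metric is treated as better
            best_config, best_score = config, score
    return best_config, best_score, trials_run


# Illustrative search space and a stand-in objective instead of real training.
space = {"lr": [1e-3, 1e-4], "batch_size": [16, 32]}
best, score, runs = toy_search(lambda cfg: 1.0 - cfg["lr"], space, n_trials=1, timeout=50)
# runs == 1: the trial budget is reached long before the 50 s time budget.
```

With `n_trials=1`, as in the configuration that follows, only a single configuration is evaluated, so the run mainly checks that the pipeline works end to end rather than searching the space.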
model = AutoImageClassifier(
    train_dataloader=train_dl,
    val_dataloader=val_dl,
    num_classes=2,
    n_trials=1,
    optimization_metric="val_accuracy",
    timeout=50,
)
print("AutoImageClassifier initialised!")
model.hp_tune()
print("completed!")
AutoImageClassifier initialised!
ray.shutdown()
Last update: September 26, 2021