Auto Image Classification
In [ ]:
import os
from pathlib import Path
import sys
import warnings
# Silence every warning so the demo output stays readable. Note this also hides
# potentially useful deprecation notices from the libraries used below.
warnings.filterwarnings("ignore")
# Move two directory levels up from where the notebook was started (presumably
# the repository root, so the relative paths below resolve — confirm layout).
# The starting directory is remembered on the first run so that re-executing
# this cell does not keep climbing the directory tree.
_NOTEBOOK_DIR = globals().get("_NOTEBOOK_DIR", Path.cwd())
os.chdir(_NOTEBOOK_DIR / "../..")
In [ ]:
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData
from gradsflow import AutoImageClassifier
In [ ]:
# Root folder for the dataset. It defaults to "data" relative to the working
# directory set in the first cell, instead of a machine-specific absolute path.
# Set the GRADSFLOW_DATA_DIR environment variable to use another location.
data_dir = os.environ.get("GRADSFLOW_DATA_DIR", "data")
# Fetch and extract the dataset only when it is not already present, so
# re-running the notebook does not download it again.
if not os.path.isdir(os.path.join(data_dir, "hymenoptera_data")):
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", data_dir)
In [ ]:
# Build the image-classification datamodule from the train/ and val/
# sub-folders of the extracted hymenoptera_data dataset under data_dir.
_hymenoptera_root = f"{data_dir}/hymenoptera_data"
datamodule = ImageClassificationData.from_folders(
    train_folder=f"{_hymenoptera_root}/train/",
    val_folder=f"{_hymenoptera_root}/val/",
)
In [ ]:
# Search space handed to the tuner: candidate optimizers plus the "lr" entry
# (presumably (low, high) learning-rate bounds — confirm against gradsflow).
suggested_conf = {
    "optimizers": ["adam", "sgd"],
    "lr": (5e-4, 1e-3),
}

# Hyper-parameter search over suggested_conf: 4 trials of at most 2 epochs each,
# ranked by "val_accuracy", with an overall timeout of 50 (units per gradsflow's
# AutoImageClassifier — presumably seconds, verify).
model = AutoImageClassifier(
    datamodule,
    suggested_conf=suggested_conf,
    optimization_metric="val_accuracy",
    n_trials=4,
    max_epochs=2,
    timeout=50,
)
print("AutoImageClassifier initialised!")

# Run the search.
model.hp_tune()
In [ ]:
In [ ]:
In [ ]:
In [ ]:
# Show model.analysis.best_checkpoint as this cell's output (a notebook renders
# the value of a cell's final bare expression).
best_checkpoint = model.analysis.best_checkpoint
best_checkpoint
In [ ]:
Last update: August 29, 2021