Auto Text Classification
In [ ]:
import os
import sys
# Switch to the directory two levels above the initial working directory
# (presumably the project root) so the relative "data/..." paths used in later
# cells resolve. The target is resolved to an absolute path once and cached,
# so re-executing this cell is a no-op instead of climbing two more levels.
if "_PROJECT_ROOT" not in globals():
    _PROJECT_ROOT = os.path.abspath("../../")
os.chdir(_PROJECT_ROOT)
In [ ]:
import warnings

# Suppress every warning category (the catch-all ``Warning`` filter) so the
# notebook output stays free of library noise.
warnings.simplefilter("ignore")
In [ ]:
from flash.core.data.utils import download_data
from flash.text import TextClassificationData
from gradsflow import AutoTextClassifier
In [ ]:
# Paths are shared between the existence check and the loader so they cannot
# drift apart.
_IMDB_TRAIN_FILE = "data/imdb/train.csv"
_IMDB_VAL_FILE = "data/imdb/valid.csv"

# Fetch the IMDB archive only when the CSVs are not already on disk: a fresh
# checkout then works out of the box, and re-runs skip the network call.
# NOTE(review): download_data is expected to unpack imdb.zip into ./data/imdb/
# — confirm against flash.core.data.utils.
if not all(os.path.isfile(path) for path in (_IMDB_TRAIN_FILE, _IMDB_VAL_FILE)):
    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

# Build the data module from the CSV files. "review" and "sentiment" are the
# input and target column names passed to from_csv.
datamodule = TextClassificationData.from_csv(
    "review",
    "sentiment",
    train_file=_IMDB_TRAIN_FILE,
    val_file=_IMDB_VAL_FILE,
)
In [ ]:
# Search space handed to the tuner: a single optimizer choice and a
# learning-rate pair (presumably sampled as a (low, high) range — confirm
# against gradsflow's suggested_conf handling).
suggested_conf = {
    "optimizers": ["adam"],
    "lr": (5e-4, 1e-3),
}

# Configure the AutoML text classifier: one tiny backbone, one epoch per
# trial, and a short search budget (timeout presumably in seconds — verify).
model = AutoTextClassifier(
    datamodule,
    optimization_metric="val_accuracy",
    suggested_backbones=["sgugger/tiny-distilbert-classification"],
    suggested_conf=suggested_conf,
    timeout=5,
    max_epochs=1,
)
print("AutoTextClassifier initialised!")

# Launch the hyperparameter search.
model.hp_tune()
Last update: September 25, 2021