
Model

Model provides Keras-inspired training functionality through model.fit(...).

Examples:

model = Model(cnn)
model.compile("crossentropyloss", "adam", learning_rate=1e-3, metrics="accuracy")
model.fit(autodataset)
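Here is a minimal stdlib-only sketch of the compile-then-fit workflow that needs neither gradsflow nor PyTorch. `ToyModel`, its one-weight linear learner (y = w * x), the `"mse"` name lookup, and the plain gradient-descent update are all hypothetical stand-ins. They illustrate the Keras-style split, where compile() records configuration and fit() runs the epoch loop. They are not gradsflow's implementation.

```python
from typing import Callable, Dict, List, Optional, Tuple

Batch = List[Tuple[float, float]]  # a batch of (x, y) pairs


def mse(pred: float, target: float) -> float:
    """Squared error for a single prediction."""
    return (pred - target) ** 2


class ToyModel:
    """Hypothetical Keras-style wrapper around a one-weight linear learner y = w * x."""

    def __init__(self) -> None:
        self.w = 0.0
        self.loss_fn: Optional[Callable[[float, float], float]] = None
        self.learning_rate = 0.0

    def compile(self, loss: str = "mse", learning_rate: float = 3e-4) -> None:
        # compile() only records configuration; no training happens here.
        losses: Dict[str, Callable[[float, float], float]] = {"mse": mse}
        self.loss_fn = losses[loss]
        self.learning_rate = learning_rate

    def fit(self, batches: List[Batch], max_epochs: int = 1) -> List[float]:
        if self.loss_fn is None:
            raise RuntimeError("call compile() before fit()")
        history: List[float] = []
        for _ in range(max_epochs):
            epoch_loss = 0.0
            for batch in batches:
                # Gradient of the mean squared error w.r.t. w: mean(2 * (w*x - y) * x).
                grad = sum(2 * (self.w * x - y) * x for x, y in batch) / len(batch)
                self.w -= self.learning_rate * grad
                # Batch loss measured after the update.
                epoch_loss += sum(self.loss_fn(self.w * x, y) for x, y in batch) / len(batch)
            history.append(epoch_loss / len(batches))
        return history


model = ToyModel()
model.compile("mse", learning_rate=0.05)
history = model.fit([[(1.0, 2.0), (2.0, 4.0)]], max_epochs=50)
```

In this toy, fit() returns a per-epoch loss history rather than a Tracker object, and the learned weight approaches 2.0 on the sample data.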

Parameters:

learner (required): trainable model.

accelerator_config (required): HuggingFace Accelerator config.

compile(self, loss=None, optimizer=None, learning_rate=0.0003, metrics=None, loss_config=None, optimizer_config=None)

Configures the loss, optimizer, learning rate, and metrics used for training.

Examples:

model = Model(net)
model.compile(loss="crossentropyloss", optimizer="adam", learning_rate=1e-3, metrics="accuracy")

Parameters:

loss (Union[str, torch.nn.modules.loss._Loss], default None): name of the loss, a torch loss class object, or any functional loss. See available_losses().

optimizer (Union[str, Callable], default None): optimizer name or a torch.optim.Optimizer class.

learning_rate (float, default 0.0003): learning rate for the optimizer; defaults to 3e-4.

metrics (Union[str, torchmetrics.metric.Metric, List[Union[str, torchmetrics.metric.Metric]]], default None): metric or list of metrics to calculate. See available_metrics().

loss_config (Optional[dict], default None): optional config dict passed to the loss function.

optimizer_config (Optional[dict], default None): optional config dict passed to the optimizer.
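compile() accepts a string name or an object for `loss` and `optimizer`, a single metric or a list for `metrics`, and optional config dicts. The stdlib-only sketch below shows one way such arguments can be normalized. The `LOSSES` registry, the `CrossEntropyLoss` stand-in, and the `resolve`, `as_list`, and `build_loss` helpers are hypothetical. Treating `loss_config` as keyword arguments for a loss class is also an assumption. gradsflow's own lookup, which available_losses() and available_metrics() expose, may behave differently.

```python
from typing import Any, Callable, Dict, List, Optional


class CrossEntropyLoss:
    """Hypothetical stand-in for a loss class such as torch.nn.CrossEntropyLoss."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


# Hypothetical registry keyed by lower-cased class names, e.g. "crossentropyloss".
LOSSES: Dict[str, Callable[..., Any]] = {"crossentropyloss": CrossEntropyLoss}


def resolve(name_or_obj: Any, registry: Dict[str, Any]) -> Any:
    """Look up a string in the registry; pass any non-string object through unchanged."""
    if isinstance(name_or_obj, str):
        key = name_or_obj.lower()
        if key not in registry:
            raise ValueError(f"unknown name {name_or_obj!r}; available: {sorted(registry)}")
        return registry[key]
    return name_or_obj


def as_list(value: Any) -> List[Any]:
    """Normalize None, a single item, or a list/tuple of items into a list."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def build_loss(loss: Any, loss_config: Optional[dict] = None) -> Any:
    loss_obj = resolve(loss, LOSSES)
    if isinstance(loss_obj, type):
        # Classes are instantiated; loss_config is assumed to hold keyword arguments.
        return loss_obj(**(loss_config or {}))
    # Instances and plain functions (functional losses) are used as-is.
    return loss_obj
```

Because the sketch lower-cases names, "CrossEntropyLoss" and "crossentropyloss" resolve to the same entry. Passing callables through unchanged lets functional losses be used without being instantiated.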

fit(self, autodataset, max_epochs=1, steps_per_epoch=None, callbacks=None, resume=True, show_progress=True, progress_kwargs=None)

Similar to Keras' model.fit(...), this method trains the model for the specified number of epochs and returns a Tracker object.

Examples:

autodataset = AutoDataset(train_dataloader, val_dataloader)
model = Model(cnn)
model.compile("crossentropyloss", "adam", learning_rate=1e-3, metrics="accuracy")
model.fit(autodataset)
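As described above, fit(...) trains for the requested number of epochs and returns a Tracker. The sketch below shows one plausible reading of `max_epochs` combined with `resume=True`. In this reading, the tracker keeps the last epoch reached, and a later fit() call continues from there, with `max_epochs` treated as a total. That interpretation, the hypothetical `ToyTracker` and its fields, and the `ToyTrainer` class are assumptions for illustration only. This page does not document the real Tracker's attributes.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyTracker:
    """Hypothetical stand-in for a Tracker: epoch progress plus per-epoch loss."""

    current_epoch: int = 0
    max_epochs: int = 0
    epoch_losses: List[float] = field(default_factory=list)


class ToyTrainer:
    """Hypothetical trainer whose fit() mimics an epoch loop with resume support."""

    def __init__(self) -> None:
        self.tracker = ToyTracker()

    def train_one_epoch(self, epoch: int) -> float:
        # Placeholder for real training; returns a fake loss that shrinks each epoch.
        return 1.0 / (epoch + 1)

    def fit(self, max_epochs: int = 1, resume: bool = True) -> ToyTracker:
        if not resume:
            # Discard progress from earlier fit() calls and start at epoch 0.
            self.tracker = ToyTracker()
        self.tracker.max_epochs = max_epochs
        # Continue from the last epoch reached; max_epochs is treated as a total.
        for epoch in range(self.tracker.current_epoch, max_epochs):
            self.tracker.epoch_losses.append(self.train_one_epoch(epoch))
            self.tracker.current_epoch = epoch + 1
        return self.tracker
```

For example, calling fit(max_epochs=2) and then fit(max_epochs=5) runs epochs 0 and 1, then 2, 3, and 4. Passing resume=False restarts from epoch 0.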

Parameters:

autodataset (AutoDataset, required): AutoDataset object encapsulating the dataloaders and datamodule.

max_epochs (int, default 1): number of epochs to train.

steps_per_epoch (Optional[int], default None): number of training steps per epoch.

callbacks (Union[List[gradsflow.callbacks.callbacks.Callback], gradsflow.callbacks.callbacks.Callback, str, List[str]], default None): a Callback object, a callback name string, or a list of either.

resume (bool, default True): resume training from the last epoch.

show_progress (bool, default True): whether to display training progress.

progress_kwargs (default None): keyword arguments for rich.progress.
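Two fit() parameters change the shape of the training loop. `steps_per_epoch` caps how many batches each epoch draws. `callbacks` may be a single object, a name, or a list mixing both. The plain-Python sketch below illustrates both ideas. The `Callback` base class, `StepCounter`, the `"step_counter"` name, the `CALLBACKS` registry, and `run_epochs` are hypothetical and do not reproduce gradsflow.callbacks. The sketch also assumes that `steps_per_epoch=None` means "use every batch".

```python
from itertools import islice
from typing import Callable, Dict, Iterable, List, Optional, Union


class Callback:
    """Hypothetical minimal callback base class with a single epoch-end hook."""

    def on_epoch_end(self, epoch: int, steps: int) -> None:
        pass


class StepCounter(Callback):
    """Hypothetical callback that records how many steps each epoch ran."""

    def __init__(self) -> None:
        self.steps_per_epoch: List[int] = []

    def on_epoch_end(self, epoch: int, steps: int) -> None:
        self.steps_per_epoch.append(steps)


# Hypothetical registry mapping callback names to factories.
CALLBACKS: Dict[str, Callable[[], Callback]] = {"step_counter": StepCounter}

CallbackArg = Optional[Union[Callback, str, List[Union[Callback, str]]]]


def normalize_callbacks(callbacks: CallbackArg) -> List[Callback]:
    """Turn None, one callback or name, or a list of them into a list of callback objects."""
    if callbacks is None:
        return []
    if not isinstance(callbacks, list):
        callbacks = [callbacks]
    # Names are looked up and instantiated; callback objects pass through unchanged.
    return [CALLBACKS[c]() if isinstance(c, str) else c for c in callbacks]


def run_epochs(
    batches: Iterable,
    max_epochs: int,
    steps_per_epoch: Optional[int] = None,
    callbacks: CallbackArg = None,
) -> List[Callback]:
    """Run a skeleton epoch loop; `batches` must be re-iterable (e.g. a list)."""
    cbs = normalize_callbacks(callbacks)
    for epoch in range(max_epochs):
        steps = 0
        # islice(batches, None) yields every batch; an int stops after that many.
        for _batch in islice(batches, steps_per_epoch):
            steps += 1  # a real loop would do forward, backward and optimizer steps here
        for cb in cbs:
            cb.on_epoch_end(epoch, steps)
    return cbs
```

Because the loop uses itertools.islice, a steps_per_epoch larger than the number of batches simply stops when the data runs out.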

Returns:

Tracker: a Tracker object.


Last update: October 3, 2021