Model
Model provides training functionality through `model.fit(...)`, inspired by Keras.
Examples:
```python
from gradsflow import Model

# `cnn` is your trainable model
model = Model(cnn)
model.compile("crossentropyloss", "adam", learning_rate=1e-3, metrics="accuracy")
model.fit(autodataset)
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| learner | | Trainable model | required |
| accelerator_config | | HuggingFace Accelerator config | required |
compile(self, loss=None, optimizer=None, learning_rate=0.0003, metrics=None, loss_config=None, optimizer_config=None)
Examples:
```python
from gradsflow import Model

# `net` is your trainable model
model = Model(net)
model.compile(loss="crossentropyloss", optimizer="adam", learning_rate=1e-3, metrics="accuracy")
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loss | `Union[str, torch.nn.modules.loss._Loss]` | Name of the loss or a torch Loss class object. See | None |
| optimizer | `Union[str, torch.optim.optimizer.Optimizer]` | Optimizer name or | None |
| learning_rate | `float` | Learning rate for the optimizer | 0.0003 |
| metrics | `Union[str, torchmetrics.metric.Metric, List[Union[str, torchmetrics.metric.Metric]]]` | List of metrics to calculate. See | None |
| loss_config | `Optional[dict]` | Optional config dict passed to the loss function | None |
| optimizer_config | `Optional[dict]` | Optional config dict passed to the optimizer | None |
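The `loss`, `optimizer`, and `metrics` arguments each accept either a string name or an object. The sketch below is a hypothetical illustration of how such arguments could be normalized; it is not gradsflow's implementation. The registries, the helper names `resolve_loss`, `resolve_optimizer`, and `normalize_metrics`, and the placeholder `CrossEntropyLoss`/`Adam` classes (stand-ins for the real torch classes, so the sketch runs without PyTorch) are all assumptions.

```python
from typing import Any, Callable, Dict, List, Optional, Union


class CrossEntropyLoss:  # placeholder standing in for the real torch loss class
    def __init__(self, **config: Any) -> None:
        self.config = config


class Adam:  # placeholder standing in for the real torch optimizer class
    def __init__(self, params: List[float], lr: float, **config: Any) -> None:
        self.params = params
        self.lr = lr
        self.config = config


# Hypothetical registries mapping lowercase names to classes.
LOSSES: Dict[str, Callable[..., Any]] = {"crossentropyloss": CrossEntropyLoss}
OPTIMIZERS: Dict[str, Callable[..., Any]] = {"adam": Adam}


def resolve_loss(loss: Union[str, Any], loss_config: Optional[dict] = None) -> Any:
    """Build a loss from its (case-insensitive) name; pass loss objects through unchanged."""
    if isinstance(loss, str):
        return LOSSES[loss.lower()](**(loss_config or {}))
    return loss


def resolve_optimizer(optimizer: Union[str, Any], params: List[float],
                      learning_rate: float = 3e-4,
                      optimizer_config: Optional[dict] = None) -> Any:
    """Build an optimizer from its name with the given learning rate and extra config."""
    if isinstance(optimizer, str):
        return OPTIMIZERS[optimizer.lower()](params, lr=learning_rate, **(optimizer_config or {}))
    return optimizer


def normalize_metrics(metrics: Union[None, str, Any, List[Any]]) -> List[Any]:
    """Accept None, a single metric (name or object), or a list; always return a list."""
    if metrics is None:
        return []
    if isinstance(metrics, (list, tuple)):
        return list(metrics)
    return [metrics]
```

Here the default `learning_rate=3e-4` mirrors the documented `compile` default of 0.0003, and a call such as `resolve_optimizer("adam", params, optimizer_config={"weight_decay": 0.01})` forwards the extra config as keyword arguments.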
fit(self, autodataset, max_epochs=1, steps_per_epoch=None, callbacks=None, resume=True, show_progress=True, progress_kwargs=None)
Similar to Keras `model.fit(...)`, it trains the model for the specified number of epochs and returns a `Tracker` object.
Examples:
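```python
from gradsflow import AutoDataset, Model

# `train_dataloader`, `val_dataloader`, and `cnn` are defined elsewhere
autodataset = AutoDataset(train_dataloader, val_dataloader)
model = Model(cnn)
model.compile("crossentropyloss", "adam", learning_rate=1e-3, metrics="accuracy")
model.fit(autodataset)
```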
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| autodataset | `AutoDataset` | AutoDataset object that encapsulates the dataloader and datamodule | required |
| max_epochs | `int` | Number of epochs to train | 1 |
| steps_per_epoch | `Optional[int]` | Number of training steps in a single epoch | None |
| callbacks | `Union[List[str], gradsflow.callbacks.callbacks.Callback]` | Callback object or list of callback names | None |
| resume | `bool` | Resume training from the last epoch | True |
| show_progress | `bool` | Show training progress | True |
| progress_kwargs | | Keyword arguments for `rich.progress` | None |
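Because `callbacks` accepts either callback objects or callback names, a trainer has to turn both forms into a uniform list before training. The following is a hypothetical sketch of that normalization, not gradsflow's code: the `Callback` base class here is a placeholder, and the `CALLBACKS` registry with its `"progress"` and `"csv_logger"` entries is invented for illustration.

```python
from typing import Dict, List, Optional, Type, Union


class Callback:
    """Placeholder base class standing in for gradsflow's Callback."""


class ProgressCallback(Callback):  # hypothetical example callback
    pass


class CSVLogger(Callback):  # hypothetical example callback
    pass


# Hypothetical name-to-class registry; the real set of callback names may differ.
CALLBACKS: Dict[str, Type[Callback]] = {"progress": ProgressCallback, "csv_logger": CSVLogger}

CallbackArg = Union[str, Callback]


def normalize_callbacks(callbacks: Optional[Union[CallbackArg, List[CallbackArg]]]) -> List[Callback]:
    """Return a list of Callback instances from None, a single item, or a list of items."""
    if callbacks is None:
        return []
    if not isinstance(callbacks, (list, tuple)):
        callbacks = [callbacks]  # wrap a single name or object in a list
    resolved: List[Callback] = []
    for cb in callbacks:
        if isinstance(cb, str):
            resolved.append(CALLBACKS[cb]())  # instantiate the callback registered under this name
        else:
            resolved.append(cb)  # already a Callback object; use it as-is
    return resolved
```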
Returns:
| Type | Description |
|---|---|
| `Tracker` | Tracker object |
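The interaction between `max_epochs`, `steps_per_epoch`, and `resume` can be illustrated with a minimal pure-Python loop. This is a sketch under stated assumptions, not gradsflow's `fit`: the `Tracker` dataclass and its fields are hypothetical stand-ins, and the sketch assumes `max_epochs` is the total epoch count, so resuming continues from the epoch stored in the tracker rather than adding `max_epochs` new epochs.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class Tracker:
    """Hypothetical stand-in for gradsflow's Tracker: records training progress."""
    current_epoch: int = 0
    steps_per_epoch_seen: List[int] = field(default_factory=list)


def fit(batches: List[Any], tracker: Tracker, max_epochs: int = 1,
        steps_per_epoch: Optional[int] = None, resume: bool = True) -> Tracker:
    if not resume:
        tracker = Tracker()  # discard previous progress and start from epoch 0
    # Assumption: max_epochs is a total, so a resumed run starts at tracker.current_epoch.
    for epoch in range(tracker.current_epoch, max_epochs):
        steps = 0
        for batch in batches:
            if steps_per_epoch is not None and steps >= steps_per_epoch:
                break  # cap the epoch at steps_per_epoch batches
            # A real loop would run forward, loss, backward, and optimizer step on `batch` here.
            steps += 1
        tracker.steps_per_epoch_seen.append(steps)
        tracker.current_epoch = epoch + 1
    return tracker
```

For example, with 10 batches, `fit(batches, Tracker(), max_epochs=2, steps_per_epoch=3)` runs two epochs of 3 steps each; passing that tracker back in with `max_epochs=3` runs only the third epoch, while `resume=False` starts again from epoch 0.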