Utility
common
¶
AverageMeter
dataclass
¶
Computes and stores the average and current value.
val is the current value, and avg is the average value over an epoch.
update(self, val, n=1)
¶
Updates the meter with a new value val, counted n times (default 1). A torch.Tensor value is first converted to a primitive Python datatype.
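The bookkeeping behind update can be sketched as a running sum and count, assuming the usual scheme where avg = sum / count and n acts as a sample count. The class name SimpleAverageMeter is hypothetical and this is not the library's source; tensor conversion is approximated by calling .item() on any value that provides it.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleAverageMeter:
    """Hypothetical sketch: tracks the latest value and a sample-weighted average."""

    val: Optional[float] = None  # most recent value passed to update()
    sum: float = 0.0             # running total of val * n
    count: int = 0               # total number of samples seen
    avg: float = 0.0             # sum / count

    def update(self, val, n: int = 1) -> None:
        # Assumption: tensors (and NumPy scalars) expose .item(); plain floats do not.
        if hasattr(val, "item"):
            val = val.item()
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = SimpleAverageMeter()
meter.update(2.0)        # one sample with value 2.0
meter.update(4.0, n=3)   # three samples with value 4.0
print(meter.val, meter.avg)  # 4.0 3.5
```

Weighting by n lets a per-batch mean be folded in with its batch size, so avg stays a per-sample average even when the last batch is smaller.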
filter_list(arr, pattern=None)
¶
Filters a list of strings, keeping the items that match the given regular-expression pattern.
>>> arr = ['crossentropy', 'binarycrossentropy', 'softmax', 'mae']
>>> filter_list(arr, ".*entropy*")
['crossentropy', 'binarycrossentropy']
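The example above can be reproduced with a short sketch built on Python's re module. It assumes each pattern is applied with re.match (anchored at the start of a string but not at its end) and that pattern=None leaves the list unfiltered; both are assumptions, not confirmed library behavior.

```python
import re
from typing import List, Optional


def filter_list(arr: List[str], pattern: Optional[str] = None) -> List[str]:
    """Sketch: keep the strings in arr that match pattern."""
    # Assumption: no pattern means nothing is filtered out.
    if pattern is None:
        return list(arr)
    regex = re.compile(pattern)
    # regex.match anchors at the start of each string only, so the pattern
    # does not have to consume the whole string to count as a match.
    return [item for item in arr if regex.match(item)]


arr = ["crossentropy", "binarycrossentropy", "softmax", "mae"]
matches = filter_list(arr, ".*entropy*")
```

Note that "entropy*" means "entrop" followed by zero or more "y" characters; ".*entropy" expresses the same intent more plainly.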
get_file_extension(path)
¶
Returns the extension of the file at path.
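One way to implement this is with pathlib. The sketch assumes the extension is returned without its leading dot and that only the last suffix counts for names such as archive.tar.gz; the library's exact convention is not stated above.

```python
from pathlib import Path


def get_file_extension(path: str) -> str:
    # Path.suffix is the final ".xxx" part of the last path component,
    # e.g. ".gz" for "archive.tar.gz" and "" when there is no dot.
    # Assumption: the leading dot is stripped from the result.
    return Path(path).suffix.lstrip(".")


ext = get_file_extension("data/images/cat.jpg")
```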
get_files(folder)
¶
Fetches every file in the given folder, recursing into subfolders.
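A recursive listing can be sketched with os.walk. This version encodes assumptions about the behavior: it returns full paths to files only (directories are skipped) and sorts them so the output is deterministic.

```python
import os
import tempfile
from typing import List


def get_files(folder: str) -> List[str]:
    """Sketch: return the paths of all files below folder, including subfolders."""
    files: List[str] = []
    # os.walk yields (dirpath, dirnames, filenames) for folder and every subdirectory.
    for root, _dirs, filenames in os.walk(folder):
        for name in filenames:
            files.append(os.path.join(root, name))
    return sorted(files)


# Usage: build a tiny temporary tree and list it relative to its root.
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "sub"))
    for rel in ("a.txt", os.path.join("sub", "b.txt")):
        open(os.path.join(tmp, rel), "w").close()
    found = [os.path.relpath(p, tmp) for p in get_files(tmp)]
```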
listify(item)
¶
Converts any scalar value into a list.
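A sketch of one reasonable listify. How the library treats None, strings, or other iterables is not stated above, so the rules here are assumptions: None maps to an empty list, lists pass through unchanged, tuples and sets are converted, and everything else (including strings) is wrapped as a single element.

```python
from typing import Any, List


def listify(item: Any) -> List[Any]:
    # Assumption: None becomes an empty list.
    if item is None:
        return []
    # Existing lists are returned as-is.
    if isinstance(item, list):
        return item
    # Tuples and sets are converted element by element.
    if isinstance(item, (tuple, set)):
        return list(item)
    # Scalars (numbers, strings, arbitrary objects) become one-element lists.
    return [item]


metrics = listify("accuracy")
```

Treating strings as scalars avoids the common pitfall where list("mae") would split a name into characters.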
module_to_cls_index(module, lower_key=True)
¶
Fetches the classes from module and returns a dictionary that maps each class name to the class itself. With lower_key=True, the class names used as keys are lowercased.
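The index can be built with inspect.getmembers and the inspect.isclass predicate. This is a sketch: the demo module and its classes are hypothetical, and including every class attribute of the module (even imported ones) is an assumption about the behavior.

```python
import inspect
import types
from typing import Dict


def module_to_cls_index(module: types.ModuleType, lower_key: bool = True) -> Dict[str, type]:
    index: Dict[str, type] = {}
    # getmembers(..., inspect.isclass) returns (name, class) pairs sorted by name,
    # including classes the module merely imports.
    for name, cls in inspect.getmembers(module, inspect.isclass):
        index[name.lower() if lower_key else name] = cls
    return index


# Usage with a throwaway module object (hypothetical names).
demo = types.ModuleType("demo_losses")


class CrossEntropy: ...


class MAE: ...


demo.CrossEntropy = CrossEntropy
demo.MAE = MAE
demo.helper = lambda x: x  # functions are not classes, so they are skipped
index = module_to_cls_index(demo)
```

Lowercased keys make lookups from user-supplied strings such as "crossentropy" case-insensitive in practice.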
to_item(data)
¶
Converts every torch.Tensor in data into CPU NumPy format.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | torch.Tensor contained in any Iterable or Dictionary. | required |
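A recursive sketch of to_item. To stay importable without PyTorch, it detects tensors by duck typing (objects exposing detach, cpu and numpy) instead of calling torch.is_tensor; that check, and turning 0-d arrays into Python scalars, are assumptions rather than documented library behavior.

```python
from typing import Any


def to_item(data: Any) -> Any:
    """Sketch: convert tensors, possibly nested in lists, tuples or dicts, to CPU NumPy."""
    # Recurse into containers, keeping their container type.
    if isinstance(data, dict):
        return {key: to_item(value) for key, value in data.items()}
    if isinstance(data, list):
        return [to_item(value) for value in data]
    if isinstance(data, tuple):
        return tuple(to_item(value) for value in data)
    # Assumption: duck-typed tensor check; torch.Tensor has detach(), cpu() and numpy().
    if all(hasattr(data, attr) for attr in ("detach", "cpu", "numpy")):
        array = data.detach().cpu().numpy()
        # A 0-d array (e.g. a scalar loss) becomes a plain Python number.
        return array.item() if array.ndim == 0 else array
    # Anything else (int, float, str, ...) is returned unchanged.
    return data


logs = to_item({"loss": [1, 2.5], "name": "mae"})
```

With PyTorch installed, to_item({"loss": torch.tensor(0.25)}) should give {"loss": 0.25} under this sketch, while a 1-D tensor comes back as a NumPy array.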
Last update:
September 9, 2021