The get_model function provides a way to quickly download a prebuilt model for easy testing of leap-ie.
Arguments
model_name (str): Name of a torchvision, Hugging Face, or Keras model. When source is "keras", the name should be of the form module_name.model_name.
Required: Yes
Default: None
source ("torchvision" | "keras" | "huggingface"): Name of the source repository to fetch the model from. Note: for the "huggingface" source, only PyTorch models are supported.
Required: No
Default: "torchvision"
segmentation (bool): Whether to fetch a segmentation model (True) or a classification model (False).
Required: No
Default: False
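The argument rules above can be sketched as a small pre-flight check. This is only an illustration: split_keras_name and check_get_model_args are hypothetical helpers, not part of leap-ie, and the actual validation performed by get_model may differ.

```python
from typing import Dict, Tuple, Union

# The three source values listed in the argument description above.
ALLOWED_SOURCES = ("torchvision", "keras", "huggingface")


def split_keras_name(model_name: str) -> Tuple[str, str]:
    """Split a Keras name of the form 'module_name.model_name' (hypothetical helper)."""
    module_name, sep, name = model_name.partition(".")
    if not sep or not module_name or not name:
        raise ValueError(
            f"Keras model names should look like 'module_name.model_name', got {model_name!r}"
        )
    return module_name, name


def check_get_model_args(
    model_name: str,
    source: str = "torchvision",
    segmentation: bool = False,
) -> Dict[str, Union[str, bool]]:
    """Check arguments against the documented rules (hypothetical helper)."""
    if source not in ALLOWED_SOURCES:
        raise ValueError(f"source must be one of {ALLOWED_SOURCES}, got {source!r}")
    if source == "keras":
        # Only the Keras source has a documented name format.
        split_keras_name(model_name)
    return {"model_name": model_name, "source": source, "segmentation": segmentation}
```

Running a check like this before calling get_model can surface a malformed Keras name or an unsupported source without waiting for a download to start.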
Returns
preprocessing_fn: The preprocessing function used on inputs for inference.
model: The fetched model.
class_list: List of class names corresponding to the model's output classes.
Examples
Torchvision: get_model('resnet18', source='torchvision')
Hugging Face: get_model('nateraw/vit-age-classifier', source='huggingface')
Keras: get_model('resnet50.ResNet50', source='keras')
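A minimal sketch of how the three return values might be consumed. It assumes get_model returns them as a tuple in the order listed under Returns (preprocessing_fn, model, class_list); that order is an assumption here. summarize_model and fake_get_model are hypothetical names introduced for illustration, and fake_get_model is a stand-in so the sketch runs without leap-ie or a network connection.

```python
from typing import Any, Callable, Dict, List, Tuple

GetModelFn = Callable[..., Tuple[Callable[[Any], Any], Any, List[str]]]


def summarize_model(
    get_model_fn: GetModelFn,
    model_name: str,
    source: str = "torchvision",
) -> Dict[str, Any]:
    """Fetch a model through get_model_fn and report basic facts about it."""
    # Assumed return order: preprocessing function, model, class names.
    preprocessing_fn, model, class_list = get_model_fn(model_name, source=source)
    return {
        "model_type": type(model).__name__,
        "num_classes": len(class_list),
        "has_preprocessing": callable(preprocessing_fn),
    }


class FakeModel:
    """Stand-in for a downloaded model (not a real leap-ie object)."""


def fake_get_model(model_name: str, source: str = "torchvision"):
    """Stand-in that mimics the assumed return shape of get_model."""
    return (lambda x: x), FakeModel(), ["cat", "dog", "bird"]


summary = summarize_model(fake_get_model, "resnet18", source="torchvision")
print(summary)
```

With leap-ie installed, passing the real get_model in place of fake_get_model should work the same way, provided the return order matches the assumption above.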