models.get_model
The get_model function provides a way to quickly download a prebuilt model for easy testing of leap-ie.
Arguments
model_name (str): Name of a Torchvision, Keras or Hugging Face model. When source is "keras", it should be of the form module_name.model_name.
Required: Yes
Default: None
source ("torchvision" | "keras" | "huggingface"): Name of a source repository to fetch the model from. Note: For the "huggingface" source, only PyTorch models are supported.
Required: No
Default: "torchvision"
segmentation (bool): Whether the fetched model should be a segmentation model rather than a classification model.
Required: No
Default: False
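The argument rules above can be sketched as a small checker. This is only an illustration: check_get_model_args and VALID_SOURCES are hypothetical names, not part of leap-ie, and the library may validate its inputs differently or not at all.

```python
# Hypothetical helper, not part of leap-ie: it only mirrors the argument
# rules described in this section.
VALID_SOURCES = ("torchvision", "keras", "huggingface")


def check_get_model_args(model_name, source="torchvision", segmentation=False):
    """Return (model_name, source, segmentation) or raise ValueError."""
    # model_name is the only required argument.
    if not isinstance(model_name, str) or not model_name:
        raise ValueError("model_name is required and must be a non-empty string")
    if source not in VALID_SOURCES:
        raise ValueError(f"source must be one of {VALID_SOURCES}, got {source!r}")
    if source == "keras":
        # Keras names are documented as module_name.model_name.
        module_name, sep, name = model_name.partition(".")
        if not (sep and module_name and name):
            raise ValueError("for source='keras', use the form module_name.model_name")
    return model_name, source, bool(segmentation)


print(check_get_model_args("resnet18"))
print(check_get_model_args("resnet50.ResNet50", source="keras"))
```

A bare Keras name such as "resnet50" with source="keras" would be rejected by this sketch, since it lacks the module_name prefix.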
Returns
preprocessing_fn: The preprocessing function used on inputs for inference.
model: The model.
class_list: List of class names corresponding to the model's output classes.
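To show how the three return values might fit together, here is a minimal sketch using stand-ins. Nothing here comes from leap-ie: the toy preprocessing_fn, model, and class_list below only imitate the roles described above, and predict_label is a hypothetical helper. How the real objects are called depends on the chosen source.

```python
# Stand-ins only: they mimic the roles of the three return values so the
# example runs without downloading anything. They are not leap-ie objects.
def preprocessing_fn(pixels):
    """Scale 0-255 pixel values to the 0-1 range."""
    return [p / 255.0 for p in pixels]


def model(inputs):
    """Toy 'model' that returns one score per class."""
    mean = sum(inputs) / len(inputs)
    return [1.0 - mean, mean]


# One name per model output, in the same order as the scores.
class_list = ["dark", "bright"]


def predict_label(pixels):
    """Preprocess, run the model, and map the top score to a class name."""
    scores = model(preprocessing_fn(pixels))
    best = max(range(len(scores)), key=scores.__getitem__)
    return class_list[best]


print(predict_label([250, 240, 255]))
```

The point of the sketch is the indexing: position i in the model's output corresponds to class_list[i].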
Examples
Torchvision:
get_model('resnet18', source='torchvision')
Hugging Face:
get_model('nateraw/vit-age-classifier', source='huggingface')
Keras:
get_model('resnet50.ResNet50', source='keras')