models.get_model
The get_model function downloads a prebuilt model, together with its preprocessing function and class list, so you can quickly try out leap-ie.
preprocessing_fn, model, class_list = get_model(
    model_name,
    source="source_name",
    segmentation=False,
)

Arguments
model_name (str): Name of a Torchvision, Keras, or Hugging Face model. When source is "keras", it should be of the form module_name.model_name.
Required: Yes
Default: None
source ("torchvision" | "keras" | "huggingface"): Name of the source repository to fetch the model from. Note: for the "huggingface" source, only PyTorch models are supported.
Required: No
Default: "torchvision"
segmentation (bool): Whether to fetch a segmentation model (True) or a classification model (False).
Required: No
Default: False
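The argument rules above can be summarized as a small pre-flight check. This is a hypothetical helper (check_get_model_args is not part of leap-ie) that only restates the constraints documented here; it assumes those rules are complete and does not reproduce leap-ie's own validation.

```python
# Hypothetical helper, NOT part of leap-ie: it restates the documented
# argument rules for get_model so they are easy to check before calling it.
VALID_SOURCES = ("torchvision", "keras", "huggingface")


def check_get_model_args(model_name, source="torchvision", segmentation=False):
    """Raise if the arguments break the rules listed in the Arguments section."""
    # model_name is required.
    if not isinstance(model_name, str) or not model_name:
        raise ValueError("model_name is required and must be a non-empty string")

    # source must be one of the three documented repositories.
    if source not in VALID_SOURCES:
        raise ValueError(f"source must be one of {VALID_SOURCES}, got {source!r}")

    # Keras names must look like module_name.model_name, e.g. resnet50.ResNet50.
    if source == "keras":
        module_name, sep, class_name = model_name.partition(".")
        if not sep or not module_name or not class_name:
            raise ValueError(
                "keras model_name must be of the form 'module_name.model_name'"
            )

    # segmentation selects a segmentation (True) or classification (False) model.
    if not isinstance(segmentation, bool):
        raise TypeError("segmentation must be a bool")


# The three documented example names all pass the check.
check_get_model_args("resnet18", source="torchvision")
check_get_model_args("nateraw/vit-age-classifier", source="huggingface")
check_get_model_args("resnet50.ResNet50", source="keras")
```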
Returns
preprocessing_fn: The preprocessing function to apply to inputs before inference.
model: The model fetched from the chosen source.
class_list: List of class names corresponding to the model's output classes.
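The three return values are meant to be used together: preprocess an input, run the model, then map the highest-scoring output index to a name in class_list. The sketch below is a hypothetical illustration: predict_label and the toy stand-ins are not part of leap-ie, and it assumes the model returns one plain score per class for a single input.

```python
from typing import Callable, List, Sequence


def predict_label(
    preprocessing_fn: Callable,
    model: Callable,
    class_list: List[str],
    raw_input,
) -> str:
    """Hypothetical helper: return the class name with the highest model score."""
    inputs = preprocessing_fn(raw_input)  # prepare the input the way the model expects
    scores: Sequence[float] = model(inputs)  # assumed: one score per output class
    if len(scores) != len(class_list):
        raise ValueError("model output size does not match class_list")
    best = max(range(len(scores)), key=lambda i: scores[i])  # argmax over classes
    return class_list[best]


# Toy stand-ins for the values get_model returns (hypothetical, for illustration).
def toy_preprocessing_fn(pixels):
    return [p / 255.0 for p in pixels]  # scale 0-255 values to 0-1


def toy_model(x):
    total = sum(x)
    return [total, -total, 0.0]  # three scores, one per class


toy_classes = ["cat", "dog", "bird"]

print(predict_label(toy_preprocessing_fn, toy_model, toy_classes, [10, 20, 30]))  # prints: cat
```

With real models the output is usually a batched tensor (for Hugging Face models, often wrapped in an output object), so the argmax step would run over the class dimension of that tensor rather than over a plain list.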
Examples
Torchvision:
get_model('resnet18', source='torchvision')

Hugging Face:
get_model('nateraw/vit-age-classifier', source='huggingface')

Keras:
get_model('resnet50.ResNet50', source='keras')