models.get_model

The get_model function downloads a prebuilt model so that you can quickly test leap-ie.

preprocessing_fn, model, class_list = get_model(
  model_name,
  source="source_name",
  segmentation=False,
)

Arguments

  • model_name (str): Name of a Torchvision, Keras or Hugging Face model. When source is "keras", the name should be of the form module_name.model_name.

    • Required: Yes

    • Default: None

  • source ("torchvision" | "keras" | "huggingface"): Name of a source repository to fetch the model from. Note: For "huggingface" source, only PyTorch models are supported.

    • Required: No

    • Default: "torchvision"

  • segmentation (bool): Whether to fetch a segmentation model (True) or a classification model (False).

    • Required: No

    • Default: False
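
The constraints listed above can be checked before any download is attempted. The following is a minimal sketch: check_get_model_args and ALLOWED_SOURCES are hypothetical names, not part of leap-ie, and the sketch encodes only the rules stated in this section (the three allowed sources, the default values, and the module_name.model_name form for Keras).

```python
# Hypothetical helper, not part of leap-ie: it checks arguments against the
# constraints documented above and performs no download itself.
ALLOWED_SOURCES = ("torchvision", "keras", "huggingface")


def check_get_model_args(model_name, source="torchvision", segmentation=False):
    """Return the arguments as a keyword dict, or raise ValueError."""
    if not isinstance(model_name, str) or not model_name:
        raise ValueError("model_name is required and must be a non-empty string")
    if source not in ALLOWED_SOURCES:
        raise ValueError(f"source must be one of {ALLOWED_SOURCES}, got {source!r}")
    if source == "keras":
        # Documented form for Keras names: module_name.model_name
        module_name, _, name = model_name.partition(".")
        if not module_name or not name:
            raise ValueError("for source='keras', use module_name.model_name")
    return {
        "model_name": model_name,
        "source": source,
        "segmentation": bool(segmentation),
    }


print(check_get_model_args("resnet50.ResNet50", source="keras"))
```

If the check passes, the returned dict could be forwarded as get_model(**kwargs), assuming the parameter names match the signature shown at the top of this page.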

Returns

  • preprocessing_fn

    • The preprocessing function used on inputs for inference.

  • model

    • The fetched model.

  • class_list

    • List of class names corresponding to the model's output classes.
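
To illustrate how the three return values relate to each other, here is a rough sketch using stand-in objects instead of a real download. It assumes the model returns a flat sequence with one score per class; real models return framework tensors, usually with a batch dimension, so the indexing would need adapting. top_class and the fake_* names are made up for this illustration.

```python
def top_class(preprocessing_fn, model, class_list, raw_input):
    """Preprocess the input, run the model, and name the highest-scoring class."""
    scores = model(preprocessing_fn(raw_input))
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return class_list[best_index]


# Stand-ins for the three values get_model returns (illustrative only).
def fake_preprocessing_fn(pixels):
    # Scale 0..255 pixel values into 0..1.
    return [value / 255.0 for value in pixels]


def fake_model(inputs):
    # One score per entry in fake_class_list.
    return [sum(inputs), 1.0, 0.0]


fake_class_list = ["cat", "dog", "fish"]

print(top_class(fake_preprocessing_fn, fake_model, fake_class_list, [255, 255, 255]))
```

The point of the sketch is the ordering: preprocessing_fn runs first, the model consumes its output, and class_list translates an output index into a readable name.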

Examples

  • Torchvision: get_model('resnet18', source='torchvision')

  • Hugging Face: get_model('nateraw/vit-age-classifier', source='huggingface')

  • Keras: get_model('resnet50.ResNet50', source='keras')
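
The calls above return three values each, so in practice they are unpacked. The sketch below loops over the same three examples and takes get_model as a parameter, which keeps it runnable without the library installed. load_examples and EXAMPLES are hypothetical names, and the import path in the comment is an assumption that should be checked against your installed version.

```python
# Assumed import path (verify against your installation):
# from leap_ie.vision.models import get_model

# The (model_name, source) pairs from the examples above.
EXAMPLES = [
    ("resnet18", "torchvision"),
    ("nateraw/vit-age-classifier", "huggingface"),
    ("resnet50.ResNet50", "keras"),
]


def load_examples(get_model):
    """Call get_model for each example and record the class count per source."""
    summary = {}
    for model_name, source in EXAMPLES:
        # Unpack the three documented return values.
        preprocessing_fn, model, class_list = get_model(model_name, source=source)
        summary[source] = len(class_list)
    return summary
```

With the real function imported, load_examples(get_model) would download all three models, so expect it to take some time and network access.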