elektronn3.models.base module

class elektronn3.models.base.InferenceModel(src, disable_cuda=False, multi_gpu=True, normalize_func=None)[source]

Bases: object

Class to perform inference using a trained elektronn3 model or nn.Module object.

Parameters
  • src (Union[str, Module]) – Path to the training folder of an elektronn3 model, or an already loaded/initialized nn.Module defining the model.

  • disable_cuda (bool) – Use the CPU only.

  • multi_gpu (bool) – Enable PyTorch multi-GPU support.

Examples

>>> import numpy as np
>>> from torch import nn
>>> from elektronn3.models.base import InferenceModel
>>> cnn = nn.Sequential(
...     nn.Conv2d(5, 32, 3, padding=1), nn.ReLU(),
...     nn.Conv2d(32, 2, 1)).to('cpu')
>>> inp = np.random.randn(2, 5, 10, 10)
>>> model = InferenceModel(cnn)
>>> out = model.predict_proba(inp)
>>> assert np.all(np.array(out.shape) == np.array([2, 2, 10, 10]))

predict_proba(inp, bs=10, verbose=False)[source]
Parameters
  • inp (ndarray) – Input data, e.g. of shape [N, C, H, W]

  • bs (int) – Batch size.

  • verbose (bool) – Report inference speed.
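
To illustrate what the bs parameter implies for an input with N samples, the following is a minimal sketch of splitting the first axis into batches. The helper batch_slices is hypothetical and not part of elektronn3; how predict_proba actually iterates over its input is not stated here, so this only shows the general idea of batching.

```python
from typing import Iterator, Tuple


def batch_slices(n_samples: int, bs: int = 10) -> Iterator[Tuple[int, int]]:
    """Yield (start, stop) index pairs covering n_samples in chunks of at most bs.

    Hypothetical helper for illustration only; not an elektronn3 function.
    """
    if bs <= 0:
        raise ValueError("bs must be a positive integer")
    for start in range(0, n_samples, bs):
        # The last batch may be smaller than bs.
        yield start, min(start + bs, n_samples)


# 23 samples with bs=10 give two full batches and one partial batch.
slices = list(batch_slices(23, bs=10))
print(slices)  # [(0, 10), (10, 20), (20, 23)]
```

Each (start, stop) pair could then be used to slice the input along its first axis, e.g. inp[start:stop], so that at most bs samples are processed at once.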

Returns:

elektronn3.models.base.load_model(src)[source]

Load trained elektronn3 model.

Parameters
  • src (str) – Source path to model directory. The directory must contain the training script and model-checkpoint.pth.

Return type

Module

Returns

Trained model
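
Because load_model expects a particular directory layout, it can help to check the directory before loading. The sketch below is an assumption-laden illustration: missing_model_files is a hypothetical helper (not part of elektronn3), and since the name of the training script is not given here, it assumes that any *.py file in the directory counts as the training script.

```python
from pathlib import Path
from typing import List


def missing_model_files(src: str) -> List[str]:
    """Return descriptions of expected entries that are absent from src.

    Hypothetical helper for illustration only; not an elektronn3 function.
    """
    model_dir = Path(src)
    missing = []
    # The checkpoint file name is taken from the load_model description.
    if not (model_dir / "model-checkpoint.pth").is_file():
        missing.append("model-checkpoint.pth")
    # Assumption: the training script is any *.py file in the directory.
    if not any(model_dir.glob("*.py")):
        missing.append("training script (*.py)")
    return missing
```

An empty list means both expected entries were found; otherwise the returned strings name what is missing, which can be reported before calling load_model(src).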