elektronn3.models.base module
class elektronn3.models.base.InferenceModel(src, disable_cuda=False, multi_gpu=True, normalize_func=None)
Bases: object
Class to perform inference using a trained elektronn3 model or nn.Module object.
Parameters:
- src (Union[str, Module]) – Path to the training folder of an e3 model, or an already loaded/initialized nn.Module defining the model.
- disable_cuda (bool) – Use the CPU only.
- multi_gpu (bool) – Enable PyTorch's multi-GPU support.
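The two device flags correspond to standard PyTorch mechanisms. The following is a minimal sketch, assuming that disable_cuda forces the CPU and that multi_gpu wraps the network in torch.nn.DataParallel when more than one GPU is visible. The helpers select_device and prepare_model are hypothetical and are not part of elektronn3; the actual InferenceModel internals may differ.

```python
import torch
from torch import nn


def select_device(disable_cuda: bool = False) -> torch.device:
    # Hypothetical helper (assumption): fall back to the CPU when CUDA
    # is disabled by the caller or not available on this machine.
    if disable_cuda or not torch.cuda.is_available():
        return torch.device('cpu')
    return torch.device('cuda')


def prepare_model(model: nn.Module, disable_cuda: bool = False,
                  multi_gpu: bool = True) -> nn.Module:
    # Hypothetical helper (assumption): mirrors what the two flags
    # plausibly control; not elektronn3's actual implementation.
    device = select_device(disable_cuda)
    # nn.DataParallel splits each input batch across all visible GPUs.
    if multi_gpu and device.type == 'cuda' and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model = model.to(device)
    # eval() switches dropout and batch-norm layers to inference behaviour.
    model.eval()
    return model
```

Wrapping in DataParallel only pays off for batches large enough to split; on a single GPU or the CPU the sketch returns the plain module unchanged.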
Examples
>>> import numpy as np
>>> from torch import nn
>>> from elektronn3.models.base import InferenceModel
>>> cnn = nn.Sequential(
...     nn.Conv2d(5, 32, 3, padding=1), nn.ReLU(),
...     nn.Conv2d(32, 2, 1)).to('cpu')
>>> inp = np.random.randn(2, 5, 10, 10)
>>> model = InferenceModel(cnn)
>>> out = model.predict_proba(inp)
>>> assert np.all(np.array(out.shape) == np.array([2, 2, 10, 10]))
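For comparison, the same forward pass can be written in plain PyTorch. This is a hedged sketch of what a predict_proba-style call typically involves (eval mode, disabled gradient tracking, NumPy-to-tensor conversion and back); the function predict_proba_sketch is hypothetical and is not InferenceModel's real code. It returns the raw network outputs and applies no softmax; whether the library method does is not stated here.

```python
import numpy as np
import torch
from torch import nn


def predict_proba_sketch(model: nn.Module, inp: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for InferenceModel.predict_proba (assumption).
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        # Cast to float32: Conv2d weights default to float32, while
        # np.random.randn produces float64 arrays.
        x = torch.as_tensor(inp, dtype=torch.float32, device=device)
        out = model(x)
    # Move back to the CPU and return a NumPy array.
    return out.cpu().numpy()


cnn = nn.Sequential(
    nn.Conv2d(5, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 2, 1),
).to('cpu')
inp = np.random.randn(2, 5, 10, 10)
out = predict_proba_sketch(cnn, inp)
print(out.shape)  # (2, 2, 10, 10)
```

The 3x3 convolution with padding=1 and the 1x1 convolution both preserve the 10x10 spatial size, so only the channel count changes from 5 to 2.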