Extensibility

New inputs, feature extractors, sample similarity measures, noise source types, and interpolation methods can be implemented and registered under a new name; once registered, they can be used by that name in calculate_metrics():

Register a new input

  1. Subclass a new dataset (e.g., NewDataset) from the torch.utils.data.Dataset class (see torch_fidelity.datasets.Cifar10_RGB for an example),

  2. Register it under some new name (new-ds): register_dataset('new-ds', lambda root, download: NewDataset(root, download)),

  3. Pass “new-ds” as the value of the input1 or input2 keyword argument to calculate_metrics().
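The three steps above might look as follows. This is a minimal sketch, assuming that, as with the built-in inputs, each dataset item is a uint8 RGB tensor of shape 3xHxW with no label. NewDataset is a hypothetical toy dataset that synthesizes solid-color images instead of loading files from root; the registration call is skipped when torch_fidelity is not installed.

```python
import torch
from torch.utils.data import Dataset


class NewDataset(Dataset):
    """Hypothetical toy input: solid-color RGB images generated on the fly.

    Each item is a uint8 tensor of shape (3, H, W), the sample format
    assumed here for torch-fidelity inputs. No labels are returned.
    """

    def __init__(self, root, download, num_samples=100, size=32):
        # `root` and `download` match the factory signature used at
        # registration; this toy dataset never reads from or writes to disk.
        self.root = root
        self.download = download
        self.num_samples = num_samples
        self.size = size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        if not 0 <= index < self.num_samples:
            raise IndexError(f"index {index} out of range")
        # Derive a deterministic RGB color from the sample index.
        color = torch.tensor(
            [(index * 37) % 256, (index * 91) % 256, (index * 53) % 256],
            dtype=torch.uint8,
        )
        # Broadcast the color to a full (3, size, size) image.
        return color.view(3, 1, 1).expand(3, self.size, self.size).contiguous()


# Step 2: register a factory that torch-fidelity calls with `root` and `download`.
try:
    from torch_fidelity import register_dataset
except ImportError:  # lets the sketch run where torch-fidelity is not installed
    register_dataset = None

if register_dataset is not None:
    register_dataset("new-ds", lambda root, download: NewDataset(root, download))
```

With the name registered, a call such as calculate_metrics(input1='new-ds', input2='cifar10-train', fid=True) would compare the toy images against the built-in CIFAR-10 training split.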

Register a new feature extractor

  1. Subclass a new feature extractor (e.g., NewFeatureExtractor) from the torch_fidelity.FeatureExtractorBase class and implement all of its methods and properties,

  2. Register it under some new name (new-fe): register_feature_extractor('new-fe', NewFeatureExtractor),

  3. Pass “new-fe” as the value of the feature_extractor keyword argument to calculate_metrics().
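A sketch of these steps, using a deliberately simple, hypothetical feature (per-channel color histograms) in place of a neural network. It assumes a FeatureExtractorBase constructor taking (name, features_list), a static get_provided_features_list(), and a forward() that returns a tuple of feature tensors for a uint8 image batch; the exact set of required methods may differ between torch_fidelity releases, so check the base class of the installed version. The feature name "hist24" is made up for this example, and the subclass is only defined when torch_fidelity is importable.

```python
import torch


def color_histogram(images: torch.Tensor, bins: int = 8) -> torch.Tensor:
    """Hypothetical feature: per-channel intensity histograms.

    images: uint8 tensor of shape (N, 3, H, W).
    Returns a float32 tensor of shape (N, 3 * bins) in which each
    channel's histogram sums to 1.
    """
    if images.dtype != torch.uint8 or images.dim() != 4 or images.shape[1] != 3:
        raise ValueError("expected a uint8 tensor of shape (N, 3, H, W)")
    n = images.shape[0]
    # Map each intensity 0..255 to a bin index 0..bins-1, then flatten pixels.
    idx = torch.div(images.long() * bins, 256, rounding_mode="floor").flatten(2)
    counts = torch.zeros(n, 3, bins, dtype=torch.float32, device=images.device)
    # Add 1 to the bin of every pixel, separately for each image and channel.
    counts.scatter_add_(2, idx, torch.ones_like(idx, dtype=torch.float32))
    # Normalize by the pixel count and concatenate the three channel histograms.
    return (counts / idx.shape[2]).flatten(1)


try:
    import torch_fidelity
except ImportError:  # lets the sketch run where torch-fidelity is not installed
    torch_fidelity = None

if torch_fidelity is not None:

    class NewFeatureExtractor(torch_fidelity.FeatureExtractorBase):
        # Assumed interface: a (name, features_list) constructor, a static
        # list of provided feature names, and forward() returning a tuple of
        # feature tensors. Newer releases may require additional methods.
        def __init__(self, name, features_list, **kwargs):
            super().__init__(name, features_list)

        @staticmethod
        def get_provided_features_list():
            return ("hist24",)

        def forward(self, x):
            return (color_histogram(x),)

    torch_fidelity.register_feature_extractor("new-fe", NewFeatureExtractor)
```

Because this extractor does not provide the feature names used by default, the metric's feature layer presumably has to be selected too, e.g. calculate_metrics(..., fid=True, feature_extractor='new-fe', feature_layer_fid='hist24'); confirm the feature_layer_* keyword names against the installed version.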

Register a new sample similarity measure

  1. Subclass a new sample similarity measure (e.g., NewSampleSimilarity) from the torch_fidelity.SampleSimilarityBase class and implement all of its methods and properties,

  2. Register it under some new name (new-ss): register_sample_similarity('new-ss', NewSampleSimilarity),

  3. Pass “new-ss” as the value of the ppl_sample_similarity keyword argument to calculate_metrics().
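The steps above could be sketched as follows, with a hypothetical mean squared pixel distance standing in for a learned measure such as LPIPS. It assumes that the perceptual path length computation calls the measure with two uint8 image batches of shape Nx3xHxW and expects one distance per pair (larger means less similar), and that SampleSimilarityBase takes the registered name in its constructor; any extra keyword options are accepted and ignored.

```python
import torch


def mean_squared_pixel_distance(in0: torch.Tensor, in1: torch.Tensor) -> torch.Tensor:
    """Hypothetical measure: per-pair mean squared pixel difference.

    in0, in1: uint8 image batches of identical shape (N, 3, H, W).
    Returns a float32 tensor of shape (N,); larger values mean the two
    images are less similar (a distance, as LPIPS is).
    """
    if in0.shape != in1.shape or in0.dim() != 4:
        raise ValueError("expected two batches of identical shape (N, 3, H, W)")
    if in0.dtype != torch.uint8 or in1.dtype != torch.uint8:
        raise ValueError("expected uint8 images")
    # Scale intensities to [0, 1] before comparing.
    diff = in0.float() / 255.0 - in1.float() / 255.0
    # Average the squared difference over channels and pixels of each pair.
    return diff.pow(2).flatten(1).mean(dim=1)


try:
    import torch_fidelity
except ImportError:  # lets the sketch run where torch-fidelity is not installed
    torch_fidelity = None

if torch_fidelity is not None:

    class NewSampleSimilarity(torch_fidelity.SampleSimilarityBase):
        # Assumed interface: the constructor receives the registered name plus
        # keyword options forwarded by calculate_metrics(); forward() compares
        # two batches and returns one value per pair.
        def __init__(self, name, **kwargs):
            super().__init__(name)

        def forward(self, in0, in1):
            return mean_squared_pixel_distance(in0, in1)

    torch_fidelity.register_sample_similarity("new-ss", NewSampleSimilarity)
```

The registered measure would then be selected with calculate_metrics(..., ppl=True, ppl_sample_similarity='new-ss').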

Register a new noise source type

  1. Implement a function that draws a sample of a given shape from a multivariate distribution, e.g., def random_new(rng, shape): pass,

  2. Register it under some new name (new-ns): register_noise_source('new-ns', random_new),

  3. Pass “new-ns” as the value of the input1_model_z_type or input2_model_z_type keyword argument to calculate_metrics().
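A minimal noise source following these steps. It assumes, as the built-in sources appear to, that rng is a NumPy RandomState and that the function should return a float32 torch tensor of the requested shape, typically (batch size, latent size). The uniform distribution on [-1, 1) is only an illustrative choice, and the registration call is skipped when torch_fidelity is not installed.

```python
import numpy as np
import torch


def random_new(rng: np.random.RandomState, shape) -> torch.Tensor:
    """Hypothetical noise source: i.i.d. samples, uniform on [-1, 1).

    `shape` is typically (batch_size, z_size); the result is a float32 tensor.
    """
    # RandomState.uniform returns float64; convert to a float32 torch tensor.
    return torch.from_numpy(rng.uniform(-1.0, 1.0, size=shape)).float()


try:
    from torch_fidelity import register_noise_source
except ImportError:  # lets the sketch run where torch-fidelity is not installed
    register_noise_source = None

if register_noise_source is not None:
    register_noise_source("new-ns", random_new)
```

For the resulting metrics to be meaningful, the generator under evaluation would normally need to have been trained on latents drawn from the same distribution.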

Register a new interpolation method

  1. Implement a function that interpolates between two samples, e.g., def new_interp(a, b, t): pass,

  2. Register it under some new name (new-interp): register_interpolation('new-interp', new_interp),

  3. Pass “new-interp” as the value of the ppl_z_interp_mode keyword argument to calculate_metrics().
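A sketch of these steps using normalized linear interpolation (nlerp): interpolate linearly, then project the result back onto the unit sphere. It is a cheap approximation of spherical interpolation and only makes sense for models whose latents are unit-norm. The sketch assumes a and b are latent batches of shape (N, D) and that t is a float or a tensor broadcastable against them, e.g. of shape (N, 1); the registration call is skipped when torch_fidelity is not installed.

```python
import torch
import torch.nn.functional as F


def new_interp(a: torch.Tensor, b: torch.Tensor, t) -> torch.Tensor:
    """Hypothetical method: normalized linear interpolation (nlerp).

    a, b: latent batches of shape (N, D); t: float or broadcastable tensor.
    Returns unit-norm latents of shape (N, D).
    """
    # Plain linear mix; written out so a float or tensor t both broadcast.
    mixed = a + (b - a) * t
    # Project each row back onto the unit sphere. F.normalize divides by
    # max(norm, eps), so an all-zero mix stays zero instead of producing NaN.
    return F.normalize(mixed, dim=-1)


try:
    from torch_fidelity import register_interpolation
except ImportError:  # lets the sketch run where torch-fidelity is not installed
    register_interpolation = None

if register_interpolation is not None:
    register_interpolation("new-interp", new_interp)
```

The method would then be selected with calculate_metrics(..., ppl=True, ppl_z_interp_mode='new-interp'). When a and b are nearly antipodal and t is close to 0.5, the linear mix approaches zero and the normalized result becomes unstable, which is one reason true spherical interpolation is sometimes preferred.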