Extensibility
It is possible to implement and register a new input, feature extractor, sample similarity measure, noise source type, or interpolation method before using them in calculate_metrics():
Register a new input
1. Subclass a new dataset (e.g., NewDataset) from the Dataset class (torch.utils.data.Dataset); refer to torch_fidelity.datasets.Cifar10_RGB for an example.
2. Register it under some new name (new-ds): register_dataset('new-ds', lambda root, download: NewDataset(root, download))
3. Pass “new-ds” as a value of either the input1 or input2 keyword argument to calculate_metrics().
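As an illustration of the subclassing step, here is a minimal sketch of a NewDataset that serves random images. It assumes PyTorch is installed and that, like Cifar10_RGB, each item is a single uint8 tensor of shape 3xHxW with no label attached. The random contents and the num_images and size parameters are hypothetical and only there to keep the example self-contained.

```python
import torch
from torch.utils.data import Dataset


class NewDataset(Dataset):
    """Hypothetical dataset of random RGB images (sketch, not a real data source).

    Assumption: items are uint8 tensors of shape 3xHxW, returned without labels.
    """

    def __init__(self, root, download, num_images=16, size=32):
        # root and download match the factory signature passed to register_dataset;
        # this toy dataset stores them but does not read or download anything.
        self.root = root
        self.download = download
        generator = torch.Generator().manual_seed(0)  # fixed seed for reproducible contents
        self.images = torch.randint(
            0, 256, (num_images, 3, size, size), dtype=torch.uint8, generator=generator
        )

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, index):
        # Return only the image tensor, with no (image, label) pair.
        return self.images[index]
```

Because the constructor accepts root and download, it fits the lambda used in the registration step; a real dataset would load its images from root instead of generating them.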
Register a new feature extractor
1. Subclass a new feature extractor (e.g., NewFeatureExtractor) from the torch_fidelity.FeatureExtractorBase class and implement all of its methods and properties.
2. Register it under some new name (new-fe): register_feature_extractor('new-fe', NewFeatureExtractor)
3. Pass “new-fe” as a value of the feature_extractor keyword argument to calculate_metrics().
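The heart of a feature extractor is its forward pass. The sketch below shows only that computation, in a standalone torch.nn.Module with the hypothetical name ToyFeatures. It assumes the extractor receives a batch of uint8 images shaped Nx3xHxW and returns a tuple with one tensor per requested feature (here a single NxD tensor); both points are our assumptions about the contract, not confirmed by this page. It deliberately does not subclass torch_fidelity.FeatureExtractorBase, because the exact set of methods and properties that base class requires should be taken from the installed version; the logic would move into NewFeatureExtractor.forward().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyFeatures(nn.Module):
    """Hypothetical feature network: one strided conv layer plus global average pooling."""

    def __init__(self, dim=64):
        super().__init__()
        self.conv = nn.Conv2d(3, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # Assumption: x is a uint8 batch of shape N x 3 x H x W; rescale to [0, 1].
        x = x.float() / 255.0
        feat = F.relu(self.conv(x))
        pooled = feat.mean(dim=(2, 3))  # average over spatial dims -> N x dim
        # Return a tuple with one entry per requested feature (only one here).
        return (pooled,)
```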
Register a new sample similarity measure
1. Subclass a new sample similarity measure (e.g., NewSampleSimilarity) from the torch_fidelity.SampleSimilarityBase class and implement all of its methods and properties.
2. Register it under some new name (new-ss): register_sample_similarity('new-ss', NewSampleSimilarity)
3. Pass “new-ss” as a value of the ppl_sample_similarity keyword argument to calculate_metrics().
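A sample similarity measure scores pairs of generated images for the PPL metric, a role the built-in LPIPS measure fills by default. The hypothetical MeanSquaredPixelDistance below is a deliberately simple stand-in. It assumes the measure receives two equally shaped batches of uint8 images and returns one distance per pair (a tensor of shape N). It does not subclass torch_fidelity.SampleSimilarityBase, whose constructor and required members should be checked in the installed version before moving this forward() into NewSampleSimilarity.

```python
import torch
import torch.nn as nn


class MeanSquaredPixelDistance(nn.Module):
    """Hypothetical measure: per-sample mean squared difference of pixels scaled to [0, 1]."""

    def forward(self, in0, in1):
        assert in0.shape == in1.shape, "both batches must have the same shape"
        a = in0.float() / 255.0
        b = in1.float() / 255.0
        # Average over all non-batch dimensions, leaving one value per sample pair.
        return ((a - b) ** 2).flatten(start_dim=1).mean(dim=1)
```

Pixel-space distances are a poor proxy for perceptual similarity, which is why learned measures such as LPIPS are normally preferred; this sketch only shows the shape of the computation.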
Register a new noise source type
1. Prepare a new function for drawing a sample from a multivariate distribution of a given shape, e.g., def random_new(rng, shape): pass
2. Register it under some new name (new-ns): register_noise_source('new-ns', random_new)
3. Pass “new-ns” as a value of either the input1_model_z_type or input2_model_z_type keyword argument to calculate_metrics().
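One way to fill in random_new is sketched below. It assumes rng is a numpy.random.RandomState and that the function should return a float32 torch.Tensor of the requested shape, which is how we understand the built-in normal noise source to behave. The standard Laplace distribution is an arbitrary choice for illustration.

```python
import numpy as np
import torch


def random_new(rng, shape):
    """Hypothetical noise source: i.i.d. samples from a standard Laplace distribution.

    Assumes rng is a numpy.random.RandomState, so results are reproducible from its seed.
    """
    samples = rng.laplace(loc=0.0, scale=1.0, size=shape)  # float64 numpy array
    return torch.from_numpy(samples).float()  # convert to a float32 tensor
```

Drawing through the supplied rng rather than a global generator keeps the latents reproducible for a fixed seed.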
Register a new interpolation method
1. Prepare a new sample interpolation function, e.g., def new_interp(a, b, t): pass
2. Register it under some new name (new-interp): register_interpolation('new-interp', new_interp)
3. Pass “new-interp” as a value of the ppl_z_interp_mode keyword argument to calculate_metrics().
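An interpolation function combines two batches of latent codes a and b at position t. A minimal sketch of new_interp follows, using normalized linear interpolation (nlerp): a linear blend that is then projected back onto the unit sphere, a cheap approximation to spherical interpolation for unit-norm latents. It assumes a and b are float tensors of shape NxD and that t is a scalar or a tensor broadcastable to them; the eps parameter is a hypothetical guard against division by zero.

```python
import torch


def new_interp(a, b, t, eps=1e-7):
    """Hypothetical interpolation: normalized linear interpolation (nlerp) of latent batches."""
    out = a + (b - a) * t  # plain linear interpolation
    # Rescale each row to unit length; clamp avoids dividing by a zero norm.
    return out / out.norm(dim=-1, keepdim=True).clamp_min(eps)
```

For unit-norm endpoints, t = 0 and t = 1 return a and b unchanged, while intermediate values stay on the unit sphere instead of cutting through its interior as plain linear interpolation would.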