from astropy.table import Table
from collections import OrderedDict
from warnings import warn
__all__ = ["UndefinedCut", "PureCountingCut", "CutFlow"]
[docs]class UndefinedCut(Exception):
pass
[docs]class PureCountingCut(Exception):
pass
[docs]class CutFlow:
"""
a class that keeps track of e.g. events/images that passed cuts or other
events that could reject them
"""
def __init__(self, name="CutFlow"):
"""
Parameters
----------
name : string (default: "CutFlow")
name for the specific instance
"""
self.cuts = OrderedDict()
self.name = name
warn(
"CutFlow is deprecated. Use ctapipe.core.Selector for similar "
"functionality",
FutureWarning,
)
[docs] def count(self, cut, weight=1):
"""
counts an event/image at a given stage of the analysis
Parameters
----------
cut : string
name of the cut/stage where you want to count
weight : int or float, optional (default: 1)
weight of the current element
Notes
-----
If ``cut`` is not yet being tracked, it will simply be added
Will be an alias to __getitem__
"""
if cut not in self.cuts:
self.cuts[cut] = [None, weight]
else:
self.cuts[cut][1] += weight
[docs] def set_cut(self, cut, function):
"""
sets a function that selects on whatever you want to count
sets the counter corresponding to the selection criterion to 0
that means: it overwrites whatever you counted before under this
name
Parameters
----------
cut : string
name of the cut/stage where you want to count
function : function
a function that is your selection criterion
Notes
-----
add_cut and set_cut are aliases
"""
self.cuts[cut] = [function, 0]
[docs] def set_cuts(self, cut_dict, clear=False):
"""
sets functions that select on whatever you want to count
sets the counter corresponding to the selection criterion to 0
that means: it overwrites whatever you counted before under this
name
Parameters
----------
cut_dict : {string: functor} dictionary
dictionary of {name: function} of cuts to add as your selection criteria
clear : bool, optional (default: False)
if set to `True`, clear the cut-dictionary before adding the new cuts
Notes
-----
add_cuts and set_cuts are aliases
"""
if clear:
self.cuts = OrderedDict()
for cut, function in cut_dict.items():
self.cuts[cut] = [function, 0]
def _check_cut(self, cut):
"""
checks if ``cut`` is a valid name for a function to select on
Parameters
----------
cut : string
name of the selection criterion
Raises
------
UndefinedCut if ``cut`` is not known
PureCountingCut if ``cut`` has no associated function
(i.e. manual counting mode)
"""
if cut not in self.cuts:
raise UndefinedCut(
"unknown cut '{}' -- only know: {}".format(
cut, [a for a in self.cuts.keys()]
)
)
elif self.cuts[cut][0] is None:
raise PureCountingCut(f"'{cut}' has no function associated")
[docs] def cut(self, cut, *args, weight=1, **kwargs):
"""
selects the function associated with ``cut`` and hands it all
additional arguments provided. if the function returns `False`,
the event counter is incremented.
Parameters
----------
cut : string
name of the selection criterion
args, kwargs: additional arguments
anything you want to hand to the associated function
weight : int or float, optional (default: 1)
weight of the current element
Returns
-------
True if the function evaluats to True
False otherwise
Raises
------
UndefinedCut if `cut` is not known
PureCountingCut if `cut` has no associated function
(i.e. manual counting mode)
"""
self._check_cut(cut)
if self.cuts[cut][0](*args, **kwargs):
return True
else:
self.cuts[cut][1] += weight
return False
[docs] def keep(self, cut, *args, weight=1, **kwargs):
"""
selects the function associated with ``cut`` and hands it all
additional arguments provided. if the function returns True,
the event counter is incremented.
Parameters
----------
cut : string
name of the selection criterion
args, kwargs: additional arguments
anything you want to hand to the associated function
weight : int or float, optional (default: 1)
weight of the current element
Returns
-------
True if the function evaluats to True
False otherwise
Raises
------
UndefinedCut if ``cut`` is not known
PureCountingCut if ``cut`` has no associated function
(i.e. manual counting mode)
"""
self._check_cut(cut)
if self.cuts[cut][0](*args, **kwargs):
self.cuts[cut][1] += weight
return True
else:
return False
[docs] def __call__(self, *args, **kwargs):
"""
creates an astropy table of the cut names, counted events and
selection efficiencies
prints the instance name and the astropy table
Parameters
----------
kwargs : keyword arguments
arguments to be passed to the ``get_table`` function; see there
Returns
-------
t : `astropy.table.Table`
the table containing the cut names, counted events and
efficiencies -- sorted in the order the cuts were added if not
specified otherwise
"""
print(self.name)
t = self.get_table(*args, **kwargs)
print(t)
return t
[docs] def get_table(
self, base_cut=None, sort_column=None, sort_reverse=False, value_format="5.3f"
):
"""
creates an astropy table of the cut names, counted events and
selection efficiencies
Parameters
----------
base_cut : string, optional (default: None)
name of the selection criterion that should be taken as 100 %
in efficiency calculation
if not given, the criterion with the highest count is used
sort_column : integer, optional (default: None)
the index of the column that should be used for sorting the entries
by default the table is sorted in the order the cuts were added
(index 0: cut name, index 1: number of passed events, index 2: efficiency)
sort_reverse : bool, optional (default: False)
if true, revert the order of the entries
value_format : string, optional (default: '5.3f')
formatting string for the efficiency column
Returns
-------
t : `astropy.table.Table`
the table containing the cut names, counted events and
efficiencies -- sorted in the order the cuts were added if not
specified otherwise
"""
if base_cut is None:
base_value = max([a[1] for a in self.cuts.values()])
elif base_cut not in self.cuts:
raise UndefinedCut(
"unknown cut '{}' -- only know: {}".format(
base_cut, [a for a in self.cuts.keys()]
)
)
else:
base_value = self.cuts[base_cut][1]
t = Table(
[
[cut for cut in self.cuts.keys()],
[self.cuts[cut][1] for cut in self.cuts.keys()],
[self.cuts[cut][1] / base_value for cut in self.cuts.keys()],
],
names=["Cut Name", "selected Events", "Efficiency"],
)
t["Efficiency"].format = value_format
if sort_column is not None:
t.sort(t.colnames[sort_column])
# if sorted by column 0 (i.e. the cut name) default sorting (alphabetically) is
# fine. if sorted by column 1 or 2 or `sort_reverse` is True,
# revert the order of the table
if (sort_column is not None and sort_column > 0) != sort_reverse:
t.reverse()
return t
add_cut = set_cut
add_cuts = set_cuts
__getitem__ = count