forked from moreo/QuaPy
adapting everything to the new format
This commit is contained in:
parent
f63575ff55
commit
238a30520c
|
@ -40,14 +40,17 @@ print(f'training matrix shape: {train.instances.shape}')
|
||||||
|
|
||||||
true_prevalence = ResultSubmission.load(T1A_devprevalence_path)
|
true_prevalence = ResultSubmission.load(T1A_devprevalence_path)
|
||||||
|
|
||||||
param_grid = {'C': np.logspace(-3,3,7), 'class_weight': ['balanced', None]}
|
param_grid = {
|
||||||
|
'C': np.logspace(-3,3,7),
|
||||||
|
'class_weight': ['balanced', None]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def gen_samples():
|
def gen_samples():
|
||||||
return gen_load_samples_T1(T1A_devvectors_path, nF, ground_truth_path=T1A_devprevalence_path, return_id=False)
|
return gen_load_samples_T1(T1A_devvectors_path, nF, ground_truth_path=T1A_devprevalence_path, return_id=False)
|
||||||
|
|
||||||
|
|
||||||
for quantifier in [CC, ACC, PCC, PACC, EMQ, HDy]:
|
for quantifier in [CC]: #, ACC, PCC, PACC, EMQ, HDy]:
|
||||||
#classifier = CalibratedClassifierCV(LogisticRegression(), n_jobs=-1)
|
#classifier = CalibratedClassifierCV(LogisticRegression(), n_jobs=-1)
|
||||||
classifier = LogisticRegression()
|
classifier = LogisticRegression()
|
||||||
model = quantifier(classifier)
|
model = quantifier(classifier)
|
||||||
|
@ -66,7 +69,7 @@ for quantifier in [CC, ACC, PCC, PACC, EMQ, HDy]:
|
||||||
print(f'{quantifier_name} mae={model.best_score_:.3f} (params: {model.best_params_})')
|
print(f'{quantifier_name} mae={model.best_score_:.3f} (params: {model.best_params_})')
|
||||||
|
|
||||||
pickle.dump(model.best_model(),
|
pickle.dump(model.best_model(),
|
||||||
open(os.path.join(models_path, quantifier_name+'.modsel.pkl'), 'wb'),
|
open(os.path.join(models_path, quantifier_name+'.pkl'), 'wb'),
|
||||||
protocol=pickle.HIGHEST_PROTOCOL)
|
protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
DEV_SAMPLES = 1000
|
DEV_SAMPLES = 1000
|
||||||
TEST_SAMPLES = 5000
|
TEST_SAMPLES = 5000
|
||||||
|
|
||||||
|
TXA_SAMPLE_SIZE = 250
|
||||||
|
TXB_SAMPLE_SIZE = 250
|
||||||
|
|
||||||
T1A_SAMPLE_SIZE = 250
|
T1A_SAMPLE_SIZE = 250
|
||||||
T1B_SAMPLE_SIZE = 1000
|
T1B_SAMPLE_SIZE = 1000
|
||||||
|
T2A_SAMPLE_SIZE = 250
|
||||||
|
T2B_SAMPLE_SIZE = 1000
|
||||||
|
|
||||||
ERROR_TOL = 1E-3
|
ERROR_TOL = 1E-3
|
||||||
|
|
|
@ -80,21 +80,30 @@ def gen_load_samples_T2B(path_dir:str, ground_truth_path:str = None):
|
||||||
|
|
||||||
class ResultSubmission:
|
class ResultSubmission:
|
||||||
|
|
||||||
def __init__(self, categories: List[str]):
|
def __init__(self):
|
||||||
if not isinstance(categories, list) or len(categories) < 2:
|
self.df = None
|
||||||
raise TypeError('wrong format for categories; a list with at least two category names (str) was expected')
|
|
||||||
self.categories = categories
|
def __init_df(self, categories:int):
|
||||||
self.df = pd.DataFrame(columns=list(categories))
|
if not isinstance(categories, int) or categories < 2:
|
||||||
self.df.index.rename('id', inplace=True)
|
raise TypeError('wrong format for categories: an int (>=2) was expected')
|
||||||
|
df = pd.DataFrame(columns=list(range(categories)))
|
||||||
|
df.index.set_names('id', inplace=True)
|
||||||
|
self.df = df
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_categories(self):
|
||||||
|
return len(self.df.columns.values)
|
||||||
|
|
||||||
def add(self, sample_id:int, prevalence_values:np.ndarray):
|
def add(self, sample_id:int, prevalence_values:np.ndarray):
|
||||||
if not isinstance(sample_id, int):
|
if not isinstance(sample_id, int):
|
||||||
raise TypeError(f'error: expected int for sample_sample, found {type(sample_id)}')
|
raise TypeError(f'error: expected int for sample_sample, found {type(sample_id)}')
|
||||||
if not isinstance(prevalence_values, np.ndarray):
|
if not isinstance(prevalence_values, np.ndarray):
|
||||||
raise TypeError(f'error: expected np.ndarray for prevalence_values, found {type(prevalence_values)}')
|
raise TypeError(f'error: expected np.ndarray for prevalence_values, found {type(prevalence_values)}')
|
||||||
|
if self.df is None:
|
||||||
|
self.__init_df(categories=len(prevalence_values))
|
||||||
if sample_id in self.df.index.values:
|
if sample_id in self.df.index.values:
|
||||||
raise ValueError(f'error: prevalence values for "{sample_id}" already added')
|
raise ValueError(f'error: prevalence values for "{sample_id}" already added')
|
||||||
if prevalence_values.ndim!=1 and prevalence_values.size != len(self.categories):
|
if prevalence_values.ndim!=1 and prevalence_values.size != self.n_categories:
|
||||||
raise ValueError(f'error: wrong shape found for prevalence vector {prevalence_values}')
|
raise ValueError(f'error: wrong shape found for prevalence vector {prevalence_values}')
|
||||||
if (prevalence_values<0).any() or (prevalence_values>1).any():
|
if (prevalence_values<0).any() or (prevalence_values>1).any():
|
||||||
raise ValueError(f'error: prevalence values out of range [0,1] for "{sample_id}"')
|
raise ValueError(f'error: prevalence values out of range [0,1] for "{sample_id}"')
|
||||||
|
@ -102,9 +111,7 @@ class ResultSubmission:
|
||||||
raise ValueError(f'error: prevalence values do not sum up to one for "{sample_id}"'
|
raise ValueError(f'error: prevalence values do not sum up to one for "{sample_id}"'
|
||||||
f'(error tolerance {constants.ERROR_TOL})')
|
f'(error tolerance {constants.ERROR_TOL})')
|
||||||
|
|
||||||
# new_entry = dict([('id', sample_id)] + [(col_i, prev_i) for col_i, prev_i in enumerate(prevalence_values)])
|
self.df.loc[sample_id] = prevalence_values
|
||||||
new_entry = pd.DataFrame(prevalence_values.reshape(1,2), index=[sample_id], columns=self.df.columns)
|
|
||||||
self.df = self.df.append(new_entry, ignore_index=False)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.df)
|
return len(self.df)
|
||||||
|
@ -112,7 +119,7 @@ class ResultSubmission:
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path: str) -> 'ResultSubmission':
|
def load(cls, path: str) -> 'ResultSubmission':
|
||||||
df = ResultSubmission.check_file_format(path)
|
df = ResultSubmission.check_file_format(path)
|
||||||
r = ResultSubmission(categories=df.columns.values.tolist())
|
r = ResultSubmission()
|
||||||
r.df = df
|
r.df = df
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
@ -120,16 +127,15 @@ class ResultSubmission:
|
||||||
ResultSubmission.check_dataframe_format(self.df)
|
ResultSubmission.check_dataframe_format(self.df)
|
||||||
self.df.to_csv(path)
|
self.df.to_csv(path)
|
||||||
|
|
||||||
def prevalence(self, sample_name:str):
|
def prevalence(self, sample_id:int):
|
||||||
sel = self.df.loc[self.df['filename'] == sample_name]
|
sel = self.df.loc[sample_id]
|
||||||
if sel.empty:
|
if sel.empty:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return sel.loc[:,self.df.columns[1]:].values.flatten()
|
return sel.values.flatten()
|
||||||
|
|
||||||
def iterrows(self):
|
def iterrows(self):
|
||||||
for index, row in self.df.iterrows():
|
for index, row in self.df.iterrows():
|
||||||
# filename = row.filename
|
|
||||||
prevalence = row.values.flatten()
|
prevalence = row.values.flatten()
|
||||||
yield index, prevalence
|
yield index, prevalence
|
||||||
|
|
||||||
|
@ -146,10 +152,11 @@ class ResultSubmission:
|
||||||
|
|
||||||
if df.index.name != 'id' or len(df.columns) < 2:
|
if df.index.name != 'id' or len(df.columns) < 2:
|
||||||
raise ValueError(f'wrong header{hint_path}, '
|
raise ValueError(f'wrong header{hint_path}, '
|
||||||
f'the format of the header should be "id,<cat_1>,...,<cat_n>"')
|
f'the format of the header should be "id,0,...,n-1", '
|
||||||
|
f'where n is the number of categories')
|
||||||
if [int(ci) for ci in df.columns.values] != list(range(len(df.columns))):
|
if [int(ci) for ci in df.columns.values] != list(range(len(df.columns))):
|
||||||
raise ValueError(f'wrong header{hint_path}, category ids should be 0,1,2,...,n')
|
raise ValueError(f'wrong header{hint_path}, category ids should be 0,1,2,...,n-1, '
|
||||||
|
f'where n is the number of categories')
|
||||||
if df.empty:
|
if df.empty:
|
||||||
raise ValueError(f'error{hint_path}: results file is empty')
|
raise ValueError(f'error{hint_path}: results file is empty')
|
||||||
elif len(df) != constants.DEV_SAMPLES and len(df) != constants.TEST_SAMPLES:
|
elif len(df) != constants.DEV_SAMPLES and len(df) != constants.TEST_SAMPLES:
|
||||||
|
@ -167,9 +174,9 @@ class ResultSubmission:
|
||||||
if unexpected:
|
if unexpected:
|
||||||
raise ValueError(f'there are {len(missing)} unexpected ids{hint_path}: {sorted(unexpected)}')
|
raise ValueError(f'there are {len(missing)} unexpected ids{hint_path}: {sorted(unexpected)}')
|
||||||
|
|
||||||
for category_name in df.columns:
|
for category_id in df.columns:
|
||||||
if (df[category_name] < 0).any() or (df[category_name] > 1).any():
|
if (df[category_id] < 0).any() or (df[category_id] > 1).any():
|
||||||
raise ValueError(f'error{hint_path} column "{category_name}" contains values out of range [0,1]')
|
raise ValueError(f'error{hint_path} column "{category_id}" contains values out of range [0,1]')
|
||||||
|
|
||||||
prevs = df.values
|
prevs = df.values
|
||||||
round_errors = np.abs(prevs.sum(axis=-1) - 1.) > constants.ERROR_TOL
|
round_errors = np.abs(prevs.sum(axis=-1) - 1.) > constants.ERROR_TOL
|
||||||
|
@ -180,13 +187,6 @@ class ResultSubmission:
|
||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def sort_categories(self):
|
|
||||||
self.df = self.df.reindex([self.df.columns[0]] + sorted(self.df.columns[1:]), axis=1)
|
|
||||||
self.categories = sorted(self.categories)
|
|
||||||
|
|
||||||
def filenames(self):
|
|
||||||
return self.df.filename.values
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_submission(true_prevs: ResultSubmission, predicted_prevs: ResultSubmission, sample_size=None, average=True):
|
def evaluate_submission(true_prevs: ResultSubmission, predicted_prevs: ResultSubmission, sample_size=None, average=True):
|
||||||
if sample_size is None:
|
if sample_size is None:
|
||||||
|
@ -199,18 +199,19 @@ def evaluate_submission(true_prevs: ResultSubmission, predicted_prevs: ResultSub
|
||||||
if len(true_prevs) != len(predicted_prevs):
|
if len(true_prevs) != len(predicted_prevs):
|
||||||
raise ValueError(f'size mismatch, ground truth file has {len(true_prevs)} entries '
|
raise ValueError(f'size mismatch, ground truth file has {len(true_prevs)} entries '
|
||||||
f'while the file of predictions contain {len(predicted_prevs)} entries')
|
f'while the file of predictions contain {len(predicted_prevs)} entries')
|
||||||
true_prevs.sort_categories()
|
if true_prevs.n_categories != predicted_prevs.n_categories:
|
||||||
predicted_prevs.sort_categories()
|
|
||||||
if true_prevs.categories != predicted_prevs.categories:
|
|
||||||
raise ValueError(f'these result files are not comparable since the categories are different: '
|
raise ValueError(f'these result files are not comparable since the categories are different: '
|
||||||
f'true={true_prevs.categories} vs. predictions={predicted_prevs.categories}')
|
f'true={true_prevs.n_categories} categories vs. '
|
||||||
|
f'predictions={predicted_prevs.n_categories} categories')
|
||||||
ae, rae = [], []
|
ae, rae = [], []
|
||||||
for sample_name, true_prevalence in true_prevs.iterrows():
|
for sample_id, true_prevalence in true_prevs.iterrows():
|
||||||
pred_prevalence = predicted_prevs.prevalence(sample_name)
|
pred_prevalence = predicted_prevs.prevalence(sample_id)
|
||||||
ae.append(qp.error.ae(true_prevalence, pred_prevalence))
|
ae.append(qp.error.ae(true_prevalence, pred_prevalence))
|
||||||
rae.append(qp.error.rae(true_prevalence, pred_prevalence, eps=1./(2*sample_size)))
|
rae.append(qp.error.rae(true_prevalence, pred_prevalence, eps=1./(2*sample_size)))
|
||||||
|
|
||||||
ae = np.asarray(ae)
|
ae = np.asarray(ae)
|
||||||
rae = np.asarray(rae)
|
rae = np.asarray(rae)
|
||||||
|
|
||||||
if average:
|
if average:
|
||||||
return ae.mean(), rae.mean()
|
return ae.mean(), rae.mean()
|
||||||
else:
|
else:
|
||||||
|
@ -224,3 +225,4 @@ def evaluate_submission(true_prevs: ResultSubmission, predicted_prevs: ResultSub
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,10 @@ LeQua2022 Official evaluation script
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
if args.task in {'T1A'}:
|
if args.task in {'T1A', 'T2A'}:
|
||||||
qp.environ['SAMPLE_SIZE'] = constants.T1A_SAMPLE_SIZE
|
qp.environ['SAMPLE_SIZE'] = constants.TXA_SAMPLE_SIZE
|
||||||
|
if args.task in {'T1B', 'T2B'}:
|
||||||
|
qp.environ['SAMPLE_SIZE'] = constants.TXB_SAMPLE_SIZE
|
||||||
true_prev = ResultSubmission.load(args.true_prevalences)
|
true_prev = ResultSubmission.load(args.true_prevalences)
|
||||||
pred_prev = ResultSubmission.load(args.pred_prevalences)
|
pred_prev = ResultSubmission.load(args.pred_prevalences)
|
||||||
mae, mrae = evaluate_submission(true_prev, pred_prev)
|
mae, mrae = evaluate_submission(true_prev, pred_prev)
|
||||||
|
|
|
@ -22,13 +22,13 @@ def main(args):
|
||||||
f'dev samples ({constants.DEV_SAMPLES}) nor with the expected number of '
|
f'dev samples ({constants.DEV_SAMPLES}) nor with the expected number of '
|
||||||
f'test samples ({constants.TEST_SAMPLES}).')
|
f'test samples ({constants.TEST_SAMPLES}).')
|
||||||
|
|
||||||
_, categories = load_category_map(args.catmap)
|
# _, categories = load_category_map(args.catmap)
|
||||||
|
|
||||||
# load pickled model
|
# load pickled model
|
||||||
model = pickle.load(open(args.model, 'rb'))
|
model = pickle.load(open(args.model, 'rb'))
|
||||||
|
|
||||||
# predictions
|
# predictions
|
||||||
predictions = ResultSubmission(categories=list(range(len(categories))))
|
predictions = ResultSubmission()
|
||||||
for sampleid, sample in tqdm(gen_load_samples_T1(args.samples, args.nf),
|
for sampleid, sample in tqdm(gen_load_samples_T1(args.samples, args.nf),
|
||||||
desc='predicting', total=nsamples):
|
desc='predicting', total=nsamples):
|
||||||
predictions.add(sampleid, model.quantify(sample))
|
predictions.add(sampleid, model.quantify(sample))
|
||||||
|
@ -48,8 +48,6 @@ if __name__=='__main__':
|
||||||
help='Path to the directory containing the samples')
|
help='Path to the directory containing the samples')
|
||||||
parser.add_argument('output', metavar='PREDICTIONS-PATH', type=str,
|
parser.add_argument('output', metavar='PREDICTIONS-PATH', type=str,
|
||||||
help='Path where to store the predictions file')
|
help='Path where to store the predictions file')
|
||||||
parser.add_argument('catmap', metavar='CATEGORY-MAP-PATH', type=str,
|
|
||||||
help='Path to the category map file')
|
|
||||||
parser.add_argument('nf', metavar='NUM-FEATURES', type=int,
|
parser.add_argument('nf', metavar='NUM-FEATURES', type=int,
|
||||||
help='Number of features seen during training')
|
help='Number of features seen during training')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -81,7 +81,7 @@ def natural_prevalence_prediction(
|
||||||
|
|
||||||
def gen_prevalence_prediction(model: BaseQuantifier, gen_fn: Callable, eval_budget=None):
|
def gen_prevalence_prediction(model: BaseQuantifier, gen_fn: Callable, eval_budget=None):
|
||||||
if not inspect.isgenerator(gen_fn()):
|
if not inspect.isgenerator(gen_fn()):
|
||||||
raise ValueError('param "gen_fun" is not a generator')
|
raise ValueError('param "gen_fun" is not a callable returning a generator')
|
||||||
|
|
||||||
if not isinstance(eval_budget, int):
|
if not isinstance(eval_budget, int):
|
||||||
eval_budget = -1
|
eval_budget = -1
|
||||||
|
|
|
@ -15,7 +15,7 @@ class GridSearchQ(BaseQuantifier):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model: BaseQuantifier,
|
model: BaseQuantifier,
|
||||||
param_grid: dict,
|
param_grid: dict,
|
||||||
sample_size: int,
|
sample_size: Union[int, None],
|
||||||
protocol='app',
|
protocol='app',
|
||||||
n_prevpoints: int = None,
|
n_prevpoints: int = None,
|
||||||
n_repetitions: int = 1,
|
n_repetitions: int = 1,
|
||||||
|
@ -32,30 +32,33 @@ class GridSearchQ(BaseQuantifier):
|
||||||
protocol for quantification.
|
protocol for quantification.
|
||||||
:param model: the quantifier to optimize
|
:param model: the quantifier to optimize
|
||||||
:param param_grid: a dictionary with keys the parameter names and values the list of values to explore for
|
:param param_grid: a dictionary with keys the parameter names and values the list of values to explore for
|
||||||
:param sample_size: the size of the samples to extract from the validation set
|
:param sample_size: the size of the samples to extract from the validation set (ignored if protocol='gen')
|
||||||
that particular parameter
|
:param protocol: either 'app' for the artificial prevalence protocol, 'npp' for the natural prevalence
|
||||||
:param protocol: either 'app' for the artificial prevalence protocol, or 'npp' for the natural prevalence
|
protocol, or 'gen' for using a custom sampling generator function
|
||||||
protocol
|
:param n_prevpoints: if specified, indicates the number of equally distant points to extract from the interval
|
||||||
:param n_prevpoints: if specified, indicates the number of equally distant point to extract from the interval
|
|
||||||
[0,1] in order to define the prevalences of the samples; e.g., if n_prevpoints=5, then the prevalences for
|
[0,1] in order to define the prevalences of the samples; e.g., if n_prevpoints=5, then the prevalences for
|
||||||
each class will be explored in [0.00, 0.25, 0.50, 0.75, 1.00]. If not specified, then eval_budget is requested.
|
each class will be explored in [0.00, 0.25, 0.50, 0.75, 1.00]. If not specified, then eval_budget is requested.
|
||||||
Ignored if protocol='npp'.
|
Ignored if protocol!='app'.
|
||||||
:param n_repetitions: the number of repetitions for each combination of prevalences. This parameter is ignored
|
:param n_repetitions: the number of repetitions for each combination of prevalences. This parameter is ignored
|
||||||
if eval_budget is set and is lower than the number of combinations that would be generated using the value
|
for the protocol='app' if eval_budget is set and is lower than the number of combinations that would be
|
||||||
assigned to n_prevpoints (for the current number of classes and n_repetitions)
|
generated using the value assigned to n_prevpoints (for the current number of classes and n_repetitions).
|
||||||
|
Ignored for protocol='npp' and protocol='gen' (use eval_budget for setting a maximum number of samples in
|
||||||
|
those cases).
|
||||||
:param eval_budget: if specified, sets a ceil on the number of evaluations to perform for each hyper-parameter
|
:param eval_budget: if specified, sets a ceil on the number of evaluations to perform for each hyper-parameter
|
||||||
combination. For example, if there are 3 classes, n_repetitions=1 and eval_budget=20, then n_prevpoints will be
|
combination. For example, if protocol='app', there are 3 classes, n_repetitions=1 and eval_budget=20, then
|
||||||
set to 5, since this will generate 15 different prevalences:
|
n_prevpoints will be set to 5, since this will generate 15 different prevalences, i.e., [0, 0, 1],
|
||||||
[0, 0, 1], [0, 0.25, 0.75], [0, 0.5, 0.5] ... [1, 0, 0]
|
[0, 0.25, 0.75], [0, 0.5, 0.5] ... [1, 0, 0], and since setting it to 6 would generate more than
|
||||||
Ignored if protocol='npp'.
|
20. When protocol='gen', indicates the maximum number of samples to generate, but fewer samples will be
|
||||||
|
generated if the generator yields fewer samples.
|
||||||
:param error: an error function (callable) or a string indicating the name of an error function (valid ones
|
:param error: an error function (callable) or a string indicating the name of an error function (valid ones
|
||||||
are those in qp.error.QUANTIFICATION_ERROR)
|
are those in qp.error.QUANTIFICATION_ERROR)
|
||||||
:param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
|
:param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
|
||||||
the best chosen hyperparameter combination
|
the best chosen hyperparameter combination. Ignored if protocol='gen'
|
||||||
:param val_split: either a LabelledCollection on which to test the performance of the different settings, or
|
:param val_split: either a LabelledCollection on which to test the performance of the different settings, or
|
||||||
a float in [0,1] indicating the proportion of labelled data to extract from the training set
|
a float in [0,1] indicating the proportion of labelled data to extract from the training set, or a callable
|
||||||
|
returning a generator function each time it is invoked (only for protocol='gen').
|
||||||
:param n_jobs: number of parallel jobs
|
:param n_jobs: number of parallel jobs
|
||||||
:param random_seed: set the seed of the random generator to replicate experiments
|
:param random_seed: set the seed of the random generator to replicate experiments. Ignored if protocol='gen'.
|
||||||
:param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
|
:param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
|
||||||
Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
|
Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
|
||||||
being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
|
being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
|
||||||
|
@ -79,17 +82,13 @@ class GridSearchQ(BaseQuantifier):
|
||||||
'unknown protocol: valid ones are "app" or "npp" for the "artificial" or the "natural" prevalence ' \
|
'unknown protocol: valid ones are "app" or "npp" for the "artificial" or the "natural" prevalence ' \
|
||||||
'protocols. Use protocol="gen" when passing a generator function thorough val_split that yields a ' \
|
'protocols. Use protocol="gen" when passing a generator function thorough val_split that yields a ' \
|
||||||
'sample (instances) and their prevalence (ndarray) at each iteration.'
|
'sample (instances) and their prevalence (ndarray) at each iteration.'
|
||||||
if self.protocol == 'npp':
|
assert self.eval_budget is None or isinstance(self.eval_budget, int)
|
||||||
if self.n_repetitions is None or self.n_repetitions == 1:
|
if self.protocol in ['npp', 'gen']:
|
||||||
if self.eval_budget is not None:
|
if self.protocol=='npp' and (self.eval_budget is None or self.eval_budget <= 0):
|
||||||
print(f'[warning] when protocol=="npp" the parameter n_repetitions should be indicated '
|
raise ValueError(f'when protocol="npp" the parameter eval_budget should be '
|
||||||
f'(and not eval_budget). Setting n_repetitions={self.eval_budget}...')
|
f'indicated (and should be >0).')
|
||||||
self.n_repetitions = self.eval_budget
|
if self.n_prevpoints != 1:
|
||||||
else:
|
print('[warning] n_prevpoints has been set and will be ignored for the selected protocol')
|
||||||
raise ValueError(f'when protocol=="npp" the parameter n_repetitions should be indicated '
|
|
||||||
f'(and should be >1).')
|
|
||||||
if self.n_prevpoints is not None:
|
|
||||||
print('[warning] n_prevpoints has been set along with the npp protocol, and will be ignored')
|
|
||||||
|
|
||||||
def sout(self, msg):
|
def sout(self, msg):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -145,7 +144,7 @@ class GridSearchQ(BaseQuantifier):
|
||||||
else:
|
else:
|
||||||
raise ValueError('unknown protocol')
|
raise ValueError('unknown protocol')
|
||||||
|
|
||||||
def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float] = None):
|
def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float, Callable] = None):
|
||||||
"""
|
"""
|
||||||
:param training: the training set on which to optimize the hyperparameters
|
:param training: the training set on which to optimize the hyperparameters
|
||||||
:param val_split: either a LabelledCollection on which to test the performance of the different settings, or
|
:param val_split: either a LabelledCollection on which to test the performance of the different settings, or
|
||||||
|
|
Loading…
Reference in New Issue