matrix extension refactored

This commit is contained in:
Lorenzo Volpi 2023-11-22 19:24:06 +01:00
parent 46d3964e6e
commit beea23db14
1 changed files with 3 additions and 3 deletions

View File

@ -45,9 +45,9 @@ class ExtendedData:
pred_proba: np.ndarray,
ext: np.ndarray = None,
) -> np.ndarray | sp.csr_matrix:
to_append = pred_proba
if ext is not None:
to_append = np.concatenate([ext, pred_proba], axis=1)
to_append = ext
if ext is None:
to_append = pred_proba
if isinstance(instances, sp.csr_matrix):
_to_append = sp.csr_matrix(to_append)