From 3d22270a4d0d7bd8f7103f2a8cb4b57de4ab17d6 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Mon, 18 Mar 2024 10:21:15 +0100 Subject: [PATCH] first look at the problem --- Census/main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Census/main.py b/Census/main.py index 435f71d..73a040d 100644 --- a/Census/main.py +++ b/Census/main.py @@ -86,6 +86,7 @@ q = CC(cls) Atr, Xtr, ytr = load_csv(survey_y, use_yhat=True) preprocessor = Preprocessor() +Xtr = preprocessor.fit_transform(Xtr) # Xtr_proc = preprocessor.fit_transform(Xtr) # big_train = LabelledCollection(Xtr_proc, ytr) # q.fit(big_train) @@ -99,14 +100,14 @@ n_area = len(trains) results = np.zeros(shape=(n_area, n_area)) for i, (Ai, Xi, yi) in tqdm(enumerate(trains), total=n_area): - Xi = preprocessor.fit_transform(Xi) + # Xi = preprocessor.fit_transform(Xi) tr = LabelledCollection(Xi, yi) q.fit(tr) len_tr = len(tr) # len_tr = len(big_train) for j, (Aj, Xj, yj) in enumerate(trains): if i==j: continue - Xj = preprocessor.transform(Xj) + # Xj = preprocessor.transform(Xj) te = LabelledCollection(Xj, yj) pred_prev = q.quantify(te.X) true_prev = te.prevalence()