diff --git a/Census/main.py b/Census/main.py index 435f71d..73a040d 100644 --- a/Census/main.py +++ b/Census/main.py @@ -86,6 +86,7 @@ q = CC(cls) Atr, Xtr, ytr = load_csv(survey_y, use_yhat=True) preprocessor = Preprocessor() +Xtr = preprocessor.fit_transform(Xtr) # Xtr_proc = preprocessor.fit_transform(Xtr) # big_train = LabelledCollection(Xtr_proc, ytr) # q.fit(big_train) @@ -99,14 +100,14 @@ n_area = len(trains) results = np.zeros(shape=(n_area, n_area)) for i, (Ai, Xi, yi) in tqdm(enumerate(trains), total=n_area): - Xi = preprocessor.fit_transform(Xi) + # Xi = preprocessor.fit_transform(Xi) tr = LabelledCollection(Xi, yi) q.fit(tr) len_tr = len(tr) # len_tr = len(big_train) for j, (Aj, Xj, yj) in enumerate(trains): if i==j: continue - Xj = preprocessor.transform(Xj) + # Xj = preprocessor.transform(Xj) te = LabelledCollection(Xj, yj) pred_prev = q.quantify(te.X) true_prev = te.prevalence()