From 5234ce138764ee7e25c5bab69c8a6c84da5801ef Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Fri, 2 Jun 2023 19:36:54 +0200 Subject: [PATCH] fixed evaluation_report and dataframe visualization --- out.html | 2237 +++++++++++++++++++++++++++++++++++++++++++ quacc/data.py | 18 +- quacc/evaluation.py | 66 +- quacc/main.py | 156 +-- quacc/test_1.py | 138 +++ 5 files changed, 2451 insertions(+), 164 deletions(-) create mode 100644 out.html create mode 100644 quacc/test_1.py diff --git a/out.html b/out.html new file mode 100644 index 0000000..7de9363 --- /dev/null +++ b/out.html @@ -0,0 +1,2237 @@ +spambase +Loading spambase (Spambase Data Set) +#instances=4601, type=, #features=57, #classes=[0 1], prevs=[0.606, 0.394] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
basetrueestimerrors
01T0F1F0T1T0F1F0T1maeraemraekldnkldf1e_truef1e_estim
00.00001.00000.00000.00000.10000.90000.00000.00380.00000.99620.05000.45710.45710.22540.1122NaN1.0000
10.00001.00000.00000.00000.08000.92000.00000.00080.00000.99920.04000.29810.29810.16410.0819NaN0.9907
20.00001.00000.00000.00000.11000.89000.00000.00120.00000.99880.05500.32890.32890.25680.1277NaN1.0000
30.00001.00000.00000.00000.09000.91000.00000.00730.00000.99270.04500.62310.62310.19600.0977NaN0.9999
40.00001.00000.00000.00000.10000.90000.00000.01780.00000.98220.05001.15221.15220.23340.1162NaN1.0000
50.00001.00000.00000.00000.06000.94000.00000.00070.00000.99930.03000.27960.27960.10860.0542NaN1.0000
60.00001.00000.00000.00000.12000.88000.00000.00080.00000.99920.06000.31500.31500.28980.1439NaN1.0000
70.00001.00000.00000.00000.11000.89000.00000.00050.00000.99950.05500.29650.29650.25680.1277NaN1.0000
80.00001.00000.00000.00000.09000.91000.00000.00090.00000.99910.04500.30550.30550.19390.0966NaN0.9761
90.00001.00000.00000.00000.14000.86000.00000.02220.00000.97780.07001.38671.38670.36940.1826NaN1.0000
100.10000.90000.09000.01000.08000.82000.00850.20460.00000.78690.09733.70343.70340.42040.20720.33330.9233
110.10000.90000.10000.00000.11000.79000.02880.07690.00000.89430.09064.28634.28630.36760.18170.35480.5715
120.10000.90000.10000.00000.10000.80000.06490.05010.00000.88500.06752.85202.85200.26890.13370.33330.2786
130.10000.90000.10000.00000.06000.84000.01820.16560.00000.81610.08288.71458.71450.33170.16430.23080.8196
140.10000.90000.09000.01000.08000.82000.02430.10760.00000.86810.07282.04922.04920.27550.13690.33330.6885
150.10000.90000.10000.00000.08000.82000.07670.02190.00000.90150.05171.40881.40880.18110.09030.28570.1248
160.10000.90000.08000.02000.07000.83000.01760.04960.00010.93270.06620.74380.74380.19810.09870.36000.5856
170.10000.90000.08000.02000.13000.77000.02100.19400.00000.78500.09452.15862.15860.47880.23490.48390.8218
180.10000.90000.08000.02000.12000.78000.01780.13580.00000.84640.09111.60241.60240.40740.20090.46670.7928
190.10000.90000.07000.03000.10000.80000.03520.13770.00000.82710.06741.13161.13160.29060.14430.48150.6616
200.20000.80000.20000.00000.11000.69000.11760.22430.00000.65810.112211.567111.56710.47950.23530.21570.4881
210.20000.80000.20000.00000.06000.74000.15030.04400.00030.80540.05472.51082.51080.14550.07260.13040.1284
220.20000.80000.18000.02000.07000.73000.06500.11820.00000.81680.09251.40041.40040.26100.12970.20000.4762
230.20000.80000.19000.01000.11000.69000.15180.02090.00000.82730.07410.51860.51860.26960.13400.24000.0643
240.20000.80000.19000.01000.11000.69000.06760.38400.00000.54840.18706.68056.68050.66270.31970.24000.7395
250.20000.80000.20000.00000.08000.72000.12450.10060.00000.77490.07785.37635.37630.26680.13260.16670.2878
260.20000.80000.19000.01000.10000.70000.14070.09860.00000.76080.07471.79891.79890.28930.14360.22450.2594
270.20000.80000.19000.01000.09000.71000.09920.12360.00000.77720.09042.26992.26990.30550.15160.20830.3838
280.20000.80000.20000.00000.07000.73000.15620.07550.00000.76830.05694.07274.07270.20110.10020.14890.1945
290.20000.80000.19000.01000.10000.70000.12020.04600.00000.83380.08490.97430.97430.26520.13180.22450.1605
300.30000.70000.29000.01000.08000.62000.20280.16600.00000.63130.08362.91312.91310.29660.14720.13430.2904
310.30000.70000.29000.01000.05000.65000.23590.10080.00000.66340.05211.79081.79080.14910.07440.09380.1760
320.30000.70000.27000.03000.07000.63000.18810.21810.00000.59380.09401.66541.66540.27280.13550.15620.3670
330.30000.70000.28000.02000.06000.64000.20940.08790.00000.70280.06530.99560.99560.15520.07740.12500.1734
340.30000.70000.28000.02000.03000.67000.17890.08050.00000.74070.06560.93390.93390.09510.04750.08200.1836
350.30000.70000.29000.01000.10000.60000.16980.24190.00000.58840.11594.20904.20900.44380.21830.15940.4160
360.30000.70000.30000.00000.08000.62000.25060.06920.00000.68010.06473.76083.76080.22380.11140.11760.1213
370.30000.70000.24000.06000.07000.63000.06370.31390.00000.62250.12691.39271.39270.41900.20650.21310.7114
380.30000.70000.27000.03000.06000.64000.14990.23030.00000.61980.10011.77821.77820.27830.13830.14290.4343
390.30000.70000.28000.02000.04000.66000.15160.18180.00000.66660.08421.95541.95540.21270.10600.09680.3749
400.40000.60000.39000.01000.07000.53000.22610.31430.00000.45950.15225.44235.44230.44450.21870.09300.4100
410.40000.60000.35000.05000.05000.55000.26130.18930.00000.54940.06970.92330.92330.16520.08240.12500.2660
420.40000.60000.38000.02000.08000.52000.21330.24900.00000.53770.12342.64222.64220.38390.18960.11630.3686
430.40000.60000.39000.01000.05000.55000.22470.34920.00000.42610.16966.04056.04050.43880.21590.07140.4373
440.40000.60000.38000.02000.09000.51000.26340.22140.00000.51520.10332.32902.32900.35830.17730.12640.2959
450.40000.60000.38000.02000.06000.54000.21250.28610.00000.50140.13313.01843.01840.36530.18060.09520.4024
460.40000.60000.39000.01000.05000.55000.32250.16080.00000.51660.07542.79902.79900.20420.10180.07140.1996
470.40000.60000.36000.04000.05000.55000.19310.36640.00000.44050.16322.20412.20410.38190.18870.11110.4868
480.40000.60000.35000.05000.04000.56000.24120.18340.00000.57540.07440.91190.91190.14580.07280.11390.2754
490.40000.60000.38000.02000.08000.52000.29010.18350.00000.52640.08491.93181.93180.28630.14220.11630.2403
500.50000.50000.47000.03000.02000.48000.26610.25480.00000.47910.11241.91351.91350.23730.11810.05050.3237
510.50000.50000.46000.04000.07000.43000.30660.28720.00000.40620.12361.70291.70290.32950.16330.10680.3190
520.50000.50000.49000.01000.07000.43000.44880.04990.00000.50130.05560.96010.96010.16060.08010.07550.0527
530.50000.50000.47000.03000.01000.49000.31170.28130.00000.40700.12572.08712.08710.22630.11270.04080.3109
540.50000.50000.47000.03000.05000.45000.38570.13250.00000.48180.06721.02151.02150.14610.07290.07840.1466
550.50000.50000.45000.05000.09000.41000.36110.21640.00000.42240.08941.04981.04980.28970.14380.13460.2306
560.50000.50000.43000.07000.07000.43000.26650.35180.00000.38170.14091.29421.29420.34230.16950.14000.3975
570.50000.50000.46000.04000.03000.47000.29690.20030.00000.50280.09661.20991.20990.16900.08430.07070.2523
580.50000.50000.49000.01000.02000.48000.40310.07200.00000.52490.05341.30031.30030.06830.03410.02970.0820
590.50000.50000.49000.01000.06000.44000.33520.30220.00000.36260.14615.22255.22250.39210.19360.06670.3107
600.60000.40000.55000.05000.02000.38000.37230.26170.00000.36600.10591.25151.25150.18190.09070.05980.2601
610.60000.40000.55000.05000.06000.34000.39440.34780.00000.25770.14891.71431.71430.34100.16890.09090.3060
620.60000.40000.58000.02000.06000.34000.48830.18690.00000.32480.08351.95031.95030.23110.11500.06450.1607
630.60000.40000.59000.01000.03000.37000.46910.30980.00000.22110.14995.36115.36110.34740.17190.03280.2482
640.60000.40000.57000.03000.06000.34000.40360.32640.00000.27000.14822.47092.47090.36270.17940.07320.2879
650.60000.40000.57000.03000.01000.39000.32510.34880.00000.32620.15942.59042.59040.32430.16070.03390.3491
660.60000.40000.53000.07000.05000.35000.29770.47880.00000.22350.20441.78751.78750.45310.22280.10170.4457
670.60000.40000.56000.04000.06000.34000.44910.23670.00000.31420.09831.39121.39120.24130.12010.08200.2086
680.60000.40000.56000.04000.04000.36000.40480.30600.00000.28920.13301.81691.81690.27200.13520.06670.2743
690.60000.40000.59000.01000.03000.37000.44620.23880.00000.31500.11444.12444.12440.25030.12450.03280.2111
700.70000.30000.69000.01000.02000.28000.63740.05340.00000.30930.03630.96790.96790.04680.02340.02130.0402
710.70000.30000.66000.04000.01000.29000.56460.21080.00000.22460.08541.20711.20710.12290.06140.03650.1573
720.70000.30000.64000.06000.03000.27000.56260.13370.00000.30370.05370.55840.55840.06950.03470.06570.1062
730.70000.30000.64000.06000.01000.29000.48900.32800.00000.18310.13401.34651.34650.21520.10720.05190.2512
740.70000.30000.63000.07000.02000.28000.47170.28270.00000.24560.10631.00151.00150.15810.07890.06670.2305
750.70000.30000.68000.02000.03000.27000.59140.21230.00000.19630.09612.23662.23660.19470.09700.03550.1522
760.70000.30000.66000.04000.03000.27000.61640.13000.00000.25350.04500.74570.74570.08070.04030.05040.0954
770.70000.30000.66000.04000.01000.29000.54850.19400.00000.25750.07701.09171.09170.10610.05300.03650.1503
780.70000.30000.66000.04000.05000.25000.55830.26940.00000.17230.11471.61611.61610.25360.12610.06380.1944
790.70000.30000.68000.02000.03000.27000.48310.29380.00000.22310.13693.06673.06670.28960.14380.03550.2332
800.80000.20000.73000.07000.03000.17000.58470.30970.00000.10560.11991.15471.15470.20270.10100.06410.2094
810.80000.20000.73000.07000.03000.17000.51180.33320.00000.15500.13161.18731.18730.22970.11440.06410.2456
820.80000.20000.71000.09000.00000.20000.53720.25000.00000.21270.08640.49870.49870.09160.04570.05960.1888
830.80000.20000.75000.05000.07000.13000.75680.15190.00000.09130.05440.77030.77030.18430.09190.07410.0912
840.80000.20000.76000.04000.02000.18000.70050.21420.00000.08530.08711.31521.31520.16370.08160.03800.1326
850.80000.20000.75000.05000.01000.19000.65690.14930.00000.19380.05160.65390.65390.05540.02770.03850.1021
860.80000.20000.76000.04000.01000.19000.61340.22220.00000.16440.09111.25961.25960.13380.06680.03180.1534
870.80000.20000.75000.05000.01000.19000.66290.23200.00000.10510.09101.13131.13130.14010.06990.03850.1489
880.80000.20000.74000.06000.00000.20000.58570.30810.00000.10620.12401.12041.12040.19610.09770.03900.2082
890.80000.20000.74000.06000.02000.18000.45810.46950.00000.07240.20482.01522.01520.42650.21010.05130.3389
900.90000.10000.86000.04000.02000.08000.86810.10340.00000.02840.03580.70630.70630.07180.03590.03370.0562
910.90000.10000.82000.08000.02000.08000.69540.23120.00020.07320.07560.70020.70020.09450.04720.05750.1426
920.90000.10000.85000.05000.01000.09000.71360.22090.00010.06540.08541.04581.04580.11550.05770.03410.1341
930.90000.10000.83000.07000.01000.09000.65370.29120.00000.05520.11061.04831.04830.15490.07730.04600.1822
940.90000.10000.83000.07000.00000.10000.82850.08100.00000.09050.00550.06020.06020.00120.00060.04050.0466
950.90000.10000.85000.05000.00000.10000.72450.22440.00000.05110.08720.94590.94590.12300.06140.02860.1341
960.90000.10000.86000.04000.01000.09000.67970.32020.00000.00010.14012.01202.01200.40730.20090.02820.1907
970.90000.10000.85000.05000.00000.10000.79740.11660.00000.08590.03330.35170.35170.02570.01280.02860.0681
980.90000.10000.86000.04000.01000.09000.78100.20400.00010.01490.08201.29671.29670.17830.08890.02820.1155
990.90000.10000.84000.06000.00000.10000.71520.16070.00000.12420.06240.48170.48170.05250.02630.03450.1010
1001.00000.00000.93000.07000.00000.00000.95650.02290.00000.02060.02361.19241.19240.03990.01990.03630.0118
1011.00000.00000.94000.06000.00000.00000.89540.07360.00000.03100.02231.61281.61280.02340.01170.03090.0395
1021.00000.00000.98000.02000.00000.00000.94860.05140.00000.00000.01570.32200.32200.01150.00580.01010.0264
1031.00000.00000.94000.06000.00000.00000.93170.06280.00000.00540.00410.28420.28420.00180.00090.03090.0326
1041.00000.00000.97000.03000.00000.00000.77350.22650.00000.00000.09821.45401.45400.15330.07650.01520.1277
1051.00000.00000.93000.07000.00000.00000.84940.12750.00000.02310.04031.36761.36760.03300.01650.03630.0698
1061.00000.00000.95000.05000.00000.00000.97350.01480.00010.01160.01760.75360.75360.02690.01350.02560.0076
1071.00000.00000.94000.06000.00000.00000.91790.04920.00000.03290.01651.69441.69440.02410.01200.03090.0261
1081.00000.00000.96000.04000.00000.00000.90190.09810.00010.00000.02910.34050.34050.02260.01130.02040.0516
1091.00000.00000.94000.06000.00000.00000.94430.02360.00030.03170.01821.74331.74330.03870.01940.03090.0125
+************************************************** diff --git a/quacc/data.py b/quacc/data.py index f8fad6d..715802b 100644 --- a/quacc/data.py +++ b/quacc/data.py @@ -1,7 +1,9 @@ +from typing import List, Optional + import numpy as np +import quapy as qp import scipy.sparse as sp from quapy.data import LabelledCollection -from typing import List, Optional class ExtendedCollection(LabelledCollection): @@ -12,3 +14,17 @@ class ExtendedCollection(LabelledCollection): classes: Optional[List] = None, ): super().__init__(instances, labels, classes=classes) + +def get_dataset(name): + datasets = { + "spambase": lambda: qp.datasets.fetch_UCIDataset( + "spambase", verbose=False + ).train_test, + "hp": lambda: qp.datasets.fetch_reviews("hp", tfidf=True).train_test, + "imdb": lambda: qp.datasets.fetch_reviews("imdb", tfidf=True).train_test, + } + + try: + return datasets[name]() + except KeyError: + raise KeyError(f"{name} is not available as a dataset") diff --git a/quacc/evaluation.py b/quacc/evaluation.py index d12d098..029d476 100644 --- a/quacc/evaluation.py +++ b/quacc/evaluation.py @@ -1,16 +1,20 @@ +import itertools from quapy.protocol import ( OnLabelledCollectionProtocol, AbstractStochasticSeededProtocol, ) -import quapy as qp from typing import Iterable, Callable, Union from .estimator import AccuracyEstimator import pandas as pd +import numpy as np import quacc.error as error -def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol): +def estimate( + estimator: AccuracyEstimator, + protocol: AbstractStochasticSeededProtocol, +): # ensure that the protocol returns a LabelledCollection for each iteration protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection") @@ -18,6 +22,9 @@ def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededPro for sample in protocol(): e_sample = estimator.extend(sample) estim_prev = estimator.estimate(e_sample.X, ext=True) + # base_prevs.append(_prettyfloat(accuracy, sample.prevalence())) + # true_prevs.append(_prettyfloat(accuracy, e_sample.prevalence())) + # estim_prevs.append(_prettyfloat(accuracy, estim_prev)) base_prevs.append(sample.prevalence()) true_prevs.append(e_sample.prevalence()) estim_prevs.append(estim_prev) @@ -25,6 +32,38 @@ def estimate(estimator: AccuracyEstimator, protocol: AbstractStochasticSeededPro return base_prevs, true_prevs, estim_prevs +_bprev_col_0 = ["base"] +_bprev_col_1 = ["0", "1"] +_prev_col_0 = ["true", "estim"] +_prev_col_1 = ["T0", "F1", "F0", "T1"] +_err_col_0 = ["errors"] + + +def _report_columns(err_names): + bprev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + prev_cols = list(itertools.product(_prev_col_0, _prev_col_1)) + + err_1 = err_names + err_cols = list(itertools.product(_err_col_0, err_1)) + + cols = bprev_cols + prev_cols + err_cols + + return pd.MultiIndex.from_tuples(cols) + + +def _dict_prev(base_prev, true_prev, estim_prev): + prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list( + itertools.product(_prev_col_0, _prev_col_1) + ) + + return { + k: v + for (k, v) in zip( + prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0) + ) + } + + def evaluation_report( estimator: AccuracyEstimator, protocol: AbstractStochasticSeededProtocol, @@ -40,26 +79,25 @@ def evaluation_report( ] assert all(hasattr(e, "__call__") for e in error_funcs), "invalid error function" error_names = [e.__name__ for e in error_funcs] + error_cols = error_names.copy() + if "f1e" in error_cols: + error_cols.remove("f1e") + error_cols.extend(["f1e_true", "f1e_estim"]) + + # df_cols = ["base_prev", "true_prev", "estim_prev"] + error_names + df_cols = _report_columns(error_cols) - df_cols = ["base_prev", "true_prev", "estim_prev"] + error_names - if "f1e" in df_cols: - df_cols.remove("f1e") - df_cols.extend(["f1e_true", "f1e_estim"]) lst = [] for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs): - series = { - "base_prev": base_prev, - "true_prev": true_prev, - "estim_prev": estim_prev, - } + series = _dict_prev(base_prev, true_prev, estim_prev) for error_name, error_metric in zip(error_names, error_funcs): if error_name == "f1e": - series["f1e_true"] = error_metric(true_prev) - series["f1e_estim"] = error_metric(estim_prev) + series[("errors", "f1e_true")] = error_metric(true_prev) + series[("errors", "f1e_estim")] = error_metric(estim_prev) continue score = error_metric(true_prev, estim_prev) - series[error_name] = score + series[("errors", error_name)] = score lst.append(series) diff --git a/quacc/main.py b/quacc/main.py index 51f4b04..14fa7d0 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -1,158 +1,17 @@ -import numpy as np +import pandas as pd import quapy as qp -import scipy.sparse as sp -from quapy.data import LabelledCollection from quapy.method.aggregative import SLD -from quapy.protocol import APP, AbstractStochasticSeededProtocol +from quapy.protocol import APP from sklearn.linear_model import LogisticRegression -from sklearn.model_selection import cross_val_predict import quacc.evaluation as eval from quacc.estimator import AccuracyEstimator -qp.environ['SAMPLE_SIZE'] = 100 +from .data import get_dataset +qp.environ["SAMPLE_SIZE"] = 100 -# Extended classes -# -# 0 ~ True 0 -# 1 ~ False 1 -# 2 ~ False 0 -# 3 ~ True 1 -# _____________________ -# | | | -# | True 0 | False 1 | -# |__________|__________| -# | | | -# | False 0 | True 1 | -# |__________|__________| -# -def get_ex_class(classes, true_class, pred_class): - return true_class * classes + pred_class - - -def extend_collection(coll, pred_prob): - n_classes = coll.n_classes - - # n_X = [ X | predicted probs. ] - if isinstance(coll.X, sp.csr_matrix): - pred_prob_csr = sp.csr_matrix(pred_prob) - n_x = sp.hstack([coll.X, pred_prob_csr]) - elif isinstance(coll.X, np.ndarray): - n_x = np.concatenate((coll.X, pred_prob), axis=1) - else: - raise ValueError("Unsupported matrix format") - - # n_y = (exptected y, predicted y) - n_y = [] - for i, true_class in enumerate(coll.y): - pred_class = pred_prob[i].argmax(axis=0) - n_y.append(get_ex_class(n_classes, true_class, pred_class)) - - return LabelledCollection(n_x, np.asarray(n_y), [*range(0, n_classes * n_classes)]) - - -def qf1e_binary(prev): - recall = prev[0] / (prev[0] + prev[1]) - precision = prev[0] / (prev[0] + prev[2]) - - return 1 - 2 * (precision * recall) / (precision + recall) - - -def compute_errors(true_prev, estim_prev, n_instances): - errors = {} - _eps = 1 / (2 * n_instances) - errors = { - "mae": qp.error.mae(true_prev, estim_prev), - "rae": qp.error.rae(true_prev, estim_prev, eps=_eps), - "mrae": qp.error.mrae(true_prev, estim_prev, eps=_eps), - "kld": qp.error.kld(true_prev, estim_prev, eps=_eps), - "nkld": qp.error.nkld(true_prev, estim_prev, eps=_eps), - "true_f1e": qf1e_binary(true_prev), - "estim_f1e": qf1e_binary(estim_prev), - } - - return errors - - -def extend_and_quantify( - model, - q_model, - train, - test: LabelledCollection | AbstractStochasticSeededProtocol, -): - model.fit(*train.Xy) - - pred_prob_train = cross_val_predict(model, *train.Xy, method="predict_proba") - _train = extend_collection(train, pred_prob_train) - - q_model.fit(_train) - - def quantify_extended(test): - pred_prob_test = model.predict_proba(test.X) - _test = extend_collection(test, pred_prob_test) - _estim_prev = q_model.quantify(_test.instances) - # check that _estim_prev has all the classes and eventually fill the missing - # ones with 0 - for _cls in _test.classes_: - if _cls not in q_model.classes_: - _estim_prev = np.insert(_estim_prev, _cls, [0.0], axis=0) - print(_estim_prev) - return _test.prevalence(), _estim_prev - - if isinstance(test, LabelledCollection): - _true_prev, _estim_prev = quantify_extended(test) - _errors = compute_errors(_true_prev, _estim_prev, test.X.shape[0]) - return ([test.prevalence()], [_true_prev], [_estim_prev], [_errors]) - - elif isinstance(test, AbstractStochasticSeededProtocol): - orig_prevs, true_prevs, estim_prevs, errors = [], [], [], [] - for index in test.samples_parameters(): - sample = test.sample(index) - _true_prev, _estim_prev = quantify_extended(sample) - - orig_prevs.append(sample.prevalence()) - true_prevs.append(_true_prev) - estim_prevs.append(_estim_prev) - errors.append(compute_errors(_true_prev, _estim_prev, sample.X.shape[0])) - - return orig_prevs, true_prevs, estim_prevs, errors - - -def get_dataset(name): - datasets = { - "spambase": lambda: qp.datasets.fetch_UCIDataset( - "spambase", verbose=False - ).train_test, - "hp": lambda: qp.datasets.fetch_reviews("hp", tfidf=True).train_test, - "imdb": lambda: qp.datasets.fetch_reviews("imdb", tfidf=True).train_test, - } - - try: - return datasets[name]() - except KeyError: - raise KeyError(f"{name} is not available as a dataset") - - -def test_1(dataset_name): - train, test = get_dataset(dataset_name) - - orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify( - LogisticRegression(), - SLD(LogisticRegression()), - train, - APP(test, n_prevalences=11, repeats=1), - ) - - for orig_prev, true_prev, estim_prev, _errors in zip( - orig_prevs, true_prevs, estim_prevs, errors - ): - print(f"original prevalence:\t{orig_prev}") - print(f"true prevalence:\t{true_prev}") - print(f"estimated prevalence:\t{estim_prev}") - for name, err in _errors.items(): - print(f"{name}={err:.3f}") - print() +pd.set_option("display.float_format", "{:.4f}".format) def test_2(dataset_name): @@ -161,9 +20,8 @@ def test_2(dataset_name): model.fit(*train.Xy) estimator = AccuracyEstimator(model, SLD(LogisticRegression())) estimator.fit(train) - df = eval.evaluation_report( - estimator, APP(test, n_prevalences=11, repeats=1) - ) + df = eval.evaluation_report(estimator, APP(test, n_prevalences=11, repeats=100)) + # print(df.to_string()) print(df.to_string()) diff --git a/quacc/test_1.py b/quacc/test_1.py new file mode 100644 index 0000000..00ea263 --- /dev/null +++ b/quacc/test_1.py @@ -0,0 +1,138 @@ +import numpy as np +import scipy as sp +import quapy as qp +from quapy.data import LabelledCollection +from quapy.method.aggregative import SLD +from quapy.protocol import APP, AbstractStochasticSeededProtocol +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import cross_val_predict + +from .data import get_dataset + +# Extended classes +# +# 0 ~ True 0 +# 1 ~ False 1 +# 2 ~ False 0 +# 3 ~ True 1 +# _____________________ +# | | | +# | True 0 | False 1 | +# |__________|__________| +# | | | +# | False 0 | True 1 | +# |__________|__________| +# +def get_ex_class(classes, true_class, pred_class): + return true_class * classes + pred_class + + +def extend_collection(coll, pred_prob): + n_classes = coll.n_classes + + # n_X = [ X | predicted probs. ] + if isinstance(coll.X, sp.csr_matrix): + pred_prob_csr = sp.csr_matrix(pred_prob) + n_x = sp.hstack([coll.X, pred_prob_csr]) + elif isinstance(coll.X, np.ndarray): + n_x = np.concatenate((coll.X, pred_prob), axis=1) + else: + raise ValueError("Unsupported matrix format") + + # n_y = (exptected y, predicted y) + n_y = [] + for i, true_class in enumerate(coll.y): + pred_class = pred_prob[i].argmax(axis=0) + n_y.append(get_ex_class(n_classes, true_class, pred_class)) + + return LabelledCollection(n_x, np.asarray(n_y), [*range(0, n_classes * n_classes)]) + + +def qf1e_binary(prev): + recall = prev[0] / (prev[0] + prev[1]) + precision = prev[0] / (prev[0] + prev[2]) + + return 1 - 2 * (precision * recall) / (precision + recall) + + +def compute_errors(true_prev, estim_prev, n_instances): + errors = {} + _eps = 1 / (2 * n_instances) + errors = { + "mae": qp.error.mae(true_prev, estim_prev), + "rae": qp.error.rae(true_prev, estim_prev, eps=_eps), + "mrae": qp.error.mrae(true_prev, estim_prev, eps=_eps), + "kld": qp.error.kld(true_prev, estim_prev, eps=_eps), + "nkld": qp.error.nkld(true_prev, estim_prev, eps=_eps), + "true_f1e": qf1e_binary(true_prev), + "estim_f1e": qf1e_binary(estim_prev), + } + + return errors + + +def extend_and_quantify( + model, + q_model, + train, + test: LabelledCollection | AbstractStochasticSeededProtocol, +): + model.fit(*train.Xy) + + pred_prob_train = cross_val_predict(model, *train.Xy, method="predict_proba") + _train = extend_collection(train, pred_prob_train) + + q_model.fit(_train) + + def quantify_extended(test): + pred_prob_test = model.predict_proba(test.X) + _test = extend_collection(test, pred_prob_test) + _estim_prev = q_model.quantify(_test.instances) + # check that _estim_prev has all the classes and eventually fill the missing + # ones with 0 + for _cls in _test.classes_: + if _cls not in q_model.classes_: + _estim_prev = np.insert(_estim_prev, _cls, [0.0], axis=0) + print(_estim_prev) + return _test.prevalence(), _estim_prev + + if isinstance(test, LabelledCollection): + _true_prev, _estim_prev = quantify_extended(test) + _errors = compute_errors(_true_prev, _estim_prev, test.X.shape[0]) + return ([test.prevalence()], [_true_prev], [_estim_prev], [_errors]) + + elif isinstance(test, AbstractStochasticSeededProtocol): + orig_prevs, true_prevs, estim_prevs, errors = [], [], [], [] + for index in test.samples_parameters(): + sample = test.sample(index) + _true_prev, _estim_prev = quantify_extended(sample) + + orig_prevs.append(sample.prevalence()) + true_prevs.append(_true_prev) + estim_prevs.append(_estim_prev) + errors.append(compute_errors(_true_prev, _estim_prev, sample.X.shape[0])) + + return orig_prevs, true_prevs, estim_prevs, errors + + + + +def test_1(dataset_name): + train, test = get_dataset(dataset_name) + + orig_prevs, true_prevs, estim_prevs, errors = extend_and_quantify( + LogisticRegression(), + SLD(LogisticRegression()), + train, + APP(test, n_prevalences=11, repeats=1), + ) + + for orig_prev, true_prev, estim_prev, _errors in zip( + orig_prevs, true_prevs, estim_prevs, errors + ): + print(f"original prevalence:\t{orig_prev}") + print(f"true prevalence:\t{true_prev}") + print(f"estimated prevalence:\t{estim_prev}") + for name, err in _errors.items(): + print(f"{name}={err:.3f}") + print()