From 829d2150ecc046e69a8af1bfdcf85827540c26db Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Thu, 4 Feb 2021 18:00:58 +0100 Subject: [PATCH] xavier uniform --- src/main.py | 1 + src/model/classifiers.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/src/main.py b/src/main.py index 68aa2ff..5c5a640 100644 --- a/src/main.py +++ b/src/main.py @@ -90,6 +90,7 @@ def main(opt): cls = AuthorshipAttributionClassifier( phi, num_authors=A.size, pad_index=pad_index, pad_length=opt.pad, device=device ) + cls.xavier_uniform() print(cls) if opt.name == 'auto': diff --git a/src/model/classifiers.py b/src/model/classifiers.py index a186d4b..58a8426 100644 --- a/src/model/classifiers.py +++ b/src/model/classifiers.py @@ -24,6 +24,11 @@ class AuthorshipAttributionClassifier(nn.Module): self.pad_length = pad_length self.device = device + def xavier_uniform(self): + for p in self.parameters(): + if p.dim() > 1 and p.requires_grad: + nn.init.xavier_uniform_(p) + def fit(self, X, y, batch_size, epochs, patience=10, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'): assert 0 <= alpha <= 1, 'wrong range, alpha must be in [0,1]' early_stop = EarlyStop(patience)