xavier uniform
This commit is contained in:
parent
2b61722890
commit
829d2150ec
|
@ -90,6 +90,7 @@ def main(opt):
|
||||||
cls = AuthorshipAttributionClassifier(
|
cls = AuthorshipAttributionClassifier(
|
||||||
phi, num_authors=A.size, pad_index=pad_index, pad_length=opt.pad, device=device
|
phi, num_authors=A.size, pad_index=pad_index, pad_length=opt.pad, device=device
|
||||||
)
|
)
|
||||||
|
cls.xavier_uniform()
|
||||||
print(cls)
|
print(cls)
|
||||||
|
|
||||||
if opt.name == 'auto':
|
if opt.name == 'auto':
|
||||||
|
|
|
@ -24,6 +24,11 @@ class AuthorshipAttributionClassifier(nn.Module):
|
||||||
self.pad_length = pad_length
|
self.pad_length = pad_length
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
|
def xavier_uniform(self):
|
||||||
|
for p in self.parameters():
|
||||||
|
if p.dim() > 1 and p.requires_grad:
|
||||||
|
nn.init.xavier_uniform_(p)
|
||||||
|
|
||||||
def fit(self, X, y, batch_size, epochs, patience=10, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
def fit(self, X, y, batch_size, epochs, patience=10, lr=0.001, val_prop=0.1, alpha=1., log='../log/tmp.csv', checkpointpath='../checkpoint/model.dat'):
|
||||||
assert 0 <= alpha <= 1, 'wrong range, alpha must be in [0,1]'
|
assert 0 <= alpha <= 1, 'wrong range, alpha must be in [0,1]'
|
||||||
early_stop = EarlyStop(patience)
|
early_stop = EarlyStop(patience)
|
||||||
|
|
Loading…
Reference in New Issue