forked from moreo/QuaPy
Merge branch 'master' of https://github.com/HLT-ISTI/QuaPy
This commit is contained in:
commit
f33abb5319
|
@ -249,7 +249,7 @@ class TextClassifierNet(torch.nn.Module, metaclass=ABCMeta):
|
||||||
|
|
||||||
class LSTMnet(TextClassifierNet):
|
class LSTMnet(TextClassifierNet):
|
||||||
|
|
||||||
def __init__(self, vocabulary_size, n_classes, embedding_size=100, hidden_size=256, repr_size=100, lstm_nlayers=1,
|
def __init__(self, vocabulary_size, n_classes, embedding_size=100, hidden_size=256, repr_size=100, lstm_class_nlayers=1,
|
||||||
drop_p=0.5):
|
drop_p=0.5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vocabulary_size_ = vocabulary_size
|
self.vocabulary_size_ = vocabulary_size
|
||||||
|
@ -258,12 +258,12 @@ class LSTMnet(TextClassifierNet):
|
||||||
'embedding_size': embedding_size,
|
'embedding_size': embedding_size,
|
||||||
'hidden_size': hidden_size,
|
'hidden_size': hidden_size,
|
||||||
'repr_size': repr_size,
|
'repr_size': repr_size,
|
||||||
'lstm_nlayers': lstm_nlayers,
|
'lstm_class_nlayers': lstm_class_nlayers,
|
||||||
'drop_p': drop_p
|
'drop_p': drop_p
|
||||||
}
|
}
|
||||||
|
|
||||||
self.word_embedding = torch.nn.Embedding(vocabulary_size, embedding_size)
|
self.word_embedding = torch.nn.Embedding(vocabulary_size, embedding_size)
|
||||||
self.lstm = torch.nn.LSTM(embedding_size, hidden_size, lstm_nlayers, dropout=drop_p, batch_first=True)
|
self.lstm = torch.nn.LSTM(embedding_size, hidden_size, lstm_class_nlayers, dropout=drop_p, batch_first=True)
|
||||||
self.dropout = torch.nn.Dropout(drop_p)
|
self.dropout = torch.nn.Dropout(drop_p)
|
||||||
|
|
||||||
self.dim = repr_size
|
self.dim = repr_size
|
||||||
|
@ -272,8 +272,8 @@ class LSTMnet(TextClassifierNet):
|
||||||
|
|
||||||
def init_hidden(self, set_size):
|
def init_hidden(self, set_size):
|
||||||
opt = self.hyperparams
|
opt = self.hyperparams
|
||||||
var_hidden = torch.zeros(opt['lstm_nlayers'], set_size, opt['lstm_hidden_size'])
|
var_hidden = torch.zeros(opt['lstm_class_nlayers'], set_size, opt['hidden_size'])
|
||||||
var_cell = torch.zeros(opt['lstm_nlayers'], set_size, opt['lstm_hidden_size'])
|
var_cell = torch.zeros(opt['lstm_class_nlayers'], set_size, opt['hidden_size'])
|
||||||
if next(self.lstm.parameters()).is_cuda:
|
if next(self.lstm.parameters()).is_cuda:
|
||||||
var_hidden, var_cell = var_hidden.cuda(), var_cell.cuda()
|
var_hidden, var_cell = var_hidden.cuda(), var_cell.cuda()
|
||||||
return var_hidden, var_cell
|
return var_hidden, var_cell
|
||||||
|
|
Loading…
Reference in New Issue