import torch
import torch.nn as nn
import torch.nn.functional as F


class DistributionRegressor(nn.Module):
    """Two-layer MLP that maps an ``n_classes``-wide vector to a probability vector.

    Architecture: ``Linear(n_classes, hidden_dim) -> ReLU ->
    Linear(hidden_dim, n_classes) -> softmax`` over the last dimension.
    The output therefore has the same trailing size as the input. Its entries
    are non-negative and sum to 1 along the last dimension.

    Args:
        n_classes: Size of the input's last dimension and of the output's
            last dimension.
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, n_classes: int, hidden_dim: int = 256) -> None:
        super().__init__()
        self.fc1 = nn.Linear(n_classes, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a probability distribution over the last dimension.

        Args:
            x: Tensor whose last dimension has size ``n_classes`` (required by
                ``fc1``); any leading dimensions are treated as batch dims.

        Returns:
            Tensor of the same shape as ``x`` holding softmax probabilities
            along the last dimension.
        """
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        # NOTE(review): this returns probabilities, not log-probabilities. If the
        # training loop uses nn.KLDivLoss (which expects log-probs as input),
        # it must take the log first. Confirm against the caller.
        return F.softmax(logits, dim=-1)