Model Components#
Modular torch.nn.Module components for building custom architectures.
Text Embedding#
Text embedding is split into two composable stages:
TokenEmbedder — maps each token to a dense vector (with optional self-attention). Output: (batch, seq_len, embedding_dim).
SentenceEmbedder — aggregates token vectors into a sentence embedding. Output: (batch, embedding_dim), or (batch, num_classes, embedding_dim) with label attention.
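The two-stage data flow can be sketched with plain torch ops. This is an illustration of the tensor shapes involved, not the library's implementation; the masked mean pooling stands in for the default sentence aggregation:

```python
import torch
import torch.nn as nn

batch, seq_len, vocab_size, embedding_dim = 2, 6, 100, 16

# Stand-in for stage 1: token ids -> dense vectors
embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
input_ids = torch.randint(1, vocab_size, (batch, seq_len))
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
attention_mask[:, 4:] = 0  # last two positions are padding

# Stage 1 output: (batch, seq_len, embedding_dim)
token_embeddings = embedding(input_ids)

# Stand-in for stage 2: masked mean pooling -> (batch, embedding_dim)
mask = attention_mask.unsqueeze(-1).float()
sentence_embedding = (token_embeddings * mask).sum(1) / mask.sum(1)
```

Padding positions are excluded from the mean via the attention mask, so the sentence embedding only reflects real tokens.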
TokenEmbedder#
Embeds tokenized text with optional self-attention.
- class torchTextClassifiers.model.components.text_embedder.TokenEmbedder(token_embedder_config)[source]#
Bases: Module
A module that takes tokenized text and outputs dense vector representations (one for each token).
- __init__(token_embedder_config)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(input_ids, attention_mask)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
TokenEmbedderConfig#
Configuration for TokenEmbedder.
- class torchTextClassifiers.model.components.text_embedder.TokenEmbedderConfig(vocab_size, embedding_dim, padding_idx, attention_config=None)[source]#
Bases: object
- attention_config: AttentionConfig | None = None#
- __init__(vocab_size, embedding_dim, padding_idx, attention_config=None)#
Example:
from torchTextClassifiers.model.components import (
    TokenEmbedder, TokenEmbedderConfig, AttentionConfig,
)

# Simple token embedder (no self-attention)
config = TokenEmbedderConfig(
    vocab_size=5000,
    embedding_dim=128,
    padding_idx=0,
)
token_embedder = TokenEmbedder(config)
out = token_embedder(input_ids, attention_mask)
# out["token_embeddings"]: (batch, seq_len, 128)

# With self-attention
attention_config = AttentionConfig(
    n_layers=2,
    n_head=4,
    n_kv_head=4,
    positional_encoding=False,
)
config = TokenEmbedderConfig(
    vocab_size=5000,
    embedding_dim=128,
    padding_idx=0,
    attention_config=attention_config,
)
token_embedder = TokenEmbedder(config)
SentenceEmbedder#
Aggregates per-token embeddings into a single sentence embedding.
- class torchTextClassifiers.model.components.text_embedder.SentenceEmbedder(sentence_embedder_config)[source]#
Bases: Module
- __init__(sentence_embedder_config)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(token_embeddings, attention_mask, return_label_attention_matrix=False)[source]#
Compute a sentence embedding from the token embeddings, collapsing the sequence dimension.
- Args:
token_embeddings (torch.Tensor), shape (batch_size, seq_len, embedding_dim): Embedded tokens produced by the TokenEmbedder
attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
- Returns:
- A dictionary containing:
'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim), or (batch_size, n_labels, embedding_dim) if label attention is enabled
'label_attention_matrix': The label attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
- Return type:
Dict[str, Optional[torch.Tensor]]
SentenceEmbedderConfig#
Configuration for SentenceEmbedder.
LabelAttentionConfig#
Configuration for the label-attention aggregation mode.
- class torchTextClassifiers.model.components.text_embedder.LabelAttentionConfig(n_head, num_classes, embedding_dim)[source]#
Bases: object
- __init__(n_head, num_classes, embedding_dim)#
Example:
from torchTextClassifiers.model.components import (
    SentenceEmbedder, SentenceEmbedderConfig,
    LabelAttentionConfig,
)

# Mean-pooling (default)
sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))
out = sentence_embedder(token_embeddings, attention_mask)
# out["sentence_embedding"]: (batch, 128)

# Label attention: one embedding per class
sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(
    aggregation_method=None,
    label_attention_config=LabelAttentionConfig(
        n_head=4,
        num_classes=6,
        embedding_dim=128,
    ),
))
out = sentence_embedder(token_embeddings, attention_mask)
# out["sentence_embedding"]: (batch, num_classes, 128)
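To show where the per-class dimension comes from, label attention can be approximated in plain PyTorch: one learned query per class attends over the token embeddings. This is an illustrative sketch using nn.MultiheadAttention, not the library's implementation:

```python
import torch
import torch.nn as nn

batch, seq_len, embedding_dim, num_classes, n_head = 2, 10, 128, 6, 4

token_embeddings = torch.randn(batch, seq_len, embedding_dim)

# One learned query vector per class
class_queries = nn.Parameter(torch.randn(num_classes, embedding_dim))
mha = nn.MultiheadAttention(embedding_dim, n_head, batch_first=True)

# Each class query attends over the token embeddings
queries = class_queries.unsqueeze(0).expand(batch, -1, -1)
per_class_embedding, attn = mha(queries, token_embeddings, token_embeddings)
# per_class_embedding: (batch, num_classes, embedding_dim)
# attn:                (batch, num_classes, seq_len)
```

Because each class has its own query, the classifier can later score class c against the embedding specialized for class c, rather than a single shared sentence vector.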
Categorical Features#
CategoricalVariableNet#
Handles categorical features alongside text.
- class torchTextClassifiers.model.components.categorical_var_net.CategoricalVariableNet(categorical_vocabulary_sizes, categorical_embedding_dims=None, text_embedding_dim=None)[source]#
Bases: Module
- __init__(categorical_vocabulary_sizes, categorical_embedding_dims=None, text_embedding_dim=None)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(categorical_vars_tensor)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
CategoricalForwardType#
Enum for categorical feature combination strategies.
- class torchTextClassifiers.model.components.categorical_var_net.CategoricalForwardType(*values)[source]#
Bases: Enum
- SUM_TO_TEXT#
Sum categorical embeddings, concatenate with text.
- AVERAGE_AND_CONCAT#
Average categorical embeddings, concatenate with text.
- CONCATENATE_ALL#
Concatenate all embeddings (text + each categorical).
- SUM_TO_TEXT = 'EMBEDDING_SUM_TO_TEXT'#
- AVERAGE_AND_CONCAT = 'EMBEDDING_AVERAGE_AND_CONCAT'#
- CONCATENATE_ALL = 'EMBEDDING_CONCATENATE_ALL'#
Example:
from torchTextClassifiers.model.components import CategoricalVariableNet

# 3 categorical variables with different vocab sizes
cat_net = CategoricalVariableNet(
    categorical_vocabulary_sizes=[10, 5, 20],
    categorical_embedding_dims=[8, 4, 16],
)

# Forward pass
cat_embeddings = cat_net(categorical_data)
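The three combination strategies can be illustrated with plain tensor ops. This is a sketch of the resulting shapes, not the library's code, and the dimensions are made up for the example:

```python
import torch

batch, text_dim, cat_dim = 2, 128, 128
text_emb = torch.randn(batch, text_dim)
# Three categorical variables, each embedded to cat_dim
cat_embs = [torch.randn(batch, cat_dim) for _ in range(3)]

# SUM_TO_TEXT: sum categorical embeddings, concatenate with text
summed = torch.stack(cat_embs).sum(0)
sum_to_text = torch.cat([text_emb, summed], dim=1)    # (batch, text_dim + cat_dim)

# AVERAGE_AND_CONCAT: average categorical embeddings, concatenate with text
averaged = torch.stack(cat_embs).mean(0)
avg_concat = torch.cat([text_emb, averaged], dim=1)   # (batch, text_dim + cat_dim)

# CONCATENATE_ALL: concatenate text and every categorical embedding
concat_all = torch.cat([text_emb] + cat_embs, dim=1)  # (batch, text_dim + 3 * cat_dim)
```

Note that only CONCATENATE_ALL preserves each variable separately; the other two strategies require all categorical embeddings to share one dimension so they can be summed or averaged.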
Classification Head#
ClassificationHead#
Linear classification layer(s).
- class torchTextClassifiers.model.components.classification_head.ClassificationHead(input_dim=None, num_classes=None, net=None)[source]#
Bases: Module
- __init__(input_dim=None, num_classes=None, net=None)[source]#
Classification head for text classification tasks. It is an nn.Module that can be either a simple linear layer or a custom neural network module.
- Parameters:
input_dim (int, optional) – Dimension of the input features. Required if net is not provided.
num_classes (int, optional) – Number of output classes. Required if net is not provided.
net (nn.Module, optional) – Custom neural network module to be used as the classification head. If provided, input_dim and num_classes are inferred from this module. Should be either an nn.Sequential whose first and last layers are nn.Linear, or a single nn.Linear.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Example:
from torchTextClassifiers.model.components import ClassificationHead

# Simple linear classifier
head = ClassificationHead(
    input_dim=128,
    num_classes=5,
)

# Custom classifier with a nested nn.Module
import torch.nn as nn

custom_head_module = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(64, 5),
)
head = ClassificationHead(net=custom_head_module)
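The dimension inference described for the net parameter can be sketched as follows. This illustrates the documented contract (first and last layers are nn.Linear), not the library's actual code:

```python
import torch.nn as nn

# A custom head matching the documented contract
net = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(64, 5),
)

# Read input/output dimensions from the module itself
if isinstance(net, nn.Linear):
    input_dim, num_classes = net.in_features, net.out_features
else:  # nn.Sequential with Linear first and last
    input_dim, num_classes = net[0].in_features, net[-1].out_features
```

This is why a net whose first or last layer is not nn.Linear cannot be used: there would be no in_features/out_features to infer the dimensions from.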
Attention Mechanism#
AttentionConfig#
Configuration for transformer-style self-attention.
Block#
Single transformer block with self-attention + MLP.
- class torchTextClassifiers.model.components.attention.Block(config, layer_idx)[source]#
Bases: Module
- __init__(config, layer_idx)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, cos_sin)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
SelfAttentionLayer#
Multi-head self-attention layer.
- class torchTextClassifiers.model.components.attention.SelfAttentionLayer(config, layer_idx)[source]#
Bases: Module
- __init__(config, layer_idx)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, cos_sin=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
MLP#
Feed-forward network.
- class torchTextClassifiers.model.components.attention.MLP(config)[source]#
Bases: Module
- __init__(config)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Example:
from torchTextClassifiers.model.components import AttentionConfig, Block

# Configure attention
config = AttentionConfig(
    n_embd=128,
    n_head=4,
    n_layer=3,
    dropout=0.1,
)

# Create a transformer block (layer_idx identifies its position in the stack)
block = Block(config, layer_idx=0)

# Forward pass (cos_sin holds the rotary position embeddings)
output = block(embeddings, cos_sin)
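A standard rotary (RoPE) precomputation for the position tables looks roughly like this. This is a hypothetical sketch: the exact cos_sin format Block.forward expects is not documented here, so treat the function below as illustrative only:

```python
import torch

def rope_cos_sin(seq_len: int, head_dim: int, base: float = 10000.0):
    """Precompute rotary cos/sin tables of shape (seq_len, head_dim // 2)."""
    # One frequency per pair of embedding dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    pos = torch.arange(seq_len).float()
    angles = torch.outer(pos, inv_freq)  # (seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)

# head_dim per attention head = n_embd / n_head = 128 / 4
cos, sin = rope_cos_sin(seq_len=32, head_dim=128 // 4)
```

The tables depend only on sequence length and head dimension, so they are typically computed once and reused across layers and batches.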
Composing Components#
Components can be composed to create custom architectures:
import torch
import torch.nn as nn
from torchTextClassifiers.model.components import (
    TokenEmbedder, TokenEmbedderConfig,
    SentenceEmbedder, SentenceEmbedderConfig,
    CategoricalVariableNet, ClassificationHead,
)

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedder = TokenEmbedder(TokenEmbedderConfig(
            vocab_size=5000, embedding_dim=128, padding_idx=0,
        ))
        self.sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig())
        self.cat_net = CategoricalVariableNet(...)
        self.head = ClassificationHead(...)

    def forward(self, input_ids, attention_mask, categorical_data):
        token_out = self.token_embedder(input_ids, attention_mask)
        sent_out = self.sentence_embedder(
            token_out["token_embeddings"], attention_mask
        )
        cat_features = self.cat_net(categorical_data)
        combined = torch.cat([sent_out["sentence_embedding"], cat_features], dim=1)
        return self.head(combined)
See Also#
Core Models - How components are used in models
Architecture Overview - Architecture explanation
Configuration Classes - ModelConfig for component configuration