Model Components#

Modular torch.nn.Module components for building custom architectures.

Text Embedding#

Text embedding is split into two composable stages:

  1. TokenEmbedder — maps each token to a dense vector (with optional self-attention). Output: (batch, seq_len, embedding_dim).

  2. SentenceEmbedder — aggregates token vectors into a sentence embedding. Output: (batch, embedding_dim) or (batch, num_classes, embedding_dim) with label attention.
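
The two stages compose directly. A minimal end-to-end sketch (input tensors and sizes are illustrative; configuration values mirror the examples below):

import torch
from torchTextClassifiers.model.components import (
    TokenEmbedder, TokenEmbedderConfig,
    SentenceEmbedder, SentenceEmbedderConfig,
)

token_embedder = TokenEmbedder(TokenEmbedderConfig(
    vocab_size=5000, embedding_dim=128, padding_idx=0,
))
sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))

# Illustrative batch of 32 sequences of length 20
input_ids = torch.randint(1, 5000, (32, 20))
attention_mask = torch.ones(32, 20, dtype=torch.long)

token_out = token_embedder(input_ids, attention_mask)
sent_out = sentence_embedder(token_out["token_embeddings"], attention_mask)
# sent_out["sentence_embedding"]: (32, 128)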

TokenEmbedder#

Embeds tokenized text with optional self-attention.

class torchTextClassifiers.model.components.text_embedder.TokenEmbedder(token_embedder_config)[source]#

Bases: Module

A module that takes tokenized text and outputs dense vector representations (one for each token).

cos: Tensor#
sin: Tensor#
__init__(token_embedder_config)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

init_weights()[source]#
forward(input_ids, attention_mask)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Dict[str, Tensor | None]

TokenEmbedderConfig#

Configuration for TokenEmbedder.

class torchTextClassifiers.model.components.text_embedder.TokenEmbedderConfig(vocab_size, embedding_dim, padding_idx, attention_config=None)[source]#

Bases: object

vocab_size: int#
embedding_dim: int#
padding_idx: int#
attention_config: AttentionConfig | None = None#
__init__(vocab_size, embedding_dim, padding_idx, attention_config=None)#

Example:

from torchTextClassifiers.model.components import (
    TokenEmbedder, TokenEmbedderConfig, AttentionConfig,
)

# Simple token embedder (no self-attention)
config = TokenEmbedderConfig(
    vocab_size=5000,
    embedding_dim=128,
    padding_idx=0,
)
token_embedder = TokenEmbedder(config)
out = token_embedder(input_ids, attention_mask)
# out["token_embeddings"]: (batch, seq_len, 128)

# With self-attention
attention_config = AttentionConfig(
    n_layers=2,
    n_head=4,
    n_kv_head=4,
    positional_encoding=False,
)
config = TokenEmbedderConfig(
    vocab_size=5000,
    embedding_dim=128,
    padding_idx=0,
    attention_config=attention_config,
)
token_embedder = TokenEmbedder(config)

SentenceEmbedder#

Aggregates per-token embeddings into a single sentence embedding.

class torchTextClassifiers.model.components.text_embedder.SentenceEmbedder(sentence_embedder_config)[source]#

Bases: Module

__init__(sentence_embedder_config)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(token_embeddings, attention_mask, return_label_attention_matrix=False)[source]#

Compute a sentence embedding from per-token embeddings by aggregating over the sequence (second) dimension.

Parameters:
  • token_embeddings (torch.Tensor), shape (batch_size, seq_len, embedding_dim) – Per-token embeddings produced by TokenEmbedder

  • attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len) – Attention mask indicating non-pad tokens (from the dataset collate_fn)

  • return_label_attention_matrix (bool) – Whether to compute and return the label attention matrix

Returns:

A dictionary containing:
  • 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled

  • 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None

Return type:

Dict[str, Optional[torch.Tensor]]

SentenceEmbedderConfig#

Configuration for SentenceEmbedder.

class torchTextClassifiers.model.components.text_embedder.SentenceEmbedderConfig(aggregation_method='mean', label_attention_config=None)[source]#

Bases: object

aggregation_method: str | None = 'mean'#
label_attention_config: LabelAttentionConfig | None = None#
__init__(aggregation_method='mean', label_attention_config=None)#

LabelAttentionConfig#

Configuration for the label-attention aggregation mode.

class torchTextClassifiers.model.components.text_embedder.LabelAttentionConfig(n_head, num_classes, embedding_dim)[source]#

Bases: object

n_head: int#
num_classes: int#
embedding_dim: int#
__init__(n_head, num_classes, embedding_dim)#

Example:

from torchTextClassifiers.model.components import (
    SentenceEmbedder, SentenceEmbedderConfig,
    LabelAttentionConfig,
)

# Mean-pooling (default)
sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(aggregation_method="mean"))
out = sentence_embedder(token_embeddings, attention_mask)
# out["sentence_embedding"]: (batch, 128)

# Label attention — one embedding per class
sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig(
    aggregation_method=None,
    label_attention_config=LabelAttentionConfig(
        n_head=4,
        num_classes=6,
        embedding_dim=128,
    ),
))
out = sentence_embedder(token_embeddings, attention_mask)
# out["sentence_embedding"]: (batch, num_classes, 128)

Categorical Features#

CategoricalVariableNet#

Handles categorical features alongside text.

class torchTextClassifiers.model.components.categorical_var_net.CategoricalVariableNet(categorical_vocabulary_sizes, categorical_embedding_dims=None, text_embedding_dim=None)[source]#

Bases: Module

__init__(categorical_vocabulary_sizes, categorical_embedding_dims=None, text_embedding_dim=None)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(categorical_vars_tensor)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tensor

CategoricalForwardType#

Enum for categorical feature combination strategies.

class torchTextClassifiers.model.components.categorical_var_net.CategoricalForwardType(*values)[source]#

Bases: Enum

SUM_TO_TEXT#

Sum categorical embeddings, concatenate with text.

AVERAGE_AND_CONCAT#

Average categorical embeddings, concatenate with text.

CONCATENATE_ALL#

Concatenate all embeddings (text + each categorical).

SUM_TO_TEXT = 'EMBEDDING_SUM_TO_TEXT'#
AVERAGE_AND_CONCAT = 'EMBEDDING_AVERAGE_AND_CONCAT'#
CONCATENATE_ALL = 'EMBEDDING_CONCATENATE_ALL'#

Example:

from torchTextClassifiers.model.components import CategoricalVariableNet

# 3 categorical variables with different vocab sizes
cat_net = CategoricalVariableNet(
    categorical_vocabulary_sizes=[10, 5, 20],
    categorical_embedding_dims=[8, 4, 16],
)

# Forward pass
cat_embeddings = cat_net(categorical_data)
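
The three forward types differ only in how the categorical embeddings are combined with the text embedding. A shape-only sketch of the documented strategies (illustrative dimensions, not the library's implementation):

import torch

text_emb = torch.randn(32, 128)                        # (batch, text_embedding_dim)
cat_embs = [torch.randn(32, 128) for _ in range(3)]    # one embedding per categorical variable

# SUM_TO_TEXT: sum the categorical embeddings, then concatenate with text -> (32, 256)
combined = torch.cat([text_emb, torch.stack(cat_embs).sum(dim=0)], dim=-1)

# AVERAGE_AND_CONCAT: average the categorical embeddings, then concatenate -> (32, 256)
combined = torch.cat([text_emb, torch.stack(cat_embs).mean(dim=0)], dim=-1)

# CONCATENATE_ALL: concatenate text and every categorical embedding -> (32, 512)
combined = torch.cat([text_emb] + cat_embs, dim=-1)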

Classification Head#

ClassificationHead#

Linear classification layer(s).

class torchTextClassifiers.model.components.classification_head.ClassificationHead(input_dim=None, num_classes=None, net=None)[source]#

Bases: Module

__init__(input_dim=None, num_classes=None, net=None)[source]#

Classification head for text classification tasks. It is an nn.Module that can be either a simple linear layer or a custom neural network module.

Parameters:
  • input_dim (int, optional) – Dimension of the input features. Required if net is not provided.

  • num_classes (int, optional) – Number of output classes. Required if net is not provided.

  • net (nn.Module, optional) – Custom neural network module to use as the classification head. If provided, input_dim and num_classes are inferred from this module. Should be either an nn.Linear or an nn.Sequential whose first and last layers are nn.Linear.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tensor

Example:

from torchTextClassifiers.model.components import ClassificationHead

# Simple linear classifier
head = ClassificationHead(
    input_dim=128,
    num_classes=5
)

# Custom classifier with nested nn.Module
import torch.nn as nn

custom_head_module = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(64, 5)
)

head = ClassificationHead(net=custom_head_module)
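
In either case the head maps feature vectors to class logits. A quick usage sketch (the features tensor is illustrative):

import torch

features = torch.randn(32, 128)   # e.g. sentence embeddings, (batch, input_dim)
logits = head(features)           # (32, 5): one logit per class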

Attention Mechanism#

AttentionConfig#

Configuration for transformer-style self-attention.

class torchTextClassifiers.model.components.attention.AttentionConfig(n_layers, n_head, n_kv_head, sequence_len=None, positional_encoding=True)[source]#

Bases: object

Attributes

n_layers: int#

Number of transformer blocks.

n_head: int#

Number of attention heads.

n_kv_head: int#

Number of key/value heads.

sequence_len: int | None = None#

Maximum sequence length.

positional_encoding: bool = True#

Whether to apply positional encoding.
__init__(n_layers, n_head, n_kv_head, sequence_len=None, positional_encoding=True)#

Block#

Single transformer block with self-attention + MLP.

class torchTextClassifiers.model.components.attention.Block(config, layer_idx)[source]#

Bases: Module

__init__(config, layer_idx)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, cos_sin)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

SelfAttentionLayer#

Multi-head self-attention layer.

class torchTextClassifiers.model.components.attention.SelfAttentionLayer(config, layer_idx)[source]#

Bases: Module

__init__(config, layer_idx)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x, cos_sin=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

MLP#

Feed-forward network.

class torchTextClassifiers.model.components.attention.MLP(config)[source]#

Bases: Module

__init__(config)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Example:

from torchTextClassifiers.model.components import AttentionConfig, Block

# Configure attention
config = AttentionConfig(
    n_layers=3,
    n_head=4,
    n_kv_head=4,
)

# Create a single transformer block
block = Block(config, layer_idx=0)

# Forward pass (cos_sin holds the precomputed rotary embeddings)
output = block(embeddings, cos_sin)

Composing Components#

Components can be composed to create custom architectures:

import torch
import torch.nn as nn
from torchTextClassifiers.model.components import (
    TokenEmbedder, TokenEmbedderConfig,
    SentenceEmbedder, SentenceEmbedderConfig,
    CategoricalVariableNet, ClassificationHead,
)

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedder = TokenEmbedder(TokenEmbedderConfig(
            vocab_size=5000, embedding_dim=128, padding_idx=0,
        ))
        self.sentence_embedder = SentenceEmbedder(SentenceEmbedderConfig())
        self.cat_net = CategoricalVariableNet(...)
        self.head = ClassificationHead(...)

    def forward(self, input_ids, attention_mask, categorical_data):
        token_out = self.token_embedder(input_ids, attention_mask)
        sent_out = self.sentence_embedder(
            token_out["token_embeddings"], token_out["attention_mask"]
        )
        cat_features = self.cat_net(categorical_data)
        combined = torch.cat([sent_out["sentence_embedding"], cat_features], dim=1)
        return self.head(combined)

See Also#