Core Models#

Core PyTorch and PyTorch Lightning models.

PyTorch Model#

TextClassificationModel#

Core PyTorch nn.Module combining all components.

class torchTextClassifiers.model.model.TextClassificationModel(classification_head, text_embedder=None, categorical_variable_net=None)[source]#

Bases: Module

FastText-style PyTorch model for text classification.

Architecture:

The model combines three main components, composed as sketched below:

  1. TextEmbedder: Converts tokens to embeddings

  2. CategoricalVariableNet (optional): Handles categorical features

  3. ClassificationHead: Produces class logits
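Conceptually, a forward pass chains these components. The following is a minimal sketch, assuming attribute names that mirror the constructor arguments and a plain concatenation of text and categorical features; the actual forward() (documented below) also handles the attention mask and the configured CategoricalForwardType:

import torch

def forward_sketch(model, input_ids, attention_mask, categorical_vars):
    # 1. Tokens -> text embedding, (batch_size, embedding_dim)
    text_emb = model.text_embedder(input_ids, attention_mask)
    # 2. Categorical codes -> categorical embedding, (batch_size, cat_dim)
    cat_emb = model.categorical_variable_net(categorical_vars)
    # 3. Concatenated features -> class logits, (batch_size, num_classes)
    features = torch.cat([text_emb, cat_emb], dim=-1)
    return model.classification_head(features)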

__init__(classification_head, text_embedder=None, categorical_variable_net=None)[source]#

Constructor for the TextClassificationModel class.

Parameters:
  • classification_head (ClassificationHead) – The classification head module.

• text_embedder (Optional[TextEmbedder]) – The text embedding module. If not provided, the input text is assumed to be already embedded (as tensors) and is passed directly to the classification head; see the sketch after this parameter list.

  • categorical_variable_net (Optional[CategoricalVariableNet]) – The categorical variable network module. If not provided, assumes no categorical variables are used.
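When text_embedder is omitted, pre-computed embeddings therefore flow straight to the head. A minimal sketch follows; that the embeddings are passed in place of input_ids, and that attention_mask and categorical_vars may be None, are assumptions rather than confirmed API:

import torch
from torchTextClassifiers.model import TextClassificationModel
from torchTextClassifiers.model.components import ClassificationHead

# Head-only model: text is embedded upstream of the model
model = TextClassificationModel(
    classification_head=ClassificationHead(input_dim=128, num_classes=5)
)

embedded_text = torch.randn(32, 128)  # pre-computed embeddings, (batch_size, input_dim)
logits = model(embedded_text, attention_mask=None, categorical_vars=None)  # (32, 5)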

forward(input_ids, attention_mask, categorical_vars, **kwargs)[source]#

Memory-efficient forward pass implementation.

Parameters (the output of the dataset collate_fn):
  • input_ids (torch.Tensor[Long]) – Tokenized and padded text, shape (batch_size, seq_len).

  • attention_mask (torch.Tensor[int]) – Attention mask indicating non-pad tokens, shape (batch_size, seq_len).

  • categorical_vars (torch.Tensor[Long]) – Additional categorical features, shape (batch_size, num_categorical_features).

Returns:

Raw model output scores (logits) for each class, shape (batch_size, num_classes); not softmaxed.

Return type:

torch.Tensor

Example:

import torch

from torchTextClassifiers.model import TextClassificationModel
from torchTextClassifiers.model.components import (
    TextEmbedder, TextEmbedderConfig,
    CategoricalVariableNet, CategoricalForwardType,
    ClassificationHead
)

# Create components
text_embedder = TextEmbedder(TextEmbedderConfig(
    vocab_size=5000,
    embedding_dim=128
))

cat_net = CategoricalVariableNet(
    vocabulary_sizes=[10, 20],
    embedding_dims=[8, 16],
    forward_type=CategoricalForwardType.AVERAGE_AND_CONCAT
)

classification_head = ClassificationHead(
    input_dim=128 + 24,  # text_dim + cat_dim
    num_classes=5
)

# Combine into model
model = TextClassificationModel(
    text_embedder=text_embedder,
    categorical_variable_net=cat_net,
    classification_head=classification_head
)

# Forward pass with dummy inputs (shapes as documented in forward())
input_ids = torch.randint(0, 5000, (32, 20))       # (batch_size, seq_len)
attention_mask = torch.ones_like(input_ids)        # 1 = real token, 0 = padding
categorical_vars = torch.randint(0, 10, (32, 2))   # (batch_size, num_categorical_features)
logits = model(input_ids, attention_mask, categorical_vars)  # (32, 5), raw logits
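The logits are unnormalized; apply a softmax to turn them into class probabilities:

probs = torch.softmax(logits, dim=-1)    # (batch_size, num_classes), rows sum to 1
predicted = probs.argmax(dim=-1)         # (batch_size,) predicted class indices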

PyTorch Lightning Module#

TextClassificationModule#

PyTorch Lightning LightningModule for automated training.

class torchTextClassifiers.model.lightning.TextClassificationModule(model, loss, optimizer, optimizer_params, scheduler, scheduler_params, scheduler_interval='epoch', **kwargs)[source]#

Bases: LightningModule

PyTorch Lightning module for TextClassificationModel.

Features:

  • Automated training/validation/test steps

  • Metrics tracking (accuracy)

  • Optimizer and scheduler management

  • Logging integration

  • PyTorch Lightning callbacks support

__init__(model, loss, optimizer, optimizer_params, scheduler, scheduler_params, scheduler_interval='epoch', **kwargs)[source]#

Initialize TextClassificationModule.

Parameters:
  • model (TextClassificationModel) – Model.

• loss – Loss function, e.g. torch.nn.CrossEntropyLoss().

  • optimizer – Optimizer class, e.g. torch.optim.Adam.

  • optimizer_params – Keyword arguments passed to the optimizer, e.g. {"lr": 1e-3}.

  • scheduler – Learning-rate scheduler class, e.g. torch.optim.lr_scheduler.StepLR.

  • scheduler_params – Keyword arguments passed to the scheduler.

  • scheduler_interval – How often the scheduler steps, "epoch" or "step" (default "epoch").

forward(batch)[source]#

Perform forward-pass.

Parameters:

batch (List[torch.LongTensor]) – Batch to perform forward-pass on.

Return type:

Tensor

Returns (torch.Tensor): Prediction.

training_step(batch, batch_idx)[source]#

Training step.

Parameters:
  • batch (List[torch.LongTensor]) – Training batch.

  • batch_idx (int) – Batch index.

Return type:

Tensor

Returns (torch.Tensor): Loss tensor.

validation_step(batch, batch_idx)[source]#

Validation step.

Parameters:
  • batch (List[torch.LongTensor]) – Validation batch.

  • batch_idx (int) – Batch index.

Returns (torch.Tensor): Loss tensor.

test_step(batch, batch_idx)[source]#

Test step.

Parameters:
  • batch (List[torch.LongTensor]) – Test batch.

  • batch_idx (int) – Batch index.

Returns (torch.Tensor): Loss tensor.

predict_step(batch, batch_idx=0, dataloader_idx=0)[source]#

Prediction step.

Parameters:
  • batch (List[torch.LongTensor]) – Prediction batch.

  • batch_idx (int) – Batch index.

  • dataloader_idx (int) – Dataloader index.

Returns (torch.Tensor): Predictions.

configure_optimizers()[source]#

Configure the optimizer and learning-rate scheduler for PyTorch Lightning.

Returns: Optimizer and scheduler configuration for PyTorch Lightning.
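A typical implementation returns the optimizer together with a scheduler dictionary. The following is a minimal sketch, assuming attributes named after the constructor arguments (self.optimizer, self.optimizer_params, and so on); the actual implementation may differ:

def configure_optimizers(self):
    # Instantiate the optimizer from the stored class and keyword arguments
    optimizer = self.optimizer(self.parameters(), **self.optimizer_params)
    # Instantiate the scheduler and tell Lightning how often to step it
    scheduler = self.scheduler(optimizer, **self.scheduler_params)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": self.scheduler_interval,  # "epoch" or "step"
        },
    }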

Example:

from torchTextClassifiers.model import (
    TextClassificationModel,
    TextClassificationModule
)
import torch.nn as nn
import torch.optim as optim
from pytorch_lightning import Trainer

# Create PyTorch model
model = TextClassificationModel(...)

# Wrap in Lightning module
lightning_module = TextClassificationModule(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam,
    optimizer_params={"lr": 1e-3},
    scheduler=optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 10, "gamma": 0.1}
)

# Train with Lightning Trainer
trainer = Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1
)

trainer.fit(
    lightning_module,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader
)

# Test
trainer.test(lightning_module, dataloaders=test_dataloader)
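Unlabeled data can be scored through predict_step via Trainer.predict:

# Returns a list of per-batch outputs from predict_step
predictions = trainer.predict(lightning_module, dataloaders=test_dataloader)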

Training Steps#

The TextClassificationModule implements standard training/validation/test steps:

Training Step:

def training_step(self, batch, batch_idx):
    # Batch layout follows the dataset collate_fn
    input_ids, attention_mask, cat_features, labels = batch
    logits = self.model(input_ids, attention_mask, cat_features)
    loss = self.loss(logits, labels)
    acc = self.compute_accuracy(logits, labels)

    self.log("train_loss", loss)
    self.log("train_acc", acc)

    return loss

Validation Step:

def validation_step(self, batch, batch_idx):
    input_ids, attention_mask, cat_features, labels = batch
    logits = self.model(input_ids, attention_mask, cat_features)
    loss = self.loss(logits, labels)
    acc = self.compute_accuracy(logits, labels)

    self.log("val_loss", loss)
    self.log("val_acc", acc)

Custom Training#

For custom training loops, use the PyTorch model directly:

from torchTextClassifiers.model import TextClassificationModel
import torch.nn as nn
import torch.optim as optim

model = TextClassificationModel(...)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Custom training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        input_ids, attention_mask, cat_features, labels = batch

        # Forward pass
        logits = model(input_ids, attention_mask, cat_features)
        loss = loss_fn(logits, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch}, Loss: {loss.item()}")

See Also#