Core Models#
Core PyTorch and PyTorch Lightning models.
PyTorch Model#
TextClassificationModel#
Core PyTorch nn.Module combining all components.
- class torchTextClassifiers.model.model.TextClassificationModel(classification_head, text_embedder=None, categorical_variable_net=None)[source]#
Bases:
Module
FastText-style PyTorch model.
Architecture:
The model combines three main components:
TextEmbedder: Converts tokens to embeddings
CategoricalVariableNet (optional): Handles categorical features
ClassificationHead: Produces class logits
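Schematically, a forward pass threads a batch through these components in order; the sketch below is conceptual only (the exact combination logic lives in forward()):
# Conceptual data flow for a batch of size B (sketch, not the literal implementation):
#   input_ids (B, seq_len)  --TextEmbedder-->           text embedding (B, text_dim)
#   categorical_vars (B, n) --CategoricalVariableNet--> cat embedding  (B, cat_dim)
#   combined embedding      --ClassificationHead-->     logits (B, num_classes)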
- __init__(classification_head, text_embedder=None, categorical_variable_net=None)[source]#
Constructor for the TextClassificationModel class.
- Parameters:
classification_head (ClassificationHead) – The classification head module.
text_embedder (Optional[TextEmbedder]) – The text embedding module. If not provided, assumes that input text is already embedded (as tensors) and directly passed to the classification head.
categorical_variable_net (Optional[CategoricalVariableNet]) – The categorical variable network module. If not provided, assumes no categorical variables are used.
- forward(input_ids, attention_mask, categorical_vars, **kwargs)[source]#
Memory-efficient forward pass implementation.
- Parameters (output from the dataset collate_fn):
input_ids (torch.LongTensor), shape (batch_size, seq_len) – Tokenized and padded text.
attention_mask (torch.Tensor), shape (batch_size, seq_len) – Attention mask indicating non-pad tokens.
categorical_vars (torch.LongTensor), shape (batch_size, num_categorical_features) – Additional categorical features.
- Returns:
Raw (not softmaxed) output scores for each class, shape (batch_size, num_classes).
- Return type:
torch.Tensor
Example:
from torchTextClassifiers.model import TextClassificationModel
from torchTextClassifiers.model.components import (
    TextEmbedder, TextEmbedderConfig,
    CategoricalVariableNet, CategoricalForwardType,
    ClassificationHead,
)

# Create components
text_embedder = TextEmbedder(TextEmbedderConfig(
    vocab_size=5000,
    embedding_dim=128,
))
cat_net = CategoricalVariableNet(
    vocabulary_sizes=[10, 20],
    embedding_dims=[8, 16],
    forward_type=CategoricalForwardType.AVERAGE_AND_CONCAT,
)
classification_head = ClassificationHead(
    input_dim=128 + 24,  # text_dim + cat_dim (8 + 16)
    num_classes=5,
)

# Combine into model
model = TextClassificationModel(
    text_embedder=text_embedder,
    categorical_variable_net=cat_net,
    classification_head=classification_head,
)

# Forward pass (inputs come from the dataset collate_fn)
logits = model(input_ids, attention_mask, categorical_vars)
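The scores returned by forward() are raw logits; to turn them into probabilities or hard predictions, normalize them yourself:
import torch

probs = torch.softmax(logits, dim=-1)  # per-example class probabilities
preds = probs.argmax(dim=-1)           # predicted class indices, shape (batch_size,)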
PyTorch Lightning Module#
TextClassificationModule#
PyTorch Lightning LightningModule for automated training.
- class torchTextClassifiers.model.lightning.TextClassificationModule(model, loss, optimizer, optimizer_params, scheduler, scheduler_params, scheduler_interval='epoch', **kwargs)[source]#
Bases:
LightningModule
PyTorch Lightning module wrapping TextClassificationModel.
Features:
Automated training/validation/test steps
Metrics tracking (accuracy)
Optimizer and scheduler management
Logging integration
PyTorch Lightning callbacks support
- __init__(model, loss, optimizer, optimizer_params, scheduler, scheduler_params, scheduler_interval='epoch', **kwargs)[source]#
Initialize TextClassificationModule.
- Parameters:
model (TextClassificationModel) – The wrapped PyTorch model.
loss – Loss function (e.g. torch.nn.CrossEntropyLoss()).
optimizer – Optimizer class (e.g. torch.optim.Adam).
optimizer_params – Keyword arguments passed to the optimizer (e.g. {"lr": 1e-3}).
scheduler – Learning-rate scheduler class.
scheduler_params – Keyword arguments passed to the scheduler.
scheduler_interval – Interval at which the scheduler steps; defaults to "epoch".
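These arguments map onto Lightning's configure_optimizers() contract; a minimal sketch of the likely wiring (assumed here for illustration, the actual implementation may differ):
def configure_optimizers(self):
    # Instantiate the optimizer and scheduler from the stored classes and params
    optimizer = self.optimizer(self.parameters(), **self.optimizer_params)
    scheduler = self.scheduler(optimizer, **self.scheduler_params)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": self.scheduler_interval,  # step per "epoch" or per "step"
        },
    }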
- forward(batch)[source]#
Perform a forward pass.
- Parameters:
batch (List[torch.LongTensor]) – Batch to perform forward-pass on.
- Returns:
Prediction tensor.
- Return type:
torch.Tensor
- training_step(batch, batch_idx)[source]#
Training step.
- Parameters:
batch (List[torch.LongTensor]) – Training batch.
batch_idx (int) – Batch index.
- Returns:
Loss tensor.
- Return type:
torch.Tensor
- validation_step(batch, batch_idx)[source]#
Validation step.
- Parameters:
batch (List[torch.LongTensor]) – Validation batch.
batch_idx (int) – Batch index.
- Returns:
Loss tensor.
- Return type:
torch.Tensor
- test_step(batch, batch_idx)[source]#
Test step.
- Parameters:
batch (List[torch.LongTensor]) – Test batch.
batch_idx (int) – Batch index.
- Returns:
Loss tensor.
- Return type:
torch.Tensor
Example:
from torchTextClassifiers.model import (
    TextClassificationModel,
    TextClassificationModule,
)
import torch.nn as nn
import torch.optim as optim
from pytorch_lightning import Trainer

# Create PyTorch model
model = TextClassificationModel(...)

# Wrap in Lightning module
lightning_module = TextClassificationModule(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam,
    optimizer_params={"lr": 1e-3},
    scheduler=optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 10, "gamma": 0.1},
)

# Train with Lightning Trainer
trainer = Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
)
trainer.fit(
    lightning_module,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

# Test
trainer.test(lightning_module, dataloaders=test_dataloader)
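Because training goes through the standard Lightning Trainer, stock callbacks such as checkpointing and early stopping plug in unchanged; a brief sketch (the monitored metric name assumes the module logs val_loss, as in the steps below):
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

trainer = Trainer(
    max_epochs=20,
    callbacks=[
        ModelCheckpoint(monitor="val_loss", mode="min"),  # keep the best checkpoint
        EarlyStopping(monitor="val_loss", patience=3),    # stop when val_loss stalls
    ],
)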
Training Steps#
The TextClassificationModule implements standard training/validation/test steps:
Training Step:
def training_step(self, batch, batch_idx):
    input_ids, attention_mask, cat_features, labels = batch
    logits = self.model(input_ids, attention_mask, cat_features)
    loss = self.loss(logits, labels)
    acc = self.compute_accuracy(logits, labels)
    self.log("train_loss", loss)
    self.log("train_acc", acc)
    return loss
Validation Step:
def validation_step(self, batch, batch_idx):
    input_ids, attention_mask, cat_features, labels = batch
    logits = self.model(input_ids, attention_mask, cat_features)
    loss = self.loss(logits, labels)
    acc = self.compute_accuracy(logits, labels)
    self.log("val_loss", loss)
    self.log("val_acc", acc)
    return loss
Custom Training#
For custom training loops, use the PyTorch model directly:
from torchTextClassifiers.model import TextClassificationModel
import torch.nn as nn
import torch.optim as optim

model = TextClassificationModel(...)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Custom training loop
model.train()
for epoch in range(num_epochs):
    for batch in dataloader:
        input_ids, attention_mask, cat_features, labels = batch

        # Forward pass
        logits = model(input_ids, attention_mask, cat_features)
        loss = loss_fn(logits, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item()}")
See Also#
Model Components - Building blocks used by these models (embedders, categorical nets, classification heads)
torchTextClassifiers Wrapper - High-level wrapper using these models
Dataset - Data loading for models
Configuration Classes - Model and training configuration