torchTextClassifiers Wrapper#

The main wrapper class for text classification tasks.

Main Class#

class torchTextClassifiers.torchTextClassifiers.torchTextClassifiers(tokenizer, model_config, ragged_multilabel=False)[source]#

Bases: object

Generic text classifier framework supporting multiple architectures.

Given a tokenizer and model configuration, this class initializes:

  • A text embedding layer (if needed)

  • A categorical variable embedding network (if categorical variables are provided)

  • A classification head

The resulting model can be trained using PyTorch Lightning and used for predictions.

Methods

train

Train the classifier using PyTorch Lightning.

predict

save

Save the complete torchTextClassifiers instance to disk.

load

Load a torchTextClassifiers instance from disk.

__init__(tokenizer, model_config, ragged_multilabel=False)[source]#

Initialize the torchTextClassifiers instance.

Parameters:
  • tokenizer (BaseTokenizer) – A tokenizer instance for text preprocessing

  • model_config (ModelConfig) – Configuration parameters for the text classification model

Example

>>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
>>> # Assume tokenizer is a trained BaseTokenizer instance
>>> model_config = ModelConfig(
...     embedding_dim=10,
...     categorical_vocabulary_sizes=[30, 25],
...     categorical_embedding_dims=[10, 5],
...     num_classes=10,
... )
>>> ttc = torchTextClassifiers(
...     tokenizer=tokenizer,
...     model_config=model_config,
... )
train(X_train, y_train, training_config, X_val=None, y_val=None, verbose=False)[source]#

Train the classifier using PyTorch Lightning.

This method handles the complete training process, including:

  • Data validation and preprocessing

  • Dataset and DataLoader creation

  • PyTorch Lightning trainer setup with callbacks

  • Model training with early stopping

  • Best model loading after training

Note on Checkpoints:

After training, the best model checkpoint is automatically loaded. This checkpoint contains the full training state (model weights, optimizer, and scheduler state). Loading uses weights_only=False as the checkpoint is self-generated and trusted.

Parameters:
  • X_train (ndarray) – Training input data

  • y_train (ndarray) – Training labels

  • X_val (Optional[ndarray]) – Validation input data

  • y_val (Optional[ndarray]) – Validation labels

  • training_config (TrainingConfig) – Configuration parameters for training

  • verbose (bool) – Whether to print training progress information

Return type:

None

Example

>>> training_config = TrainingConfig(
...     lr=1e-3,
...     batch_size=4,
...     num_epochs=1,
... )
>>> ttc.train(
...     X_train=X,
...     y_train=Y,
...     X_val=X,
...     y_val=Y,
...     training_config=training_config,
... )
predict(X_test, top_k=1, explain=False)[source]#
Parameters:
  • X_test (np.ndarray) – Input data to predict on, of shape (N, d), where the first column is text and the remaining columns are categorical variables

  • top_k (int) – For each input, return the top_k most likely predictions (default: 1)

  • explain (bool) – If True, run gradient integration to compute token-level attributions explaining the predictions (default: False)

Returns: A dictionary containing the following fields:
  • predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes for each input.

  • confidence (torch.Tensor, shape (len(text), top_k)): A tensor containing the corresponding confidence scores.

  • If explain is True:
    • attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
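
Example

A minimal sketch, assuming ttc is a trained instance with two categorical variables (as in the constructor example above); the input values are purely illustrative:

>>> import numpy as np
>>> X_test = np.array([["some input text", 3, 7]], dtype=object)  # text + two categorical codes
>>> result = ttc.predict(X_test, top_k=2)
>>> result["predictions"]  # shape (1, 2): top-2 predicted codes for the single input row
>>> result["confidence"]   # shape (1, 2): corresponding confidence scores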

save(path)[source]#

Save the complete torchTextClassifiers instance to disk.

This saves:

  • Model configuration

  • Tokenizer state

  • PyTorch Lightning checkpoint (if trained)

  • All other instance attributes

Parameters:

path (Union[str, Path]) – Directory path where the model will be saved

Return type:

None

Example

>>> ttc = torchTextClassifiers(tokenizer, model_config)
>>> ttc.train(X_train, y_train, training_config)
>>> ttc.save("my_model")
classmethod load(path, device='auto')[source]#

Load a torchTextClassifiers instance from disk.

Parameters:
  • path (Union[str, Path]) – Directory path where the model was saved

  • device (str) – Device to load the model on ('auto', 'cpu', 'cuda', etc.)

Return type:

torchTextClassifiers

Returns:

Loaded torchTextClassifiers instance

Example

>>> loaded_ttc = torchTextClassifiers.load("my_model")
>>> predictions = loaded_ttc.predict(X_test)

Usage Example#

from torchTextClassifiers import torchTextClassifiers, ModelConfig, TrainingConfig
from torchTextClassifiers.tokenizers import WordPieceTokenizer

# Create tokenizer
tokenizer = WordPieceTokenizer()
tokenizer.train(texts, vocab_size=1000)

# Configure model
model_config = ModelConfig(embedding_dim=64, num_classes=2)
training_config = TrainingConfig(num_epochs=10, batch_size=16, lr=1e-3)

# Create and train classifier
# X_train is a numpy array whose first column is text (categorical variables, if any, follow);
# y_train holds the corresponding labels
classifier = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config)
classifier.train(X_train=X_train, y_train=y_train, training_config=training_config)

# Make predictions: predict returns a dict with "predictions" and "confidence" tensors
result = classifier.predict(X_test)
predictions = result["predictions"]
confidence = result["confidence"]
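
Predictions can also be explained. The following sketch assumes the classifier trained above and the explain=True return format documented for predict (the "attributions" key):

# Token-level attributions for the top prediction of each input row
explained = classifier.predict(X_test, explain=True)
attributions = explained["attributions"]  # shape (N, 1, seq_len) with the default top_k=1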

See Also#