class MLFlowPyTorchWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that serves a PyTorch classifier.

    Wraps a ``torch`` module (``self.module``) behind the MLflow pyfunc
    ``predict`` contract: raw ``SingleForm`` inputs in, ranked
    ``PredictionResponse`` objects out.

    NOTE(review): this excerpt is elided (``...`` placeholders) — several
    attributes used below (``self.module``, ``self.text_feature``,
    ``self.categorical_features``, ``self.libs``, ``nb_echos_max``,
    ``prob_min``, ``dataloader``, ``predictions``) are defined in code not
    shown here.
    """

    # NOTE(review): as written the signature is missing `self` and the
    # trailing colon — almost certainly mangled/elided in this excerpt;
    # confirm against the original source before relying on it.
    def __init__()
        ...

    def predict(self, model_input: list[SingleForm], params=None) -> list[PredictionResponse]:
        """Run inference on a batch of forms and return ranked responses.

        Args:
            model_input: Raw input forms to classify.
            params: Optional inference parameters (unused in the visible
                code; part of the MLflow pyfunc signature).

        Returns:
            One ``PredictionResponse`` per input, built by
            ``process_response``.
        """
        # Normalize/clean the raw forms into a tabular structure
        # (presumably a pandas DataFrame, given the `.values` accesses
        # below — TODO confirm).
        query = self.preprocess_inputs(
            inputs=model_input,
        )
        # Preprocess inputs
        # Split out the text column and the categorical columns as arrays.
        # NOTE(review): `text` and `categorical_variables` are not used in
        # the visible code — presumably consumed by the elided dataloader
        # construction below; verify.
        text = query[self.text_feature].values
        categorical_variables = query[self.categorical_features].values
        ...
        # Batched forward pass; no_grad + detach since this is pure
        # inference (no autograd graph needed).
        all_scores = []
        for batch_idx, batch in enumerate(dataloader):
            with torch.no_grad():
                scores = self.module(batch).detach()
            all_scores.append(scores)
        # Concatenate per-batch logits and convert to class probabilities
        # (softmax over dim=1, i.e. the class dimension).
        all_scores = torch.cat(all_scores)
        probs = torch.nn.functional.softmax(all_scores, dim=1)
        ...
        # Build one response per input item. NOTE(review): `predictions`,
        # `nb_echos_max` and `prob_min` come from elided code above —
        # `predictions[0]` is assumed to be indexed per input item; confirm.
        responses = []
        for i in range(len(predictions[0])):
            response = process_response(predictions, i, nb_echos_max, prob_min, self.libs)
            responses.append(response)
        return responses





