Custom Model Management System

Introduction

The Custom Model Management System is a modular framework designed for creating, training, and deploying machine learning models within a customizable API structure. This system enables developers to work with various input types, incorporate advanced features like explainability and Out-of-Distribution (OOD) detection, and integrate Quantum Graph Neural Networks (QGNN) for fraud detection.

Architecture and Components

The system is structured to handle multiple machine learning models, ensuring modularity and flexibility. Key components include model management, API routes, and utility functions.

Core Modules and Functionalities

CustomModelManager

This is the primary class responsible for managing custom models, providing methods for model creation, training, inference, classification, and forecasting.

Model Creation: Allows users to create new models based on specified configurations. Configurations include input type (e.g., float_array, nested_json), model parameters, and input structure.

Training: Handles model training with support for custom batch sizes, epochs, and binary or regression loss functions. It preprocesses data, manages device allocation, and logs training progress.

Prediction: Supports standard predictions as well as Monte Carlo (MC) Dropout-based predictions for uncertainty estimation.

Forecasting: Enables multi-step forecasting by iteratively predicting future values based on recent inputs and optional external data.

Classification: Incorporates features like temperature scaling, MC Dropout, explainability using SHAP, and OOD detection.
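The manager's surface described above can be pictured as a small class skeleton. This is an illustrative sketch only: the method names, signatures, and internal registries are assumptions based on the descriptions above, not the actual API.

```python
class CustomModelManager:
    """Minimal sketch of the manager interface described above.

    Method names and signatures are illustrative assumptions,
    not the system's actual API.
    """

    def __init__(self):
        self.configs = {}   # model name -> configuration dict
        self.models = {}    # model name -> trained model object

    def create_model(self, name, config):
        # Persist the configuration so training and inference stay consistent.
        self.configs[name] = dict(config)

    def train(self, name, inputs, targets, epochs=10, batch_size=32):
        raise NotImplementedError  # training loop lives here

    def predict(self, name, inputs, mc_dropout=False):
        raise NotImplementedError  # standard or MC-Dropout inference

    def classify(self, name, inputs, temperature=1.0, explain=False):
        raise NotImplementedError  # temperature scaling, SHAP, OOD checks

    def forecast(self, name, recent_inputs, steps=1, external=None):
        raise NotImplementedError  # iterative multi-step prediction


manager = CustomModelManager()
manager.create_model("demand", {"input_type": "float_array", "input_size": 8})
```

Keeping configuration separate from the model object is what lets the same settings drive creation, training, and later inference.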

QGNN Models

Quantum Graph Neural Networks (QGNN) are integrated for specialized fraud detection tasks. These models use transaction data, transformed into graphs, for real-time and batch fraud predictions.

Training: Includes methods for training QGNN models on transaction data, allowing advanced fraud detection.

Prediction: Supports fraud prediction using preprocessed graph structures.

Explainability: Enables visualization and insights into QGNN predictions, adding transparency to fraud detection processes.

Core Functionalities

1. Model Creation

Custom models can be created based on configurations, including input types, preprocessing steps, and model parameters. A configuration file stores all model details, ensuring consistency across training and inference tasks.

Input Nodes Calculation: Automatically calculates input nodes based on configuration, especially for JSON and float array inputs.

Configuration Management: Saves configuration files in designated directories, enabling consistent model initialization and updates.
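The input-node calculation might look like the following sketch. The rule used here, float arrays contribute their declared length and nested JSON contributes one node per leaf key, is an assumption for illustration; the helper name is hypothetical.

```python
def count_input_nodes(config):
    """Derive the model's input width from its configuration.

    Hypothetical helper: float_array inputs contribute their declared
    length; nested_json inputs contribute one node per leaf key.
    """
    if config["input_type"] == "float_array":
        return config["length"]
    if config["input_type"] == "nested_json":
        def leaves(obj):
            # Recursively count leaf values in the declared structure.
            if isinstance(obj, dict):
                return sum(leaves(v) for v in obj.values())
            return 1
        return leaves(config["structure"])
    raise ValueError("unknown input type: %s" % config["input_type"])


flat = count_input_nodes({"input_type": "float_array", "length": 16})    # 16
nested = count_input_nodes({
    "input_type": "nested_json",
    "structure": {"user": {"age": 0, "income": 0}, "txn": {"amount": 0}},
})                                                                        # 3
```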

2. Training

Models are trained using specified parameters, with support for batch processing, custom loss functions, and logging.

Epoch and Batch Management: Customizable number of epochs and batch size, managed in training loops for model optimization.

Loss Calculation: Provides different loss functions based on classification or regression tasks.

Checkpointing: Periodically saves model states to ensure training progress can be recovered.
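A toy version of the training loop shows how epochs, per-sample updates, and periodic checkpointing fit together. The single-weight model and squared-error loss are stand-ins for the real network; recording the weight in a dict stands in for saving model state to disk.

```python
import random

def train(data, epochs=50, lr=0.01, checkpoint_every=10):
    """Toy training loop illustrating epoch management and checkpointing."""
    w = 0.0
    checkpoints = {}
    for epoch in range(1, epochs + 1):
        random.shuffle(data)
        for x, y in data:               # batch size of 1, for brevity
            pred = w * x
            grad = 2 * (pred - y) * x   # d/dw of (w*x - y)^2
            w -= lr * grad
        if epoch % checkpoint_every == 0:
            checkpoints[epoch] = w      # stand-in for saving model state
    return w, checkpoints


data = [(x, 2.0 * x) for x in range(1, 6)]
w, checkpoints = train(data)            # w converges toward 2.0
```

Because a checkpoint is saved every few epochs, an interrupted run can resume from the most recent saved state rather than restarting from scratch.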

3. Prediction and Inference

Models make predictions based on new input data, with support for enhanced techniques like MC Dropout and temperature scaling.

Monte Carlo Dropout: Optionally applies dropout during inference to estimate prediction uncertainty.

Temperature Scaling: Divides output logits by a temperature parameter to calibrate confidence in probabilistic outputs.

Output Processing: Converts model outputs back to the original scale, ensuring user-friendly predictions.
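The two inference techniques above can be sketched in a few lines: temperature scaling divides logits before the softmax, and MC Dropout keeps dropout active at inference and averages repeated stochastic passes. The `model(x, mask)` callable convention is an assumption made for this sketch.

```python
import math
import random

def softmax(logits, temperature=1.0):
    """Temperature scaling: T > 1 softens the distribution, T < 1 sharpens it."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                      # subtract max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def mc_dropout_predict(model, x, n_samples=100, p_drop=0.1, rng=None):
    """MC Dropout sketch: run the model repeatedly with random dropout
    masks and report the mean prediction and its spread (uncertainty)."""
    rng = rng or random.Random(0)
    preds = []
    for _ in range(n_samples):
        mask = [0.0 if rng.random() < p_drop else 1.0 for _ in x]
        preds.append(model(x, mask))
    mean = sum(preds) / n_samples
    var = sum((p - mean) ** 2 for p in preds) / n_samples
    return mean, math.sqrt(var)


# Toy linear "model" whose inputs can be dropped out.
weights = [0.5, -0.2, 0.8]
toy_model = lambda x, mask: sum(w * xi * mi for w, xi, mi in zip(weights, x, mask))

mean, std = mc_dropout_predict(toy_model, [1.0, 2.0, 3.0])
probs = softmax([2.0, 1.0, 0.1], temperature=2.0)
```

The standard deviation across dropout samples gives a cheap uncertainty estimate without training a separate model.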

4. Classification

Classification includes uncertainty estimation, explainability, and OOD detection for robust model deployment.

Explainability: Uses SHAP to interpret model predictions, providing detailed insights for each feature’s contribution.

Out-of-Distribution (OOD) Detection: Identifies OOD samples by comparing prediction confidence against predefined thresholds.

Class Probability and Confidence: Uses temperature scaling and Monte Carlo Dropout to return probabilistic outputs with confidence scores.
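A common way to realize threshold-based OOD detection is to flag samples whose top softmax probability falls below a cutoff. The max-softmax criterion and the 0.7 threshold below are illustrative assumptions, not the system's exact scheme.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def classify_with_ood(logits, threshold=0.7):
    """Flag a sample as out-of-distribution when the model's top class
    probability falls below a predefined threshold (illustrative values)."""
    probs = softmax(logits)
    confidence = max(probs)
    return {
        "class": probs.index(confidence),
        "confidence": confidence,
        "ood": confidence < threshold,
    }


confident = classify_with_ood([4.0, 0.5, 0.2])   # clear winner: in-distribution
uncertain = classify_with_ood([1.0, 0.9, 1.1])   # near-uniform: flagged OOD
```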

5. Forecasting

Provides multi-step forecasts, allowing the system to predict future data points by feeding previous outputs back into the model.

Input Data Evolution: Uses historical data for forecasting; optionally integrates future events or data for refined predictions.

Rolling Prediction: Feeds each step's prediction back into the input window, together with any additional input data, to extend the forecast over multiple steps; errors can compound over longer horizons.
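The rolling scheme amounts to sliding a fixed-size window over the series and appending each prediction as the newest observation. The one-external-value-per-step convention for future events is a simplifying assumption for this sketch.

```python
def forecast(model, history, steps, future_events=None):
    """Multi-step forecasting sketch: slide a fixed-size window over the
    series, feeding each prediction back in as the newest observation."""
    window = list(history)
    out = []
    for i in range(steps):
        extra = future_events[i] if future_events else 0.0
        pred = model(window) + extra       # optionally fold in known future data
        out.append(pred)
        window = window[1:] + [pred]       # roll the window forward
    return out


# Toy model: the next value is the mean of the current window.
mean_model = lambda w: sum(w) / len(w)
preds = forecast(mean_model, [1.0, 2.0, 3.0], steps=3)
```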

6. Fraud Detection with QGNN

The system includes Quantum Graph Neural Network models for fraud detection tasks. Transaction data is processed into graph form, enabling detailed fraud analysis.

Graph Preprocessing: Prepares transaction data as node features in a graph structure.

Prediction: QGNN-based prediction of fraud scores, supporting both real-time and batch analysis.

Visualization and Explainability: Provides tools for visualizing QGNN-based predictions, assisting in transparent fraud detection.
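The graph preprocessing step can be sketched as follows: each transaction becomes a node carrying its amount as a feature, and two transactions are linked when they share an account. Both the feature choice and the edge rule are illustrative assumptions, not the system's exact scheme.

```python
from collections import defaultdict

def transactions_to_graph(transactions):
    """Sketch of graph preprocessing for fraud detection: transactions are
    nodes; an edge connects any two transactions sharing an account."""
    node_features = [[t["amount"]] for t in transactions]
    by_account = defaultdict(list)
    for i, t in enumerate(transactions):
        by_account[t["src"]].append(i)
        by_account[t["dst"]].append(i)
    edges = set()
    for nodes in by_account.values():
        for i in nodes:
            for j in nodes:
                if i < j:
                    edges.add((i, j))
    return node_features, sorted(edges)


txns = [
    {"src": "A", "dst": "B", "amount": 120.0},
    {"src": "B", "dst": "C", "amount": 75.5},
    {"src": "D", "dst": "E", "amount": 9.99},
]
features, edges = transactions_to_graph(txns)
# Transactions 0 and 1 share account B, so they are connected.
```

The resulting node features and edge list are the inputs a graph neural network consumes, whether the QGNN runs in real time or over a batch.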

Using the Custom Model Management System

Initialization and Configuration

Custom Model Creation: Use the /create_model endpoint to define a model with input type, preprocessing, and model parameters.

Model Parameters: Specify configurations for each model, including input structure and data types, allowing consistent training and predictions.
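A request body for the /create_model endpoint might look like the following. The field names and values are illustrative assumptions drawn from the configuration options described above, not the endpoint's documented schema.

```python
import json

# Hypothetical request body for the /create_model endpoint.
payload = {
    "model_name": "churn_classifier",
    "input_type": "nested_json",
    "input_structure": {
        "user": {"age": 0, "tenure_months": 0},
        "usage": {"logins": 0},
    },
    "model_params": {"hidden_layers": [64, 32], "dropout": 0.1},
    "task": "classification",
}

body = json.dumps(payload)  # serialize for the POST request
# e.g. requests.post(base_url + "/create_model", data=body,
#                    headers={"Content-Type": "application/json"})
```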

Training and Evaluation

Training: Models are trained using input and target data, with logging enabled for each epoch’s progress.

Loss Calculation and Checkpointing: Evaluates loss at each epoch, saving checkpoints to ensure continuity in case of interruptions.

Prediction and Inference

Single-Step Prediction: Accepts input data through the /predict endpoint, returning structured outputs for each data point.

Uncertainty and Confidence: Supports confidence calibration via Monte Carlo Dropout and temperature scaling, providing users with robust predictions.

Model Management

Saving and Loading: Each model’s state is saved post-training, allowing reloading for continued use.

Configuration Persistence: Stores configuration files and scalers for consistent model initialization, retraining, and inference.
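Persistence of configurations and scaler parameters can be sketched as a JSON round trip. The file-naming layout here is an assumption; the point is that saving both artifacts together lets a model be re-initialized consistently for retraining or inference.

```python
import json
import tempfile
from pathlib import Path

def save_artifacts(directory, name, config, scaler_params):
    """Persist configuration and scaler parameters as JSON
    (hypothetical file layout)."""
    d = Path(directory)
    (d / (name + "_config.json")).write_text(json.dumps(config))
    (d / (name + "_scaler.json")).write_text(json.dumps(scaler_params))

def load_artifacts(directory, name):
    """Reload both artifacts so the model initializes exactly as saved."""
    d = Path(directory)
    return (
        json.loads((d / (name + "_config.json")).read_text()),
        json.loads((d / (name + "_scaler.json")).read_text()),
    )


with tempfile.TemporaryDirectory() as tmp:
    save_artifacts(tmp, "demand",
                   {"input_type": "float_array", "length": 8},
                   {"mean": [0.2], "std": [1.4]})
    config, scaler = load_artifacts(tmp, "demand")
```

Because the scaler parameters travel with the configuration, inference can invert the scaling and return predictions on the original scale.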