# 🧠 Fed-X-ViT: A Federated Transfer Learning & XAI Pipeline for Brain Tumor Diagnosis using Azure MLOps
A federated simulation pipeline demonstrating how a pre-trained "expert" AI model can be securely fine-tuned on private, decentralized datasets (e.g., from multiple hospitals) to solve a new clinical task. All results, metrics, and models are tracked and versioned in the cloud using Azure Machine Learning.
- Project Overview
- Key Features
- MLOps Architecture
- The Model: Fed-X-ViT
- Performance & Results
- Explainable AI (XAI) Validation
- How to Run the Simulation
- Contributors
## Project Overview

This project addresses three critical challenges in clinical AI: Accuracy, Privacy, and Trust.
- Accuracy: We use a state-of-the-art hybrid CNN-Vision Transformer model (Fed-X-ViT) to achieve greater than 99% accuracy in brain tumor classification.
- Privacy: We simulate a Federated Learning (FL) environment using the Flower framework, where three "clients" (simulating 3 different hospitals) collaboratively train a global model on their own private data. No private data ever leaves the local client.
- Trust: We integrate Explainable AI (XAI) using Grad-CAM to visualize precisely why the model makes its predictions. To ensure our results are transparent and reproducible, we use Azure Machine Learning to log all metrics, parameters, and final model artifacts.
This repository demonstrates the full, end-to-end pipeline:
- Phase 1: Training a powerful 4-class "expert" model on a public dataset.
- Phase 2: Using Federated Transfer Learning to adapt this expert for a new 2-class (Tumor/Healthy) task, using data partitioned across 3 simulated clients.
- Phase 3: Running a final XAI analysis using Grad-CAM to validate that the final federated model's decisions are clinically relevant.
## Key Features

- State-of-the-Art Hybrid Model: A novel hierarchical model combining a CNN (EfficientNetV2) for local feature extraction with a Vision Transformer (ViT) for global context analysis.
- Federated Transfer Learning: An advanced technique where a pre-trained "expert" model is surgically adapted for a new task. This preserves its core knowledge, allowing it to be fine-tuned rapidly and efficiently on small, private datasets.
- Azure MLOps Integration: The entire simulation is connected to an Azure Machine Learning Workspace. All metrics (e.g., `global_val_accuracy`, `global_val_loss`) are logged in real-time via MLflow, and the final model is automatically versioned and stored (a minimal logging sketch follows this feature list).
- Parallel GPU Simulation: The 3-client simulation runs on a single local GPU (e.g., RTX 3060) using Flower's high-performance Ray backend for parallelization.
- Explainable & Trustworthy: Includes an `xai.py` script to run Grad-CAM analysis, generating heatmaps that show the model's decisions are based on the correct clinical features.
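For illustration, the Azure MLOps wiring described above could look roughly like the following: MLflow is pointed at the workspace's tracking URI and round-level metrics are logged under a single run. This is a minimal sketch assuming the `azure-ai-ml`, `azure-identity`, and `azureml-mlflow` packages are installed and a `config.json` sits in the repository root; the exact code in `server.py` may differ.

```python
# Minimal sketch: route MLflow logging to the Azure ML workspace.
# Assumes azure-ai-ml, azure-identity, and azureml-mlflow are installed
# and config.json (downloaded from the workspace) is in the repo root.
import mlflow
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient.from_config(credential=DefaultAzureCredential())
workspace = ml_client.workspaces.get(ml_client.workspace_name)
mlflow.set_tracking_uri(workspace.mlflow_tracking_uri)

mlflow.set_experiment("fed-x-vit-federated-finetune")  # experiment name is illustrative
with mlflow.start_run(run_name="fedavg-5-rounds"):
    mlflow.log_param("num_clients", 3)
    mlflow.log_param("num_rounds", 5)
    # Inside the Flower strategy's evaluation hook, per-round metrics would be logged like this
    # (the two example values below are taken from the results table in this README):
    for rnd, (acc, loss) in enumerate([(0.99321, 0.0201), (0.99481, 0.0148)], start=1):
        mlflow.log_metric("global_val_accuracy", acc, step=rnd)
        mlflow.log_metric("global_val_loss", loss, step=rnd)
    # After the final round, the aggregated weights can be stored as a run artifact.
    mlflow.log_artifact("final_federated_model.pth")
```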
## MLOps Architecture

We simulate a "Hub-and-Spoke" model. The "Hub" is the central server and cloud infrastructure, and the "Spokes" are the private clients (hospitals).
| Component | Role & Technology |
|---|---|
| Local Machine | A single PC (e.g., with an RTX 3060 GPU) runs the entire simulation. |
| ↳ Server Process | The server.py script acts as the Coordinator. It loads the expert model, performs transfer learning, connects to Azure, and orchestrates the federated rounds using the FedAvg strategy. |
| ↳ Client Processes | The server spawns 3 client processes (client.py) using Ray. Each client is given its own sandboxed dataset and performs local training on the GPU. |
| Azure Cloud | The central, immutable "lab notebook" for the project. |
| ↳ Azure ML Workspace | Acts as the central hub for tracking experiments, models, and results. |
| ↳ MLflow Tracking | The API used by the server to log metrics and parameters to the Azure ML run history. |
| ↳ Azure Blob Storage | The secure storage where the final, globally aggregated model (final_federated_model.pth) and XAI heatmaps are automatically uploaded. |
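To make the hub-and-spoke table above concrete, here is a minimal sketch of how a three-client Flower simulation with a FedAvg strategy can be launched on a single GPU via the Ray backend, assuming the Flower 1.x simulation API. `FedXViTClient` and `make_client` are illustrative placeholders, not the actual contents of `server.py` or `client.py`.

```python
# Illustrative Flower + Ray simulation: 1 coordinator, 3 "hospital" clients, FedAvg.
import flwr as fl
import numpy as np

class FedXViTClient(fl.client.NumPyClient):
    """Placeholder client: stands in for the project's client.py training logic."""
    def __init__(self, cid: str):
        self.cid = cid
        self.weights = [np.zeros(10, dtype=np.float32)]  # dummy parameters

    def get_parameters(self, config):
        return self.weights

    def fit(self, parameters, config):
        # Local fine-tuning on this client's private partition would happen here.
        return parameters, 100, {}

    def evaluate(self, parameters, config):
        # Local validation on this client's private split would happen here.
        return 0.0, 100, {"accuracy": 0.0}

def make_client(cid: str) -> fl.client.Client:
    """Build one sandboxed client per simulated hospital."""
    return FedXViTClient(cid).to_client()

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,        # every client trains in every round
    min_fit_clients=3,
    min_available_clients=3,
)

history = fl.simulation.start_simulation(
    client_fn=make_client,
    num_clients=3,                                        # three simulated hospitals
    config=fl.server.ServerConfig(num_rounds=5),          # five global FedAvg rounds
    strategy=strategy,
    client_resources={"num_cpus": 2, "num_gpus": 0.33},   # share one local GPU via Ray
)
```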
## The Model: Fed-X-ViT

Our model's architecture is inspired by the diagnostic workflow of an expert radiologist:
- CNN (EfficientNetV2): Acts as the "Eyes." It performs a high-resolution scan of the MRI, extracting thousands of low-level features like textures, edges, and local patterns indicative of anomalies.
- ViT (Vision Transformer): Acts as the "Brain." It receives the rich feature map from the CNN, segments it into patches, and uses its self-attention mechanism to analyze the global context and relationships between all features. This allows it to make a final, highly-informed classification.
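As a rough sketch of this "eyes and brain" pairing, the hybrid could be wired in PyTorch as below, using torchvision's EfficientNetV2-S feature extractor followed by a small Transformer encoder over its feature map. The class name `HybridCnnViT` and all hyperparameters are illustrative assumptions, not the project's exact model definition.

```python
# Illustrative hybrid CNN + ViT classifier (not the exact Fed-X-ViT implementation).
import torch
import torch.nn as nn
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights

class HybridCnnViT(nn.Module):
    def __init__(self, num_classes: int = 4, embed_dim: int = 256, depth: int = 4, heads: int = 8):
        super().__init__()
        # "Eyes": EfficientNetV2 backbone yields a low-resolution feature map (B, 1280, H', W').
        backbone = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT)
        self.cnn = backbone.features
        # Project CNN channels down to the transformer embedding size.
        self.proj = nn.Conv2d(1280, embed_dim, kernel_size=1)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # "Brain": a small transformer encoder attends over all spatial patches
        # (positional embeddings omitted here for brevity).
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=heads, batch_first=True)
        self.vit = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(embed_dim, num_classes)  # task-specific decision layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.proj(self.cnn(x))                   # (B, D, H', W')
        tokens = feats.flatten(2).transpose(1, 2)        # (B, H'*W', D) patch tokens
        cls = self.cls_token.expand(x.size(0), -1, -1)   # prepend a class token
        tokens = torch.cat([cls, tokens], dim=1)
        out = self.vit(tokens)
        return self.head(out[:, 0])                      # classify from the class token

# Example: a 4-class "expert" over 224x224 MRI slices.
model = HybridCnnViT(num_classes=4)
logits = model(torch.randn(2, 3, 224, 224))
```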
This is the core of our Federated Transfer Learning pipeline. Instead of training a new model from scratch, we adapt our pre-trained expert.
- Load Expert: We start with our 99.31% accurate 4-class "expert" model, which has learned a rich representation of brain MRI features.
- Create New Model: We instantiate a new, blank model with the same architecture but a 2-class (Tumor/Healthy) classification head.
- Transplant Weights: We surgically transfer the weights from all the feature extraction layers (the entire CNN and ViT backbone) from the expert model to the new model.
- Isolate Final Layer: We discard the old 4-class "decision layer" and leave the new 2-class decision layer randomly initialized.
This new hybrid model, in which every layer except the final decision layer is pre-trained, is sent to the clients. They only need to fine-tune that final layer on their small, private datasets, making the federated process highly data-efficient and fast.
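A minimal sketch of this weight transplant, reusing the hypothetical `HybridCnnViT` class from the previous snippet and an illustrative checkpoint name `expert_4class.pth`; `server.py` may organize this differently.

```python
# Illustrative transfer-learning "transplant": reuse the expert's backbone, reset the head.
import torch

# 1. Load the 4-class expert checkpoint (file name is illustrative).
expert = HybridCnnViT(num_classes=4)
expert.load_state_dict(torch.load("expert_4class.pth", map_location="cpu"))

# 2. Instantiate a fresh model with a 2-class (Tumor/Healthy) head.
student = HybridCnnViT(num_classes=2)

# 3. Copy every weight except the final decision layer ("head").
backbone_state = {k: v for k, v in expert.state_dict().items() if not k.startswith("head.")}
missing, unexpected = student.load_state_dict(backbone_state, strict=False)
assert all(k.startswith("head.") for k in missing)  # only the new head stays randomly initialized

# 4. Optionally freeze the transplanted backbone so clients fine-tune only the head.
for name, param in student.named_parameters():
    param.requires_grad = name.startswith("head.")
```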
## Performance & Results

First, the Fed-X-ViT model was trained centrally on a public 4-class dataset (Glioma, Meningioma, Pituitary, No Tumor) to establish a powerful baseline.
| Metric | Result |
|---|---|
| Final Test Set Accuracy | 99.31% |
Next, we ran the 3-client federated simulation for 5 rounds on the new 2-class (Tumor/Healthy) dataset.
| Parameter | Details |
|---|---|
| Clients | 3 (Simulated hospitals) |
| Federation Strategy | Federated Averaging (FedAvg) |
| Total Training Images | 11,693 |
| Total Validation Images | 2,505 |
| Total Testing Images | 2,514 |
| Global Rounds | 5 |
| Framework | Flower + Ray |
| GPU | NVIDIA RTX 3060 |
The following metrics were logged in real-time to Azure ML during the simulation run. The model showed rapid convergence, starting at an impressive 99.32% and peaking at 99.72%—demonstrating that federated fine-tuning can improve an already-expert model.
| Round | Global Validation Accuracy | Global Validation Loss |
|---|---|---|
| 1 | 99.321% | 0.0201 |
| 2 | 99.481% | 0.0148 |
| 3 | 99.481% | 0.0133 |
| 4 | 99.600% | 0.0116 |
| 5 | 99.720% | 0.0078 |
After 5 rounds, the server saved the final global model (final_federated_model.pth) and automatically tested it on a combined, unseen test set of 2,514 images.
| Metric | Final Result |
|---|---|
| Final Combined Test Accuracy | 99.60% |
| Final Combined Test Loss | 0.0125 |
This report shows the final federated model has near-perfect, well-balanced performance for both classes.
| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| Healthy | 0.99 | 0.99 | 0.99 | 666 |
| Tumor | 1.00 | 1.00 | 1.00 | 1848 |
| macro avg | 0.99 | 1.00 | 0.99 | 2514 |
| weighted avg | 1.00 | 1.00 | 1.00 | 2514 |
## Explainable AI (XAI) Validation

To ensure our model is trustworthy, we must verify that it is looking at the correct regions of an MRI scan to make its predictions. We use Gradient-weighted Class Activation Mapping (Grad-CAM) for this purpose.
Grad-CAM is traditionally used on CNNs. Our hybrid model requires a more sophisticated approach:
- The ViT is the "Brain": The Vision Transformer makes the final classification decision (e.g., "Tumor").
- The CNN is the "Eyes": Grad-CAM is applied to the last convolutional layer of the EfficientNetV2 backbone.
- We ask Grad-CAM: "Show me the pixels in the final feature map that were most important for the ViT's final 'Tumor' decision."
The result is a heatmap where the ViT's high-level reasoning guides the Grad-CAM's pixel-level visualization. This shows us exactly which parts of the original image contributed most to the final classification.
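Below is a minimal hook-based Grad-CAM sketch over the hypothetical `HybridCnnViT` from the model section: activations and gradients are captured at the last convolutional block of the backbone, while the gradient signal comes from the classifier's "Tumor" logit (class index assumed to be 1). The project's `xai.py` may implement this differently or use a dedicated library.

```python
# Illustrative Grad-CAM for the hybrid model: hook the last CNN block,
# backpropagate from the "Tumor" logit of the ViT-based classifier.
import torch
import torch.nn.functional as F

def grad_cam(model, image, target_class=1):
    activations, gradients = {}, {}
    target_layer = model.cnn[-1]  # last convolutional block of the EfficientNetV2 backbone

    def fwd_hook(_, __, output):
        activations["value"] = output

    def bwd_hook(_, grad_input, grad_output):
        gradients["value"] = grad_output[0]

    h1 = target_layer.register_forward_hook(fwd_hook)
    h2 = target_layer.register_full_backward_hook(bwd_hook)

    model.eval()
    logits = model(image.unsqueeze(0))       # (1, num_classes)
    model.zero_grad()
    logits[0, target_class].backward()       # "why did the model say Tumor?"

    h1.remove(); h2.remove()

    acts, grads = activations["value"], gradients["value"]     # (1, C, H', W')
    weights = grads.mean(dim=(2, 3), keepdim=True)             # per-channel importance
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))    # weighted channel sum
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear", align_corners=False)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)   # normalize to [0, 1]
    return cam.squeeze().detach()

# Usage: heatmap = grad_cam(model, mri_tensor)  # then overlay on the original MRI slice
```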
The `xai.py` script was run after the simulation completed. It randomly selected 25 tumor images from each client's private test set and generated Grad-CAM heatmaps.
The generated heatmaps confirm that the model focuses its attention on the tumorous regions of the brain, indicating that its high accuracy comes from learning genuine, clinically relevant features rather than from exploiting biases or artifacts in the data.
## How to Run the Simulation

Prerequisites:

- Python 3.10+
- An NVIDIA GPU with CUDA 12.1+ support (e.g., RTX 3060 or higher).
- An Azure account with an Azure Machine Learning Workspace already created. (We used a Visual Studio Enterprise subscription, with credits from the Microsoft Learn Student Ambassadors program.)
Setup:

- Clone the repository:

  ```bash
  git clone https://github.com/harshitsinghcode/Fed-X-ViT.git
  cd Fed-X-ViT
  ```

- Create and activate a virtual environment:

  ```bash
  python -m venv venv
  .\venv\Scripts\activate     # On Windows :(
  # source venv/bin/activate  # On macOS/Linux :)
  ```

- Install all required libraries:
  - First, install PyTorch matching your system's CUDA version (check it with `nvidia-smi`):

    ```bash
    # Example for CUDA 12.1
    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
    ```

  - Then, install the remaining dependencies:

    ```bash
    pip install -r requirements.txt
    ```

- Download the Azure config file:
  - Go to your Azure ML Workspace in the Azure Portal.
  - On the Overview page, click Download config.json.
  - Place this `config.json` file in the root folder of the project (e.g., `D:\FedXViT\`).

- Authenticate with the Azure CLI. Run the following command in your terminal; it will open a browser window for you to log in:

  ```bash
  az login
  ```
Prepare the data:

- Download the pre-organized client datasets and pre-trained models from the provided link.
- Run the data splitting and counting script once to organize your source data into the final `SplitData` folder, structured exactly as the simulation expects:

  ```bash
  python count.py
  ```
Run the simulation:

You only need to run a single command. It starts the server, connects to Azure, spawns the 3 clients, runs 5 rounds of federated training, saves the final model, and runs the final evaluation on the test set.

```bash
python server.py
```

To save the complete console output to a log file for later review:

```bash
python server.py | tee logs.txt
```
## Contributors

- Sinchan Shetty (22BCE5238)
- Riddhi Bandyopadhyay (22BCE1068)
- Harshit Kumar Singh (22BLC1079)

Guided by: Dr. Suganya R, SCOPE (52858)