MRI-Based Diagnostic Framework for Ankylosing Spondylitis

Complete Technical Documentation: Dataset, Models, Training, Results & Comparison

Deep Learning | Medical Imaging | Explainable AI | Flask Web App

🎯Aim & Objectives

Aim

To develop an explainable hybrid deep learning framework that accurately detects and stages Ankylosing Spondylitis from sacroiliac joint MRI scans, providing automated segmentation, multi-stage classification, and interpretable visual explanations to support clinical decision-making.

Objectives

  • To develop an Attention U-Net segmentation model for automatic extraction of sacroiliac joint regions from MRI scans, eliminating the need for manual ROI selection.
  • To build a multi-output CNN classifier for simultaneous binary AS detection and 4-stage disease severity classification (Normal → Early → Moderate → Advanced).
  • To implement a Hybrid CNN-Transformer model that combines EfficientNetB0 spatial features with Transformer attention for global context modeling.
  • To integrate Grad-CAM explainability with a novel region-focused variant that highlights sacroiliac joint areas influencing predictions.
  • To evaluate models using accuracy, precision, recall, F1-score, IoU, and confusion matrices on a publicly available dataset.
  • To deploy a complete Flask web application enabling clinicians to upload MRI scans, receive predictions with confidence scores, and view visual explanations.

Plan of Action

Phase 1 (3 Weeks)

Literature Review

Deep dive into CNNs, Vision Transformers, Hybrid architectures, and XAI techniques for AS diagnosis

Phase 2 (2 Weeks)

Dataset Collection

Source MRI images from Kaggle Lumbar Coordinate Pretraining dataset, preprocess and organize

Phase 3 (2 Weeks)

Data Preprocessing

Resize, normalize, generate masks, apply feature-based labeling, split into train/test sets

Phase 4 (5 Weeks)

Model Development

Build Attention U-Net, Simple CNN, and Hybrid CNN-Transformer architectures

Phase 5 (2 Weeks)

Testing & Evaluation

Measure accuracy, precision, recall, F1-score, IoU, generate confusion matrices

Phase 6 (2 Weeks)

XAI Integration

Implement Grad-CAM and focused Grad-CAM for sacroiliac joint region visualization

Phase 7 (2 Weeks)

Web App Development

Build Flask application with authentication, upload, prediction, and history features

Phase 8 (2 Weeks)

Documentation

Compile findings, prepare technical report, finalize Dockerfile for deployment

1. Project Overview

This project implements an end-to-end AI-powered diagnostic framework for detecting and staging Ankylosing Spondylitis (AS) from sacroiliac joint MRI scans. It addresses three critical gaps identified in existing literature (see Section 9):

  • Manual ROI Extraction Bottleneck: solved via Attention U-Net automatic segmentation
  • Binary-Only Classification: solved via multi-output stage-wise classification (Normal / Early / Moderate / Advanced)
  • Black Box Problem: solved via region-focused Grad-CAM heatmap explainability

Trained Models: 5
Dataset Images: 900
Best Binary Accuracy: 96.67%
Best Stage Accuracy: 82.22%

System Architecture

Training pipeline:
Kaggle Dataset (lumbar MRI images) → Preprocessing (256×256, normalize, masks) → Model Training (U-Net + CNN + Hybrid) → Trained Models (5 .keras files, ~581 MB)

Inference pipeline:
Flask Web App (Upload → Predict → Display) → User Upload (MRI image, PNG/JPG) → Segmentation (Attention U-Net) → Classification (CNN / Hybrid) → Grad-CAM (focused heatmap) → Results (AS status + stage + confidence)

2. Dataset Creation & Collection

Step 1: Original Source (Kaggle Lumbar Coordinate Pretraining)

Source: Kaggle: Lumbar Coordinate Pretraining Dataset

The original raw dataset contains lumbar spine MRI scans in NumPy (.npy) and JPEG formats from 4 medical imaging sources:

| Source Folder | Files (.npy) | Files (.jpg) | Description |
|---|---|---|---|
| processed_lsd | 516 | 516 | Lumbar Spine Degeneration dataset |
| processed_spider | 211 | 210 | SPIDER spine segmentation dataset |
| processed_osf | 35 | 35 | Open Science Framework spine images |
| processed_tseg | 479 | 479 | T-SEG thoracolumbar segmentation dataset |
| Total | 1,241 | 1,240 | |

Additional CSV files provide spine coordinate annotations:

  • coords_pretrain.csv (277 KB): maps filename → source → x, y coordinates → spine level (L1/L2 to L5/S1)
  • coords_rsna_improved.csv (5.2 MB): RSNA improved coordinates with conditions (Neural Foraminal Narrowing, etc.)

Step 2: Dataset Curation (from 1,241 Raw Images to 900 Curated Images)

From the original 1,241 raw MRI files, 900 images were selected and processed through the following pipeline:

  1. Source Selection: Images were selected from all 4 source datasets to ensure diversity in MRI acquisition parameters and spinal anatomy variations.
  2. Format Conversion: Raw NumPy arrays and JPEGs were converted to standardized 256×256 grayscale PNG format.
  3. Quality Filtering: Corrupt, blank, and duplicate scans were removed, leaving 900 final images.
  4. Mask Generation: Corresponding segmentation masks (256×256 PNG) were generated for each image to delineate sacroiliac joint regions. Each mask is ~815 bytes.
  5. Augmentation: 3× augmentation per image was applied during the dataset generation phase (rotation, flip, brightness adjustment, noise addition).
  6. Annotation: A CSV file (dataset.csv) was created with image paths, mask paths, AS status labels, stage labels, and bounding box coordinates.
Important Note: The original Kaggle dataset does NOT contain AS-specific labels. The AS labels and stage labels were generated using a feature-based analysis approach (see Section 3: Labeling Strategy).

Step 3: Final Dataset Structure

Dataset/
├── images/              # 900 PNG files (img_0000.png to img_0899.png)
│   ├── img_0000.png     # 256×256 grayscale MRI
│   ├── img_0001.png
│   └── ... (900 files)
├── masks/               # 900 PNG files (mask_0000.png to mask_0899.png)
│   ├── mask_0000.png    # 256×256 binary segmentation mask (~815 bytes each)
│   ├── mask_0001.png
│   └── ... (900 files)
├── annotations/
│   └── dataset.csv      # 900 rows with labels and bounding boxes
└── dataset_info.txt     # Dataset metadata summary

CSV Annotation Schema

| Column | Type | Description | Example |
|---|---|---|---|
| image_id | String | Unique image identifier | img_0000.png |
| image_path | String | Relative path to image | images/img_0000.png |
| mask_path | String | Relative path to segmentation mask | masks/mask_0000.png |
| AS_status | Integer | 0 = Negative, 1 = Positive | 1 |
| stage | Integer | 0 = Normal, 1 = Early, 2 = Moderate, 3 = Advanced | 1 |
| stage_name | String | Human-readable stage label | Early |
| bbox_x1, bbox_y1 | Integer | Top-left bounding box corner | 76, 102 |
| bbox_x2, bbox_y2 | Integer | Bottom-right bounding box corner | 179, 204 |

Class Distribution (Initial CSV Labels)

Binary Classification

| Class | Count | % |
|---|---|---|
| AS Positive | 469 | 52.1% |
| AS Negative | 431 | 47.9% |

Stage Distribution

| Stage | Count | % |
|---|---|---|
| Normal (0) | 557 | 61.9% |
| Early (1) | 111 | 12.3% |
| Moderate (2) | 129 | 14.3% |
| Advanced (3) | 103 | 11.4% |

Note: These initial CSV labels were later replaced by feature-based labels (see next section) to create more meaningful labels that correlate with actual image characteristics. The final labels used for training are the balanced feature-based labels from Cell 9 of the notebook.

3. Labeling Strategy (Feature-Based)

Since the original Kaggle dataset did not include AS-specific labels, a two-iteration feature-based labeling approach was developed to assign clinically motivated labels based on actual image characteristics.

Iteration 1: Initial Feature-Based Labels (Notebook Cell 8)

Three features were extracted from each image:

| Feature | Method | Thresholds |
|---|---|---|
| Brightness | np.mean(img) | 45th percentile |
| Texture | np.std(img) | 55th percentile |
| Structure (Edge Density) | cv2.Canny(img, 50, 150) | 30th/50th/70th percentiles for staging |

Rule: AS+ if brightness < 45th percentile AND texture > 55th percentile

Problem: This produced highly unbalanced labels: only 16 AS Positive vs 884 AS Negative. Staging was similarly skewed (8 Early, 8 Moderate, 0 Advanced). This was unusable for training.

Iteration 2: Balanced Feature-Based Labels (Notebook Cell 9) ✅ Final Version

Four features were extracted and combined into a composite score:

| Feature | Method | Weight in Score | Clinical Rationale |
|---|---|---|---|
| Mean Intensity | np.mean(img) | 40% of brightness | Overall tissue density |
| Lower-Half Intensity | Mean of bottom 50% of image | 60% of brightness | Sacroiliac region is in lower spine |
| Texture (Std Dev) | np.std(img) | 40% of total | Inflammatory changes vary texture |
| Edge Density | cv2.Canny(img, 50, 150) | 30% of total | Structural damage increases edges |

# Combined score formula (brightness contributes the remaining 30% of the total)
brightness_scores = mean_intensity * 0.4 + lower_half_intensity * 0.6
combined_score = brightness_scores * 0.3 + texture_scores * 0.4 + edge_density * 0.3

# Binary: AS+ if score > 52nd percentile
threshold_binary = np.percentile(combined_score, 52)

# Staging (for AS+ images only):
#   Advanced (3): score > 85th percentile
#   Moderate (2): score > 70th percentile
#   Early (1):    score > 60th percentile
#   Normal (0):   below 60th percentile (AS+ but no clear stage)

Resulting balanced distribution:

| Label | Count | Result |
|---|---|---|
| AS Positive | 432 | Well balanced ✓ |
| AS Negative | 468 | Well balanced ✓ |
| Stage 0 (Normal) | 540 | Includes all AS- (468) + some AS+ (72) |
| Stage 1 (Early) | 90 | AS+ with moderate scores |
| Stage 2 (Moderate) | 135 | AS+ with high scores |
| Stage 3 (Advanced) | 135 | AS+ with highest scores |
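
The counts in this table follow directly from the percentile cut-offs. The arithmetic can be sketched as below, assuming exactly 900 scores with no ties at any threshold; the helper name `count_between` is illustrative and does not come from the notebook.

```python
# Derive the expected label counts from the percentile cut-offs alone.
# Assumption: 900 images and no tied scores at any threshold.
N_IMAGES = 900

def count_between(lower_pct, upper_pct, n=N_IMAGES):
    """Number of images whose score falls between two percentiles."""
    return round(n * (upper_pct - lower_pct) / 100)

as_negative = count_between(0, 52)     # score at or below the 52nd percentile
as_positive = count_between(52, 100)   # score above the 52nd percentile

stage_counts = {
    "Normal":   as_negative + count_between(52, 60),  # all AS- plus low-scoring AS+
    "Early":    count_between(60, 70),
    "Moderate": count_between(70, 85),
    "Advanced": count_between(85, 100),
}

print(as_negative, as_positive)
print(stage_counts)
```

Because the thresholds are percentiles of the score distribution itself, the class sizes are fixed by the chosen cut-offs rather than by anything in the images.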

Train/Test Split

| Split | Count | Ratio |
|---|---|---|
| Training | 720 images | 80% |
| Testing / Validation | 180 images | 20% |

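Method: train_test_split(test_size=0.2, random_state=42, stratify=y_binary), stratified on the binary labels so that both sets keep the same AS+/AS- proportions. Stage labels are not part of the stratification, so per-stage proportions can differ between the two sets.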
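
To make the effect of stratification concrete, here is a minimal pure-Python sketch of a stratified split. It is not scikit-learn's implementation (the shuffling and rounding details differ), and `stratified_split` is a hypothetical helper; it only illustrates why a 468/432 label set yields a 180-image test set with about 94 negatives and 86 positives.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx) so each class keeps roughly the same share."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, test_idx = [], []
    for label, indices in sorted(by_class.items()):
        rng.shuffle(indices)                           # shuffle within the class
        n_test = round(len(indices) * test_size)       # per-class test quota
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return train_idx, test_idx

# 468 AS-negative (0) and 432 AS-positive (1) labels, as in the balanced set.
y_binary = [0] * 468 + [1] * 432
train_idx, test_idx = stratified_split(y_binary)

test_positive = sum(y_binary[i] for i in test_idx)
test_negative = len(test_idx) - test_positive
print(len(train_idx), len(test_idx), test_negative, test_positive)
```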

4. Model Architectures

Model 1: Attention U-Net Segmentation

Purpose: Automatic segmentation of sacroiliac joint regions, eliminating manual ROI extraction.

| Property | Value |
|---|---|
| Parameters | 7,869,572 |
| Input Shape | (256, 256, 1), grayscale |
| Output | (256, 256, 1), sigmoid binary mask |
| File | attention_unet_model.keras (~94.6 MB) |
| Encoder Filters | 64 → 128 → 256 → 512 (bottleneck) |
| Decoder | 3 levels with UpSampling2D + Conv2D(2) + Attention Gate + Concatenate |
| Activation | ReLU (hidden), Sigmoid (output) |
| Notebook Cell | Cell 3 (build) + Cell 4 (train) |

Encoder Path (each level = 2× Conv2D + MaxPool2D):
  Level 1: Conv(64) → Conv(64) → Pool         # 256→128
  Level 2: Conv(128) → Conv(128) → Pool        # 128→64
  Level 3: Conv(256) → Conv(256) → Pool        # 64→32
  Bottleneck: Conv(512) → Conv(512)             # 32×32

Decoder Path (each level = UpSample + AttGate + Concat + 2× Conv2D):
  Up Level 3: UpSample(2) → Conv(256,2) → AttentionGate(conv3, up, 256) → Concat → Conv(256)×2
  Up Level 2: UpSample(2) → Conv(128,2) → AttentionGate(conv2, up, 128) → Concat → Conv(128)×2
  Up Level 1: UpSample(2) → Conv(64,2)  → AttentionGate(conv1, up, 64)  → Concat → Conv(64)×2

Output: Conv2D(1, kernel=1, activation='sigmoid')

Attention Gate Mechanism:

θ(x) = Conv2D(inter_ch, 1)(skip_connection)   # Transform skip features
φ(g) = Conv2D(inter_ch, 1)(gating_signal)     # Transform decoder features
ψ    = sigmoid(Conv2D(1, 1)(relu(θ + φ)))      # Compute attention coefficients
output = skip_connection × ψ                    # Apply learned attention weighting

Why Attention U-Net? Standard U-Net passes all skip connection features equally. The attention gates learn to suppress irrelevant background regions (muscle, fat) and highlight the small sacroiliac joint structures, which is critical since the joint occupies only ~15-20% of the full MRI field of view.
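
The gating arithmetic can be illustrated on a tiny single-channel example. This is a simplified sketch, not the notebook's layer: the 1×1 convolutions are replaced by scalar weights (`w_theta`, `w_phi`, `w_psi`) plus a bias `b_psi`, all of which are made-up values chosen for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def attention_gate(skip, gate, w_theta, w_phi, w_psi, b_psi):
    """Per-pixel attention gate for single-channel feature maps.

    skip, gate: 2D lists of equal shape (skip connection and gating signal).
    The scalar weights stand in for the 1x1 convolutions of the real layer.
    """
    gated = []
    for skip_row, gate_row in zip(skip, gate):
        out_row = []
        for x, g in zip(skip_row, gate_row):
            combined = max(0.0, w_theta * x + w_phi * g)   # relu(theta + phi)
            psi = sigmoid(w_psi * combined + b_psi)         # attention coefficient in (0, 1)
            out_row.append(x * psi)                         # re-weight the skip feature
        gated.append(out_row)
    return gated

skip = [[1.0, 1.0], [1.0, 1.0]]    # identical skip features everywhere
gate = [[3.0, -3.0], [0.0, 3.0]]   # decoder signal: strong, weak, neutral, strong
out = attention_gate(skip, gate, w_theta=1.0, w_phi=1.0, w_psi=2.0, b_psi=-2.0)

for row in out:
    print([round(v, 3) for v in row])
```

Pixels where the gating signal is strong keep almost all of their skip feature, while pixels with a weak gating signal are scaled down, which is the suppression behaviour described above.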

Model 2: Simple CNN Classifier (Primary)

Purpose: Dual-output classification: simultaneous AS detection and disease stage classification. ★ Best Performing Model

| Property | Value |
|---|---|
| Parameters | 16,870,790 |
| Input Shape | (256, 256, 1) |
| Outputs | 2 heads: binary_output (2 classes, softmax) + stage_output (4 classes, softmax) |
| File | cnn_classifier_model.keras (~202.5 MB) |
| Notebook Cell | Cell 7 (build_simple_cnn) |

Input(256, 256, 1)
  → Conv2D(32, 3, relu, same) → MaxPooling2D(2)      # 256→128
  → Conv2D(64, 3, relu, same) → MaxPooling2D(2)      # 128→64
  → Conv2D(128, 3, relu, same) → MaxPooling2D(2)     # 64→32
  → Flatten()                                          # 32×32×128 = 131,072
  → Dense(128, relu) → Dropout(0.5)
  ├── Dense(2, softmax) → binary_output  [AS Negative / AS Positive]
  └── Dense(4, softmax) → stage_output   [Normal / Early / Moderate / Advanced]

Design choice: Despite its simplicity, this 3-block CNN significantly outperformed the more complex Hybrid models. Almost all of its capacity sits in the large Flatten+Dense layer (131,072 → 128), which holds the vast majority of the model's parameters. The dual-output design allows simultaneous binary detection and staging from a single forward pass, with loss_weights={'binary': 1.0, 'stage': 0.5} prioritizing correct AS detection.
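
The reported parameter count can be checked by hand from the layer list. The sketch below uses the standard formulas for convolution and dense layers with biases; the dictionary keys are illustrative labels, not the actual Keras layer names.

```python
def conv2d_params(kernel, in_channels, out_channels):
    """Kernel weights plus one bias per output filter."""
    return kernel * kernel * in_channels * out_channels + out_channels

def dense_params(in_units, out_units):
    """Weight matrix plus one bias per output unit."""
    return in_units * out_units + out_units

flatten_units = 32 * 32 * 128   # 256 -> 128 -> 64 -> 32 after three 2x2 pools

layers = {
    "conv_32":       conv2d_params(3, 1, 32),
    "conv_64":       conv2d_params(3, 32, 64),
    "conv_128":      conv2d_params(3, 64, 128),
    "dense_128":     dense_params(flatten_units, 128),
    "binary_output": dense_params(128, 2),
    "stage_output":  dense_params(128, 4),
}
total = sum(layers.values())
dense_share = layers["dense_128"] / total

print(f"total parameters: {total:,}")
print(f"share held by Dense(128): {dense_share:.1%}")
```

The total matches the 16,870,790 parameters listed in the table, and more than 99% of them belong to the single Dense(128) layer that follows Flatten.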

Model 3: Hybrid CNN-Transformer v1

Purpose: Combine EfficientNetB0 CNN features with Transformer self-attention blocks for global context.

| Property | Value |
|---|---|
| Parameters | ~5.2M |
| Backbone | EfficientNetB0 (ImageNet, frozen) |
| Transformer | 2 × MultiHeadAttention blocks (4 heads, key_dim=256) |
| File | classifier_best_model.keras (~162.9 MB) |
| Notebook Cell | Cell 5 (build_hybrid_cnn_transformer) |

Input(256, 256, 1) → Conv2D(3,1,same)           # Grayscale→RGB adapter
  → EfficientNetB0(frozen, imagenet)             # Feature extraction
  → GlobalAveragePooling2D()                     # (batch, 1280)
  → Reshape(1, 1280)                             # Sequence for Transformer
  → TransformerBlock(4 heads, mlp_dim=256)       # Self-attention + MLP + residual
  → TransformerBlock(4 heads, mlp_dim=256)       # Second Transformer layer
  → Flatten()
  → Dense(256, relu) → Dropout(0.3)
  ├── Dense(2, softmax) → binary_output
  └── Dense(4, softmax) → stage_output

Transformer Block internals:

LayerNorm → MultiHeadAttention(4 heads) → Residual Add
LayerNorm → Dense(mlp_dim, relu) → Dense(original_dim) → Residual Add

Performance: Binary: 54.44%, Stage: 49.44% (early stopped at epoch 41, best epoch 21)

Model 4: Hybrid CNN-Transformer v2

Purpose: Simplified version replacing Transformer blocks with dense layers.

| Property | Value |
|---|---|
| Parameters | 4,838,319 |
| Backbone | EfficientNetB0 (ImageNet, frozen) |
| Post-CNN | Dense layers only (no Transformer blocks) |
| File | hybrid_cnn_transformer_model.keras (~26.5 MB) |
| Notebook Cell | Cell 15 (build_hybrid_cnn_transformer_v2) |

Input(256, 256, 1) → Conv2D(3,1,same)           # Grayscale→RGB adapter
  → EfficientNetB0(frozen, imagenet)             # Feature extraction
  → GlobalAveragePooling2D()                     # (batch, 1280)
  → Dense(512, relu) → Dropout(0.4)
  → Dense(256, relu) → Dropout(0.3)
  ├── Dense(2, softmax) → binary_output
  └── Dense(4, softmax) → stage_output

Performance: Binary: 52.22%, Stage: 57.78% (early stopped at epoch 16, best epoch 1)

Why Both Hybrid Models Underperformed: The EfficientNetB0 backbone was frozen (base_model.trainable = False), so it could not adapt its ImageNet features (learned from natural images such as cats, dogs, and cars) to the very different domain of medical MRI scans. In addition, global average pooling collapses the feature map to a single token (1×1280), and self-attention over a one-token sequence has no other positions to attend to, so the Transformer blocks in v1 add little beyond extra dense projections. Fine-tuning the last few EfficientNet layers would likely improve results.
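
The single-token limitation can be seen from the attention weights alone. The sketch below implements plain scaled dot-product attention weights for one query in pure Python (no learned projections, which is a simplification of MultiHeadAttention); the vectors are arbitrary example values.

```python
import math

def softmax(scores):
    peak = max(scores)                                 # subtract max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(query, keys):
    """Scaled dot-product attention weights of one query over a list of keys."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    return softmax(scores)

token = [0.7, -1.2, 3.4]

# Sequence length 1: softmax over a single score is always exactly 1.0.
single = attention_weights(token, [token])

# Sequence length 3: the weights now depend on how the tokens relate to each other.
several = attention_weights(token, [token, [0.1, 0.2, 0.3], [-1.0, 0.5, 2.0]])

print(single)
print([round(w, 3) for w in several])
```

With one token the attention output is just that token's own value vector, whatever the input is. Keeping the spatial feature map as a sequence of tokens (instead of pooling it first) is one way to give the attention layers something to compare, though that variant is not evaluated in this project.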

All Models Summary

| Model | File | Size | Type | Cell | Status |
|---|---|---|---|---|---|
| Attention U-Net | attention_unet_model.keras | 94.6 MB | Segmentation | 3+4 | Deployed |
| U-Net Best Checkpoint | unet_best_model.keras | 94.6 MB | Segmentation | 4 | Deployed |
| Simple CNN Classifier ★ | cnn_classifier_model.keras | 202.5 MB | Classifier | 7+10 | Primary |
| Hybrid v1 (Transformer) | classifier_best_model.keras | 162.9 MB | Classifier | 5+6 | Needs tuning |
| Hybrid v2 (Dense) | hybrid_cnn_transformer_model.keras | 26.5 MB | Classifier | 15 | Needs tuning |

Total model storage: ~581.1 MB

5. Training Configuration & Details

Implementation Steps

  1. Data Loading (Cell 0-2): CSV loaded with pd.read_csv(), images read via cv2.imread(IMREAD_GRAYSCALE), normalized to [0,1], reshaped to (N, 256, 256, 1). Train/test split with stratification.
  2. Feature-Based Relabeling (Cell 8-9): Two rounds of feature extraction (intensity, texture, edges, lower-half intensity) to create balanced, image-characteristic-based AS labels.
  3. Attention U-Net Training (Cell 3-4): Segmentation model trained on image→mask pairs using binary crossentropy.
  4. Hybrid CNN-Transformer v1 Training (Cell 5-6): EfficientNetB0 + Transformer blocks + simplified classifier trained on relabeled data.
  5. Simple CNN Training (Cell 7+10): 3-block CNN with dual heads trained on balanced feature-based labels; achieved the best results.
  6. Evaluation (Cell 11): Full classification reports, confusion matrices, IoU computation.
  7. Testing (Cell 12): Visual prediction on random test samples with overlays.
  8. Grad-CAM (Cell 13-14): Standard and focused Grad-CAM heatmap generation and visualization.
  9. Hybrid v2 Training (Cell 15): Simplified hybrid model with dense-only layers.

Training Hyperparameters Comparison

| Parameter | Attention U-Net | Simple CNN ★ | Hybrid v1 | Hybrid v2 |
|---|---|---|---|---|
| Optimizer | Adam (default lr) | Adam (lr=0.001) | Adam (default) | Adam (lr=0.001) |
| Loss | Binary CE | Sparse Cat. CE (×2) | Sparse Cat. CE (×2) | Sparse Cat. CE (×2) |
| Loss Weights | N/A | binary:1.0, stage:0.5 | binary:1.0, stage:0.5 | binary:1.0, stage:0.5 |
| Max Epochs | 50 | 100 | 50 | 50 |
| Batch Size | 16 | 32 | 16 | 16 |
| EarlyStopping Monitor | val_loss | val_binary_output_accuracy | val_binary_output_accuracy | val_binary_output_accuracy |
| EarlyStopping Patience | 10 | 20 | 15 | 15 |
| ReduceLROnPlateau | Yes (factor=0.5, patience=5) | Yes (factor=0.5, patience=7) | Yes (factor=0.5, patience=5) | Yes (factor=0.5, patience=5) |
| ModelCheckpoint | Yes (unet_best_model.keras) | No | No | No |
| Actual Epochs | ~50 (full run) | 26 (best @ epoch 6) | 41 (best @ epoch 21) | 16 (best @ epoch 1) |
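
The "Actual Epochs" row is what patience-based early stopping produces: training ends once the monitored metric has gone `patience` epochs without improving. A minimal simulation is sketched below; the validation curves are synthetic values invented to place the peak at a chosen epoch, not the real training logs, and `simulate_early_stopping` is a hypothetical helper rather than the Keras callback.

```python
def simulate_early_stopping(val_scores, patience):
    """Return (stop_epoch, best_epoch), both 1-indexed, for a higher-is-better metric."""
    best_score, best_epoch = float("-inf"), 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            best_score, best_epoch = score, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch          # no improvement for `patience` epochs
    return len(val_scores), best_epoch        # ran to the epoch limit

# Synthetic curves: accuracy peaks at epoch 6 (CNN) and epoch 1 (Hybrid v2).
cnn_curve = [0.50 + 0.05 * e for e in range(1, 7)] + [0.70] * 94   # max 100 epochs
hybrid_v2_curve = [0.52] + [0.50] * 49                              # max 50 epochs

cnn_stop = simulate_early_stopping(cnn_curve, patience=20)
hybrid_v2_stop = simulate_early_stopping(hybrid_v2_curve, patience=15)

print(cnn_stop)
print(hybrid_v2_stop)
```

A best epoch of 6 with patience 20 stops at epoch 26, and a best epoch of 1 with patience 15 stops at epoch 16, consistent with the Simple CNN and Hybrid v2 columns.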

Training Environment

Libraries

| Library | Version |
|---|---|
| Python | 3.x (Kaggle) |
| TensorFlow | 2.19.0 |
| Keras | 3.10.0 |
| NumPy | 2.0.2 |
| Pandas | 2.2.2 |
| OpenCV | 4.12.0 |
| Scikit-learn | 1.6.1 |
| Matplotlib | 3.10.0 |
| Seaborn | 0.13.2 |

Hardware

| Component | Spec |
|---|---|
| Platform | Kaggle Notebooks |
| GPU | 1× GPU (T4 / P100) |
| RAM | ~16 GB |
| CUDA | Enabled |
| Storage | Kaggle workspace |

6. Results & Performance Metrics

Performance Summary

| Model | Type | Parameters | Binary Acc | Stage Acc | IoU | Status |
|---|---|---|---|---|---|---|
| Attention U-Net | Segmentation | 7,869,572 | 89.35% (val) | - | 0.5652 | ✅ Deployed |
| Simple CNN ★ | Classification | 16,870,790 | 96.67% | 82.22% | - | ✅ Primary |
| Hybrid v1 | Classification | ~5.2M | 54.44% | 49.44% | - | ⚠️ Needs tuning |
| Hybrid v2 | Classification | 4,838,319 | 52.22% | 57.78% | - | ⚠️ Needs tuning |

CNN Classifier: Detailed Classification Report

Binary Classification (AS+/AS-): 96.67% Accuracy

| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| AS Negative | 0.99 | 0.95 | 0.97 | 94 |
| AS Positive | 0.94 | 0.99 | 0.97 | 86 |
| Overall Accuracy | | | 0.97 (96.67%) | 180 |
| Macro Avg | 0.97 | 0.97 | 0.97 | 180 |
| Weighted Avg | 0.97 | 0.97 | 0.97 | 180 |
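
As a worked example of how these figures relate, the sketch below recomputes them from a 2×2 confusion matrix. The four counts are an assumption: they are inferred from the rounded report (94 and 86 support, 96.67% accuracy) rather than copied from the notebook, and they are the counts that reproduce every value in the table.

```python
# Rows are true classes, columns are predicted classes: [negative, positive].
# Assumption: counts inferred from the rounded report, not taken from the notebook.
confusion = [[89, 5],
             [1, 85]]

def per_class_metrics(matrix, cls):
    """Precision, recall and F1 for one class of a square confusion matrix."""
    true_positive = matrix[cls][cls]
    predicted = sum(row[cls] for row in matrix)   # column sum
    actual = sum(matrix[cls])                     # row sum (support)
    precision = true_positive / predicted
    recall = true_positive / actual
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

total = sum(sum(row) for row in confusion)
accuracy = sum(confusion[i][i] for i in range(len(confusion))) / total

negative = per_class_metrics(confusion, 0)
positive = per_class_metrics(confusion, 1)

print(f"accuracy: {accuracy:.4f}")
print("AS Negative:", [round(v, 2) for v in negative])
print("AS Positive:", [round(v, 2) for v in positive])
```

Read this way, the classifier makes 6 errors on 180 test images: 5 negatives flagged as positive and 1 positive missed.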

Stage Classification: 82.22% Accuracy

| Stage | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| Normal (0) | 0.87 | 1.00 | 0.93 | 104 |
| Early (1) | 0.00 | 0.00 | 0.00 | 18 |
| Moderate (2) | 0.61 | 0.73 | 0.67 | 26 |
| Advanced (3) | 0.93 | 0.78 | 0.85 | 32 |
| Accuracy | | | 82.22% | 180 |
| Macro Avg | 0.60 | 0.63 | 0.61 | 180 |
| Weighted Avg | 0.75 | 0.82 | 0.78 | 180 |

Early Stage Detection Issue: Stage 1 (Early) has 0% recall: none of the 18 Early test samples is classified correctly. The class is small and heavily outnumbered by Normal (104 test samples), and the model assigns borderline Early cases to Normal or Moderate. This is a known limitation and an area for improvement (see Section 12).
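
One of the remedies listed later in this document is class weighting. A minimal sketch of inverse-frequency ("balanced") weights is shown below, computed from the full-dataset stage distribution because the training-split stage counts are not reported; how such weights would be wired into the dual-output training loop is left open here.

```python
# Stage distribution of the full 900-image dataset (training-split counts are not reported).
stage_counts = {"Normal": 540, "Early": 90, "Moderate": 135, "Advanced": 135}

def balanced_class_weights(counts):
    """weight = n_samples / (n_classes * class_count), so rarer classes weigh more."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {name: n_samples / (n_classes * count) for name, count in counts.items()}

weights = balanced_class_weights(stage_counts)
for name, weight in weights.items():
    print(f"{name:9s} {weight:.3f}")
```

Under this scheme a misclassified Early sample would count six times as much as a misclassified Normal sample, which directly targets the imbalance described above.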

7. Grad-CAM Explainability

Two Grad-CAM implementations provide visual explanations for the model's classification decisions:

Standard Grad-CAM (Cell 13)

1. Create sub-model: [input] → [last_conv_layer_output, model_predictions]
2. Forward pass with GradientTape
3. Compute gradients: d(predicted_class_score) / d(last_conv_layer_output)
4. Global average pooling of gradients → per-channel importance weights
5. Weighted combination: heatmap = conv_output @ pooled_grads
6. ReLU activation (keep only positive influence) + normalize to [0,1]
7. Resize to input dimensions, apply JET colormap, overlay with alpha=0.4
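
Steps 4 to 6 can be traced on a toy example. The sketch below works on nested Python lists instead of tensors and uses made-up activations and gradients; in the real pipeline these would come from the sub-model and GradientTape in steps 1 to 3, and `grad_cam` here is an illustrative helper, not the notebook's function.

```python
def grad_cam(conv_output, grads):
    """conv_output, grads: nested lists shaped [H][W][C]. Returns an [H][W] heatmap in [0, 1]."""
    height, width = len(conv_output), len(conv_output[0])
    channels = len(conv_output[0][0])

    # Step 4: global-average-pool the gradients into one weight per channel.
    pooled = [
        sum(grads[y][x][c] for y in range(height) for x in range(width)) / (height * width)
        for c in range(channels)
    ]

    # Steps 5-6: channel-weighted sum at each location, then ReLU.
    heatmap = [
        [max(0.0, sum(conv_output[y][x][c] * pooled[c] for c in range(channels)))
         for x in range(width)]
        for y in range(height)
    ]

    # Step 6 (continued): normalize so the strongest location is close to 1.0.
    peak = max(max(row) for row in heatmap)
    return [[value / (peak + 1e-8) for value in row] for row in heatmap]

# Toy 2x2 feature map with 2 channels (values are made up).
conv_output = [[[1.0, 0.0], [2.0, 1.0]],
               [[0.0, 3.0], [4.0, 0.0]]]
grads = [[[0.5, -0.5], [0.5, -0.5]],
         [[0.5, -0.5], [0.5, -0.5]]]

heatmap = grad_cam(conv_output, grads)
for row in heatmap:
    print([round(v, 3) for v in row])
```

The location dominated by the negatively weighted channel is clipped to zero by the ReLU, so only regions that push the predicted class score up remain in the heatmap.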

Region-Focused Grad-CAM: Novel Enhancement (Cell 14)

# After standard Grad-CAM computation, apply sacroiliac joint ROI mask.
# Assumes `heatmap` is a 2D NumPy array already resized to the input image size.
import numpy as np

height, width = heatmap.shape
roi_y_start = int(height * 0.5)   # Focus on lower 50% of image
roi_y_end   = int(height * 0.9)   # Down to 90% (avoiding edge)
roi_x_start = int(width * 0.3)    # Central 40% horizontally
roi_x_end   = int(width * 0.7)

mask = np.zeros_like(heatmap)
mask[roi_y_start:roi_y_end, roi_x_start:roi_x_end] = 1.0   # 1 inside the ROI, 0 elsewhere
heatmap_focused = heatmap * mask                            # Zero out activations outside the ROI
heatmap_focused = heatmap_focused / (np.max(heatmap_focused) + 1e-8)  # Renormalize to [0, 1]

Target Conv Layer: conv2d_49 (last convolutional layer in the Simple CNN)

Clinical Value: Standard Grad-CAM may highlight any discriminative region, including background artifacts. The focused variant masks the heatmap to a fixed rectangle that approximates the sacroiliac joint area, where radiologists actually look for AS signs such as bone marrow edema, erosion, and ankylosis.

8. End-to-End Prediction Pipeline

Upload MRI → Preprocess → Segment → Classify → Grad-CAM → Save to DB → Display

Detailed Flow (from predict.py and app.py):

  1. Upload (app.py): user uploads PNG/JPG/JPEG via Flask form → saved to uploads/
  2. Preprocessing (predict.py)
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (256, 256))
    img = img / 255.0
    img = np.expand_dims(img, axis=(0, -1))  # → shape (1, 256, 256, 1)
  3. Segmentation: selected model(s) predict mask → threshold at 0.5 → overlay on original (red highlight, alpha=0.3)
  4. Classification: selected model(s) predict dual outputs:
    • binary_output: softmax(2) → argmax → AS Positive/Negative + confidence %
    • stage_output: softmax(4) → argmax → Normal/Early/Moderate/Advanced + confidence %
  5. Grad-CAM: focused heatmap → JET colormap → overlay on MRI (alpha=0.4)
  6. Storage: all outputs saved to static/results/{uuid}/, metadata to SQLite DB
  7. Display: results page shows original, mask, overlay, Grad-CAM, predictions with confidence
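
The classification step turns each softmax vector into a label and a confidence percentage. A minimal sketch of that decoding follows; the probability values are made up, and the `decode` helper and label lists are illustrative rather than taken from predict.py.

```python
AS_LABELS = ["AS Negative", "AS Positive"]
STAGE_LABELS = ["Normal", "Early", "Moderate", "Advanced"]

def decode(probabilities, labels):
    """Return (label, confidence in %) for the highest-probability class."""
    best = max(range(len(probabilities)), key=probabilities.__getitem__)  # argmax
    return labels[best], round(probabilities[best] * 100, 2)

# Made-up softmax outputs for one uploaded image.
binary_probs = [0.08, 0.92]
stage_probs = [0.05, 0.15, 0.70, 0.10]

as_status, as_confidence = decode(binary_probs, AS_LABELS)
stage, stage_confidence = decode(stage_probs, STAGE_LABELS)

print(as_status, as_confidence)
print(stage, stage_confidence)
```

Because the two heads are decoded independently, nothing in this step prevents a combination such as "AS Negative" with a non-Normal stage, so the display layer may want to handle that case explicitly.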

Available Models at Runtime

| Key | Display Name | Type | File |
|---|---|---|---|
| attention_unet | Attention U-Net | Segmentation | attention_unet_model.keras |
| unet_best | U-Net Best | Segmentation | unet_best_model.keras |
| cnn_classifier | CNN Classifier | Classifier | cnn_classifier_model.keras |
| classifier_best | Best Classifier | Classifier | classifier_best_model.keras |
| hybrid_cnn_transformer | Hybrid CNN-Transformer | Classifier | hybrid_cnn_transformer_model.keras |

Users can select multiple models simultaneously for side-by-side comparison.

9. Comparison with Existing Systems

Key Differentiators: Unlike existing systems that are (1) binary-only, (2) require manual ROI, or (3) lack explainability, our framework provides all three: automatic segmentation + stage-wise classification + visual Grad-CAM evidence.

| Study | Dataset | Model | Task | Performance | Limitations / Notes |
|---|---|---|---|---|---|
| Lee et al. (2023) | 296 patients, 4,746 slices | Faster R-CNN + VGG-19 | Detect sacroiliitis | AUROC ~0.83, Sens ~0.725, Spec ~0.936 | Binary only, no staging |
| Xie et al. (2025) | 1,294 patients, 4 centers | ResNet50 + KNN-11 | axSpA classification | AUC ~0.912, Acc ~86.9% | Manual ROI, binary only |
| Zhou et al. (2024) | 485 patients | 3D U-Net + ResNet50 + ensemble | Diagnose sacroiliitis | AUC ~0.910, Acc ~85.6% | No stage classification |
| Bordner et al. (2023) | 362 images (DESIR) | Deep Learning | BME & sacroiliitis | - | Binary only |
| Deep Learning Chris (2023) | 326 axSpA + 63 NSBP | Attention U-Net | Segmentation/Detection | AUC ~0.96, Sens ~0.90, Spec ~0.93 | Segmentation only, no classification |
| Kumar et al. (2025) | - | Transfer Learning | AS detection | - | Binary only, no staging |
| Kocaoglu (2025) | - | FPGA DL | AS detection | - | Hardware-specific |
| Manikandan et al. (2023) | - | ASNET | AS diagnosis | - | No visual explainability |
| Our Framework | 900 images | Att. U-Net + CNN + Hybrid + Grad-CAM | Seg + Binary + Staging | Binary: 96.67%, Stage: 82.22%, IoU: 0.5652 | Fully automated, explainable, web-deployed |

Feature Comparison Matrix

| Feature | Most Existing Systems | Our Framework |
|---|---|---|
| ROI Extraction | Manual / Semi-automatic | ✅ Fully automatic (Attention U-Net) |
| Classification Type | Binary only (AS+/AS-) | ✅ Binary + 4-stage severity |
| Explainability | Probability scores only | ✅ Region-focused Grad-CAM heatmaps |
| Multi-model Comparison | Single model | ✅ 5 selectable models |
| Web Deployment | Research code only | ✅ Flask app with auth + history |
| Clinician-Verified Labels | ✅ Radiologist annotations | ⚠️ Feature-based (synthetic) |
| Dataset Size | 296-1,294 patients | ⚠️ 900 images (single source) |
| Clinical Validation | ✅ Some prospective studies | ❌ Not yet clinically validated |

10. System Requirements

Hardware Requirements

| Component | Minimum | Recommended |
|---|---|---|
| CPU | Intel i5 / AMD Ryzen 5 | Intel i7 / AMD Ryzen 7 (multi-core) |
| GPU | Not required (CPU inference) | NVIDIA with CUDA (GTX 1660+ / RTX 2060+) |
| RAM | 4 GB | 16 GB (32 GB for training) |
| Storage | 1 GB (models only) | 500 GB SSD (models + dataset + outputs) |

Software Requirements

| Software | Version |
|---|---|
| OS | Windows 10/11, Linux (Ubuntu 20.04+), macOS |
| Python | 3.8+ (3.12 recommended) |
| TensorFlow | 2.19.0 |
| Keras | 3.10.0 |
| OpenCV | 4.12.0 |
| Flask | Latest |
| NumPy / Pandas / Scikit-learn | Latest compatible |
| Docker (optional) | For containerized deployment |

Model File Sizes

| Model File | Size |
|---|---|
| attention_unet_model.keras | 94.6 MB |
| unet_best_model.keras | 94.6 MB |
| cnn_classifier_model.keras | 202.5 MB |
| classifier_best_model.keras | 162.9 MB |
| hybrid_cnn_transformer_model.keras | 26.5 MB |
| Total | ~581.1 MB |

Models stored in models/ directory. Git-ignored due to size; must be transferred separately for deployment.

11. Web Application Architecture

| Component | Technology | Details |
|---|---|---|
| Backend | Flask (Python) | app.py: routes, auth, file handling |
| ML Inference | TensorFlow/Keras | predict.py: preprocessing, prediction, Grad-CAM |
| Model Management | Custom | model_loader.py: lazy loading, caching, model registry |
| Database | SQLite | database.py: users + predictions tables |
| Auth | Session-based | SHA-256 password hashing, Flask sessions |
| Frontend | HTML/CSS templates | 6 pages: index, login, signup, upload, results, history |
| File Storage | Local filesystem | uploads/ for input, static/results/{uuid}/ for outputs |
| Port | 8000 | Configurable |
| Containerization | Docker | Dockerfile + docker-compose.yml available |
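
The "lazy loading, caching, model registry" behaviour attributed to model_loader.py can be sketched with the standard library alone. Everything below is hypothetical: the registry contents mirror two rows of the runtime model table, `load_from_disk` is a stand-in for the real (slow) Keras load, and the actual module may be organized differently.

```python
from functools import lru_cache

# Hypothetical registry: keys and file names follow the runtime model table.
MODEL_REGISTRY = {
    "attention_unet": "models/attention_unet_model.keras",
    "cnn_classifier": "models/cnn_classifier_model.keras",
}

load_calls = []  # records every simulated disk load

def load_from_disk(path):
    """Stand-in for the real model load; returns a placeholder object."""
    load_calls.append(path)
    return {"path": path}

@lru_cache(maxsize=None)
def get_model(key):
    """Load a model on first request and reuse the cached object afterwards."""
    if key not in MODEL_REGISTRY:
        raise KeyError(f"Unknown model key: {key}")
    return load_from_disk(MODEL_REGISTRY[key])

first = get_model("cnn_classifier")
second = get_model("cnn_classifier")   # served from the cache, no second load

print(first is second, load_calls)
```

Loading on demand keeps start-up fast and avoids holding all five models (about 581 MB on disk) in memory when a user only selects one or two of them.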

Database Schema

CREATE TABLE users (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    email TEXT UNIQUE NOT NULL,
    password TEXT NOT NULL,       -- SHA-256 hash
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE predictions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    uuid TEXT UNIQUE,             -- Unique run identifier
    user_id INTEGER NOT NULL,
    image_path TEXT NOT NULL,
    as_status TEXT NOT NULL,      -- "AS Positive" / "AS Negative"
    stage TEXT NOT NULL,           -- "Normal" / "Early" / "Moderate" / "Advanced"
    confidence REAL,               -- Binary confidence %
    stage_confidence REAL,         -- Stage confidence %
    segmentation_mask TEXT,        -- Path to predicted mask image
    gradcam_overlay TEXT,          -- Path to Grad-CAM overlay image
    segmentation_overlay TEXT,     -- Path to segmentation overlay image
    model_results TEXT,            -- JSON of all model outputs
    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (user_id) REFERENCES users(id)
);
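
To show how the two tables work together, the sketch below creates a trimmed version of the schema in an in-memory SQLite database, stores one user and one prediction, and reads the user's history back. The e-mail address, password, file name, and result values are placeholders, and several optional columns are omitted for brevity.

```python
import hashlib
import sqlite3
import uuid

# In-memory database so the sketch leaves no file behind.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE users (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    email TEXT UNIQUE NOT NULL,
    password TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE predictions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    uuid TEXT UNIQUE,
    user_id INTEGER NOT NULL,
    image_path TEXT NOT NULL,
    as_status TEXT NOT NULL,
    stage TEXT NOT NULL,
    confidence REAL,
    stage_confidence REAL,
    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    FOREIGN KEY (user_id) REFERENCES users(id)
);
""")

# Store the SHA-256 hex digest, as the schema comment describes.
password_hash = hashlib.sha256(b"example-password").hexdigest()
cursor = conn.execute(
    "INSERT INTO users (email, password) VALUES (?, ?)",
    ("clinician@example.com", password_hash),
)
user_id = cursor.lastrowid

run_id = str(uuid.uuid4())   # unique run identifier, also used in the results path
conn.execute(
    "INSERT INTO predictions (uuid, user_id, image_path, as_status, stage,"
    " confidence, stage_confidence) VALUES (?, ?, ?, ?, ?, ?, ?)",
    (run_id, user_id, f"static/results/{run_id}/input.png",
     "AS Positive", "Moderate", 92.0, 70.0),
)

# History view: all predictions for one user, newest first.
history = conn.execute(
    "SELECT as_status, stage, confidence FROM predictions"
    " WHERE user_id = ? ORDER BY timestamp DESC",
    (user_id,),
).fetchall()
print(history)
```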

Source Files

| File | Lines | Purpose |
|---|---|---|
| app.py | ~173 | Flask routes: signup, login, upload, predict, history, results |
| predict.py | ~202 | Core ML: preprocessing, segmentation, classification, Grad-CAM |
| model_loader.py | ~57 | Model registry, lazy loading, caching |
| database.py | ~145 | SQLite operations: init, CRUD for users and predictions |
| ankylosings.ipynb | 18 cells | Full training notebook: data loading → training → evaluation → Grad-CAM |

12. Project Status & Areas for Improvement

Current Status

| Component | Status | Details |
|---|---|---|
| Dataset Collection | Complete | 900 images curated from Kaggle lumbar dataset |
| Feature-Based Labeling | Complete | Balanced labels based on image characteristics |
| Attention U-Net (Segmentation) | Complete | 89.35% val accuracy, IoU 0.5652 |
| Simple CNN Classifier | Complete | 96.67% binary, 82.22% stage; deployed as primary |
| Hybrid CNN-Transformer v1 | Needs Improvement | 54.44% binary; frozen backbone limits performance |
| Hybrid CNN-Transformer v2 | Needs Improvement | 52.22% binary; frozen backbone limits performance |
| Standard Grad-CAM | Complete | Working for CNN classifier |
| Focused Grad-CAM | Complete | Sacroiliac joint ROI-focused variant |
| Flask Web App | Complete | Auth, upload, predict, history; all working |
| Docker Deployment | Complete | Dockerfile + docker-compose.yml ready |
| Early Stage Detection | Needs Work | 0% recall for Stage 1 (Early); class imbalance issue |
| Clinical Validation | Not Started | No radiologist-verified labels or prospective trials |
| Transformer Attention Maps | Not Started | Visualizing Transformer self-attention patterns |

Known Issues & Improvements Needed

| Issue | Impact | Proposed Solution |
|---|---|---|
| Frozen EfficientNetB0 Backbone | Hybrid models stuck at 52-54% binary accuracy (near chance) | Unfreeze last 20-30 layers of EfficientNetB0, use differential learning rates (base: 1e-5, head: 1e-3) |
| Early Stage (Stage 1) 0% Recall | Model cannot detect early AS | Oversample early-stage images (SMOTE/augmentation), use focal loss instead of CE, add class weights |
| Synthetic Labels | Labels don't reflect true clinical ground truth | Partner with radiology department for expert annotations on subset; validate feature-label correlation |
| Single MRI Modality | Misses clinical context | Add HLA-B27 status, CRP levels, patient age/gender as additional inputs (multimodal fusion) |
| 2D Slice Analysis Only | Misses inter-slice continuity | Implement 3D CNN or use multiple adjacent slices as input channels |
| Small Dataset (900 images) | Limited generalization | Expand to multi-center datasets (DESIR, ASAS cohorts), apply heavy augmentation |
| U-Net IoU = 0.5652 | Moderate segmentation quality | Use Dice loss + BCE combined, increase training data, add boundary-aware loss |
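
The focal loss proposed for the early-stage problem can be illustrated for a single sample. The sketch below uses the common form FL = (1 - p)^γ · CE with γ = 2, a frequently used value that is an assumption here and not a setting from this project; the class-balancing α term is omitted for brevity.

```python
import math

def cross_entropy(p_true):
    """Cross-entropy for one sample; p_true is the predicted probability of the true class."""
    return -math.log(p_true)

def focal_loss(p_true, gamma=2.0):
    """Focal loss for one sample: down-weights examples the model already gets right."""
    return (1.0 - p_true) ** gamma * cross_entropy(p_true)

easy, hard = 0.9, 0.1   # confident-correct vs. badly misclassified sample

ce_ratio = cross_entropy(hard) / cross_entropy(easy)
focal_ratio = focal_loss(hard) / focal_loss(easy)

print(f"CE:    easy={cross_entropy(easy):.4f}  hard={cross_entropy(hard):.4f}  ratio={ce_ratio:.1f}")
print(f"Focal: easy={focal_loss(easy):.4f}  hard={focal_loss(hard):.4f}  ratio={focal_ratio:.1f}")
```

Relative to plain cross-entropy, the hard sample dominates the loss far more strongly, which is why focal loss is often suggested when an abundant, easy class (here Normal) drowns out a rare one (Early).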

13. Expected Outcomes & Future Scope

Expected Project Outcomes

  • A complete Explainable AI system for multi-stage AS classification using sacroiliac joint MRI, deployable as a web application.
  • A hybrid CNN-Transformer model trained to capture both local spatial patterns (bone erosion, sclerosis) and global structural relationships.
  • Visual explanations (Grad-CAM heatmaps) highlighting the anatomical regions influencing predictions, enabling radiologists to verify AI decisions.
  • A practical, transparent diagnostic tool that bridges the gap between AI accuracy and clinical usability, supporting faster and more confident AS diagnosis.

Future Scope

Short-Term Improvements

  • Unfreeze Hybrid Model: Fine-tune EfficientNetB0 backbone layers for medical domain adaptation
  • Address Class Imbalance: Implement focal loss, SMOTE, or weighted sampling for early-stage detection
  • Improve Segmentation: Use combined Dice + BCE loss to boost IoU beyond 0.6
  • Add Confusion Matrix Visualization: Interactive confusion matrices in the web app
  • Batch Prediction: Enable processing multiple MRI scans in one upload

Medium-Term Goals

  • Multimodal Integration: Combine MRI data with clinical variables (HLA-B27, CRP, age, gender)
  • 3D Analysis: Process volumetric MRI data to capture inter-slice relationships
  • Multi-Center Validation: Train on datasets from multiple hospitals to reduce institutional bias
  • Transformer Attention Maps: Extract and visualize self-attention patterns alongside Grad-CAM
  • Edge Deployment: Optimize models (quantization, pruning) for inference on mobile/tablet devices

Long-Term Vision

  • Clinical Trials: Prospective studies measuring real-world impact on radiologist decision-making
  • Telemedicine Integration: Deploy on cloud platforms for remote screening in underserved areas
  • Disease Progression Tracking: Longitudinal analysis comparing scans over time for treatment response
  • Multi-Disease Extension: Extend to other spondyloarthritis conditions and spinal pathologies
  • DICOM Integration: Direct integration with hospital PACS systems for seamless workflow

14. Glossary of Technical Terms

Definitions of key machine learning and medical imaging terms used in this documentation.

Core ML Concepts

  • Deep Learning (DL): A subset of machine learning using neural networks with many layers (deep) to learn complex patterns from data.
  • Convolutional Neural Network (CNN): A type of deep learning model specifically designed for image analysis. It uses "filters" to automatically detect features like edges, textures, and shapes.
  • Transformer: A newer deep learning architecture that uses "attention mechanisms" to weigh the importance of different parts of the input data. Originally for text, now used for images (Vision Transformers).
  • Epoch: One complete pass of the entire training dataset through the model during training.
  • Batch Size: The number of training examples used in one iteration to update the model's internal parameters.
  • Overfitting: When a model learns the training data too well, including noise, and performs poorly on new, unseen data.
  • Fine-Tuning: Taking a pre-trained model (e.g., trained on millions of general images) and training it further on a specific dataset (e.g., MRI scans) to adapt it to a new task.

Metrics & Evaluation

  • Accuracy: The percentage of correct predictions made by the model. (Correct / Total).
  • Precision: "Quality" of positivity. Of all images predicted as Positive, how many were actually Positive?
  • Recall (Sensitivity): "Quantity" of positivity. Of all actual Positive images, how many did the model correctly find?
  • F1-Score: The harmonic mean of Precision and Recall. A balanced metric useful when classes are uneven.
  • IoU (Intersection over Union): A metric for segmentation. Measures the overlap between the predicted mask and the ground truth mask. 0 = no overlap, 1 = perfect match.
  • Confusion Matrix: A table showing correct and incorrect predictions for each class, helping to see where the model is making mistakes.
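
As a concrete illustration of the IoU definition above, the sketch below computes it for two small binary masks written as nested lists; the masks are made-up examples, and returning 1.0 for two empty masks is a convention chosen here.

```python
def iou(pred_mask, true_mask):
    """Intersection over Union for two binary masks of equal shape."""
    intersection = union = 0
    for pred_row, true_row in zip(pred_mask, true_mask):
        for p, t in zip(pred_row, true_row):
            intersection += 1 if (p and t) else 0
            union += 1 if (p or t) else 0
    return intersection / union if union else 1.0   # two empty masks count as a match

predicted = [[0, 1, 1, 0],
             [0, 1, 1, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]]
ground_truth = [[0, 0, 1, 1],
                [0, 0, 1, 1],
                [0, 0, 0, 0],
                [0, 0, 0, 0]]

score = iou(predicted, ground_truth)
print(round(score, 3))
```

Here each mask covers 4 pixels and they share 2, so the union is 6 pixels and the IoU is 2/6. Even a prediction that overlaps half of the target scores well below 0.5, which helps put the reported IoU of 0.5652 in context.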

Project-Specific Terms

  • Segmentation: The process of partitioning an image into different regions. Here, separating the sacroiliac joint from the background.
  • Classification: Categorizing an entire image into a class (e.g., AS Positive or Negative).
  • ROI (Region of Interest): A specific part of an image identified for further analysis (e.g., the joint area).
  • Grad-CAM: "Gradient-weighted Class Activation Mapping". A technique to visualize which parts of an image were most important for the model's decision (displayed as a heatmap).
  • Augmentation: Artificially increasing the training dataset size by creating modified versions of images (rotating, flipping, adding noise) to help the model generalize better.