Skip to content

iridite/Med-ML

Repository files navigation

CS-MDSS: 因果-相似性医疗决策支持系统

English | 中文

🏥 基于AI的皮肤病学分析演示系统

结合深度学习、度量学习和因果推理

Python 3.13+ PyTorch Streamlit License: MIT


📋 概述

CS-MDSS 是一个最小可行性产品(MVP)演示系统,展示了人工智能如何通过三种创新方法辅助皮肤病学的临床决策:

  1. 相似性检索(度量学习):使用学习到的嵌入向量查找相似的历史病例
  2. 因果推理:理解临床特征的变化如何影响预测结果
  3. 决策链可视化:通过 Grad-CAM 热力图解释AI推理过程

该系统旨在增强而非替代临床专业知识,提供透明、可解释的AI辅助。

🎯 核心特性

🔬 多头深度学习架构

  • ResNet50 骨干网络:在 DermaMNIST 上微调的预训练特征提取器
  • 分类头:7类皮肤病变分类
  • 嵌入头:用于相似性搜索的128维向量
  • 度量学习:结合交叉熵 + 三元组边界 + 余弦相似度损失

🔍 相似性搜索引擎

  • Faiss驱动的检索:快速近似最近邻搜索
  • Top-K相似病例:历史病例比较以提供临床背景
  • 视觉相似性:"人脸识别"范式应用于医学图像

🧠 可解释性与因果分析

  • Grad-CAM热力图:可视化驱动预测的图像区域
  • 反事实分析:针对临床特征的"假设"场景
  • 风险评估:自动化严重程度分类

🖥️ 交互式临床界面

  • 基于Streamlit的UI:专业、直观的界面
  • 实时分析:即时预测和解释
  • 临床建议:AI生成的决策支持

📁 项目结构

ws/
├── src/
│   ├── model.py          # 神经网络架构和损失函数
│   ├── utils.py          # 数据加载、Faiss索引、反事实逻辑
│   ├── train.py          # PyTorch Lightning训练脚本
│   ├── quick_demo.py     # 快速演示脚本(生成模拟数据)
│   └── app.py            # Streamlit Web界面
├── test/                 # 单元测试文件
├── models/               # 保存的模型检查点、Faiss索引及因果模型缓存
├── data/                 # DermaMNIST数据集(自动下载)
├── logs/                 # TensorBoard训练日志
├── pyproject.toml        # 依赖项(uv包管理器)
└── README.md

🚀 快速开始

前置要求

  • Python 3.13+
  • uv 包管理器

安装

# 克隆仓库
cd ws

# 使用uv安装依赖
uv sync

# 或者安装特定包
uv add torch torchvision pytorch-lightning monai medmnist faiss-cpu shap streamlit pandas numpy matplotlib plotly opencv-python seaborn

快速演示(无需训练)

如果你希望跳过训练过程直接体验系统界面,可以使用 quick_demo.py 生成模拟的模型权重和索引数据:

# 生成演示资源(模拟权重和合成数据)
uv run python src/quick_demo.py

# 启动 Streamlit 应用
uv run streamlit run src/app.py

训练模型

# 使用默认参数训练
uv run python src/train.py

# 使用自定义参数训练
uv run python src/train.py \
    --batch-size 64 \
    --max-epochs 50 \
    --learning-rate 1e-4 \
    --image-size 224

# 仅快速测试运行
uv run python src/train.py --fast-dev-run

说明:执行 train.py 不仅会训练深度学习模型,还会自动训练并保存基于 DoWhy 的因果推断模型 (models/causal_analyzer.pkl)。该过程会预计算各临床特征的平均干预效应(ATE),以便在网页演示中实现实时的反事实推理。

训练参数:

参数 默认值 描述
--batch-size 32 训练批次大小
--max-epochs 30 最大训练轮数
--learning-rate 1e-4 初始学习率
--image-size 224 输入图像大小(28或224)
--embedding-dim 128 嵌入向量维度
--ce-weight 1.0 交叉熵损失权重
--triplet-weight 0.5 三元组损失权重
--cosine-weight 0.3 余弦相似度损失权重
--index-samples 2000 要索引的样本数量

运行测试

确保代码模块功能正常:

uv run python -m unittest discover test

运行演示

# 启动Streamlit应用
uv run streamlit run src/app.py

# 或使用自定义端口
uv run streamlit run src/app.py --server.port 8501

应用程序将在浏览器中打开,地址为 http://localhost:8501

📊 数据集

本项目使用来自 MedMNIST v2 集合的 DermaMNIST 数据集。

类别 标签 描述
0 akiec 光化性角化病
1 bcc 基底细胞癌
2 bkl 良性角化病
3 df 皮肤纤维瘤
4 mel 黑色素瘤
5 nv 黑素细胞痣
6 vasc 血管病变

数据集统计:

  • 训练集:7,007张图像
  • 验证集:1,003张图像
  • 测试集:2,005张图像
  • 图像尺寸:224×224(RGB)

🔬 技术创新

1. 医学影像的度量学习

与传统分类不同,CS-MDSS学习一个语义嵌入空间,其中相似的病变聚集在一起。这使得:

  • 少样本学习:用有限的样本识别罕见病症
  • 基于案例的推理:通过相似的历史病例解释预测
  • 迁移学习:嵌入向量可以泛化到未见过的病变类型

损失函数:

L_total = λ_ce × L_CrossEntropy + λ_triplet × L_TripletMargin + λ_cosine × L_CosineSimilarity

2. 因果反事实分析

系统集成 DoWhy 因果推理库,通过结构因果模型(SCM)模拟因果干预以回答"假设"问题:

  • "如果病变是对称的会怎样?"
  • "20岁的年龄差异会如何影响诊断?"
  • "如果边界是规则的会怎样?"

这为临床医生提供了对特征重要性的直观理解。

3. Grad-CAM可解释性

梯度加权类激活映射突出显示诊断相关区域:

  • 视觉注意力:模型在看哪里?
  • 临床相关性:高亮区域是否符合已知的诊断标准?
  • 信任校准:帮助临床医生评估AI可靠性

📈 模型架构

CSMDSSEncoder
├── 骨干网络:ResNet50(ImageNet预训练)
│   └── 输出:2048维特征向量
├── 分类头
│   ├── Dropout(0.3)
│   ├── Linear(2048 → 512) + ReLU
│   ├── Dropout(0.3)
│   └── Linear(512 → 7)  # 7个类别
└── 嵌入头
    ├── Linear(2048 → 512) + ReLU
    ├── Dropout(0.3)
    ├── Linear(512 → 128)
    └── L2归一化  # 单位球面

🖥️ 用户界面

Streamlit界面包含:

侧边栏

  • 图像上传组件
  • Top-K检索滑块
  • 置信度阈值调整
  • Grad-CAM开关
  • 因果分析开关
  • 系统状态指示器

主面板

第一行:相似病例

  • 查询图像显示
  • Top-K相似历史病例
  • 相似度分数和确诊诊断

第二行:预测分析

  • 主要诊断及置信度
  • 风险级别评估
  • 概率分布图表
  • Grad-CAM注意力热力图

第三行:决策支持

  • 因果反事实探索器
  • 临床建议
  • 基于风险级别的行动项目

⚠️ 局限性与免责声明

  1. 仅用于研究目的:该系统设计用于演示和研究。未获批准用于临床使用。

  2. 数据集局限性:DermaMNIST包含处理过的标准化图像,可能无法反映真实世界的临床摄影情况。

  3. 因果分析:本项目集成了 DoWhy 库进行因果推理。由于 DermaMNIST 数据集缺乏真实的临床元数据,系统目前使用基于皮肤病学领域知识(如 ABCD 规则)生成的合成数据来训练因果模型,以演示完整的反事实分析流程。

由于 MedMNIST 版数据集不含元数据,当前的因果分析模块是通过领域知识(Domain Knowledge)模拟出来的(def _generate_domain_knowledge_data(self, n_samples=2000))。它利用 DoWhy 学习的是皮肤病学中公认的规律(例如:基底细胞癌在中老年人中更常见),而不是从当前上传的具体病人的真实元数据中读取的。

  1. 模型性能:结果高度依赖于训练数据质量,可能无法泛化到所有人群或成像条件。

🔮 未来方向

  • Swin Transformer骨干网络选项
  • SHAP集成用于特征重要性
  • 真实因果模型(DoWhy、CausalML)
  • 带病变分割的多任务学习
  • 不确定性量化(MC Dropout、集成)
  • DICOM图像支持
  • 临床验证研究

📚 参考文献

  1. Yang, J., et al. "MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification." Scientific Data, 2023.

  2. Schroff, F., et al. "FaceNet: A Unified Embedding for Face Recognition and Clustering." CVPR, 2015.

  3. Selvaraju, R.R., et al. "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." ICCV, 2017.

  4. Pearl, J. "Causality: Models, Reasoning, and Inference." Cambridge University Press, 2009.

📄 许可证

本项目采用MIT许可证 - 详情请参阅LICENSE文件。

🤝 贡献

欢迎贡献!请随时提交Pull Request。


CS-MDSS | 因果-相似性医疗决策支持系统

通过可解释的机器学习连接人工智能与临床实践

About

No description or website provided.

Topics

Resources

Stars

Watchers

Forks

Contributors 2

  •  
  •