English | 中文
CS-MDSS 是一个最小可行性产品(MVP)演示系统,展示了人工智能如何通过三种创新方法辅助皮肤病学的临床决策:
- 相似性检索(度量学习):使用学习到的嵌入向量查找相似的历史病例
- 因果推理:理解临床特征的变化如何影响预测结果
- 决策链可视化:通过 Grad-CAM 热力图解释AI推理过程
该系统旨在增强而非替代临床专业知识,提供透明、可解释的AI辅助。
- ResNet50 骨干网络:在 DermaMNIST 上微调的预训练特征提取器
- 分类头:7类皮肤病变分类
- 嵌入头:用于相似性搜索的128维向量
- 度量学习:结合交叉熵 + 三元组边界 + 余弦相似度损失
- Faiss驱动的检索:快速近似最近邻搜索
- Top-K相似病例:历史病例比较以提供临床背景
- 视觉相似性:"人脸识别"范式应用于医学图像
- Grad-CAM热力图:可视化驱动预测的图像区域
- 反事实分析:针对临床特征的"假设"场景
- 风险评估:自动化严重程度分类
- 基于Streamlit的UI:专业、直观的界面
- 实时分析:即时预测和解释
- 临床建议:AI生成的决策支持
ws/
├── src/
│ ├── model.py # 神经网络架构和损失函数
│ ├── utils.py # 数据加载、Faiss索引、反事实逻辑
│ ├── train.py # PyTorch Lightning训练脚本
│ ├── quick_demo.py # 快速演示脚本(生成模拟数据)
│ └── app.py # Streamlit Web界面
├── test/ # 单元测试文件
├── models/ # 保存的模型检查点、Faiss索引及因果模型缓存
├── data/ # DermaMNIST数据集(自动下载)
├── logs/ # TensorBoard训练日志
├── pyproject.toml # 依赖项(uv包管理器)
└── README.md
- Python 3.13+
- uv 包管理器
# 克隆仓库
cd ws
# 使用uv安装依赖
uv sync
# 或者安装特定包
uv add torch torchvision pytorch-lightning monai medmnist faiss-cpu shap streamlit pandas numpy matplotlib plotly opencv-python seaborn如果你希望跳过训练过程直接体验系统界面,可以使用 quick_demo.py 生成模拟的模型权重和索引数据:
# 生成演示资源(模拟权重和合成数据)
uv run python src/quick_demo.py
# 启动 Streamlit 应用
uv run streamlit run src/app.py# 使用默认参数训练
uv run python src/train.py
# 使用自定义参数训练
uv run python src/train.py \
--batch-size 64 \
--max-epochs 50 \
--learning-rate 1e-4 \
--image-size 224
# 仅快速测试运行
uv run python src/train.py --fast-dev-run说明:执行
train.py不仅会训练深度学习模型,还会自动训练并保存基于 DoWhy 的因果推断模型 (models/causal_analyzer.pkl)。该过程会预计算各临床特征的平均干预效应(ATE),以便在网页演示中实现实时的反事实推理。
训练参数:
| 参数 | 默认值 | 描述 |
|---|---|---|
--batch-size |
32 | 训练批次大小 |
--max-epochs |
30 | 最大训练轮数 |
--learning-rate |
1e-4 | 初始学习率 |
--image-size |
224 | 输入图像大小(28或224) |
--embedding-dim |
128 | 嵌入向量维度 |
--ce-weight |
1.0 | 交叉熵损失权重 |
--triplet-weight |
0.5 | 三元组损失权重 |
--cosine-weight |
0.3 | 余弦相似度损失权重 |
--index-samples |
2000 | 要索引的样本数量 |
确保代码模块功能正常:
uv run python -m unittest discover test# 启动Streamlit应用
uv run streamlit run src/app.py
# 或使用自定义端口
uv run streamlit run src/app.py --server.port 8501应用程序将在浏览器中打开,地址为 http://localhost:8501。
本项目使用来自 MedMNIST v2 集合的 DermaMNIST 数据集。
| 类别 | 标签 | 描述 |
|---|---|---|
| 0 | akiec | 光化性角化病 |
| 1 | bcc | 基底细胞癌 |
| 2 | bkl | 良性角化病 |
| 3 | df | 皮肤纤维瘤 |
| 4 | mel | 黑色素瘤 |
| 5 | nv | 黑素细胞痣 |
| 6 | vasc | 血管病变 |
数据集统计:
- 训练集:7,007张图像
- 验证集:1,003张图像
- 测试集:2,005张图像
- 图像尺寸:224×224(RGB)
与传统分类不同,CS-MDSS学习一个语义嵌入空间,其中相似的病变聚集在一起。这使得:
- 少样本学习:用有限的样本识别罕见病症
- 基于案例的推理:通过相似的历史病例解释预测
- 迁移学习:嵌入向量可以泛化到未见过的病变类型
损失函数:
L_total = λ_ce × L_CrossEntropy + λ_triplet × L_TripletMargin + λ_cosine × L_CosineSimilarity
系统集成 DoWhy 因果推理库,通过结构因果模型(SCM)模拟因果干预以回答"假设"问题:
- "如果病变是对称的会怎样?"
- "20岁的年龄差异会如何影响诊断?"
- "如果边界是规则的会怎样?"
这为临床医生提供了对特征重要性的直观理解。
梯度加权类激活映射突出显示诊断相关区域:
- 视觉注意力:模型在看哪里?
- 临床相关性:高亮区域是否符合已知的诊断标准?
- 信任校准:帮助临床医生评估AI可靠性
CSMDSSEncoder
├── 骨干网络:ResNet50(ImageNet预训练)
│ └── 输出:2048维特征向量
├── 分类头
│ ├── Dropout(0.3)
│ ├── Linear(2048 → 512) + ReLU
│ ├── Dropout(0.3)
│ └── Linear(512 → 7) # 7个类别
└── 嵌入头
├── Linear(2048 → 512) + ReLU
├── Dropout(0.3)
├── Linear(512 → 128)
└── L2归一化 # 单位球面
Streamlit界面包含:
- 图像上传组件
- Top-K检索滑块
- 置信度阈值调整
- Grad-CAM开关
- 因果分析开关
- 系统状态指示器
第一行:相似病例
- 查询图像显示
- Top-K相似历史病例
- 相似度分数和确诊诊断
第二行:预测分析
- 主要诊断及置信度
- 风险级别评估
- 概率分布图表
- Grad-CAM注意力热力图
第三行:决策支持
- 因果反事实探索器
- 临床建议
- 基于风险级别的行动项目
-
仅用于研究目的:该系统设计用于演示和研究。未获批准用于临床使用。
-
数据集局限性:DermaMNIST包含处理过的标准化图像,可能无法反映真实世界的临床摄影情况。
-
因果分析:本项目集成了 DoWhy 库进行因果推理。由于 DermaMNIST 数据集缺乏真实的临床元数据,系统目前使用基于皮肤病学领域知识(如 ABCD 规则)生成的合成数据来训练因果模型,以演示完整的反事实分析流程。
由于 MedMNIST 版数据集不含元数据,当前的因果分析模块是通过领域知识(Domain Knowledge)模拟出来的(
def _generate_domain_knowledge_data(self, n_samples=2000))。它利用DoWhy学习的是皮肤病学中公认的规律(例如:基底细胞癌在中老年人中更常见),而不是从当前上传的具体病人的真实元数据中读取的。
- 模型性能:结果高度依赖于训练数据质量,可能无法泛化到所有人群或成像条件。
- Swin Transformer骨干网络选项
- SHAP集成用于特征重要性
- 真实因果模型(DoWhy、CausalML)
- 带病变分割的多任务学习
- 不确定性量化(MC Dropout、集成)
- DICOM图像支持
- 临床验证研究
-
Yang, J., et al. "MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification." Scientific Data, 2023.
-
Schroff, F., et al. "FaceNet: A Unified Embedding for Face Recognition and Clustering." CVPR, 2015.
-
Selvaraju, R.R., et al. "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." ICCV, 2017.
-
Pearl, J. "Causality: Models, Reasoning, and Inference." Cambridge University Press, 2009.
本项目采用MIT许可证 - 详情请参阅LICENSE文件。
欢迎贡献!请随时提交Pull Request。
CS-MDSS | 因果-相似性医疗决策支持系统
通过可解释的机器学习连接人工智能与临床实践