可解释人工智能及其研究-SHAP算法应用篇


SHAP算法在成人人口普查数据的应用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# 使用的是shap = 0.48.0 版本
# 使用的是xgboost= 3.0.4 版本
# 使用的是numpy = 1.26.4 版本
# 使用的是matplotlib = 3.10.5版本
# 使用的是sklearn= 1.6.1版本
# 使用的是lightgbm = 4.6.0版本
# 使用的是pandas = 2.2.3版本
# 查看本地全部包和对应版本命令:pip list
import shap
import xgboost
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import lightgbm as lgb
import pandas as pd


# 如果 display 为 True,则 X 包含不包含“Education”、“Target”和“fnlwgt”目标列和冗余列的原始数据。否则,X 包含没有“Target”和“fnlwgt”列的已处理数据。
# X,X_display都是32561 rows x 12 columns的数据
# y,y_display都是32561 rows x 1 columns的数据,数组返回的“Target”列,这里可能有bug.
X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)


# XGBClassifier是sklearn中XGBoost分类器的实现,集成多个决策树来改善模型预测精度
# eval_metric确认任务类型为分类任务,使用负对数似然函数'mlogloss'
# 模型初始化并做模型训练
model = xgboost.XGBClassifier(eval_metric='mlogloss').fit(X, y)

# 这是模型自带的特征重要性,可以查看特征重要性,plot_importance()函数基于XGBoost模型训练后计算的特征重要性分数来绘制图表,
# 默认使用importance_type = 'weight',特征在树中的平均权重做为重要性的度量.
# xgboost.plot_importance(model)


# 创建Explainer解释器并利用训练数据X计算shap值
# shap_values的数据类型为 <class 'numpy.ndarray'>
# shap_values2的数据类型为<class 'shap._explanation.Explanation'>
explainer = shap.Explainer(model)
shap_values = explainer.shap_values(X) # 第一个特征的shap values:shap_values[0][0]
shap_values2 = explainer(X)

'''
# 全局条形图:summary plot是针对全部样本预测的解释,是取每个特征的shap values的平均绝对值来获得标准条形图,这个其实就是全局重要度
# Summary_plot 为每一个样本绘制其每个特征的Shapley value,它说明哪些特征最重要,以及它们对数据集的影响范围。
# 另一种是通过散点简单绘制每个样本的每个特征的shap values,通过颜色可以看到特征值大小与预测影响之间的关系,同时展示其特征值分布
# 两个图都可以看到Relationship全局重要度是最高的,其次是Age。第一个图可以看到各个特征重要度的相对关系,虽然Capital Gain是第三,但是重要度只有Relationship的60%,
# shap.summary_plot使用的是shap_values的数据类型为 <class 'numpy.ndarray'>
# shap.plots.bar使用的是shap_values2的数据类型为<class 'shap._explanation.Explanation'>
#shap.plots.bar(shap_values2, max_display = 12)
shap.summary_plot(shap_values, X, plot_type="bar")
#按照计算公式,可得其数值表格化如下
feature_importance = pd.DataFrame()
feature_importance['feature'] = X.columns
feature_importance['importance'] = np.abs(shap_values).mean(0)
feature_importance.sort_values('importance', ascending=False)
print(feature_importance)
'''



# y 轴上的位置由特征确定,x 轴上的位置由每 Shapley value 确定。颜色表示特征值(红色高,蓝色低),颜色使我们能够匹配特征值的变化如何影响风险的变化。
# 重叠点在 y 轴方向抖动,因此我们可以了解每个特征的 Shapley value分布,并且这些特征是根据它们的重要性排序的。
# 由颜色深浅则可以看到Relationship和Age都是值越大,个人年收入超过5万美元的可能性越大,平均而言年龄是最重要的特征,与年轻(蓝色)人相比,收入超过 5 万美元的可能性较小
# 也使用可以shap.plots.beeswarm绘制一样的蜂群图,在显示数据集中的TOP特征如何影响模型输出的信息密集摘要。
# 点的 x 位置由该特征的 SHAP 值 ( shap_values.value[instance,feature]) 确定,并且点沿每个特征行“堆积”以显示密度。
#shap.summary_plot(shap_values, X)
#shap.plots.beeswarm(shap_values2, order=np.abs(shap_values).mean(0).argsort()[::-1], max_display = 12)







# 可以使用matplotlib构建自定义颜色图,color默认使用shap.plots.colors.red_blue的颜色
#shap.plots.beeswarm(shap_values2, color=plt.get_cmap("cool"))

# 局部特征重要性条形图, 条形是每个特征的 SHAP 值。最左侧的灰色显示数值是特征值
#shap.plots.bar(shap_values2[1], max_display = 12)


# 绘制男性和女性特征重要性的全局摘要
#sex = ["Women" if shap_values2[i,"Sex"].data == 0 else "Men" for i in range(shap_values2.shape[0])]
#shap.plots.bar(shap_values2.cohorts(sex).abs.mean(0))


# 使用Explanation对象的自动群组功能来创建一个群组,调用Explanation.cohorts(N)将创建N个队列,
# 使用 sklearn DecisionTreeRegressor 最佳地分离实例的 SHAP 值,图例中的方括号显示的是每个队列中的实例数
#shap.plots.bar(shap_values2.cohorts(2).abs.mean(0), max_display = 12)


# 拟合相对于目标变量 y 的特征 X 的分层聚类模型,可以在SHAP中通过模型损失比较来测量特征冗余
# 计算聚类并传递给条形图,就可以同时可视化特征冗余结构和特征重要性。默认只会显示距离 < 0.5 的聚类部分
#clustering = shap.utils.hclust(X, y)
#shap.plots.bar(shap_values2, clustering=clustering, max_display = 12)



# SHAP Partial dependence plot (PDP or PD plot) 依赖图显示了一个或两个特征对机器学习模型的预测结果的边际效应
# PDP 的一个假设是第一个特征与第二个特征不相关。如果违反此假设,则 PDP 计算的平均值将包括极不可能甚至不可能的数据点
# 参数interaction_index用于设置交互项,用于验证特征之间是否存在交互效应
# Dependence plot 是一个散点图,显示单个特征对整个数据集的影响。每个点都是来自数据集的单个预测(行)。
# x 轴是数据集中的实际值。(来自 X 矩阵,存储在 中shap_values.data)。
# y 轴是该特征的 SHAP 值(存储在 中shap_values.values),它表示该特征值对该预测的模型输出的改变程度
#先看某个特征是如何影响到模型预测结果的
#看到Age越大,个人年收入超过5万美元的可能性越大,但是在65岁后出现波动,80岁后可能性下降
#shap.dependence_plot('Age', shap_values, X, interaction_index=None)
#相同的图片画法
#plt.figure(figsize=(7.5, 5))
#plt.scatter(X['Age'], shap_values[:, 0], s=10, alpha=1)

# 一个特征是如何和另一个特征交互影响到模型预测结果的,这里以Age和Capital Gain为例子
# 散点还是由Age绘制,但是Age的预测其实是有其他特征相互作用的,散点图垂直分散就是由相互作用效应驱动,
# 所以用Capital Gain进行着色以突出显示可能的相互作用
shap.dependence_plot('Age', shap_values, X, display_features=X_display, interaction_index='Capital Gain')


#shap values绘制归因关系是有其他特征的相互作用使图垂直分散的,可以使用shap_interaction_values消除这种相互作用
shap_interaction_values = explainer.shap_interaction_values(X)
shap.dependence_plot(('Age', 'Age'), shap_interaction_values, X, interaction_index=None)
#相同的图片画法
plt.figure(figsize=(7.5, 5))
plt.scatter(X['Age'], shap_interaction_values[:, 0, 0], s=10, alpha=1)

# shap.plots.scatte图底部的浅灰色区域是显示数据值分布的直方图
# 在交互颜色方面,散点图则需要将整个 Explanation 对象传递给 color 参数
#shap_values2.display_data = X_display.values
#shap.plots.scatter(shap_values2[:, "Age"], color=shap_values2[:,"Workclass"])

'''
# SHAP force plot 提供了单一模型预测的可解释性,可用于误差分析,找到对特定实例预测的解释
# 如果不想用JS,需要在shap.force_plot传入matplotlib=True的参数。否则就需要使用shap.initjs()
# 模型输出值是-6.75,模型基值是-1.297,绘图箭头下方数字是此实例的特征值,将预测推高的特征用红色表示,将预测推低的特征用蓝色表示
# 箭头越长,特征对输出的影响越大。explainer.expected_value是解释模型的常数
#shap.initjs() # 初始化JavaScript库
#shap.force_plot(explainer.expected_value, shap_values[0,:], X_display.iloc[0,:])
# 或者
shap.force_plot(explainer.expected_value, shap_values[0,:], X_display.iloc[0,:], matplotlib=True)
#其数值表格化如下:
sample_0_shap = pd.DataFrame(X.iloc[0,:])
sample_0_shap.rename(columns={0: 'feature_value'}, inplace=True)
sample_0_shap['shap_value'] = shap_values[0]
sample_0_shap.sort_values('shap_value', ascending=False)
print(sample_0_shap)
'''



# explainer.shap_interaction_values是实现快速两两交互计算,将为每个预测返回一个矩阵,其中主要影响在对角线上,交互影响在对角线外
#shap.summary_plot(explainer.shap_interaction_values(X), X) # 第一个特征的shap interaction values:explainer.shap_interaction_values(X)[0][0]

# Decision plot决策图:SHAP 决策图显示复杂模型如何得出其预测(即模型如何做出决策)
# 决策图中间灰色垂直直线标记了模型的基础值,彩色线是预测,表示每个特征是否将输出值移动到高于或低于平均预测的值。特征值在预测线旁边以供参考。
# 从图的底部开始,预测线显示 SHAP value 如何从基础值累积到图顶部的模型最终分数
#expected_value = explainer.expected_value
#features = X.iloc[range(20)]# 限制20个样本
#shap_values = explainer.shap_values(features)[1]# 展示第一条样本
#features_display = X_display.loc[features.index]
#shap.decision_plot(expected_value, shap_values, features_display)


# 决策图支持将对link='logit'数几率转换为概率。
# 使用虚线样式highlight=misclassified突出显示一个错误分类的观察结果
'''
shap_values = explainer.shap_values(features)
y_pred = (shap_values.sum(1) + expected_value) > 0
misclassified = y_pred != y[:20]
shap.decision_plot(expected_value, shap_values,
features_display,
link='logit',
highlight=misclassified)

# 通过单独绘制来检查错误分类的观察结果
shap.decision_plot(expected_value, shap_values[misclassified],
features_display[misclassified],
link='logit',
highlight=0)

# 错误分类观察的力图
shap.force_plot(expected_value, shap_values[misclassified],
features_display[misclassified],
link='logit')
'''


# 瀑布图旨在显示单个预测的解释,因此将解释对象的单行作为输入。瀑布图从底部的模型输出的预期值开始,
# 每一行显示每个特征的是正(红色)或负(蓝色)贡献,即如何将值从数据集上的模型预期输出值推动到模型预测的输出值
# waterfall绘图显示单个样本数据
#shap.plots.waterfall(shap_values2[5], max_display = 12)
#shap.plots.scatter(shap_values2[:,"Relationship"])

参考资料:

[1] 用 SHAP 可视化解释机器学习模型的输出实用指南
[2] SHAP的理解与应用
[3]机器学习可解释性工具:SHAP

以后的研究注意方向:使用 GPU 加速 SHAP 解释机器学习模型预测


 上一篇
AI应用实践-Text2SQL(NL2SQL, Natural Language to SQL) AI应用实践-Text2SQL(NL2SQL, Natural Language to SQL)
AI应用实践-Text2SQL(NL2SQL, Natural Language to SQL),主要介绍Text2SQL技术的历史发展、关键技术环节、中外评测数据集、评分标准、学术界当前的研究情况、目前可以使用的开源工具和项目和个人设想等。
2025-09-17
下一篇 
可解释人工智能及其研究-SHAP算法说明篇 可解释人工智能及其研究-SHAP算法说明篇
可解释人工智能及其研究-SHAP算法说明篇,主要介绍SHAP可解释性算法,包括这个算法的历史、价值、解决问题、数据基础、计算解决方案、应用价值、SHAP开源算法库的使用(安装、API、解释器、可视化工具、使用注意事项等)。
2025-09-16
  目录