"""
#-*- coding: utf-8 -*-
年龄:0代表<=30,1代表31~40,2代表>40
收入:0代表高,1代表中,2代表低
学生:0代表否,1代表是
信誉:0代表中,1代表优
类别:C1代表是,C2代表否
"""
from matplotlib
.font_manager
import FontProperties
import matplotlib
.pyplot
as plt
from math
import log
import operator
import pickle
"""
函数说明:创建测试数据集
Parameters:
None
Returns:
dataSet - 数据集
labels - 分类属性
"""
def createDataSet():
dataSet
= [[0, 0, 0, 0, 'C2'],
[0, 0, 0, 1, 'C2'],
[1, 0, 0, 0, 'C1'],
[2, 1, 0, 0, 'C1'],
[2, 2, 1, 0, 'C1'],
[2, 2, 1, 1, 'C2'],
[1, 2, 1, 1, 'C1'],
[0, 1, 0, 0, 'C2'],
[0, 2, 1, 0, 'C1'],
[2, 1, 1, 0, 'C1']]
labels
= ['年龄', '收入', '学生', '信誉']
return dataSet
, labels
"""
函数说明:计算给定数据集的经验熵(香农熵)
Ent(D) = -SUM(kp*Log2(kp))
Parameters:
dataSet - 数据集
Returns:
shannonEnt - 经验熵(香农熵)
"""
def calcShannonEnt(dataSet
):
numEntires
= len(dataSet
)
labelCounts
= {}
for featVec
in dataSet
:
currentLabel
= featVec
[-1]
if currentLabel
not in labelCounts
.keys
():
labelCounts
[currentLabel
] = 0
labelCounts
[currentLabel
] += 1
shannonEnt
= 0.0
for key
in labelCounts
:
prob
= float(labelCounts
[key
]) / numEntires
shannonEnt
-= prob
* log
(prob
, 2)
return shannonEnt
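# Quick sanity check on the sample dataset above: it holds 6 'C1' and 4 'C2'
# samples, so
#
#     Ent(D) = -(0.6 * log2(0.6) + 0.4 * log2(0.4)) ≈ 0.971
#
# which is what this returns:
#
#     dataSet, labels = createDataSet()
#     calcShannonEnt(dataSet)   # ≈ 0.971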
"""
函数说明:按照给定特征划分数据集
Parameters:
dataSet - 待划分的数据集
axis - 划分数据集的特征
values - 需要返回的特征的值
Returns:
None
"""
def splitDataSet(dataSet
, axis
, value
):
retDataSet
= []
for featVec
in dataSet
:
if featVec
[axis
] == value
:
reducedFeatVec
= featVec
[:axis
]
reducedFeatVec
.extend
(featVec
[axis
+ 1:])
retDataSet
.append
(reducedFeatVec
)
return retDataSet
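# Example on the sample dataset: keep the rows whose first feature (age)
# equals 1 and drop that column:
#
#     splitDataSet(dataSet, 0, 1)   # [[0, 0, 0, 'C1'], [2, 1, 1, 'C1']]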
"""
函数说明:选择最优特征
Gain(D,g) = Ent(D) - SUM(|Dv|/|D|)*Ent(Dv)
Parameters:
dataSet - 数据集
Returns:
bestFeature - 信息增益最大的(最优)特征的索引值
"""
def chooseBestFeatureToSplit(dataSet
):
numFeatures
= len(dataSet
[0]) - 1
baseEntropy
= calcShannonEnt
(dataSet
)
bestInfoGain
= 0.0
bestFeature
= -1
for i
in range(numFeatures
):
featList
= [example
[i
] for example
in dataSet
]
uniqueVals
= set(featList
)
newEntropy
= 0.0
for value
in uniqueVals
:
subDataSet
= splitDataSet
(dataSet
, i
, value
)
prob
= len(subDataSet
) / float(len(dataSet
))
newEntropy
+= prob
* calcShannonEnt
(subDataSet
)
infoGain
= baseEntropy
- newEntropy
print("第%d个特征的增益为%.3f" % (i
, infoGain
))
if (infoGain
> bestInfoGain
):
bestInfoGain
= infoGain
bestFeature
= i
return bestFeature
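# Worked example for feature 0 (age) on the sample dataset: age=0 and age=2
# each cover 4 samples with a 1:3 class split (entropy ≈ 0.811), while age=1
# covers 2 samples of a single class (entropy 0), so
#
#     Gain(D, age) = 0.971 - (0.4 * 0.811 + 0.2 * 0.0 + 0.4 * 0.811) ≈ 0.322
#
# which matches the value printed for feature 0.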
"""
函数说明:统计classList中出现次数最多的元素(类标签)
服务于递归第两个终止条件
Parameters:
classList - 类标签列表
Returns:
sortedClassCount[0][0] - 出现次数最多的元素(类标签)
"""
def majorityCnt(classList
):
classCount
= {}
for vote
in classList
:
if vote
not in classCount
.keys
():
classCount
[vote
] = 0
classCount
[vote
] += 1
sortedClassCount
= sorted(classCount
.items
(), key
=operator
.itemgetter
(1), reverse
=True)
return sortedClassCount
[0][0]
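# For instance:
#
#     majorityCnt(['C1', 'C2', 'C1'])   # 'C1'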
"""
函数说明:创建决策树(ID3算法)
递归有两个终止条件:1、所有的类标签完全相同,直接返回类标签
2、用完所有标签但是得不到唯一类别的分组,即特征不够用,挑选出现数量最多的类别作为返回
Parameters:
dataSet - 训练数据集
labels - 分类属性标签
featLabels - 存储选择的最优特征标签
Returns:
myTree - 决策树
"""
def createTree(dataSet
, labels
, featLabels
):
classList
= [example
[-1] for example
in dataSet
]
if classList
.count
(classList
[0]) == len(classList
):
return classList
[0]
if len(dataSet
[0]) == 1:
return majorityCnt
(classList
)
bestFeat
= chooseBestFeatureToSplit
(dataSet
)
bestFeatLabel
= labels
[bestFeat
]
featLabels
.append
(bestFeatLabel
)
myTree
= {bestFeatLabel
: {}}
featValues
= [example
[bestFeat
] for example
in dataSet
]
uniqueVals
= set(featValues
)
for value
in uniqueVals
:
del_bestFeat
= bestFeat
del_labels
= labels
[bestFeat
]
del (labels
[bestFeat
])
myTree
[bestFeatLabel
][value
] = createTree
(splitDataSet
(dataSet
, bestFeat
, value
), labels
, featLabels
)
labels
.insert
(del_bestFeat
, del_labels
)
return myTree
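# For the sample dataset this yields the nested dict below (a sketch: branch
# order follows set iteration, and ties in information gain go to the
# lower-indexed feature):
#
#     {'年龄': {0: {'收入': {0: 'C2', 1: 'C2', 2: 'C1'}},
#               1: 'C1',
#               2: {'信誉': {0: 'C1', 1: 'C2'}}}}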
"""
函数说明:获取决策树叶子结点的数目
Parameters:
myTree - 决策树
Returns:
numLeafs - 决策树的叶子结点的数目
"""
def getNumLeafs(myTree
):
numLeafs
= 0
firstStr
= next(iter(myTree
))
secondDict
= myTree
[firstStr
]
for key
in secondDict
.keys
():
if type(secondDict
[key
]).__name__
== 'dict':
numLeafs
+= getNumLeafs
(secondDict
[key
])
else:
numLeafs
+= 1
return numLeafs
"""
函数说明:获取决策树的层数
Parameters:
myTree - 决策树
Returns:
maxDepth - 决策树的层数
"""
def getTreeDepth(myTree
):
maxDepth
= 0
firstStr
= next(iter(myTree
))
secondDict
= myTree
[firstStr
]
for key
in secondDict
.keys
():
if type(secondDict
[key
]).__name__
== 'dict':
thisDepth
= 1 + getTreeDepth
(secondDict
[key
])
else:
thisDepth
= 1
if thisDepth
> maxDepth
:
maxDepth
= thisDepth
return maxDepth
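# For the tree sketched after createTree() above (6 leaves, 2 levels):
#
#     getNumLeafs(myTree)    # 6
#     getTreeDepth(myTree)   # 2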
"""
函数说明:绘制结点
Parameters:
nodeTxt - 结点名
centerPt - 文本位置
parentPt - 标注的箭头位置
nodeType - 结点格式
Returns:
None
"""
def plotNode(nodeTxt
, centerPt
, parentPt
, nodeType
):
arrow_args
= dict(arrowstyle
="<-")
font
= FontProperties
(fname
=r
"C:\Windows\Fonts\simsun.ttc", size
=14)
createPlot
.ax1
.annotate
(nodeTxt
, xy
=parentPt
, xycoords
='axes fraction',
xytext
=centerPt
, textcoords
='axes fraction',
va
='center', ha
='center', bbox
=nodeType
,
arrowprops
=arrow_args
, FontProperties
=font
)
"""
函数说明:标注有向边属性值
Parameters:
cntrPt、parentPt - 用于计算标注位置
txtString - 标注内容
Returns:
None
"""
def plotMidText(cntrPt
, parentPt
, txtString
):
xMid
= (parentPt
[0] - cntrPt
[0]) / 2.0 + cntrPt
[0]
yMid
= (parentPt
[1] - cntrPt
[1]) / 2.0 + cntrPt
[1]
createPlot
.ax1
.text
(xMid
, yMid
, txtString
, va
="center", ha
="center", rotation
=30)
"""
函数说明:绘制决策树
Parameters:
myTree - 决策树(字典)
parentPt - 标注的内容
nodeTxt - 结点名
Returns:
None
"""
def plotTree(myTree
, parentPt
, nodeTxt
):
decisionNode
= dict(boxstyle
="sawtooth", fc
="0.8")
leafNode
= dict(boxstyle
="round4", fc
="0.8")
numLeafs
= getNumLeafs
(myTree
)
depth
= getTreeDepth
(myTree
)
firstStr
= next(iter(myTree
))
cntrPt
= (plotTree
.xoff
+ (1.0 + float(numLeafs
)) / 2.0 / plotTree
.totalW
, plotTree
.yoff
)
plotMidText
(cntrPt
, parentPt
, nodeTxt
)
plotNode
(firstStr
, cntrPt
, parentPt
, decisionNode
)
secondDict
= myTree
[firstStr
]
plotTree
.yoff
= plotTree
.yoff
- 1.0 / plotTree
.totalD
for key
in secondDict
.keys
():
if type(secondDict
[key
]).__name__
== 'dict':
plotTree
(secondDict
[key
], cntrPt
, str(key
))
else:
plotTree
.xoff
= plotTree
.xoff
+ 1.0 / plotTree
.totalW
plotNode
(secondDict
[key
], (plotTree
.xoff
, plotTree
.yoff
), cntrPt
, leafNode
)
plotMidText
((plotTree
.xoff
, plotTree
.yoff
), cntrPt
, str(key
))
plotTree
.yoff
= plotTree
.yoff
+ 1.0 / plotTree
.totalD
"""
函数说明:创建绘图面板
Parameters:
inTree - 决策树(字典)
Returns:
None
"""
def createPlot(inTree
):
fig
= plt
.figure
(1, facecolor
="white")
fig
.clf
()
axprops
= dict(xticks
=[], yticks
=[])
createPlot
.ax1
= plt
.subplot
(111, frameon
=False, **axprops
)
plotTree
.totalW
= float(getNumLeafs
(inTree
))
plotTree
.totalD
= float(getTreeDepth
(inTree
))
plotTree
.xoff
= -0.5 / plotTree
.totalW
plotTree
.yoff
= 1.0
plotTree
(inTree
, (0.5, 1.0), '')
plt
.show
()
"""
函数说明:使用决策树分类
Parameters:
inputTree - 已经生成的决策树
featLabels - 存储选择的最优特征标签
testVec - 测试数据列表,顺序对应最优特征标签
Returns:
classLabel - 分类结果
"""
def classify(inputTree
, featLabels
, testVec
):
firstStr
= next(iter(inputTree
))
secondDict
= inputTree
[firstStr
]
featIndex
= featLabels
.index
(firstStr
)
for key
in secondDict
.keys
():
if testVec
[featIndex
] == key
:
if type(secondDict
[key
]).__name__
== 'dict':
classLabel
= classify
(secondDict
[key
], featLabels
, testVec
)
else:
classLabel
= secondDict
[key
]
return classLabel
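# Example, assuming the tree built from the sample dataset, where featLabels
# comes out as ['年龄', '收入', '信誉'] (age, income, credit), so the test
# vector reads [age, income, credit]:
#
#     classify(myTree, featLabels, [0, 1, 1])   # 'C2'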
"""
函数说明:存储决策树
Parameters:
inputTree - 已经生成的决策树
filename - 决策树的存储文件名
Returns:
None
Modify:
2018-07-17
"""
def storeTree(inputTree
, filename
):
with open(filename
, 'wb') as fw
:
pickle
.dump
(inputTree
, fw
)
"""
函数说明:读取决策树
Parameters:
filename - 决策树的存储文件名
Returns:
pickle.load(fr) - 决策树字典
Modify:
2018-07-17
"""
def grabTree(filename
):
fr
= open(filename
, 'rb')
return pickle
.load
(fr
)
"""
函数说明:main函数
Parameters:
None
Returns:
None
"""
def main():
dataSet
, features
= createDataSet
()
featLabels
= []
myTree
= createTree
(dataSet
, features
, featLabels
)
testVec
= [0, 1, 1]
result
= classify
(myTree
, featLabels
, testVec
)
if result
== 'C1':
print('C1')
if result
== 'C2':
print('C2')
print(myTree
)
createPlot
(myTree
)
print("最优特征索引值:" + str(chooseBestFeatureToSplit
(dataSet
)))
if __name__
== '__main__':
main
()
# Source: https://blog.csdn.net/weixin_41475854/article/details/106303409