博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
JS简单实现决策树(ID3算法)
阅读量:5832 次
发布时间:2019-06-18

本文共 5943 字,大约阅读时间需要 19 分钟。

img_876d970d8d4b0c5624b8cbd30239d9b4.png

推荐阅读:

完整示例代码:

决策树算法代码实现

1.准备测试数据

这里我假设公司有个小姐姐相亲见面为例

得到以下是已经见面或被淘汰了的数据(部分数据使用来生成的):

var data =        [            { "姓名": "余夏", "年龄": 29, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },            { "姓名": "豆豆", "年龄": 25, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },            { "姓名": "帅常荣", "年龄": 26, "长相": "帅", "体型": "胖", "收入": "高", 见面: "见" },            { "姓名": "王涛", "年龄": 22, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },            { "姓名": "李东", "年龄": 23, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },            { "姓名": "王五五", "年龄": 23, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "见" },            { "姓名": "王小涛", "年龄": 22, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "见" },            { "姓名": "李缤", "年龄": 21, "长相": "帅", "体型": "胖", "收入": "高", 见面: "见" },            { "姓名": "刘明", "年龄": 21, "长相": "帅", "体型": "胖", "收入": "低", 见面: "不见" },            { "姓名": "红鹤", "年龄": 21, "长相": "不帅", "体型": "胖", "收入": "高", 见面: "不见" },            { "姓名": "李理", "年龄": 32, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "不见" },            { "姓名": "周州", "年龄": 31, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "不见" },            { "姓名": "李乐", "年龄": 27, "长相": "不帅", "体型": "胖", "收入": "高", 见面: "不见" },            { "姓名": "韩明", "年龄": 24, "长相": "不帅", "体型": "瘦", "收入": "高", 见面: "不见" },            { "姓名": "小吕", "年龄": 28, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "不见" },            { "姓名": "李四", "年龄": 25, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "不见" },            { "姓名": "王鹏", "年龄": 30, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "不见" },        ];

2.搭建决策树基本函数

代码:

function DecisionTree(config) {    if (typeof config == "object" && !Array.isArray(config)) this.training(config);};DecisionTree.prototype = {    //分割函数    _predicates: {},    //统计属性值在数据集中的次数    countUniqueValues(items, attr) {},    //获取对象中值最大的Key  假设 counter={a:9,b:2} 得到 "a"     getMaxKey(counter) {},    //寻找最频繁的特定属性值    mostFrequentValue(items, attr) {},    //根据属性切割数据集     split(items, attr, predicate, pivot) {},    //计算熵    entropy(items, attr) {},    //生成决策树    buildDecisionTree(config) {},    //初始化生成决策树    training(config) {},    //预测 测试    predict(data) {},};var decisionTree = new DecisionTree();

3.实现函数功能

由于部分函数过于简单我就不进行讲解了

可前往 查看完整代码
里面包含注释,与每个函数的测试方法

这里的话我主要讲解下:计算熵的函数、生成决策树函数(信息增益)、与预测函数的实现

中解释了计算熵信息增益的公式

img_2efae1bec3a63509d5e1beb8b3fec9ce.jpe
截图

3.1.计算熵(entropy)函数

根据公式:

img_7417a48f0189e3816498d3e305fafc14.png
公式

我们可以知道计算H(S)(也就是熵)需要得到 p(x)=x/总数量 然后进行计算累加就行了

代码:

//......略//统计属性值在数据集中的次数countUniqueValues(items, attr) {    var counter = {}; // 获取不同的结果值 与出现次数    for (var i of items) {        if (!counter[i[attr]]) counter[i[attr]] = 0;        counter[i[attr]] += 1;    }    return counter;},//......略//计算熵entropy(items, attr) {    var counter = this.countUniqueValues(items, attr); //计算值的出现数    var p, entropy = 0; //H(S)=entropy=∑(P(Xi)(log2(P(Xi))))    for (var i in counter) {        p = counter[i] / items.length; //P(Xi)概率值        entropy += -p * Math.log2(p); //entropy+=-(P(Xi)(log2(P(Xi))))    }    return entropy;},//......略var decisionTree = new DecisionTree();console.log("函数 countUniqueValues 测试:");console.log("   长相", decisionTree.countUniqueValues(data, "长相")); //测试console.log("   年龄", decisionTree.countUniqueValues(data, "年龄")); //测试console.log("   收入", decisionTree.countUniqueValues(data, "收入")); //测试console.log("函数 entropy 测试:");console.log("   长相", decisionTree.entropy(data, "长相")); //测试console.log("   年龄", decisionTree.entropy(data, "年龄")); //测试console.log("   收入", decisionTree.entropy(data, "收入")); //测试
3.2.信息增益
img_b009e09dc33b1a21074149f83cb556d5.png
公式

根据公式我们知道要得到信息增益的值需要得到:

  • H(S) 训练集熵
  • p(t)分支元素的占比
  • H(t)分支数据集的熵

其中t我们就先分 match(合适的)on match(不合适),所以H(t):

  • H(match) 分割后合适的数据集的熵
  • H(on match) 分割后不合适的数据集的熵

所以信息增益G=H(S)-(p(match)H(match)+p(on match)H(on match))

因为p(match)=match数量/数据集总项数量
信息增益G=H(S)-((match数量)xH(match)+(on match数量)xH(on match))/数据集总项数量

//......略buildDecisionTree(config){    var trainingSet = config.trainingSet;//训练集     var categoryAttr = config.categoryAttr;//用于区分的类别属性    //......略    //初始计算 训练集的熵    var initialEntropy = this.entropy(trainingSet, categoryAttr);//<===H(S)    //......略    var alreadyChecked = [];//标识已经计算过了    var bestSplit = { gain: 0 };//储存当前最佳的分割节点数据信息    //遍历数据集    for (var item of trainingSet) {        // 遍历项中的所有属性        for (var attr in item) {            //跳过区分属性与忽略属性            if ((attr == categoryAttr) || (ignoredAttributes.indexOf(attr) >= 0)) continue;            var pivot = item[attr];// 当前属性的值             var predicateName = ((typeof pivot == 'number') ? '>=' : '=='); //根据数据类型选择判断条件            var attrPredPivot = attr + predicateName + pivot;            if (alreadyChecked.indexOf(attrPredPivot) >= 0) continue;//已经计算过则跳过            alreadyChecked.push(attrPredPivot);//记录            var predicate = this._predicates[predicateName];//匹配分割方式            var currSplit = this.split(trainingSet, attr, predicate, pivot);            var matchEntropy = this.entropy(currSplit.match, categoryAttr);//  H(match) 计算分割后合适的数据集的熵            var notMatchEntropy = this.entropy(currSplit.notMatch, categoryAttr);// H(on match) 计算分割后不合适的数据集的熵             //计算信息增益:              // IG(A,S)=H(S)-(∑P(t)H(t)))              // t为分裂的子集match(匹配),on match(不匹配)             // P(match)=match的长度/数据集的长度             // P(on match)=on match的长度/数据集的长度             var iGain = initialEntropy - ((matchEntropy * currSplit.match.length                        + notMatchEntropy * currSplit.notMatch.length) / trainingSet.length);              //不断匹配最佳增益值对应的节点信息              if (iGain > bestSplit.gain) {                  //......略              }        }    }     //......递归计算分支}
3.3.预测功能

预测功能的话就只要将要预测的值传入,循环去寻找符合条件的分支,直到找到最后的所属分类为止,这里就不详细解释了

代码:

//......略//预测 测试predict(data) {    var attr, value, predicate, pivot;    var tree = this.root;    while (true) {        if (tree.category) {            return tree.category;        }        attr = tree.attribute;        value = data[attr];        predicate = tree.predicate;        pivot = tree.pivot;        if (predicate(value, pivot)) {            tree = tree.match;        } else {            tree = tree.notMatch;        }    }}//......略

4.最后测试

img_a5b67b4bb67c52863414145317bd7a2a.gif

转载地址:http://fxrdx.baihongyu.com/

你可能感兴趣的文章
listbox用法
查看>>
冲刺第九天 1.10 THU
查看>>
传值方式:ajax技术和普通传值方式
查看>>
Linux-网络连接-(VMware与CentOS)
查看>>
寻找链表相交节点
查看>>
AS3——禁止swf缩放
查看>>
linq 学习笔记之 Linq基本子句
查看>>
[Js]布局转换
查看>>
Hot Bath
查看>>
国内常用NTP服务器地址及
查看>>
Java annotation 自定义注释@interface的用法
查看>>
Apache Spark 章节1
查看>>
phpcms与discuz的ucenter整合
查看>>
Linux crontab定时执行任务
查看>>
mysql root密码重置
查看>>
33蛇形填数
查看>>
选择排序
查看>>
SQL Server 数据库的数据和日志空间信息
查看>>
前端基础之JavaScript
查看>>
自己动手做个智能小车(6)
查看>>