diff options
Diffstat (limited to 'decision_tree')
| -rw-r--r-- | decision_tree/decision_tree.go | 106 | ||||
| -rw-r--r-- | decision_tree/go.mod | 3 | 
2 files changed, 109 insertions, 0 deletions
| diff --git a/decision_tree/decision_tree.go b/decision_tree/decision_tree.go new file mode 100644 index 0000000..2e80778 --- /dev/null +++ b/decision_tree/decision_tree.go @@ -0,0 +1,106 @@ +package main + +import ( +	"fmt" +	"math" +	"math/rand" +	"time" +) + +type Class int32 + +const ( +	classOne Class = 0 +	classTwo       = 1 +) + +type pointType struct { +	x     float32 +	y     float32 +	class Class +} + +func generateRandData(n int32, points []pointType) { +	rand.Seed(time.Now().UnixNano()) +	var i int32 +	for i = 0; i < n; i++ { +		dirX := rand.Float32() +		dirY := rand.Float32() +		if dirX > 0.5 { + +			points[i].x = rand.Float32() * 20 +		} else { +			points[i].x = -rand.Float32() * 20 +		} +		if dirY > 0.5 { +			points[i].y = rand.Float32() * 20 +		} else { +			points[i].y = -rand.Float32() * 20 +		} +		class := rand.Float32() +		if class < 0.5 { +			points[i].class = classOne +		} else { +			points[i].class = classTwo +		} +	} +} + +func decisionTreeV1(points []pointType) { +	for i := 0; i < len(points); i++ { +		if points[i].x <= -12 { +			if points[i].x <= 9 { +				if points[i].y < 9 { + +				} else { + +				} +			} else { + +			} +		} else { + +		} +	} +} + +func calcClassProbs(points []pointType, classCount int) []float64 { +	var counts = make([]int, classCount) +	for i := 0; i < len(points); i++ { +		counts[points[i].class]++ +	} + +	var probs = make([]float64, classCount) +	for i := 0; i < classCount; i++ { +		probs[i] = float64(counts[i]) / float64(len(points)) +	} + +	return probs +} + +func calcEntropy(probs []float64) float64 { +	var entropy float64 + +	for i := 0; i < len(probs); i++ { +		entropy += probs[i] * float64(math.Log2(probs[i])) +	} + +	entropy = entropy * -1 +	return entropy +} + +func main() { +	var n int32 = 1000 +	var points = make([]pointType, n) +	generateRandData(n, points) +	var i int32 +	for i = 0; i < n; i++ { +		fmt.Println(points[i].x, ":", points[i].y, ":", points[i].class) +	} +	probs := calcClassProbs(points, 2) +	for i := 0; i < 2; i++ { +		fmt.Println(probs[i]) +	} +	entropy := calcEntropy(probs) +	fmt.Println(entropy) +} diff --git a/decision_tree/go.mod b/decision_tree/go.mod new file mode 100644 index 0000000..0ffdf1b --- /dev/null +++ b/decision_tree/go.mod @@ -0,0 +1,3 @@ +module decision_tree + +go 1.18 | 
