diff options
author | terminaldweller <thabogre@gmail.com> | 2022-07-08 08:46:16 +0000 |
---|---|---|
committer | terminaldweller <thabogre@gmail.com> | 2022-07-08 08:46:16 +0000 |
commit | 23dae20a546b51d0c87f95289f31966c48cafafa (patch) | |
tree | a94743535377627af8f8d4e7d7e15f8cde24c94e /decision_tree/decision_tree.go | |
parent | poetry (diff) | |
download | seer-23dae20a546b51d0c87f95289f31966c48cafafa.tar.gz seer-23dae20a546b51d0c87f95289f31966c48cafafa.zip |
Diffstat (limited to 'decision_tree/decision_tree.go')
-rw-r--r-- | decision_tree/decision_tree.go | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/decision_tree/decision_tree.go b/decision_tree/decision_tree.go new file mode 100644 index 0000000..2e80778 --- /dev/null +++ b/decision_tree/decision_tree.go @@ -0,0 +1,106 @@ +package main + +import ( + "fmt" + "math" + "math/rand" + "time" +) + +type Class int32 + +const ( + classOne Class = 0 + classTwo = 1 +) + +type pointType struct { + x float32 + y float32 + class Class +} + +func generateRandData(n int32, points []pointType) { + rand.Seed(time.Now().UnixNano()) + var i int32 + for i = 0; i < n; i++ { + dirX := rand.Float32() + dirY := rand.Float32() + if dirX > 0.5 { + + points[i].x = rand.Float32() * 20 + } else { + points[i].x = -rand.Float32() * 20 + } + if dirY > 0.5 { + points[i].y = rand.Float32() * 20 + } else { + points[i].y = -rand.Float32() * 20 + } + class := rand.Float32() + if class < 0.5 { + points[i].class = classOne + } else { + points[i].class = classTwo + } + } +} + +func decisionTreeV1(points []pointType) { + for i := 0; i < len(points); i++ { + if points[i].x <= -12 { + if points[i].x <= 9 { + if points[i].y < 9 { + + } else { + + } + } else { + + } + } else { + + } + } +} + +func calcClassProbs(points []pointType, classCount int) []float64 { + var counts = make([]int, classCount) + for i := 0; i < len(points); i++ { + counts[points[i].class]++ + } + + var probs = make([]float64, classCount) + for i := 0; i < classCount; i++ { + probs[i] = float64(counts[i]) / float64(len(points)) + } + + return probs +} + +func calcEntropy(probs []float64) float64 { + var entropy float64 + + for i := 0; i < len(probs); i++ { + entropy += probs[i] * float64(math.Log2(probs[i])) + } + + entropy = entropy * -1 + return entropy +} + +func main() { + var n int32 = 1000 + var points = make([]pointType, n) + generateRandData(n, points) + var i int32 + for i = 0; i < n; i++ { + fmt.Println(points[i].x, ":", points[i].y, ":", points[i].class) + } + probs := calcClassProbs(points, 2) + for i := 0; i < 2; i++ { + fmt.Println(probs[i]) + } + entropy := calcEntropy(probs) + fmt.Println(entropy) +} |