aboutsummaryrefslogtreecommitdiffstats
path: root/decision_tree/decision_tree.go
diff options
context:
space:
mode:
authorterminaldweller <thabogre@gmail.com>2022-07-08 08:46:16 +0000
committerterminaldweller <thabogre@gmail.com>2022-07-08 08:46:16 +0000
commit23dae20a546b51d0c87f95289f31966c48cafafa (patch)
treea94743535377627af8f8d4e7d7e15f8cde24c94e /decision_tree/decision_tree.go
parentpoetry (diff)
downloadseer-master.tar.gz
seer-master.zip
a simple decision tree implementationHEADmaster
Diffstat (limited to 'decision_tree/decision_tree.go')
-rw-r--r--decision_tree/decision_tree.go106
1 files changed, 106 insertions, 0 deletions
diff --git a/decision_tree/decision_tree.go b/decision_tree/decision_tree.go
new file mode 100644
index 0000000..2e80778
--- /dev/null
+++ b/decision_tree/decision_tree.go
@@ -0,0 +1,106 @@
+package main
+
+import (
+ "fmt"
+ "math"
+ "math/rand"
+ "time"
+)
+
+type Class int32
+
+const (
+ classOne Class = 0
+ classTwo = 1
+)
+
+type pointType struct {
+ x float32
+ y float32
+ class Class
+}
+
+func generateRandData(n int32, points []pointType) {
+ rand.Seed(time.Now().UnixNano())
+ var i int32
+ for i = 0; i < n; i++ {
+ dirX := rand.Float32()
+ dirY := rand.Float32()
+ if dirX > 0.5 {
+
+ points[i].x = rand.Float32() * 20
+ } else {
+ points[i].x = -rand.Float32() * 20
+ }
+ if dirY > 0.5 {
+ points[i].y = rand.Float32() * 20
+ } else {
+ points[i].y = -rand.Float32() * 20
+ }
+ class := rand.Float32()
+ if class < 0.5 {
+ points[i].class = classOne
+ } else {
+ points[i].class = classTwo
+ }
+ }
+}
+
+func decisionTreeV1(points []pointType) {
+ for i := 0; i < len(points); i++ {
+ if points[i].x <= -12 {
+ if points[i].x <= 9 {
+ if points[i].y < 9 {
+
+ } else {
+
+ }
+ } else {
+
+ }
+ } else {
+
+ }
+ }
+}
+
+func calcClassProbs(points []pointType, classCount int) []float64 {
+ var counts = make([]int, classCount)
+ for i := 0; i < len(points); i++ {
+ counts[points[i].class]++
+ }
+
+ var probs = make([]float64, classCount)
+ for i := 0; i < classCount; i++ {
+ probs[i] = float64(counts[i]) / float64(len(points))
+ }
+
+ return probs
+}
+
+func calcEntropy(probs []float64) float64 {
+ var entropy float64
+
+ for i := 0; i < len(probs); i++ {
+ entropy += probs[i] * float64(math.Log2(probs[i]))
+ }
+
+ entropy = entropy * -1
+ return entropy
+}
+
+func main() {
+ var n int32 = 1000
+ var points = make([]pointType, n)
+ generateRandData(n, points)
+ var i int32
+ for i = 0; i < n; i++ {
+ fmt.Println(points[i].x, ":", points[i].y, ":", points[i].class)
+ }
+ probs := calcClassProbs(points, 2)
+ for i := 0; i < 2; i++ {
+ fmt.Println(probs[i])
+ }
+ entropy := calcEntropy(probs)
+ fmt.Println(entropy)
+}