aboutsummaryrefslogblamecommitdiffstats
path: root/bin/tunneltop
blob: 7bf187840cb7487824da62a437c8562550732c3c (plain) (tree)
1
2
3
4
5
6
7
8
9




                                                   
           
             
           
         
             














                                                          


                                                        




                                 
                                                       





                                 
                                                                               































                                                   
                   







                                    
                                           




                     





                                  









































                                                                           

































                                                                         

                              
 





                                                                    






















                                                                        




                                                                               
 




                                           
 


















                                                                   

                                                             
















                                                                          
 
                    
 





                                                                     

                                    

                                                                           
             

                                  

                    
                                                                       
                                                                      

















                                                                               



                                                                       
                         

                                          







                                              
                                           


                                                                   
                                        

                                                             




























































                                                                 

                                 


                                  
 
                                             
 



                                                                   
             













































                                                                             


                          

                                  
#!/usr/bin/env python
"""A top-like program for monitoring ssh tunnels"""

import argparse
import asyncio
import copy
import curses
import enum
import os
import signal
import sys
import typing

import tomllib


class Argparser:  # pylint: disable=too-few-public-methods
    """Command-line options of tunneltop.

    Creating an instance parses ``sys.argv`` immediately; the resulting
    namespace is stored in ``self.args`` with the attributes ``config``,
    ``noheader`` and ``delay``.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument(
            "--config",
            "-c",
            type=str,
            help="The path to the .tunneltop.toml file,"
            " defaults to $HOME/.tunneltop.toml",
            default="~/.tunneltop.toml",
        )
        self.parser.add_argument(
            "--noheader",
            "-n",
            # ``type=bool`` would be wrong here: bool("False") is True, so
            # every non-empty value used to switch the header off.  The
            # value is now optional (a bare ``-n`` means True) and, when
            # given, is parsed as a real boolean.
            type=self._parse_bool,
            nargs="?",
            const=True,
            help="Dont print the header in the output",
            default=False,
        )
        self.parser.add_argument(
            "--delay",
            "-d",
            type=float,
            help="The delay between redraws in seconds, defaults to 5 seconds",
            default=5,
        )
        self.args = self.parser.parse_args()

    @staticmethod
    def _parse_bool(text: str) -> bool:
        """Convert a command-line word such as "true" or "0" to a bool.

        Raises:
            argparse.ArgumentTypeError: if the word is not a recognised
                boolean spelling, so argparse reports a usage error.
        """
        word = text.strip().lower()
        if word in ("1", "true", "t", "yes", "y", "on"):
            return True
        # The empty string keeps its old meaning (bool("") is False).
        if word in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {text!r}"
        )


# pylint: disable=too-few-public-methods
class Colors(enum.EnumType):
    """static color definitions"""

    purple = "\033[95m"
    blue = "\033[94m"
    green = "\033[92m"
    yellow = "\033[93m"
    red = "\033[91m"
    grey = "\033[1;37m"
    darkgrey = "\033[1;30m"
    cyan = "\033[1;36m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    blueblue = "\x1b[38;5;24m"
    greenie = "\x1b[38;5;23m"
    goo = "\x1b[38;5;22m"
    screen_clear = "\033c\033[3J"
    hide_cursor = "\033[?25l"


# pylint: disable=too-many-locals
def ffs(
    offset: int,
    header_list: typing.Optional[typing.List[str]],
    numbered: bool,
    no_color: bool,
    *args,
) -> typing.List[str]:
    """A simple columnar printer.

    Args:
        offset: extra padding added to every column, applied only when a
            header is given.
        header_list: one title per column, or None for no header line.
            The list is not modified.
        numbered: when true, prepend an index column (titled "idx" in the
            header) holding the zero-based row number.
        no_color: when true, or when stdout is not a tty, no escape
            sequences are emitted.
        *args: the columns, each a sequence of strings.  The number of
            rows is taken from the last column.

    Returns:
        The formatted lines: the header line first (if any), then one
        line per row.  Titles without a matching column are left out.
    """
    if no_color or not sys.stdout.isatty():
        greenie = bold = endc = goo = blueblue = ""
    else:
        greenie = Colors.greenie
        bold = Colors.BOLD
        endc = Colors.ENDC
        goo = Colors.goo
        blueblue = Colors.blueblue

    row_count = len(args[-1]) if args else 0

    # Widths are measured on repr(), i.e. two characters wider than the
    # text itself; this keeps the layout the printer has always produced.
    # default=0 lets an empty column through instead of raising.
    widths: typing.List[int] = [
        max((len(repr(cell)) for cell in column), default=0)
        for column in args
    ]

    # Work on a copy so the caller's header list is left alone.
    headers = None if header_list is None else list(header_list)

    if numbered:
        # The index column is drawn first, so its width goes first too,
        # whether or not a header is printed.
        widths.insert(
            0,
            max((len(repr(n)) for n in range(1, row_count + 1)), default=0),
        )
        if headers is not None:
            headers.insert(0, "idx")

    lines: typing.List[str] = []

    if headers is not None:
        for i, header in enumerate(headers[: len(widths)]):
            widths[i] = max(len(header), widths[i]) + offset
        lines.append(
            "".join(
                greenie + bold + header.ljust(width) + endc
                for header, width in zip(headers, widths)
            )
        )

    for i in range(row_count):
        cells: typing.List[str] = []
        data_widths = widths
        if numbered:
            cells.append(goo + bold + repr(i).ljust(widths[0]) + endc)
            data_widths = widths[1:]
        cells.extend(
            blueblue + column[i].ljust(width) + endc
            for column, width in zip(args, data_widths)
        )
        lines.append("".join(cells))
    return lines


def render(lines: typing.List[str], stdscr, sel: int):
    """Render the text.

    The first entry of ``lines`` is drawn on screen row 1 with color pair
    3; every following entry is drawn on its own row starting at row 2,
    with color pair 2 for the entry whose zero-based position equals
    ``sel`` and color pair 1 for the others.  A box is drawn around the
    window afterwards.  An empty ``lines`` draws only the box.

    Args:
        lines: the header line followed by the data lines.
        stdscr: the curses window to draw on.
        sel: zero-based index of the selected data line.
    """
    if lines:
        stdscr.addstr(1, 1, lines[0], curses.color_pair(3))
        for i, line in enumerate(lines[1:]):
            pair_number = 2 if i == sel else 1
            stdscr.addstr(2 + i, 1, line, curses.color_pair(pair_number))
            stdscr.addstr("\n")
    stdscr.box()


def curses_init():
    """Start curses, configure the terminal and return the main window.

    Colors are enabled with the terminal defaults, the cursor is hidden,
    echo is off, the keypad is on and input runs in half-delay mode with
    a value of 20 tenths of a second.  Color pairs 1 to 4 are registered.
    """
    window = curses.initscr()

    curses.start_color()
    curses.use_default_colors()

    curses.curs_set(False)
    curses.noecho()
    curses.cbreak()
    window.keypad(True)
    curses.halfdelay(20)

    # (pair number, foreground, background)
    palette = (
        (1, curses.COLOR_GREEN, curses.COLOR_BLACK),
        (2, curses.COLOR_BLACK, curses.COLOR_GREEN),
        (3, curses.COLOR_BLUE, curses.COLOR_BLACK),
        (4, curses.COLOR_CYAN, curses.COLOR_BLACK),
    )
    for pair_number, foreground, background in palette:
        curses.init_pair(pair_number, foreground, background)

    return window


class TunnelTop:
    """The tunnel top class"""

    def __init__(self):
        self.argparser = Argparser()
        self.data_cols: typing.Dict[str, typing.Dict[str, str]] = {}
        self.tunnel_tasks: typing.List[asyncio.Task] = []
        self.tunnel_test_tasks: typing.List[asyncio.Task] = []

    def read_conf(self) -> typing.Dict[str, typing.Dict[str, str]]:
        """Read the config file"""
        data_cols: typing.Dict[str, typing.Dict[str, str]] = {}
        with open(
            os.path.expanduser(self.argparser.args.config), "rb"
        ) as conf_file:
            data = tomllib.load(conf_file)
            for key, value in data.items():
                data_cols[key] = {
                    "name": key,
                    "address": value["address"],
                    "port": value["port"],
                    "command": value["command"],
                    "status": "UNKN",
                    "test_command": value["test_command"],
                    "test_command_result": value["test_command_result"],
                    "test_interval": value["test_interval"],
                    "test_timeout": value["test_timeout"],
                    "stdout": "",
                    "stderr": "",
                }
        return data_cols

    async def run_subshell(self, cmd: str) -> typing.Tuple[bytes, bytes]:
        """Run a command in a subshell"""
        proc = await asyncio.create_subprocess_shell(
            cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )

        try:
            return await proc.communicate()
        except asyncio.CancelledError:
            self.write_log("fucking fuck")
            return (bytes(), bytes())

    def tunnel_test_callback(self, task: asyncio.Task) -> None:
        """Tunnel test callback function."""
        try:
            task_name = task.get_name()
            self.data_cols[task_name]["stdout"] = (
                task.result()[0].decode("utf-8").strip("\n")
            )
            self.data_cols[task_name]["stderr"] = (
                task.result()[1].decode("utf-8").strip("\n")
            )
            if (
                task.result()[0].decode("utf-8").strip("\n")
                == self.data_cols[task_name]["test_command_result"]
            ):
                self.data_cols[task_name]["status"] = "UP"
            else:
                self.data_cols[task_name]["status"] = "DOWN"
        except asyncio.TimeoutError:
            self.data_cols[task_name]["status"] = "TMOUT"
        except asyncio.CancelledError:
            self.data_cols[task_name]["status"] = "CANCELLED"

    async def tunnel_test_procs(self) -> typing.List[asyncio.Task]:
        """run all the tunnel tests in the background as separate tasks"""
        tasks: typing.List[asyncio.Task] = []
        for _, value in self.data_cols.items():
            if value["test_command"] != "":
                tasks.append(
                    asyncio.create_task(
                        asyncio.wait_for(
                            self.run_subshell(value["test_command"]),
                            timeout=float(value["test_timeout"]),
                        ),
                        name=value["name"],
                    )
                )
                tasks[-1].add_done_callback(self.tunnel_test_callback)
                await asyncio.sleep(0)

        return tasks

    async def tunnel_procs(
        self,
    ) -> typing.List[asyncio.Task]:
        """run all the tunnels in the background as separate tasks"""
        tasks: typing.List[asyncio.Task] = []
        for _, value in self.data_cols.items():
            tasks.append(
                asyncio.create_task(
                    self.run_subshell(value["command"]), name=value["name"]
                ),
            )
            await asyncio.sleep(0)

        return tasks

    async def sighup_handler_async_worker(self, data_cols_new) -> None:
        """Handles the actual updating of tasks when we get SIGTERM"""
        for k, value in data_cols_new.items():
            if k not in self.data_cols:
                self.tunnel_tasks.append(
                    asyncio.create_task(
                        self.run_subshell(value["command"]), name=k
                    )
                )
                await asyncio.sleep(0)
            else:
                if (
                    self.data_cols[k]["command"] != data_cols_new[k]["command"]
                    or self.data_cols[k]["port"] != data_cols_new[k]["port"]
                    or self.data_cols[k]["address"]
                    != data_cols_new[k]["address"]
                ):
                    for task in self.tunnel_tasks:
                        if task.get_name() == k:
                            task.cancel()
                    self.data_cols[k] = copy.deepcopy(data_cols_new[k])
                    self.tunnel_tasks.append(
                        asyncio.create_task(
                            self.run_subshell(value["command"]), name=k
                        )
                    )
                    await asyncio.sleep(0)

        for k, _ in self.data_cols.items():
            if k not in data_cols_new:
                for task in self.tunnel_tasks:
                    if task.get_name() == k:
                        task.cancel()
                del self.data_cols[k]

    async def sighup_handler(self) -> None:
        """SIGHUP handler. we want to reload the config."""
        # type: ignore # pylint: disable=E0203
        data_cols_new: typing.Dict[str, typing.Dict[str, str]] = {}
        data_cols_new = self.read_conf()
        await self.sighup_handler_async_worker(data_cols_new)

    def write_log(self, log: str):
        """A simple logger"""
        with open(
            "/home/devi/devi/abbatoir/hole15/log",
            "w",
            encoding="utf-8",
        ) as logfile:
            logfile.write(log)

    async def restart_task(self, line_content: str) -> None:
        """restart a task"""
        name: str = line_content[: line_content.find(" ")]
        was_cancelled: bool = False
        for task in self.tunnel_tasks:
            if task.get_name() == name:
                was_cancelled = task.cancel()
                self.write_log(f"was_cancelled: {was_cancelled}")
                await task
        for _, value in self.data_cols.items():
            if value["name"] == name:
                self.tunnel_tasks.append(
                    asyncio.create_task(
                        self.run_subshell(value["command"]),
                        name=value["name"],
                    )
                )
                await asyncio.sleep(0)

    async def flip_task(self, line_content: str) -> None:
        """flip a task"""
        name: str = line_content[: line_content.find(" ")]
        was_cancelled: bool = False
        was_active: bool = False
        for task in self.tunnel_tasks:
            if task.get_name() == name:
                was_cancelled = task.cancel()
                await asyncio.sleep(0)
                self.write_log(f"was_cancelled: {was_cancelled}")
                await task
                was_active = True
                break

        if not was_active:
            for _, value in self.data_cols.items():
                if value["name"] == name:
                    self.tunnel_tasks.append(
                        asyncio.create_task(
                            self.run_subshell(value["command"]),
                            name=value["name"],
                        )
                    )
                    await asyncio.sleep(0)
                    break

    async def quit(self) -> None:
        """Cleanly quit the applicaiton"""
        for tunnel_test_task in self.tunnel_test_tasks:
            tunnel_test_task.cancel()
        for tunnel_task in self.tunnel_tasks:
            tunnel_task.cancel()

    async def main(self) -> None:
        """entrypoint"""
        sel: int = 0
        try:
            stdscr = curses_init()

            self.data_cols = self.read_conf()

            loop = asyncio.get_event_loop()
            loop.add_signal_handler(
                signal.SIGHUP,
                lambda: asyncio.create_task(self.sighup_handler()),
            )
            self.tunnel_tasks = await self.tunnel_procs()

            while True:
                # self.tunnel_test_tasks = await self.tunnel_test_procs()
                lines = ffs(
                    2,
                    ["NAME", "ADDRESS", "PORT", "STATUS", "STDOUT", "STDERR"]
                    if not self.argparser.args.noheader
                    else None,
                    False,
                    True,
                    [v["name"] for _, v in self.data_cols.items()],
                    [v["address"] for _, v in self.data_cols.items()],
                    [repr(v["port"]) for _, v in self.data_cols.items()],
                    [v["status"] for _, v in self.data_cols.items()],
                    [v["stdout"] for _, v in self.data_cols.items()],
                    [v["stderr"] for _, v in self.data_cols.items()],
                )
                stdscr.clear()
                render(lines, stdscr, sel)
                char = stdscr.getch()

                if char == ord("j") or char == curses.KEY_DOWN:
                    sel = (sel + 1) % len(self.data_cols)
                elif char == ord("k") or char == curses.KEY_UP:
                    sel = (sel - 1) % len(self.data_cols)
                elif char == ord("r"):
                    line_content = stdscr.instr(sel + 2, 1)
                    await self.restart_task(line_content.decode("utf-8"))
                elif char == ord("q"):
                    await self.quit()
                # elif char == curses.KEY_ENTER:
                elif char == ord("s"):
                    line_content = stdscr.instr(sel + 2, 1)
                    await self.flip_task(line_content.decode("utf-8"))

                stdscr.refresh()
                await asyncio.sleep(0)
        finally:
            curses.nocbreak()
            stdscr.keypad(False)
            curses.echo()
            curses.endwin()
            tasks = asyncio.all_tasks()
            for task in tasks:
                task.cancel()


if __name__ == "__main__":
    # Script entry point: build the application object (which parses the
    # command line) and run its main coroutine until it returns.
    tunnel_top = TunnelTop()
    entrypoint = tunnel_top.main()
    asyncio.run(entrypoint)