aboutsummaryrefslogblamecommitdiffstats
path: root/bin/tunneltop
blob: b3a39620f4852f523be503541117ba97e31ace49 (plain) (tree)
1
2
3
4
5
6
7
8
9




                                                   
           
           
         
             














                                                          


                                                        




                                 
                                                       





                                 
                                                                               





























































































                                                                           

                              
 










                                                                               
 
                                       
 



































                                                                          
 
                    
 





                                                                     

                                    

                                                                           
             

                                  

                    

                                                                      

















                                                                               



                                                                       
                         

                                          







                                              





















                                                                        

                                 
                                                           


                                          


                                                                















                                                                        




                                                               
                                                     

                   
                                                                   


                                                                         
                                                   

                          





                                                                     



                              
                                                          

                                              


                          

                                  
#!/usr/bin/env python
"""A top-like program for monitoring ssh tunnels"""

import argparse
import asyncio
import copy
import enum
import os
import signal
import sys
import typing

import tomllib


class Argparser:  # pylint: disable=too-few-public-methods
    """Command-line argument parsing for tunneltop.

    The parsed namespace is stored in ``self.args`` with these attributes:
    ``config`` (str, default "~/.tunneltop.toml"), ``noheader`` (bool,
    default False) and ``delay`` (float seconds, default 5).
    """

    @staticmethod
    def _str_to_bool(value: str) -> bool:
        """Parse an explicit boolean option value such as "true" or "0".

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised
                boolean spelling (argparse turns this into a usage error).
        """
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}"
        )

    def __init__(self, argv: typing.Optional[typing.Sequence[str]] = None):
        """Build the parser and parse ``argv`` (``sys.argv[1:]`` when None)."""
        self.parser = argparse.ArgumentParser()
        self.parser.add_argument(
            "--config",
            "-c",
            type=str,
            help="The path to the .tunneltop.toml file,"
            " defaults to $HOME/.tunneltop.toml",
            default="~/.tunneltop.toml",
        )
        # `type=bool` would treat any non-empty string (even "false") as
        # True; accept a bare flag (-> True) or an explicit boolean instead.
        self.parser.add_argument(
            "--noheader",
            "-n",
            type=self._str_to_bool,
            nargs="?",
            const=True,
            help="Dont print the header in the output"
            " (bare flag, or an explicit true/false)",
            default=False,
        )
        self.parser.add_argument(
            "--delay",
            "-d",
            type=float,
            help="The delay between redraws in seconds, defaults to 5 seconds",
            default=5,
        )
        self.args = self.parser.parse_args(argv)


# pylint: disable=too-few-public-methods
class Colors:
    """Static ANSI escape sequences used for terminal output.

    A plain namespace of ``str`` constants (not an Enum): each attribute is
    the raw escape string, so it can be concatenated with text and printed
    directly, e.g. ``Colors.green + "ok" + Colors.ENDC``.
    """

    purple = "\033[95m"
    blue = "\033[94m"
    green = "\033[92m"
    yellow = "\033[93m"
    red = "\033[91m"
    grey = "\033[1;37m"
    darkgrey = "\033[1;30m"
    cyan = "\033[1;36m"
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    blueblue = "\x1b[38;5;24m"
    greenie = "\x1b[38;5;23m"
    goo = "\x1b[38;5;22m"
    # Reset the terminal, then also clear the scrollback buffer.
    screen_clear = "\033c\033[3J"
    hide_cursor = "\033[?25l"


# pylint: disable=too-many-locals
def ffs(
    offset: int,
    header_list: typing.Optional[typing.List[str]],
    numbered: bool,
    *args,
) -> typing.List[str]:
    """A simple columnar printer.

    Render ``args`` (each one a column: a sequence of strings) as
    left-justified table lines. Each column is as wide as its longest cell
    or header, plus ``offset``.

    Args:
        offset: Extra spaces added to every column's width (the column gap).
        header_list: Column titles, or None for no header line. The caller's
            list is never modified.
        numbered: Prepend an "idx" column holding the 0-based row number.
        *args: The columns. The row count is taken from the last column.

    Returns:
        One string per output line, header first if any. ANSI colour codes
        are included only when stdout is a TTY.
    """
    if sys.stdout.isatty():
        greenie, bold, endc = Colors.greenie, Colors.BOLD, Colors.ENDC
        goo, blueblue = Colors.goo, Colors.blueblue
    else:
        greenie = bold = endc = goo = blueblue = ""

    row_count = len(args[-1]) if args else 0
    # default=0 keeps an empty column (e.g. no tunnels configured) from
    # crashing max().
    widths = [max((len(cell) for cell in column), default=0) for column in args]
    # Work on a copy so inserting "idx" never mutates the caller's list.
    headers = None if header_list is None else list(header_list)

    if numbered:
        # The index column always comes first, with or without a header, so
        # that widths[1:] lines up with the data columns.
        widths.insert(0, len(str(row_count - 1)) if row_count else 0)
        if headers is not None:
            headers.insert(0, "idx")

    if headers is not None:
        for i, header in enumerate(headers[: len(widths)]):
            widths[i] = max(widths[i], len(header))
    widths = [width + offset for width in widths]

    lines: typing.List[str] = []
    if headers is not None:
        lines.append(
            "".join(
                greenie + bold + header.ljust(width) + endc
                for header, width in zip(headers, widths)
            )
        )

    data_widths = widths[1:] if numbered else widths
    for i in range(row_count):
        cells = []
        if numbered:
            cells.append(goo + bold + str(i).ljust(widths[0]) + endc)
        cells.extend(
            blueblue + column[i].ljust(width) + endc
            for column, width in zip(args, data_widths)
        )
        lines.append("".join(cells))
    return lines


class TunnelTop:
    """The tunnel top class"""

    def __init__(self):
        self.argparser = Argparser()
        self.data_cols: typing.Dict[str, typing.Dict[str, str]] = {}
        self.tunnel_tasks: typing.List[asyncio.Task] = []
        self.tunnel_test_tasks: typing.List[asyncio.Task] = []

    async def run_subshell(self, cmd: str) -> typing.Tuple[bytes, bytes]:
        """Run a command in a subshell"""
        proc = await asyncio.create_subprocess_shell(
            cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )

        return await proc.communicate()

    def tunnel_test_callback(self, task: asyncio.Task) -> None:
        """Tunnel test callback function."""
        try:
            task_name = task.get_name()
            self.data_cols[task_name]["stdout"] = (
                task.result()[0].decode("utf-8").strip("\n")
            )
            self.data_cols[task_name]["stderr"] = (
                task.result()[1].decode("utf-8").strip("\n")
            )
            if (
                task.result()[0].decode("utf-8").strip("\n")
                == self.data_cols[task_name]["test_command_result"]
            ):
                self.data_cols[task_name]["status"] = "UP"
            else:
                self.data_cols[task_name]["status"] = "DOWN"
        except asyncio.TimeoutError:
            self.data_cols[task_name]["status"] = "TMOUT"

    async def tunnel_test_procs(self) -> typing.List[asyncio.Task]:
        """run all the tunnel tests in the background as separate tasks"""
        tasks: typing.List[asyncio.Task] = []
        for _, value in self.data_cols.items():
            if value["test_command"] != "":
                tasks.append(
                    asyncio.create_task(
                        asyncio.wait_for(
                            self.run_subshell(value["test_command"]),
                            timeout=float(value["test_timeout"]),
                        ),
                        name=value["name"],
                    )
                )
                tasks[-1].add_done_callback(self.tunnel_test_callback)
                await asyncio.sleep(0)

        return tasks

    async def tunnel_procs(
        self,
    ) -> typing.List[asyncio.Task]:
        """run all the tunnels in the background as separate tasks"""
        tasks: typing.List[asyncio.Task] = []
        for _, value in self.data_cols.items():
            tasks.append(
                asyncio.create_task(
                    self.run_subshell(value["command"]), name=value["name"]
                ),
            )
            await asyncio.sleep(0)

        return tasks

    async def sighup_handler_async_worker(self, data_cols_new):
        """Handles the actual updating of tasks when we get SIGTERM"""
        for k, value in data_cols_new.items():
            if k not in self.data_cols:
                self.tunnel_tasks.append(
                    asyncio.create_task(
                        self.run_subshell(value["command"]), name=k
                    )
                )
                await asyncio.sleep(0)
            else:
                if (
                    self.data_cols[k]["command"] != data_cols_new[k]["command"]
                    or self.data_cols[k]["port"] != data_cols_new[k]["port"]
                    or self.data_cols[k]["address"]
                    != data_cols_new[k]["address"]
                ):
                    for task in self.tunnel_tasks:
                        if task.get_name() == k:
                            task.cancel()
                    self.data_cols[k] = copy.deepcopy(data_cols_new[k])
                    self.tunnel_tasks.append(
                        asyncio.create_task(
                            self.run_subshell(value["command"]), name=k
                        )
                    )
                    await asyncio.sleep(0)

        for k, _ in self.data_cols.items():
            if k not in data_cols_new:
                for task in self.tunnel_tasks:
                    if task.get_name() == k:
                        task.cancel()
                del self.data_cols[k]

    async def sighup_handler(self):
        """SIGHUP handler. we want to reload the config."""
        # type: ignore # pylint: disable=E0203
        data_cols_new: typing.Dict[str, typing.Dict[str, str]] = {}
        with open(self.argparser.args.config, "rb") as conf_file:
            data = tomllib.load(conf_file)
            for key, value in data.items():
                data_cols_new[key] = {
                    "name": key,
                    "address": value["address"],
                    "port": value["port"],
                    "command": value["command"],
                    "status": "UNKN",
                    "test_command": value["test_command"],
                    "test_command_result": value["test_command_result"],
                    "test_interval": value["test_interval"],
                    "test_timeout": value["test_timeout"],
                    "stdout": "",
                    "stderr": "",
                }
        await self.sighup_handler_async_worker(data_cols_new)

    async def main(self) -> None:
        """entrypoint"""
        # signal.signal(signal.SIGHUP, self.sighup_handler)
        print(Colors.screen_clear, end="")
        print(Colors.hide_cursor, end="")

        with open(
            os.path.expanduser(self.argparser.args.config), "rb"
        ) as conf_file:
            data = tomllib.load(conf_file)
            for key, value in data.items():
                self.data_cols[key] = {
                    "name": key,
                    "address": value["address"],
                    "port": value["port"],
                    "command": value["command"],
                    "status": "UNKN",
                    "test_command": value["test_command"],
                    "test_command_result": value["test_command_result"],
                    "test_interval": value["test_interval"],
                    "test_timeout": value["test_timeout"],
                    "stdout": "",
                    "stderr": "",
                }

        loop = asyncio.get_event_loop()
        loop.add_signal_handler(
            signal.SIGHUP,
            lambda: asyncio.create_task(self.sighup_handler()),
        )
        self.tunnel_tasks = await self.tunnel_procs()

        while True:
            self.tunnel_test_tasks = await self.tunnel_test_procs()
            lines = ffs(
                2,
                ["NAME", "ADDRESS", "PORT", "STATUS", "STDOUT", "STDERR"]
                if not self.argparser.args.noheader
                else None,
                False,
                [v["name"] for _, v in self.data_cols.items()],
                [v["address"] for _, v in self.data_cols.items()],
                [repr(v["port"]) for _, v in self.data_cols.items()],
                [v["status"] for _, v in self.data_cols.items()],
                [v["stdout"] for _, v in self.data_cols.items()],
                [v["stderr"] for _, v in self.data_cols.items()],
            )
            for line in lines:
                print(line)

            await asyncio.sleep(self.argparser.args.delay)
            print(Colors.screen_clear, end="")
            print(Colors.hide_cursor, end="")


if __name__ == "__main__":
    # Constructing TunnelTop parses the command line; asyncio.run then owns
    # the event loop for the lifetime of the monitor.
    asyncio.run(TunnelTop().main())