diff options
Diffstat (limited to '')
| -rwxr-xr-x | bin/tunneltop | 333 | 
1 files changed, 225 insertions, 108 deletions
| diff --git a/bin/tunneltop b/bin/tunneltop index 7bf1878..df3510c 100755 --- a/bin/tunneltop +++ b/bin/tunneltop @@ -1,6 +1,6 @@  #!/usr/bin/env python  """A top-like program for monitoring ssh tunnels""" - +# TODO- task cancellation is very slow as should be with tasks  import argparse  import asyncio  import copy @@ -9,9 +9,8 @@ import enum  import os  import signal  import sys -import typing -  import tomllib +import typing  class Argparser:  # pylint: disable=too-few-public-methods @@ -134,18 +133,52 @@ def ffs(      return lines -def render(lines: typing.List[str], stdscr, sel: int): +def render( +    data_cols: typing.Dict[str, typing.Dict[str, str]], +    tasks: typing.List[asyncio.Task], +    stdscr, +    sel: int, +):      """Render the text""" +    lines = ffs( +        2, +        ["NAME", "ADDRESS", "PORT", "STATUS", "STDOUT", "STDERR"], +        False, +        True, +        [v["name"] for _, v in data_cols.items()], +        [v["address"] for _, v in data_cols.items()], +        [repr(v["port"]) for _, v in data_cols.items()], +        [v["status"] for _, v in data_cols.items()], +        [v["stdout"] for _, v in data_cols.items()], +        [v["stderr"] for _, v in data_cols.items()], +    )      iterator = iter(lines)      stdscr.addstr(1, 1, lines[0], curses.color_pair(3))      next(iterator)      for i, line in enumerate(iterator): +        try: +            line_content = stdscr.instr(sel + 2, 1).decode("utf-8") +            name: str = line_content[: line_content.find(" ")] +        finally: +            name = ""          if i == sel:              stdscr.addstr( -                (2 + i) % (len(lines) + 1), 1, line, curses.color_pair(2) +                (2 + i) % (len(lines) + 1), +                1, +                line, +                curses.color_pair(2) +                if name not in tasks +                else curses.color_pair(5),              )          else: -            stdscr.addstr(2 + i, 1, line, curses.color_pair(1)) +            stdscr.addstr( +                2 + i, +                1, +                line, +                curses.color_pair(1) +                if name not in tasks +                else curses.color_pair(4), +            )          stdscr.addstr("\n")      stdscr.box() @@ -164,18 +197,56 @@ def curses_init():      curses.init_pair(2, curses.COLOR_BLACK, curses.COLOR_GREEN)      curses.init_pair(3, curses.COLOR_BLUE, curses.COLOR_BLACK)      curses.init_pair(4, curses.COLOR_CYAN, curses.COLOR_BLACK) +    curses.init_pair(5, curses.COLOR_BLACK, curses.COLOR_CYAN)      return stdscr -class TunnelTop: +class TunnelManager:      """The tunnel top class"""      def __init__(self):          self.argparser = Argparser() -        self.data_cols: typing.Dict[str, typing.Dict[str, str]] = {} +        self.data_cols: typing.Dict[ +            str, typing.Dict[str, str] +        ] = self.read_conf()          self.tunnel_tasks: typing.List[asyncio.Task] = []          self.tunnel_test_tasks: typing.List[asyncio.Task] = [] +        self.scheduler_task: asyncio.Task +        self.scheduler_table: typing.Dict[ +            str, int +        ] = self.init_scheduler_table() +        # we use this when its time to quit. this will prevent any +        # new tasks from being scheduled +        self.are_we_dying: bool = False + +    def init_scheduler_table(self) -> typing.Dict[str, int]: +        """initialize the scheduler table""" +        result: typing.Dict[str, int] = {} +        for key, value in self.data_cols.items(): +            if "test_interval" in value and value["test_command"] != "": +                result[key] = 0 + +        return result + +    async def stop_task( +        self, +        delete_task: asyncio.Task, +        task_list: typing.List[asyncio.Task], +        delete: bool = True, +    ): +        """Remove the reference""" +        delete_index: int = -1 +        delete_task.cancel() +        self.write_log(f"{delete_task.get_name()} is being cancelled\n") +        await asyncio.sleep(0) +        for i, task in enumerate(task_list): +            if task.get_name() == delete_task.get_name(): +                delete_index = i +                break + +        if delete and delete_index >= 0: +            task_list.remove(self.tunnel_tasks[delete_index])      def read_conf(self) -> typing.Dict[str, typing.Dict[str, str]]:          """Read the config file""" @@ -200,58 +271,43 @@ class TunnelTop:                  }          return data_cols -    async def run_subshell(self, cmd: str) -> typing.Tuple[bytes, bytes]: -        """Run a command in a subshell""" -        proc = await asyncio.create_subprocess_shell( -            cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE -        ) - +    async def run_subprocess(self, cmd: str) -> typing.Tuple[bytes, bytes]: +        """Run a command"""          try: +            proc = await asyncio.create_subprocess_exec( +                *cmd.split(" "), +                stdout=asyncio.subprocess.PIPE, +                stderr=asyncio.subprocess.PIPE, +            ) +              return await proc.communicate() -        except asyncio.CancelledError: -            self.write_log("fucking fuck") +        except asyncio.TimeoutError: +            proc.terminate()              return (bytes(), bytes()) +        except asyncio.CancelledError: +            proc.terminate() +            raise -    def tunnel_test_callback(self, task: asyncio.Task) -> None: -        """Tunnel test callback function.""" +    async def run_test_coro( +        self, cmd: str, task_name: str +    ) -> typing.Tuple[bytes, bytes]: +        """Run a test command"""          try: -            task_name = task.get_name() -            self.data_cols[task_name]["stdout"] = ( -                task.result()[0].decode("utf-8").strip("\n") -            ) -            self.data_cols[task_name]["stderr"] = ( -                task.result()[1].decode("utf-8").strip("\n") -            ) -            if ( -                task.result()[0].decode("utf-8").strip("\n") -                == self.data_cols[task_name]["test_command_result"] -            ): +            stdout, stderr = await self.run_subprocess(cmd) +            stdout_str: str = stdout.decode("utf-8").strip("\n").strip('"') +            stderr_str: str = stderr.decode("utf-8").strip("\n").strip('"') + +            self.data_cols[task_name]["stdout"] = stdout_str +            self.data_cols[task_name]["stderr"] = stderr_str +            if stdout_str == self.data_cols[task_name]["test_command_result"]:                  self.data_cols[task_name]["status"] = "UP"              else:                  self.data_cols[task_name]["status"] = "DOWN" + +            return stdout, stderr          except asyncio.TimeoutError:              self.data_cols[task_name]["status"] = "TMOUT" -        except asyncio.CancelledError: -            self.data_cols[task_name]["status"] = "CANCELLED" - -    async def tunnel_test_procs(self) -> typing.List[asyncio.Task]: -        """run all the tunnel tests in the background as separate tasks""" -        tasks: typing.List[asyncio.Task] = [] -        for _, value in self.data_cols.items(): -            if value["test_command"] != "": -                tasks.append( -                    asyncio.create_task( -                        asyncio.wait_for( -                            self.run_subshell(value["test_command"]), -                            timeout=float(value["test_timeout"]), -                        ), -                        name=value["name"], -                    ) -                ) -                tasks[-1].add_done_callback(self.tunnel_test_callback) -                await asyncio.sleep(0) - -        return tasks +            raise      async def tunnel_procs(          self, @@ -261,7 +317,7 @@ class TunnelTop:          for _, value in self.data_cols.items():              tasks.append(                  asyncio.create_task( -                    self.run_subshell(value["command"]), name=value["name"] +                    self.run_subprocess(value["command"]), name=value["name"]                  ),              )              await asyncio.sleep(0) @@ -270,14 +326,18 @@ class TunnelTop:      async def sighup_handler_async_worker(self, data_cols_new) -> None:          """Handles the actual updating of tasks when we get SIGTERM""" +        delete_task: typing.Optional[asyncio.Task] = None          for k, value in data_cols_new.items():              if k not in self.data_cols:                  self.tunnel_tasks.append(                      asyncio.create_task( -                        self.run_subshell(value["command"]), name=k +                        self.run_subprocess(value["command"]), name=k                      )                  )                  await asyncio.sleep(0) +                self.data_cols[k] = copy.deepcopy(value) +                if k in self.scheduler_table: +                    self.scheduler_table[k] = 0              else:                  if (                      self.data_cols[k]["command"] != data_cols_new[k]["command"] @@ -287,21 +347,38 @@ class TunnelTop:                  ):                      for task in self.tunnel_tasks:                          if task.get_name() == k: -                            task.cancel() +                            delete_task = task +                            break +                            # task.cancel() +                            # await asyncio.sleep(0) + +                    if delete_task is not None: +                        await self.stop_task(delete_task, self.tunnel_tasks) +                        delete_task = None                      self.data_cols[k] = copy.deepcopy(data_cols_new[k])                      self.tunnel_tasks.append(                          asyncio.create_task( -                            self.run_subshell(value["command"]), name=k +                            self.run_subprocess(value["command"]), name=k                          )                      ) +                    if k in self.scheduler_table: +                        self.scheduler_table[k] = 0                      await asyncio.sleep(0)          for k, _ in self.data_cols.items():              if k not in data_cols_new:                  for task in self.tunnel_tasks:                      if task.get_name() == k: -                        task.cancel() +                        # task.cancel() +                        # await asyncio.sleep(0) +                        delete_task = task +                        break +                if delete_task is not None: +                    await self.stop_task(delete_task, self.tunnel_tasks) +                    delete_task = None                  del self.data_cols[k] +                if k in self.scheduler_table: +                    del self.scheduler_table[k]      async def sighup_handler(self) -> None:          """SIGHUP handler. we want to reload the config.""" @@ -314,7 +391,7 @@ class TunnelTop:          """A simple logger"""          with open(              "/home/devi/devi/abbatoir/hole15/log", -            "w", +            "a",              encoding="utf-8",          ) as logfile:              logfile.write(log) @@ -322,33 +399,36 @@ class TunnelTop:      async def restart_task(self, line_content: str) -> None:          """restart a task"""          name: str = line_content[: line_content.find(" ")] -        was_cancelled: bool = False +        # was_cancelled: bool = False          for task in self.tunnel_tasks:              if task.get_name() == name: -                was_cancelled = task.cancel() -                self.write_log(f"was_cancelled: {was_cancelled}") -                await task -        for _, value in self.data_cols.items(): -            if value["name"] == name: -                self.tunnel_tasks.append( -                    asyncio.create_task( -                        self.run_subshell(value["command"]), -                        name=value["name"], -                    ) -                ) -                await asyncio.sleep(0) +                # was_cancelled = task.cancel() +                # self.write_log(f"was_cancelled: {was_cancelled}") +                await self.stop_task(task, self.tunnel_tasks) +                # await task +                # await asyncio.sleep(0) +                for _, value in self.data_cols.items(): +                    if value["name"] == name and task.cancelled(): +                        self.tunnel_tasks.append( +                            asyncio.create_task( +                                self.run_subprocess(value["command"]), +                                name=value["name"], +                            ) +                        ) +                        await asyncio.sleep(0)      async def flip_task(self, line_content: str) -> None:          """flip a task"""          name: str = line_content[: line_content.find(" ")] -        was_cancelled: bool = False +        # was_cancelled: bool = False          was_active: bool = False          for task in self.tunnel_tasks:              if task.get_name() == name: -                was_cancelled = task.cancel() -                await asyncio.sleep(0) -                self.write_log(f"was_cancelled: {was_cancelled}") -                await task +                await self.stop_task(task, self.tunnel_tasks) +                # was_cancelled = task.cancel() +                # await asyncio.sleep(0) +                # self.write_log(f"was_cancelled: {was_cancelled}") +                # await task                  was_active = True                  break @@ -357,69 +437,105 @@ class TunnelTop:                  if value["name"] == name:                      self.tunnel_tasks.append(                          asyncio.create_task( -                            self.run_subshell(value["command"]), +                            self.run_subprocess(value["command"]),                              name=value["name"],                          )                      )                      await asyncio.sleep(0) -                    break      async def quit(self) -> None:          """Cleanly quit the applicaiton""" -        for tunnel_test_task in self.tunnel_test_tasks: -            tunnel_test_task.cancel() -        for tunnel_task in self.tunnel_tasks: -            tunnel_task.cancel() +        # scheduler checks for this so stop making new tasks +        # when we want to quit +        self.are_we_dying = True -    async def main(self) -> None: -        """entrypoint""" -        sel: int = 0 +        for task in asyncio.all_tasks(): +            task.cancel() +            await asyncio.sleep(0)          try: -            stdscr = curses_init() +            await asyncio.gather(*asyncio.all_tasks()) +        finally: +            sys.exit(0) + +    async def scheduler(self) -> None: +        """schedulaer manages running the tests and reviving dead tunnels""" +        try: +            while True: +                if self.are_we_dying: +                    return +                for key, value in self.scheduler_table.items(): +                    if value == 0 and key not in self.tunnel_test_tasks: +                        tunnel_entry = self.data_cols[key] +                        test_task = asyncio.create_task( +                            asyncio.wait_for( +                                self.run_test_coro( +                                    tunnel_entry["test_command"], +                                    tunnel_entry["name"], +                                ), +                                timeout=float(tunnel_entry["test_timeout"]), +                            ), +                            name=key, +                        ) +                        self.tunnel_test_tasks.append(test_task) +                        self.scheduler_table[key] = int( +                            tunnel_entry["test_interval"] +                        ) +                        await asyncio.sleep(0) +                    else: +                        self.scheduler_table[key] = ( +                            self.scheduler_table[key] - 1 +                        ) -            self.data_cols = self.read_conf() +                # we are using a 1 second ticker. basically the scheduler +                # runs every second instead of as fast as it can +                await asyncio.sleep(1) +        except asyncio.CancelledError: +            pass +    async def tui_loop(self) -> None: +        """the tui loop""" +        sel: int = 0 +        try: +            stdscr = curses_init() +            # we spawn the tunnels and the test scheduler put them +            # in the background and then run the TUI loop +            self.tunnel_tasks = await self.tunnel_procs() +            self.scheduler_task = asyncio.create_task( +                self.scheduler(), name="scheduler" +            )              loop = asyncio.get_event_loop()              loop.add_signal_handler(                  signal.SIGHUP,                  lambda: asyncio.create_task(self.sighup_handler()),              ) -            self.tunnel_tasks = await self.tunnel_procs()              while True: -                # self.tunnel_test_tasks = await self.tunnel_test_procs() -                lines = ffs( -                    2, -                    ["NAME", "ADDRESS", "PORT", "STATUS", "STDOUT", "STDERR"] -                    if not self.argparser.args.noheader -                    else None, -                    False, -                    True, -                    [v["name"] for _, v in self.data_cols.items()], -                    [v["address"] for _, v in self.data_cols.items()], -                    [repr(v["port"]) for _, v in self.data_cols.items()], -                    [v["status"] for _, v in self.data_cols.items()], -                    [v["stdout"] for _, v in self.data_cols.items()], -                    [v["stderr"] for _, v in self.data_cols.items()], -                )                  stdscr.clear() -                render(lines, stdscr, sel) +                render(self.data_cols, self.tunnel_tasks, stdscr, sel)                  char = stdscr.getch()                  if char == ord("j") or char == curses.KEY_DOWN:                      sel = (sel + 1) % len(self.data_cols)                  elif char == ord("k") or char == curses.KEY_UP:                      sel = (sel - 1) % len(self.data_cols) +                elif char == ord("g") or char == curses.KEY_UP: +                    sel = 0 +                elif char == ord("G") or char == curses.KEY_UP: +                    sel = len(self.data_cols) - 1                  elif char == ord("r"):                      line_content = stdscr.instr(sel + 2, 1)                      await self.restart_task(line_content.decode("utf-8"))                  elif char == ord("q"):                      await self.quit() -                # elif char == curses.KEY_ENTER:                  elif char == ord("s"):                      line_content = stdscr.instr(sel + 2, 1)                      await self.flip_task(line_content.decode("utf-8")) +                for task in self.tunnel_tasks: +                    self.write_log( +                        f"{task.get_name()} is {task.cancelled()} or {task.cancelling()}\n" +                    ) +                  stdscr.refresh()                  await asyncio.sleep(0)          finally: @@ -427,11 +543,12 @@ class TunnelTop:              stdscr.keypad(False)              curses.echo()              curses.endwin() -            tasks = asyncio.all_tasks() -            for task in tasks: -                task.cancel() +            # tasks = asyncio.all_tasks() +            # for task in tasks: +            #     task.cancel() +            await self.quit()  if __name__ == "__main__": -    tunnel_top = TunnelTop() -    asyncio.run(tunnel_top.main()) +    tunnel_manager = TunnelManager() +    asyncio.run(tunnel_manager.tui_loop()) | 
