diff options
-rwxr-xr-x | bin/tunneltop | 333 |
1 files changed, 225 insertions, 108 deletions
diff --git a/bin/tunneltop b/bin/tunneltop index 7bf1878..df3510c 100755 --- a/bin/tunneltop +++ b/bin/tunneltop @@ -1,6 +1,6 @@ #!/usr/bin/env python """A top-like program for monitoring ssh tunnels""" - +# TODO- task cancellation is very slow as should be with tasks import argparse import asyncio import copy @@ -9,9 +9,8 @@ import enum import os import signal import sys -import typing - import tomllib +import typing class Argparser: # pylint: disable=too-few-public-methods @@ -134,18 +133,52 @@ def ffs( return lines -def render(lines: typing.List[str], stdscr, sel: int): +def render( + data_cols: typing.Dict[str, typing.Dict[str, str]], + tasks: typing.List[asyncio.Task], + stdscr, + sel: int, +): """Render the text""" + lines = ffs( + 2, + ["NAME", "ADDRESS", "PORT", "STATUS", "STDOUT", "STDERR"], + False, + True, + [v["name"] for _, v in data_cols.items()], + [v["address"] for _, v in data_cols.items()], + [repr(v["port"]) for _, v in data_cols.items()], + [v["status"] for _, v in data_cols.items()], + [v["stdout"] for _, v in data_cols.items()], + [v["stderr"] for _, v in data_cols.items()], + ) iterator = iter(lines) stdscr.addstr(1, 1, lines[0], curses.color_pair(3)) next(iterator) for i, line in enumerate(iterator): + try: + line_content = stdscr.instr(sel + 2, 1).decode("utf-8") + name: str = line_content[: line_content.find(" ")] + finally: + name = "" if i == sel: stdscr.addstr( - (2 + i) % (len(lines) + 1), 1, line, curses.color_pair(2) + (2 + i) % (len(lines) + 1), + 1, + line, + curses.color_pair(2) + if name not in tasks + else curses.color_pair(5), ) else: - stdscr.addstr(2 + i, 1, line, curses.color_pair(1)) + stdscr.addstr( + 2 + i, + 1, + line, + curses.color_pair(1) + if name not in tasks + else curses.color_pair(4), + ) stdscr.addstr("\n") stdscr.box() @@ -164,18 +197,56 @@ def curses_init(): curses.init_pair(2, curses.COLOR_BLACK, curses.COLOR_GREEN) curses.init_pair(3, curses.COLOR_BLUE, curses.COLOR_BLACK) curses.init_pair(4, curses.COLOR_CYAN, curses.COLOR_BLACK) + curses.init_pair(5, curses.COLOR_BLACK, curses.COLOR_CYAN) return stdscr -class TunnelTop: +class TunnelManager: """The tunnel top class""" def __init__(self): self.argparser = Argparser() - self.data_cols: typing.Dict[str, typing.Dict[str, str]] = {} + self.data_cols: typing.Dict[ + str, typing.Dict[str, str] + ] = self.read_conf() self.tunnel_tasks: typing.List[asyncio.Task] = [] self.tunnel_test_tasks: typing.List[asyncio.Task] = [] + self.scheduler_task: asyncio.Task + self.scheduler_table: typing.Dict[ + str, int + ] = self.init_scheduler_table() + # we use this when its time to quit. this will prevent any + # new tasks from being scheduled + self.are_we_dying: bool = False + + def init_scheduler_table(self) -> typing.Dict[str, int]: + """initialize the scheduler table""" + result: typing.Dict[str, int] = {} + for key, value in self.data_cols.items(): + if "test_interval" in value and value["test_command"] != "": + result[key] = 0 + + return result + + async def stop_task( + self, + delete_task: asyncio.Task, + task_list: typing.List[asyncio.Task], + delete: bool = True, + ): + """Remove the reference""" + delete_index: int = -1 + delete_task.cancel() + self.write_log(f"{delete_task.get_name()} is being cancelled\n") + await asyncio.sleep(0) + for i, task in enumerate(task_list): + if task.get_name() == delete_task.get_name(): + delete_index = i + break + + if delete and delete_index >= 0: + task_list.remove(self.tunnel_tasks[delete_index]) def read_conf(self) -> typing.Dict[str, typing.Dict[str, str]]: """Read the config file""" @@ -200,58 +271,43 @@ class TunnelTop: } return data_cols - async def run_subshell(self, cmd: str) -> typing.Tuple[bytes, bytes]: - """Run a command in a subshell""" - proc = await asyncio.create_subprocess_shell( - cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) - + async def run_subprocess(self, cmd: str) -> typing.Tuple[bytes, bytes]: + """Run a command""" try: + proc = await asyncio.create_subprocess_exec( + *cmd.split(" "), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + return await proc.communicate() - except asyncio.CancelledError: - self.write_log("fucking fuck") + except asyncio.TimeoutError: + proc.terminate() return (bytes(), bytes()) + except asyncio.CancelledError: + proc.terminate() + raise - def tunnel_test_callback(self, task: asyncio.Task) -> None: - """Tunnel test callback function.""" + async def run_test_coro( + self, cmd: str, task_name: str + ) -> typing.Tuple[bytes, bytes]: + """Run a test command""" try: - task_name = task.get_name() - self.data_cols[task_name]["stdout"] = ( - task.result()[0].decode("utf-8").strip("\n") - ) - self.data_cols[task_name]["stderr"] = ( - task.result()[1].decode("utf-8").strip("\n") - ) - if ( - task.result()[0].decode("utf-8").strip("\n") - == self.data_cols[task_name]["test_command_result"] - ): + stdout, stderr = await self.run_subprocess(cmd) + stdout_str: str = stdout.decode("utf-8").strip("\n").strip('"') + stderr_str: str = stderr.decode("utf-8").strip("\n").strip('"') + + self.data_cols[task_name]["stdout"] = stdout_str + self.data_cols[task_name]["stderr"] = stderr_str + if stdout_str == self.data_cols[task_name]["test_command_result"]: self.data_cols[task_name]["status"] = "UP" else: self.data_cols[task_name]["status"] = "DOWN" + + return stdout, stderr except asyncio.TimeoutError: self.data_cols[task_name]["status"] = "TMOUT" - except asyncio.CancelledError: - self.data_cols[task_name]["status"] = "CANCELLED" - - async def tunnel_test_procs(self) -> typing.List[asyncio.Task]: - """run all the tunnel tests in the background as separate tasks""" - tasks: typing.List[asyncio.Task] = [] - for _, value in self.data_cols.items(): - if value["test_command"] != "": - tasks.append( - asyncio.create_task( - asyncio.wait_for( - self.run_subshell(value["test_command"]), - timeout=float(value["test_timeout"]), - ), - name=value["name"], - ) - ) - tasks[-1].add_done_callback(self.tunnel_test_callback) - await asyncio.sleep(0) - - return tasks + raise async def tunnel_procs( self, @@ -261,7 +317,7 @@ class TunnelTop: for _, value in self.data_cols.items(): tasks.append( asyncio.create_task( - self.run_subshell(value["command"]), name=value["name"] + self.run_subprocess(value["command"]), name=value["name"] ), ) await asyncio.sleep(0) @@ -270,14 +326,18 @@ class TunnelTop: async def sighup_handler_async_worker(self, data_cols_new) -> None: """Handles the actual updating of tasks when we get SIGTERM""" + delete_task: typing.Optional[asyncio.Task] = None for k, value in data_cols_new.items(): if k not in self.data_cols: self.tunnel_tasks.append( asyncio.create_task( - self.run_subshell(value["command"]), name=k + self.run_subprocess(value["command"]), name=k ) ) await asyncio.sleep(0) + self.data_cols[k] = copy.deepcopy(value) + if k in self.scheduler_table: + self.scheduler_table[k] = 0 else: if ( self.data_cols[k]["command"] != data_cols_new[k]["command"] @@ -287,21 +347,38 @@ class TunnelTop: ): for task in self.tunnel_tasks: if task.get_name() == k: - task.cancel() + delete_task = task + break + # task.cancel() + # await asyncio.sleep(0) + + if delete_task is not None: + await self.stop_task(delete_task, self.tunnel_tasks) + delete_task = None self.data_cols[k] = copy.deepcopy(data_cols_new[k]) self.tunnel_tasks.append( asyncio.create_task( - self.run_subshell(value["command"]), name=k + self.run_subprocess(value["command"]), name=k ) ) + if k in self.scheduler_table: + self.scheduler_table[k] = 0 await asyncio.sleep(0) for k, _ in self.data_cols.items(): if k not in data_cols_new: for task in self.tunnel_tasks: if task.get_name() == k: - task.cancel() + # task.cancel() + # await asyncio.sleep(0) + delete_task = task + break + if delete_task is not None: + await self.stop_task(delete_task, self.tunnel_tasks) + delete_task = None del self.data_cols[k] + if k in self.scheduler_table: + del self.scheduler_table[k] async def sighup_handler(self) -> None: """SIGHUP handler. we want to reload the config.""" @@ -314,7 +391,7 @@ class TunnelTop: """A simple logger""" with open( "/home/devi/devi/abbatoir/hole15/log", - "w", + "a", encoding="utf-8", ) as logfile: logfile.write(log) @@ -322,33 +399,36 @@ class TunnelTop: async def restart_task(self, line_content: str) -> None: """restart a task""" name: str = line_content[: line_content.find(" ")] - was_cancelled: bool = False + # was_cancelled: bool = False for task in self.tunnel_tasks: if task.get_name() == name: - was_cancelled = task.cancel() - self.write_log(f"was_cancelled: {was_cancelled}") - await task - for _, value in self.data_cols.items(): - if value["name"] == name: - self.tunnel_tasks.append( - asyncio.create_task( - self.run_subshell(value["command"]), - name=value["name"], - ) - ) - await asyncio.sleep(0) + # was_cancelled = task.cancel() + # self.write_log(f"was_cancelled: {was_cancelled}") + await self.stop_task(task, self.tunnel_tasks) + # await task + # await asyncio.sleep(0) + for _, value in self.data_cols.items(): + if value["name"] == name and task.cancelled(): + self.tunnel_tasks.append( + asyncio.create_task( + self.run_subprocess(value["command"]), + name=value["name"], + ) + ) + await asyncio.sleep(0) async def flip_task(self, line_content: str) -> None: """flip a task""" name: str = line_content[: line_content.find(" ")] - was_cancelled: bool = False + # was_cancelled: bool = False was_active: bool = False for task in self.tunnel_tasks: if task.get_name() == name: - was_cancelled = task.cancel() - await asyncio.sleep(0) - self.write_log(f"was_cancelled: {was_cancelled}") - await task + await self.stop_task(task, self.tunnel_tasks) + # was_cancelled = task.cancel() + # await asyncio.sleep(0) + # self.write_log(f"was_cancelled: {was_cancelled}") + # await task was_active = True break @@ -357,69 +437,105 @@ class TunnelTop: if value["name"] == name: self.tunnel_tasks.append( asyncio.create_task( - self.run_subshell(value["command"]), + self.run_subprocess(value["command"]), name=value["name"], ) ) await asyncio.sleep(0) - break async def quit(self) -> None: """Cleanly quit the applicaiton""" - for tunnel_test_task in self.tunnel_test_tasks: - tunnel_test_task.cancel() - for tunnel_task in self.tunnel_tasks: - tunnel_task.cancel() + # scheduler checks for this so stop making new tasks + # when we want to quit + self.are_we_dying = True - async def main(self) -> None: - """entrypoint""" - sel: int = 0 + for task in asyncio.all_tasks(): + task.cancel() + await asyncio.sleep(0) try: - stdscr = curses_init() + await asyncio.gather(*asyncio.all_tasks()) + finally: + sys.exit(0) + + async def scheduler(self) -> None: + """schedulaer manages running the tests and reviving dead tunnels""" + try: + while True: + if self.are_we_dying: + return + for key, value in self.scheduler_table.items(): + if value == 0 and key not in self.tunnel_test_tasks: + tunnel_entry = self.data_cols[key] + test_task = asyncio.create_task( + asyncio.wait_for( + self.run_test_coro( + tunnel_entry["test_command"], + tunnel_entry["name"], + ), + timeout=float(tunnel_entry["test_timeout"]), + ), + name=key, + ) + self.tunnel_test_tasks.append(test_task) + self.scheduler_table[key] = int( + tunnel_entry["test_interval"] + ) + await asyncio.sleep(0) + else: + self.scheduler_table[key] = ( + self.scheduler_table[key] - 1 + ) - self.data_cols = self.read_conf() + # we are using a 1 second ticker. basically the scheduler + # runs every second instead of as fast as it can + await asyncio.sleep(1) + except asyncio.CancelledError: + pass + async def tui_loop(self) -> None: + """the tui loop""" + sel: int = 0 + try: + stdscr = curses_init() + # we spawn the tunnels and the test scheduler put them + # in the background and then run the TUI loop + self.tunnel_tasks = await self.tunnel_procs() + self.scheduler_task = asyncio.create_task( + self.scheduler(), name="scheduler" + ) loop = asyncio.get_event_loop() loop.add_signal_handler( signal.SIGHUP, lambda: asyncio.create_task(self.sighup_handler()), ) - self.tunnel_tasks = await self.tunnel_procs() while True: - # self.tunnel_test_tasks = await self.tunnel_test_procs() - lines = ffs( - 2, - ["NAME", "ADDRESS", "PORT", "STATUS", "STDOUT", "STDERR"] - if not self.argparser.args.noheader - else None, - False, - True, - [v["name"] for _, v in self.data_cols.items()], - [v["address"] for _, v in self.data_cols.items()], - [repr(v["port"]) for _, v in self.data_cols.items()], - [v["status"] for _, v in self.data_cols.items()], - [v["stdout"] for _, v in self.data_cols.items()], - [v["stderr"] for _, v in self.data_cols.items()], - ) stdscr.clear() - render(lines, stdscr, sel) + render(self.data_cols, self.tunnel_tasks, stdscr, sel) char = stdscr.getch() if char == ord("j") or char == curses.KEY_DOWN: sel = (sel + 1) % len(self.data_cols) elif char == ord("k") or char == curses.KEY_UP: sel = (sel - 1) % len(self.data_cols) + elif char == ord("g") or char == curses.KEY_UP: + sel = 0 + elif char == ord("G") or char == curses.KEY_UP: + sel = len(self.data_cols) - 1 elif char == ord("r"): line_content = stdscr.instr(sel + 2, 1) await self.restart_task(line_content.decode("utf-8")) elif char == ord("q"): await self.quit() - # elif char == curses.KEY_ENTER: elif char == ord("s"): line_content = stdscr.instr(sel + 2, 1) await self.flip_task(line_content.decode("utf-8")) + for task in self.tunnel_tasks: + self.write_log( + f"{task.get_name()} is {task.cancelled()} or {task.cancelling()}\n" + ) + stdscr.refresh() await asyncio.sleep(0) finally: @@ -427,11 +543,12 @@ class TunnelTop: stdscr.keypad(False) curses.echo() curses.endwin() - tasks = asyncio.all_tasks() - for task in tasks: - task.cancel() + # tasks = asyncio.all_tasks() + # for task in tasks: + # task.cancel() + await self.quit() if __name__ == "__main__": - tunnel_top = TunnelTop() - asyncio.run(tunnel_top.main()) + tunnel_manager = TunnelManager() + asyncio.run(tunnel_manager.tui_loop()) |