diff options
Diffstat (limited to 'bin/tunneltop')
-rwxr-xr-x | bin/tunneltop | 250 |
1 files changed, 92 insertions, 158 deletions
diff --git a/bin/tunneltop b/bin/tunneltop index 7c32e17..7bf1878 100755 --- a/bin/tunneltop +++ b/bin/tunneltop @@ -1,6 +1,6 @@ #!/usr/bin/env python """A top-like program for monitoring ssh tunnels""" -# TODO- task cancellation is very slow as should be with tasks + import argparse import asyncio import copy @@ -134,52 +134,18 @@ def ffs( return lines -def render( - data_cols: typing.Dict[str, typing.Dict[str, str]], - tasks: typing.List[asyncio.Task], - stdscr, - sel: int, -): +def render(lines: typing.List[str], stdscr, sel: int): """Render the text""" - lines = ffs( - 2, - ["NAME", "ADDRESS", "PORT", "STATUS", "STDOUT", "STDERR"], - False, - True, - [v["name"] for _, v in data_cols.items()], - [v["address"] for _, v in data_cols.items()], - [repr(v["port"]) for _, v in data_cols.items()], - [v["status"] for _, v in data_cols.items()], - [v["stdout"] for _, v in data_cols.items()], - [v["stderr"] for _, v in data_cols.items()], - ) iterator = iter(lines) stdscr.addstr(1, 1, lines[0], curses.color_pair(3)) next(iterator) for i, line in enumerate(iterator): - try: - line_content = stdscr.instr(sel + 2, 1).decode("utf-8") - name: str = line_content[: line_content.find(" ")] - finally: - name = "" if i == sel: stdscr.addstr( - (2 + i) % (len(lines) + 1), - 1, - line, - curses.color_pair(2) - if name not in tasks - else curses.color_pair(5), + (2 + i) % (len(lines) + 1), 1, line, curses.color_pair(2) ) else: - stdscr.addstr( - 2 + i, - 1, - line, - curses.color_pair(1) - if name not in tasks - else curses.color_pair(4), - ) + stdscr.addstr(2 + i, 1, line, curses.color_pair(1)) stdscr.addstr("\n") stdscr.box() @@ -198,43 +164,18 @@ def curses_init(): curses.init_pair(2, curses.COLOR_BLACK, curses.COLOR_GREEN) curses.init_pair(3, curses.COLOR_BLUE, curses.COLOR_BLACK) curses.init_pair(4, curses.COLOR_CYAN, curses.COLOR_BLACK) - curses.init_pair(5, curses.COLOR_BLACK, curses.COLOR_CYAN) return stdscr -class TunnelManager: +class TunnelTop: """The tunnel top class""" def __init__(self): self.argparser = Argparser() - self.data_cols: typing.Dict[ - str, typing.Dict[str, str] - ] = self.read_conf() + self.data_cols: typing.Dict[str, typing.Dict[str, str]] = {} self.tunnel_tasks: typing.List[asyncio.Task] = [] self.tunnel_test_tasks: typing.List[asyncio.Task] = [] - self.scheduler_task: asyncio.Task - self.scheduler_table: typing.Dict[ - str, int - ] = self.init_scheduler_table() - # we use this when its time to quit. this will prevent any - # new tasks from being scheduled - self.are_we_dying: bool = False - - loop = asyncio.get_event_loop() - loop.add_signal_handler( - signal.SIGHUP, - lambda: asyncio.create_task(self.sighup_handler()), - ) - - def init_scheduler_table(self) -> typing.Dict[str, int]: - """initialize the scheduler table""" - result: typing.Dict[str, int] = {} - for key, value in self.data_cols.items(): - if "test_interval" in value and value["test_command"] != "": - result[key] = 0 - - return result def read_conf(self) -> typing.Dict[str, typing.Dict[str, str]]: """Read the config file""" @@ -259,36 +200,58 @@ class TunnelManager: } return data_cols - async def run_subprocess(self, cmd: str) -> typing.Tuple[bytes, bytes]: - """Run a command""" - proc = await asyncio.create_subprocess_exec( - *cmd.split(" "), - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + async def run_subshell(self, cmd: str) -> typing.Tuple[bytes, bytes]: + """Run a command in a subshell""" + proc = await asyncio.create_subprocess_shell( + cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) - return await proc.communicate() - - async def run_test_coro( - self, cmd: str, task_name: str - ) -> typing.Tuple[bytes, bytes]: - """Run a test command""" try: - stdout, stderr = await self.run_subprocess(cmd) - stdout_str: str = stdout.decode("utf-8").strip("\n").strip('"') - stderr_str: str = stderr.decode("utf-8").strip("\n").strip('"') + return await proc.communicate() + except asyncio.CancelledError: + self.write_log("fucking fuck") + return (bytes(), bytes()) - self.data_cols[task_name]["stdout"] = stdout_str - self.data_cols[task_name]["stderr"] = stderr_str - if stdout_str == self.data_cols[task_name]["test_command_result"]: + def tunnel_test_callback(self, task: asyncio.Task) -> None: + """Tunnel test callback function.""" + try: + task_name = task.get_name() + self.data_cols[task_name]["stdout"] = ( + task.result()[0].decode("utf-8").strip("\n") + ) + self.data_cols[task_name]["stderr"] = ( + task.result()[1].decode("utf-8").strip("\n") + ) + if ( + task.result()[0].decode("utf-8").strip("\n") + == self.data_cols[task_name]["test_command_result"] + ): self.data_cols[task_name]["status"] = "UP" else: self.data_cols[task_name]["status"] = "DOWN" - - return stdout, stderr except asyncio.TimeoutError: self.data_cols[task_name]["status"] = "TMOUT" - raise + except asyncio.CancelledError: + self.data_cols[task_name]["status"] = "CANCELLED" + + async def tunnel_test_procs(self) -> typing.List[asyncio.Task]: + """run all the tunnel tests in the background as separate tasks""" + tasks: typing.List[asyncio.Task] = [] + for _, value in self.data_cols.items(): + if value["test_command"] != "": + tasks.append( + asyncio.create_task( + asyncio.wait_for( + self.run_subshell(value["test_command"]), + timeout=float(value["test_timeout"]), + ), + name=value["name"], + ) + ) + tasks[-1].add_done_callback(self.tunnel_test_callback) + await asyncio.sleep(0) + + return tasks async def tunnel_procs( self, @@ -298,7 +261,7 @@ class TunnelManager: for _, value in self.data_cols.items(): tasks.append( asyncio.create_task( - self.run_subprocess(value["command"]), name=value["name"] + self.run_subshell(value["command"]), name=value["name"] ), ) await asyncio.sleep(0) @@ -311,13 +274,10 @@ class TunnelManager: if k not in self.data_cols: self.tunnel_tasks.append( asyncio.create_task( - self.run_subprocess(value["command"]), name=k + self.run_subshell(value["command"]), name=k ) ) await asyncio.sleep(0) - self.data_cols[k] = copy.deepcopy(value) - if k in self.scheduler_table: - self.scheduler_table[k] = 0 else: if ( self.data_cols[k]["command"] != data_cols_new[k]["command"] @@ -331,11 +291,9 @@ class TunnelManager: self.data_cols[k] = copy.deepcopy(data_cols_new[k]) self.tunnel_tasks.append( asyncio.create_task( - self.run_subprocess(value["command"]), name=k + self.run_subshell(value["command"]), name=k ) ) - if k in self.scheduler_table: - self.scheduler_table[k] = 0 await asyncio.sleep(0) for k, _ in self.data_cols.items(): @@ -344,8 +302,6 @@ class TunnelManager: if task.get_name() == k: task.cancel() del self.data_cols[k] - if k in self.scheduler_table: - del self.scheduler_table[k] async def sighup_handler(self) -> None: """SIGHUP handler. we want to reload the config.""" @@ -358,7 +314,7 @@ class TunnelManager: """A simple logger""" with open( "/home/devi/devi/abbatoir/hole15/log", - "a", + "w", encoding="utf-8", ) as logfile: logfile.write(log) @@ -372,15 +328,15 @@ class TunnelManager: was_cancelled = task.cancel() self.write_log(f"was_cancelled: {was_cancelled}") await task - for _, value in self.data_cols.items(): - if value["name"] == name and task.cancelled(): - self.tunnel_tasks.append( - asyncio.create_task( - self.run_subprocess(value["command"]), - name=value["name"], - ) - ) - await asyncio.sleep(0) + for _, value in self.data_cols.items(): + if value["name"] == name: + self.tunnel_tasks.append( + asyncio.create_task( + self.run_subshell(value["command"]), + name=value["name"], + ) + ) + await asyncio.sleep(0) async def flip_task(self, line_content: str) -> None: """flip a task""" @@ -401,76 +357,53 @@ class TunnelManager: if value["name"] == name: self.tunnel_tasks.append( asyncio.create_task( - self.run_subprocess(value["command"]), + self.run_subshell(value["command"]), name=value["name"], ) ) await asyncio.sleep(0) + break async def quit(self) -> None: """Cleanly quit the applicaiton""" - # scheduler checks for this so stop making new tasks - # when we want to quit - self.are_we_dying = True - # alternatively we could ask asyncio to cancel all tasks for tunnel_test_task in self.tunnel_test_tasks: tunnel_test_task.cancel() for tunnel_task in self.tunnel_tasks: tunnel_task.cancel() - try: - await asyncio.gather(*self.tunnel_test_tasks) - await asyncio.gather(*self.tunnel_tasks) - except asyncio.TimeoutError: - pass - finally: - sys.exit(0) - - async def scheduler(self) -> None: - """schedulaer manages running the tests and reviving dead tunnels""" - while True: - if self.are_we_dying: - return - for key, value in self.scheduler_table.items(): - if value == 0 and key not in self.tunnel_test_tasks: - tunnel_entry = self.data_cols[key] - test_task = asyncio.create_task( - asyncio.wait_for( - self.run_test_coro( - tunnel_entry["test_command"], - tunnel_entry["name"], - ), - timeout=float(tunnel_entry["test_timeout"]), - ), - name=key, - ) - # test_task.add_done_callback(self.tunnel_test_callback) - self.tunnel_test_tasks.append(test_task) - self.scheduler_table[key] = int( - tunnel_entry["test_interval"] - ) - await asyncio.sleep(0) - else: - self.scheduler_table[key] = self.scheduler_table[key] - 1 - - # we are using a 1 second ticker. basically the scheduler - # runs every second instead of as fast as it can - await asyncio.sleep(1) - async def tui_loop(self) -> None: + async def main(self) -> None: """entrypoint""" sel: int = 0 try: stdscr = curses_init() - # we spawn the tunnels and the test scheduler put them - # in the background and then run the TUI loop - self.tunnel_tasks = await self.tunnel_procs() - self.scheduler_task = asyncio.create_task( - self.scheduler(), name="scheduler" + + self.data_cols = self.read_conf() + + loop = asyncio.get_event_loop() + loop.add_signal_handler( + signal.SIGHUP, + lambda: asyncio.create_task(self.sighup_handler()), ) + self.tunnel_tasks = await self.tunnel_procs() while True: + # self.tunnel_test_tasks = await self.tunnel_test_procs() + lines = ffs( + 2, + ["NAME", "ADDRESS", "PORT", "STATUS", "STDOUT", "STDERR"] + if not self.argparser.args.noheader + else None, + False, + True, + [v["name"] for _, v in self.data_cols.items()], + [v["address"] for _, v in self.data_cols.items()], + [repr(v["port"]) for _, v in self.data_cols.items()], + [v["status"] for _, v in self.data_cols.items()], + [v["stdout"] for _, v in self.data_cols.items()], + [v["stderr"] for _, v in self.data_cols.items()], + ) stdscr.clear() - render(self.data_cols, self.tunnel_tasks, stdscr, sel) + render(lines, stdscr, sel) char = stdscr.getch() if char == ord("j") or char == curses.KEY_DOWN: @@ -482,6 +415,7 @@ class TunnelManager: await self.restart_task(line_content.decode("utf-8")) elif char == ord("q"): await self.quit() + # elif char == curses.KEY_ENTER: elif char == ord("s"): line_content = stdscr.instr(sel + 2, 1) await self.flip_task(line_content.decode("utf-8")) @@ -499,5 +433,5 @@ class TunnelManager: if __name__ == "__main__": - tunnel_manager = TunnelManager() - asyncio.run(tunnel_manager.tui_loop()) + tunnel_top = TunnelTop() + asyncio.run(tunnel_top.main()) |