From 7c8f259f9b7523f5fc79d22c0eeab5226b3e1076 Mon Sep 17 00:00:00 2001
From: Jonathan Weth <git@jonathanweth.de>
Date: Mon, 27 Jul 2020 11:25:42 +0200
Subject: [PATCH] Create global trigger server and handle SIGINT correctly

---
 service.py | 44 ++++++++++++++++++++++++++++++++++++++++----
 session.py | 40 +++++-----------------------------------
 trigger.py | 36 ++++++++++++++++++++++++++++++++++++
 3 files changed, 81 insertions(+), 39 deletions(-)
 create mode 100644 trigger.py

diff --git a/service.py b/service.py
index 4c4e927..47ebfcb 100755
--- a/service.py
+++ b/service.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 import argparse
 import json
+import signal
 import time
 from threading import Thread
 from typing import Union
@@ -9,15 +10,22 @@ import dbus
 import dbus.service
 
 from session import Session
+from trigger import TriggerServerThread
+
+ONE_TIME_SERVICE = False
 
 
 class RWAService(dbus.service.Object):
-    def __init__(self, mockup_mode: bool):
+    def __init__(self, loop, mockup_mode: bool = False):
+        self.loop = loop
         self.mockup_mode = mockup_mode
 
         self.bus = dbus.SessionBus()
         name = dbus.service.BusName("org.ArcticaProject.RWA", bus=self.bus)
 
+        self.trigger_service = TriggerServerThread(self._trigger)
+        self.trigger_service.start()
+
         self.update_service_running = False
         self.sessions = {}
         super().__init__(name, "/RWA")
@@ -26,7 +34,7 @@ class RWAService(dbus.service.Object):
     def start(self):
         """Start a new remote session."""
         # Start session
-        session = Session(mockup_mode)
+        session = Session(self.trigger_service.port, mockup_mode)
 
         # Add session to sessions list
         self.sessions[session.pid] = session
@@ -102,7 +110,28 @@ class RWAService(dbus.service.Object):
             time.sleep(2)
 
         self.update_service_running = False
-        # TODO Probably kill daemon here (quit main loop)
+        if ONE_TIME_SERVICE:
+            self._stop_all()
+
+    def _trigger(self, token: str) -> bool:
+        """Trigger a specific session via trigger token."""
+        print("Triggered with token", token)
+
+        for session in self.sessions.values():
+            if token == session.trigger_token:
+                print("Trigger session", session)
+                session.trigger()
+                return True
+
+        return False
+
+    def _stop_all(self):
+        """Stop all sessions and this daemon."""
+        for session in list(self.sessions.values()):
+            session.stop()
+            del self.sessions[session.pid]
+        self.trigger_service.shutdown()
+        self.loop.quit()
 
 
 def str2bool(v: Union[str, bool, int]) -> bool:
@@ -144,5 +173,12 @@ if __name__ == "__main__":
     dbus.mainloop.glib.DBusGMainLoop(set_as_default=True)
 
     loop = GLib.MainLoop()
-    object = RWAService(mockup_mode)
+    object = RWAService(loop, mockup_mode)
+
+    def signal_handler(sig, frame):
+        print("You pressed Ctrl+C!")
+        object._stop_all()
+
+    signal.signal(signal.SIGINT, signal_handler)
+
     loop.run()
diff --git a/session.py b/session.py
index d5c4281..b57bf57 100644
--- a/session.py
+++ b/session.py
@@ -36,20 +36,6 @@ def get_desktop_dir():
     )
 
 
-class ServerThread(threading.Thread):
-    def __init__(self, app, port: int):
-        super().__init__()
-        self.srv = make_server("127.0.0.1", port, app)
-        self.ctx = app.app_context()
-        self.ctx.push()
-
-    def run(self):
-        self.srv.serve_forever()
-
-    def shutdown(self):
-        self.srv.shutdown()
-
-
 class Session:
     #: Session is running
     STATUS_RUNNING = "running"
@@ -57,13 +43,15 @@ class Session:
     #: Remote has joined the session
     STATUS_JOINED = "active"
 
-    def __init__(self, mockup_session: bool):
+    def __init__(self, trigger_port: int, mockup_session: bool = False):
+        self.trigger_token = secrets.token_urlsafe(20)
+        self.trigger_port = trigger_port
         self.done_jobs = []
         self.mockup_session = mockup_session
         self.desktop_dir = get_desktop_dir()
+        self.desktop_dir = get_desktop_dir()
         self._generate_password()
         self._start_vnc()
-        self._start_trigger_service()
         self._register_session()
         self.status_text = self.STATUS_RUNNING
 
@@ -135,25 +123,7 @@ class Session:
             self.api_token = secrets.token_urlsafe(10)
             self.pin = int(random_digits(5))
 
-    def _start_trigger_service(self):
-        self.trigger_port = port_for.select_random()
-        self.trigger_token = secrets.token_urlsafe(20)
-
-        app = Flask(__name__)
-
-        @app.route("/", methods=["POST"])
-        def trigger():
-            json = request.json
-            if json.get("token", "") == self.trigger_token:
-                self._trigger()
-                return "Successful triggered"
-            else:
-                return abort(403)
-
-        self.trigger_thread = ServerThread(app=app, port=self.trigger_port)
-        self.trigger_thread.start()
-
-    def _trigger(self):
+    def trigger(self):
         """Event triggered by Django."""
         print("Triggered")
         self.pull()
diff --git a/trigger.py b/trigger.py
new file mode 100644
index 0000000..a7fb66c
--- /dev/null
+++ b/trigger.py
@@ -0,0 +1,36 @@
+import threading
+from typing import Any, Callable
+from wsgiref.simple_server import make_server
+
+import port_for
+from flask import Flask, abort, request
+
+
+class TriggerServerThread(threading.Thread):
+    """Simple Flask server (wrapped as thread) for triggering actions on sesssions."""
+
+    def __init__(self, trigger_method: Callable[[str], Any]):
+        super().__init__()
+        self.port = port_for.select_random()
+
+        app = Flask(__name__)
+
+        @app.route("/", methods=["POST"])
+        def trigger():
+            json = request.json
+            token = json.get("token", "")
+            r = trigger_method(token)
+            if r:
+                return "Successful triggered"
+            else:
+                return abort(403)
+
+        self.srv = make_server("127.0.0.1", self.port, app)
+        self.ctx = app.app_context()
+        self.ctx.push()
+
+    def run(self):
+        self.srv.serve_forever()
+
+    def shutdown(self):
+        self.srv.shutdown()
-- 
GitLab