From 2a2435dccdcebb209e9829e144e1ee6ebc9f5be9 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 18 Aug 2025 02:43:28 +0300 Subject: [PATCH 01/18] feat: add first e2e test --- tests/test_e2e_run/http_apps/csrf_form/app.py | 41 ++++++ tests/test_e2e_run/selenium_runners.py | 29 ++++ tests/test_e2e_run/test_http_apps.py | 137 ++++++++++++++++++ 3 files changed, 207 insertions(+) create mode 100644 tests/test_e2e_run/http_apps/csrf_form/app.py create mode 100644 tests/test_e2e_run/selenium_runners.py create mode 100644 tests/test_e2e_run/test_http_apps.py diff --git a/tests/test_e2e_run/http_apps/csrf_form/app.py b/tests/test_e2e_run/http_apps/csrf_form/app.py new file mode 100644 index 0000000..1a63b26 --- /dev/null +++ b/tests/test_e2e_run/http_apps/csrf_form/app.py @@ -0,0 +1,41 @@ +from flask import Flask, request, render_template_string +from flask_wtf import FlaskForm +from wtforms import StringField, SubmitField +from wtforms.validators import DataRequired +import secrets + +app = Flask(__name__) +app.secret_key = secrets.token_hex(16) # Needed for CSRF + +SECRET1 = "alpha" +SECRET2 = "beta" + + +class MyForm(FlaskForm): + field1 = StringField("Field 1", validators=[DataRequired()]) + field2 = StringField("Field 2", validators=[DataRequired()]) + submit = SubmitField("Submit") + + +@app.route("/", methods=["GET", "POST"]) +def index(): + form = MyForm() + if form.validate_on_submit() and form.field1.data == SECRET1 and form.field2.data == SECRET2: + return "SUCCESS", 200 + return render_template_string( + """ +
+ +The secret is not alpha
+ """, + form=form, + ) + + +if __name__ == "__main__": + app.run(debug=False) diff --git a/tests/test_e2e_run/selenium_runners.py b/tests/test_e2e_run/selenium_runners.py new file mode 100644 index 0000000..5175fbe --- /dev/null +++ b/tests/test_e2e_run/selenium_runners.py @@ -0,0 +1,29 @@ +from selenium import webdriver +from selenium.webdriver.common.by import By + + +def run_csrf_form_submitsuccess(driver): + # Step # | name | target | value + # 1 | open | / | + driver.get("http://127.0.0.1:5000/") + # 2 | setWindowSize | 736x729 | + driver.set_window_size(736, 729) + # 3 | click | css=html | + driver.find_element(By.CSS_SELECTOR, "html").click() + # 4 | click | css=form | + driver.find_element(By.CSS_SELECTOR, "form").click() + # 5 | click | id=field2 | + driver.find_element(By.ID, "field2").click() + # 6 | click | id=field1 | + driver.find_element(By.ID, "field1").click() + # 7 | type | id=field1 | alpha + driver.find_element(By.ID, "field1").send_keys("alpha") + # 8 | type | id=field2 | beta + driver.find_element(By.ID, "field2").send_keys("beta") + # 9 | click | id=submit | + driver.find_element(By.ID, "submit").click() + + +selenium_runs = { + "http_apps/csrf_form": [run_csrf_form_submitsuccess], +} diff --git a/tests/test_e2e_run/test_http_apps.py b/tests/test_e2e_run/test_http_apps.py new file mode 100644 index 0000000..4d7d1f8 --- /dev/null +++ b/tests/test_e2e_run/test_http_apps.py @@ -0,0 +1,137 @@ +import asyncio +import logging +import os +import subprocess +import time +import pathlib +import urllib.request + +import pytest +from selenium import webdriver +from selenium_runners import run_csrf_form_submitsuccess + +from cdprecorder import erpeto, skopo, recorder + + +VENV_HTTP_APPS = ".venv_http_apps" + + +class VenvAppRunner: + def __init__(self, app_path: str): + abs_app_path = pathlib.Path(__file__).parent.resolve() / app_path / "app.py" + python_binary = os.path.join(VENV_HTTP_APPS, "bin/python") + self.proc = subprocess.Popen( + 
[python_binary, abs_app_path], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + def wait_until_up(self, timeout: int = 3): + start_time = time.time() + while time.time() - start_time < timeout: + try: + ret = self.proc.poll() + if ret is not None: + out, err = self.proc.communicate() + raise RuntimeError(f"Processed terminated. retcode: {ret}. stderr: {err}") + urllib.request.urlopen("http://localhost:5000") + break + except urllib.error.URLError: + time.sleep(0.1) + else: + self.proc.kill() + out, err = self.proc.communicate() + print(f"stdout: {out}") + print(f"stderr: {err}") + raise TimeoutError(f"Timeout exceeded: {timeout} seconds") + + def __del__(self): + self.proc.terminate() + + +async def on_fail(comparator, httpobj1, httpobj2): + logging.error(f"Comparator failed after {comparator.requests_passed} passes") + logging.error(f"First httpobj: {httpobj1}") + logging.error(f"Second httpobj: {httpobj2}") + raise Exception("Failure") + + +@pytest.mark.asyncio +async def test_csrf_form_run_csrf_from_submit_success(): + logging.info("Starting web app") + app = VenvAppRunner("http_apps/csrf_form") + app.wait_until_up() + logging.info("Web app started") + + sniffer_manager1 = skopo.MitmproxySnifferManager("1") + sniffer_manager1.start_sniffer_on_thread() + proxy1 = sniffer_manager1.start_proxy_instance(port=8080) + + sniffer_manager2 = skopo.MitmproxySnifferManager("2") + sniffer_manager2.start_sniffer_on_thread() + proxy2 = sniffer_manager2.start_proxy_instance(port=8081) + + await sniffer_manager1.wait_for_proxy_connection_with_sniffer() + await sniffer_manager2.wait_for_proxy_connection_with_sniffer() + + proxy_url1 = f"http://{proxy1.host}:{proxy1.port}" + proxy_url2 = f"http://{proxy2.host}:{proxy2.port}" + + # proxy_url = f"http://{proxy_info.host}:{proxy_info.port}" + options = webdriver.ChromeOptions() + cdp_port = 9222 + options.add_argument(f"--remote-debugging-port={cdp_port}") + 
options.add_argument(f"--proxy-server={proxy_url1}") + options.add_argument("--ignore-certificate-errors") + driver = webdriver.Chrome(options) + logging.info("Instantiated web driver") + + recorder_options = recorder.RecorderOptions( + "http://localhost:5000", + cdp_host="localhost", + cdp_port=cdp_port, + collect_all=True, + ) + + try: + rec = await recorder.init_recorder(recorder_options) + + t1 = asyncio.create_task(asyncio.to_thread(run_csrf_form_submitsuccess, driver)) + t2 = asyncio.create_task(recorder.collect_communications(rec, 20)) + done, pending = await asyncio.wait([t1, t2], return_when=asyncio.FIRST_COMPLETED) + + if t1 in pending: + raise Exception("Recorder stopped before selenium test") + + if t2 in pending: + rec.listener.cancel() + communications = await t2 + finally: + await rec.close() + + logging.info("Recorded communications") + + actions = erpeto.parse_communications_into_actions(communications) + erpeto.make_action_ids_consecutive_from_list(actions) + # actions = await erpeto.run_recorder(recorder_options) + erpeto.run_analyse(actions) + + # TODO: probably, a sniffer that's not linked to a comparator should not block + # comparator = skopo.SnifferComparator(on_fail, sniffer_manager1.sniffer, sniffer_manager2.sniffer) + + t1 = asyncio.create_task(asyncio.to_thread(run_csrf_form_submitsuccess, driver)) + t2 = asyncio.create_task(asyncio.to_thread(erpeto.run_replicate, actions)) + t3 = asyncio.create_task(comparator.run()) + + pair = asyncio.create_task(asyncio.wait([t1, t2], return_when=asyncio.ALL_COMPLETED))  # asyncio.wait() only accepts tasks/futures, not bare coroutines + done, pending = await asyncio.wait([pair, t3], return_when=asyncio.FIRST_COMPLETED) + + # TODO: this probably can't cancel t1 and t2 + for t in pending: + t.cancel() + + if t3 in done: + res = t3.result() + + assert res is True From 82e5956ab4b91ed9bdf28ab2bd7059b92c9164f0 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 18 Aug 2025 02:45:04 +0300 Subject: [PATCH 02/18] feat: move managing stuff into
erpeto module --- cdprecorder/erpeto.py | 361 +++++++++++++++++++++++++++++++++++++ main.py | 403 +++++------------------------------------- 2 files changed, 408 insertions(+), 356 deletions(-) create mode 100644 cdprecorder/erpeto.py diff --git a/cdprecorder/erpeto.py b/cdprecorder/erpeto.py new file mode 100644 index 0000000..b19fce5 --- /dev/null +++ b/cdprecorder/erpeto.py @@ -0,0 +1,361 @@ +from __future__ import annotations + +from typing import cast, Optional, Union, TYPE_CHECKING + +import bs4 +import bs4.builder._htmlparser +import pycdp +import requests +import sys +import twisted.internet.reactor + +from twisted.python.log import err +from twisted.internet import defer, threads +from twisted.internet.interfaces import IReactorCore +from pycdp import cdp + +import cdprecorder +from cdprecorder import generate_python +from cdprecorder.action import ( + BrowserAction, + InputAction, + HttpAction, + LowercaseStr, + RequestAction, + ResponseAction, + response_action_from_python_response, +) +from cdprecorder.recorder import ( + HttpCommunication, + RecorderOptions, + record, +) + +import cdprecorder.analyser + +if TYPE_CHECKING: + import bs4 + + from pycdp.cdp.util import T_JSON_DICT + from twisted.python.failure import Failure + + from cdprecorder.type_checking import CdpEvent, HttpTarget + + +# https://github.com/twisted/twisted/issues/9909 +reactor = cast(IReactorCore, twisted.internet.reactor) + + +def generate_action(action: HttpAction, prev_new_actions: list[Optional[HttpAction]]) -> RequestAction: + new_action = RequestAction() + new_action.shallow_copy_from_action(action) + for target in action.targets: + target.apply(new_action, prev_new_actions) + + return new_action + + +def run_actions(actions: list[HttpAction]) -> None: + new_actions: list[Optional[HttpAction]] = [] + + for action in actions: + if isinstance(action, RequestAction): + new_action = generate_action(action, new_actions) + new_actions.append(new_action) + + with requests.Session() as 
session: + req = requests.Request( + method=new_action.method, + url=new_action.url, + headers=new_action.headers, + data=new_action.body, + cookies=new_action.cookies_to_dict(), + ) + prepared_request = req.prepare() + resp = session.send(prepared_request, allow_redirects=False) + resp_action = response_action_from_python_response(resp) + new_actions.append(resp_action) + + print(f"{new_action.method} {new_action.url} - {resp.status_code}") + + elif not isinstance(action, ResponseAction): + new_actions.append(None) + + +def to_cdp_event(event: CdpEvent) -> dict[str, Union[str, T_JSON_DICT]]: + cdp_method = None + for key, val in cdp.util._event_parsers.items(): + if val == event.__class__: + cdp_method = key + break + else: + raise Exception + + return { + "method": cdp_method, + "params": event.to_json(), + "type": "recv", + "domain": "-", + } + + +def get_only_http_actions(actions: list[BrowserAction]) -> list[HttpAction]: + return [action for action in actions if isinstance(action, HttpAction)] + + +def _generate_events_with_redirects_extracted(events: list[CdpEvent]) -> list[CdpEvent]: + new_events = [] + future_events: list[CdpEvent] = [] + wait_response_extra = False + wait_request_extra = False + wait_extra = False + for evt in events: + if wait_extra: + if ( + isinstance(evt, cdp.network.RequestWillBeSentExtraInfo) + and not wait_request_extra + or isinstance(evt, cdp.network.ResponseReceivedExtraInfo) + and not wait_response_extra + ): + wait_extra = False + new_events += future_events + future_events = [] + + if isinstance(evt, cdp.network.RequestWillBeSentExtraInfo): + wait_request_extra = False + elif isinstance(evt, cdp.network.ResponseReceivedExtraInfo): + wait_response_extra = False + wait_extra = wait_response_extra or wait_request_extra + + new_events.append(evt) + continue + else: + new_events += future_events + future_events = [] + + future_events.append(evt) + if not isinstance(evt, cdp.network.RequestWillBeSent): + continue + + if not
evt.redirect_response: + new_events += future_events + future_events = [] + wait_extra = False + else: + if evt.redirect_has_extra_info: + wait_response_extra = True + wait_request_extra = True + wait_extra = True + + response_evt = cdp.network.ResponseReceived( + request_id=evt.request_id, + loader_id=evt.loader_id, + timestamp=evt.timestamp, + type_=evt.type_, + response=evt.redirect_response, + has_extra_info=evt.redirect_has_extra_info, + frame_id=evt.frame_id, + ) + new_events.append(response_evt) + + new_events += future_events + + return new_events + + +def parse_communications_into_actions( + communications: list[Union[HttpCommunication, InputAction]] +) -> list[BrowserAction]: + from cdprecorder import logger + + actions: list[BrowserAction] = [] + + for comm in communications: + logger.debug("Comm: %s", repr(comm)) + + if not isinstance(comm, HttpCommunication): + actions.append(comm) + continue + + if comm.ignored: + continue + + response_bodies = list(comm.response_bodies) + + curr_request: Optional[RequestAction] = None + request_extra: Optional[RequestAction] = None + curr_response: Optional[ResponseAction] = None + response_extra: Optional[ResponseAction] = None + events = _generate_events_with_redirects_extracted(comm.events) + print("--------------------------------------------------------") + # Append to actions the requests/responses from each event + for evt in events: + if isinstance(evt, cdp.network.RequestWillBeSent): + if curr_request is not None: + if all((curr_request, request_extra, curr_response)): + curr_request.has_response = True + actions.append(curr_request) + actions.append(curr_response) + else: + actions.append(curr_request) + if curr_response is not None: + curr_request.has_response = True + actions.append(curr_response) + + curr_request = None + request_extra = None + curr_response = None + + """ + if curr_request is not None: + # Consume the previous request + if curr_response: + curr_request.has_response = True +
actions.append(curr_request) + curr_request = None + request_extra = None + + if curr_response: + # Consume the previous response + actions.append(curr_response) + if response_extra: + raise Exception + curr_response = None + """ + + curr_request = RequestAction() + curr_request.update_info(evt.request) + if evt.request.has_post_data and evt.request.post_data: + # TODO: Check if bytes in other entry + curr_request.set_body(evt.request.post_data.encode()) + + if request_extra is not None: + curr_request.merge(request_extra) + + elif isinstance(evt, cdp.network.RequestWillBeSentExtraInfo): + if request_extra is not None: + if all((curr_request, request_extra, curr_response)): + curr_request.has_response = True + actions.append(curr_request) + curr_request = None + request_extra = None + + actions.append(curr_response) + curr_response = None + + """ + if curr_request is not None and request_extra is not None: + # Consume the previous request + if curr_response: + curr_request.has_response = True + actions.append(curr_request) + curr_request = None + request_extra = None + + if curr_response: + # Consume the previous response + actions.append(curr_response) + if response_extra: + raise Exception + curr_response = None + """ + + if request_extra is not None: + raise Exception + request_extra = RequestAction() + request_extra.update_info(evt) + + if curr_request is not None: + curr_request.merge(request_extra) + # request_extra = None + + elif isinstance(evt, cdp.network.ResponseReceived): + if curr_response is None: + curr_response = ResponseAction(evt.response) + else: + raise Exception + + if response_extra is not None: + # Always merge response_extra over curr_response, not the other way + curr_response.merge(response_extra) + response_extra = None + + elif isinstance(evt, cdp.network.ResponseReceivedExtraInfo): + if curr_response is not None: + # Always merge response_extra over curr_response, not the other way + curr_response.merge(ResponseAction(evt)) + elif 
response_extra is None: + response_extra = ResponseAction(evt) + else: + raise Exception + + elif isinstance(evt, cdp.network.LoadingFinished): + # Manually inserted + response_body = response_bodies.pop(0) + if response_body is not None: + if curr_response: + curr_response.set_body(response_body) + elif response_extra: + response_extra.set_body(response_body) + else: + raise Exception + + if curr_request is not None: + if curr_response or response_extra: + curr_request.has_response = True + actions.append(curr_request) + curr_request = None + if curr_response is not None: + actions.append(curr_response) + curr_response = None + elif response_extra is not None: + actions.append(response_extra) + response_extra = None + + if curr_request is not None: + if curr_response is not None: + curr_request.has_response = True + curr_request.merge(request_extra) + actions.append(curr_request) + if curr_response is not None: + # Always merge response_extra over curr_response, not the other way + curr_response.merge(response_extra) + actions.append(curr_response) + + return actions + + +def make_action_ids_consecutive_from_list(actions: list[BrowserAction]): + for i, action in enumerate(actions): + action.ID = i + + +async def run_recorder(options: RecorderOptions): + communications = await record(options) + actions = parse_communications_into_actions(communications) + make_action_ids_consecutive_from_list(actions) + + return actions + + +def run_analyse(actions): + cdprecorder.analyser.analyse_actions(actions) + + +def run_replicate(actions): + run_actions(actions) + + +async def run(options: RecorderOptions) -> None: + actions = await run_recorder(options) + run_analyse(actions) + run_replicate(actions) + + # actions = get_only_http_actions(actions) + # run_actions(actions) + + # generate_python.write_python_code(actions, "generated.py") + + """ + await threads.deferToThread(chrome.kill) + """ diff --git a/main.py b/main.py index 791a07d..6c0fcf4 100644 --- a/main.py +++ 
b/main.py @@ -1,385 +1,76 @@ from __future__ import annotations -from typing import cast, Optional, Union, TYPE_CHECKING - -import bs4 -import bs4.builder._htmlparser -import pycdp -import requests +import asyncio import sys -import twisted.internet.reactor - -from twisted.python.log import err -from twisted.internet import defer, threads -from twisted.internet.interfaces import IReactorCore -from pycdp import cdp import cdprecorder -from cdprecorder import generate_python -from cdprecorder.action import ( - BrowserAction, - InputAction, - HttpAction, - LowercaseStr, - RequestAction, - ResponseAction, - response_action_from_python_response, -) -from cdprecorder.recorder import ( - HttpCommunication, - RecorderOptions, - record, -) - -import cdprecorder.analyser - -if TYPE_CHECKING: - import bs4 - - from pycdp.cdp.util import T_JSON_DICT - from twisted.python.failure import Failure - - from cdprecorder.type_checking import CdpEvent, HttpTarget - - -# https://github.com/twisted/twisted/issues/9909 -reactor = cast(IReactorCore, twisted.internet.reactor) - - -def generate_action( - action: HttpAction, prev_new_actions: list[Optional[HttpAction]] -) -> RequestAction: - new_action = RequestAction() - new_action.shallow_copy_from_action(action) - for target in action.targets: - target.apply(new_action, prev_new_actions) - - return new_action - - -def run_actions(actions: list[HttpAction]) -> None: - new_actions: list[Optional[HttpAction]] = [] - - for action in actions: - if isinstance(action, RequestAction): - new_action = generate_action(action, new_actions) - new_actions.append(new_action) - - with requests.Session() as session: - req = requests.Request( - method=new_action.method, - url=new_action.url, - headers=new_action.headers, - data=new_action.body, - cookies=new_action.cookies_to_dict(), - ) - prepared_request = req.prepare() - resp = session.send(prepared_request, allow_redirects=False) - resp_action = response_action_from_python_response(resp) - 
new_actions.append(resp_action) - - print(f"{new_action.method} {new_action.url} - {resp.status_code}") - - elif not isinstance(action, ResponseAction): - new_actions.append(None) - - -def to_cdp_event(event: CdpEvent) -> dict[str, Union[str, T_JSON_DICT]]: - cdp_method = None - for key, val in cdp.util._event_parsers.items(): - if val == event.__class__: - cdp_method = key - break - else: - raise Exception - - return { - "method": cdp_method, - "params": event.to_json(), - "type": "recv", - "domain": "-", - } - - -def get_only_http_actions(actions: list[BrowserActions]) -> list[HttpActions]: - return [action for action in actions if isinstance(action, HttpAction)] - - -def _generate_events_with_redirects_extracted(events: list[CdpEvent]) -> list[CdpEvent]: - new_events = [] - future_events: list[CdpEvents] = [] - wait_response_extra = False - wait_request_extra = False - wait_extra = False - for evt in events: - if wait_extra: - if ( - isinstance(evt, cdp.network.RequestWillBeSentExtraInfo) - and not wait_request_extra - or isinstance(evt, cdp.network.ResponseReceivedExtraInfo) - and not wait_response_extra - ): - wait_extra = False - new_events += future_events - future_events = [] - - if isinstance(evt, cdp.network.RequestWillBeSentExtraInfo): - wait_request_extra = False - elif isinstance(evt, cdp.network.ResponseReceivedExtraInfo): - wait_response_extra = False - wait_extra = wait_response_extra or wait_request_extra - - new_events.append(evt) - continue - else: - new_events += future_events - future_events = [] - - future_events.append(evt) - if not isinstance(evt, cdp.network.RequestWillBeSent): - continue - - if not evt.redirect_response: - new_events += future_events - future_events = [] - wait_extra = False - else: - if evt.redirect_has_extra_info: - wait_response_extra = True - wait_request_extra = True - wait_extra = True - - response_evt = cdp.network.ResponseReceived( - request_id=evt.request_id, - loader_id=evt.loader_id, - timestamp=evt.timestamp, - 
type_=evt.type_, - response=evt.redirect_response, - has_extra_info=evt.redirect_has_extra_info, - frame_id=evt.frame_id, - ) - new_events.append(response_evt) - - new_events += future_events +from cdprecorder import erpeto, recorder, skopo - return new_events +async def on_fail(comparator, httpobj1, httpobj2): + print(f"Comparator failed after {comparator.requests_passed} passes") + print(f"First httpobj: {httpobj1}") + print(f"Second httpobj: {httpobj2}") + raise Exception("Failure") -def parse_communications_into_actions( - communications: list[Union[HttpCommunication, InputActioni]] -) -> list[BrowserAction]: - from cdprecorder import logger - actions: list[BrowserAction] = [] - - for comm in communications: - logger.debug("Comm: %s", repr(comm)) - - if not isinstance(comm, HttpCommunication): - actions.append(comm) - continue - - if comm.ignored: - continue - - response_bodies = list(comm.response_bodies) - - curr_request: Optional[RequestAction] = None - request_extra: Optional[RequestAction] = None - curr_response: Optional[ResponseAction] = None - response_extra: Optional[ResponseAction] = None - events = _generate_events_with_redirects_extracted(comm.events) - print("--------------------------------------------------------") - # Append to actions the requests/responses from each event - for evt in events: - if isinstance(evt, cdp.network.RequestWillBeSent): - if curr_request is not None: - if all((curr_request, request_extra, curr_response)): - curr_request.has_response = True - actions.append(curr_request) - actions.append(curr_response) - else: - actions.append(curr_request) - if curr_response is not None: - curr_request.has_response = True - actions.append(curr_response) - - curr_request = None - request_extra = None - curr_response = None - - """ - if curr_request is not None: - # Consume the previous request - if curr_response: - curr_request.has_response = True - actions.append(curr_request) - curr_request = None - request_extra = None - - if 
curr_response: - # Consume the previous response - actions.append(curr_response) - if response_extra: - raise Exception - curr_response = None - """ - - curr_request = RequestAction() - curr_request.update_info(evt.request) - if evt.request.has_post_data and evt.request.post_data: - # TODO: Check if bytes in other entry - curr_request.set_body(evt.request.post_data.encode()) - - if request_extra is not None: - curr_request.merge(request_extra) - - elif isinstance(evt, cdp.network.RequestWillBeSentExtraInfo): - if request_extra is not None: - if all((curr_request, request_extra, curr_response)): - curr_request.has_response = True - actions.append(curr_request) - curr_request = None - request_extra = None - - actions.append(curr_response) - curr_response = None - - """ - if curr_request is not None and request_extra is not None: - # Consume the previous request - if curr_response: - curr_request.has_response = True - actions.append(curr_request) - curr_request = None - request_extra = None - - if curr_response: - # Consume the previous response - actions.append(curr_response) - if response_extra: - raise Exception - curr_response = None - """ - - if request_extra is not None: - raise Exception - request_extra = RequestAction() - request_extra.update_info(evt) - - if curr_request is not None: - curr_request.merge(request_extra) - # request_extra = None - - elif isinstance(evt, cdp.network.ResponseReceived): - if curr_response is None: - curr_response = ResponseAction(evt.response) - else: - raise Exception - - if response_extra is not None: - # Always merge response_extra over curr_response, not the other way - curr_response.merge(response_extra) - response_extra = None - - elif isinstance(evt, cdp.network.ResponseReceivedExtraInfo): - if curr_response is not None: - # Always merge response_extra over curr_response, not the other way - curr_response.merge(ResponseAction(evt)) - elif response_extra is None: - response_extra = ResponseAction(evt) - else: - raise 
Exception +async def main() -> None: + sniffer_manager1 = skopo.MitmproxySnifferManager("1") + sniffer_manager1.start_sniffer_on_thread() + proxy1 = sniffer_manager1.start_proxy_instance(port=8082) - elif isinstance(evt, cdp.network.LoadingFinished): - # Manually inserted - response_body = response_bodies.pop(0) - if response_body is not None: - if curr_response: - curr_response.set_body(response_body) - elif response_extra: - response_extra.set_body(response_body) - else: - raise Exception + sniffer_manager2 = skopo.MitmproxySnifferManager("2") + sniffer_manager2.start_sniffer_on_thread() + proxy2 = sniffer_manager2.start_proxy_instance(port=8083) - if curr_request is not None: - if curr_response or response_extra: - curr_request.has_response = True - actions.append(curr_request) - curr_request = None - if curr_response is not None: - actions.append(curr_response) - curr_response = None - elif response_extra is not None: - actions.append(response_extra) - response_extra = None + print(f"Event loop: {asyncio.get_event_loop()}") + comparator = skopo.SnifferComparator( + on_fail, sniffer_manager1.sniffer, sniffer_manager2.sniffer, asyncio.get_event_loop() + ) - if curr_request is not None: - if curr_response is not None: - curr_request.has_response = True - curr_request.merge(request_extra) - actions.append(curr_request) - if curr_response is not None: - # Always merge response_extra over curr_response, not the other way - curr_response.merge(response_extra) - actions.append(curr_response) + await sniffer_manager1.wait_for_proxy_connection_with_sniffer() + await sniffer_manager2.wait_for_proxy_connection_with_sniffer() - return actions + print("Waited") + proxy_url1 = f"http://{proxy1.host}:{proxy1.port}" + proxy_url2 = f"http://{proxy2.host}:{proxy2.port}" -def make_action_ids_consecutive_from_list(actions: list[BrowserAction]): - for i, action in enumerate(actions): - action.ID = i + import requests + def threaded_run1(): + proxies = {"http": proxy_url1, "https": 
proxy_url1} + r = requests.get("https://google.com", proxies=proxies, verify=False) + print("done") + r = requests.get("https://google.com/test", proxies=proxies, verify=False) + r = requests.get("https://google.com/test2", proxies=proxies, verify=False) -async def run(options: RecorderOptions) -> None: - communications = await record(options) - actions = parse_communications_into_actions(communications) - make_action_ids_consecutive_from_list(actions) - cdprecorder.analyser.analyse_actions(actions) - # actions = get_only_http_actions(actions) - # run_actions(actions) + def threaded_run2(): + proxies = {"http": proxy_url2, "https": proxy_url2} + r = requests.get("https://google.com", proxies=proxies, verify=False) + r = requests.get("https://google.com/test", proxies=proxies, verify=False) + r = requests.get("https://google.com/test", proxies=proxies, verify=False) - generate_python.write_python_code(actions, "generated.py") + import threading - """ - async with target_session.wait_for(cdp.runtime.ConsoleAPICalled) as evt: - print(evt) - if evt.args[0].value == "click": - for obj in evt.args: - if obj.class_name == "PointerEvent": + thread1 = threading.Thread(target=threaded_run1, daemon=True) + thread1.start() + thread2 = threading.Thread(target=threaded_run2, daemon=True) + thread2.start() - result = await target_session.execute(cdp.runtime.get_properties(obj.object_id)) - pointer_evt_props = result[0] - break + await comparator.run() - for prop in pointer_evt_props: - if prop.name == "srcElement": - print(prop) - result = await target_session.execute(cdp.dom.describe_node(object_id=prop.value.object_id)) - print(result) - """ + time.sleep(20) + exit() """ - await threads.deferToThread(chrome.kill) - """ - - -async def main() -> None: cdprecorder.enable_logger() cdprecorder.configure_root_logger(stream=sys.stdout) start_url = "https://github.com" - options = RecorderOptions(start_url) - await run(options) - - -def main_error(failure: Failure) -> None: - 
err(failure) # type: ignore[no-untyped-call] - reactor.stop() + options = recorder.RecorderOptions(start_url) + await erpeto.run(options) + """ if __name__ == "__main__": - d = defer.ensureDeferred(main()) - d.addErrback(main_error) - d.addCallback(lambda *args: reactor.stop()) - reactor.run() + asyncio.run(main()) From 4f056335e4d21b3954b61d1f8ed1ab9870a5e797 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 18 Aug 2025 03:04:47 +0300 Subject: [PATCH 03/18] feat: add skopo --- cdprecorder/skopo/__init__.py | 419 +++++++++++++++++++++++++++ cdprecorder/skopo/common_data.py | 204 +++++++++++++ cdprecorder/skopo/intercept_addon.py | 224 ++++++++++++++ cdprecorder/skopo/mitm_runner.py | 39 +++ 4 files changed, 886 insertions(+) create mode 100644 cdprecorder/skopo/__init__.py create mode 100644 cdprecorder/skopo/common_data.py create mode 100644 cdprecorder/skopo/intercept_addon.py create mode 100644 cdprecorder/skopo/mitm_runner.py diff --git a/cdprecorder/skopo/__init__.py b/cdprecorder/skopo/__init__.py new file mode 100644 index 0000000..b944eb1 --- /dev/null +++ b/cdprecorder/skopo/__init__.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import asyncio +import os +import pickle +import queue +import socket +import sys +import threading +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generic, TypeVar + +from .._storage import get_runtime_dir, DEFAULT_SOCKET_NAME +from .. import logger +from .
import mitm_runner, common_data +from .common_data import RequestData, ResponseData + + +if TYPE_CHECKING: + import subprocess + from asyncio.events import AbstractEventLoop + from typing import Any, Callable, Coroutine, Optional + +T = TypeVar("T") + + +# for pickle to find the module +# Reference: https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory +sys.modules["common_data"] = common_data + + +class SnifferException(Exception): + pass + + +def asyncio_run_coroutine_threadsafe(coro: Coroutine) -> None: + asyncio.run_coroutine_threadsafe(coro, asyncio.get_event_loop()) + + +class ThreadsafeAsyncWaker: + def __init__(self, loop: AbstractEventLoop, callback: Callable[[], Any]): + self.loop = loop + self.callback = callback + + def wake(self) -> None: + asyncio.run_coroutine_threadsafe(self.callback(), self.loop) + + +class ThreadSafeAsyncQueue(Generic[T]): + def __init__(self, size: int, loop: Optional[AbstractEventLoop] = None): + if loop is None: + loop = asyncio.get_event_loop() + self._tqueue: queue.Queue[T] = queue.Queue(size) + self._aqueue: asyncio.Queue[T] = asyncio.Queue(size) + self._async_waker = ThreadsafeAsyncWaker(loop, self.transfer_to_async) + + async def transfer_to_async(self) -> None: + try: + item = self._tqueue.get(block=False) + await self._aqueue.put(item) + except queue.Empty: + pass + + def put(self, item: T) -> None: + self._tqueue.put(item) + self._async_waker.wake() + + def get(self) -> T: + return self._tqueue.get() + + async def async_put(self, item: T) -> None: + await self._aqueue.put(item) + try: + item = self._aqueue.get_nowait()  # get_nowait() returns the item directly; awaiting it raises TypeError + # TODO: This might block :( + self._tqueue.put(item) + except asyncio.QueueEmpty: + pass + + async def async_get(self) -> T: + return await self._aqueue.get() + + +class CrossThreadSnifferBase: + """Thread safe methods used by the Sniffer class.""" + + def __init__(self, loop: Optional[AbstractEventLoop] = None): + self.lock =
threading.Lock() + self.signal_queue = None + if loop is not None: + self.signal_queue = ThreadSafeAsyncQueue(0, loop) + self.httpobj_queue = None + self.modification_queue = queue.Queue(maxsize=1) + + self.client_connected = False + + def set_httpobj_queue(self, httpobj_queue: queue.Queue): + with self.lock: + self.httpobj_queue = httpobj_queue + + def publish_httpobj(self, httpobj): + if self.httpobj_queue is None: + return + self.httpobj_queue.put((self, httpobj)) + + def get_modification(self): + return self.modification_queue.get() + + def publish_modification(self, obj: object): + self.modification_queue.put(obj) + + def send_signal(self, signal): + self.signal_queue.put(signal) + + +class Sniffer(CrossThreadSnifferBase): + SIG_CLIENT_CONNECTED = 1 + + def __init__(self, coroutine_runner=None, coroutine_runner_threadsafe=None, sock_suffix: str = "", *args, **kwargs): + super().__init__(*args, **kwargs) + if coroutine_runner is not None: + self.coroutine_runner = coroutine_runner + else: + self.coroutine_runner = asyncio.run + if coroutine_runner_threadsafe is not None: + self.coroutine_runner_threadsafe = coroutine_runner_threadsafe + else: + self.coroutine_runner_threadsafe = asyncio_run_coroutine_threadsafe + + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.running = False + + dirpath = get_runtime_dir() + self.sockpath = os.path.join(dirpath, DEFAULT_SOCKET_NAME + sock_suffix) + # Try to delete the socket if it already exists + try: + os.unlink(self.sockpath) + except FileNotFoundError: + pass + + self._id_counter = 0 + self.client_connected = False + + def run(self) -> None: + self.coroutine_runner(self._async_run()) + + async def on_client_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + if self.client_connected: + raise SnifferException("Sniffer only supports one client at a time") + self.client_connected = True + self.send_signal(self.SIG_CLIENT_CONNECTED) + try: + print("Client connected") + conn = 
common_data.AsyncioSniffConnection(reader, writer) + try: + try: + proxyname = await conn.read_str() + except common_data.SniffProtocolException as exc: + raise SnifferException from exc + + print(f"Proxy name: {proxyname}") + while True: + try: + httpobj = await conn.read_httpobj() + except common_data.SniffProtocolException as exc: + raise SnifferException from exc + + print(f"Got {httpobj}") + + # Yes, the calls to Queue can block the thread, that's + # why we allow only one client connection at a time + self.publish_httpobj(httpobj) + modification = self.get_modification() + + req, resp = modification + + if resp is not None: + await conn.send("REPLACE_RESPONSE") + await conn.send(resp) + elif req is not None and not isinstance(httpobj, ResponseData): + await conn.send("REPLACE_REQUEST") + await conn.send(req) + else: + await conn.send("OK") + except asyncio.IncompleteReadError: + pass + + print("Client done") + except Exception as exc: + print(f"Sniffer Exception: {exc}") + import traceback + + traceback.print_exception(exc) + raise + finally: + self.client_connected = False + + async def _async_run(self): + try: + os.unlink(self.sockpath) + except FileNotFoundError: + pass + self.server = await asyncio.start_unix_server(self.on_client_connected, self.sockpath) + await self.server.serve_forever() + + async def _async_stop(self): + self.server.close() + + def stop(self) -> None: + self.coroutine_runner_threadsafe(self._stop()) + + +class SnifferThreadCombiner: + """Combines sniffers from different threads and makes them accessible + from a single thread.""" + + async def __init__(self, sniffers: list[Sniffer]): + self.sniffers = sniffers + self.read_queue = ThreadSafeAsyncQueue(0, asyncio.get_event_loop()) + + for sniffer in self.sniffers: + sniffer.set_httpobj_queue(self.read_queue) + + async def get_message(self): + return await self.read_queue.async_get() + + async def send_message(self, sniffer, obj: object): + sniffer.modification_queue.put(obj) + + +# 
TODO: The sniffer comparator should better have references to the sniffers +class SnifferComparator: + def __init__(self, on_fail, sniffer1: Sniffer, sniffer2: Sniffer, event_loop=None): + self.on_fail = on_fail + + """ + self.req1 = None + self.req2 = None + self.res1 = None + self.request1_ready = asyncio.Semaphore(0) + self.request2_ready = asyncio.Semaphore(0) + self.requests_ready = asyncio.Semaphore(0) + + self.request1_lock = asyncio.Lock() + self.request2_lock = asyncio.Lock() + """ + + self.requests_passed = 0 + + self.sniffer1 = sniffer1 + self.sniffer2 = sniffer2 + + self.sniffer1_httpobj_queue = ThreadSafeAsyncQueue(0, event_loop) + self.sniffer2_httpobj_queue = ThreadSafeAsyncQueue(0, event_loop) + + sniffer1.set_httpobj_queue(self.sniffer1_httpobj_queue) + sniffer2.set_httpobj_queue(self.sniffer2_httpobj_queue) + # self.combiner = SnifferThreadCombiner([sniffer1, sniffer2]) + self.running = False + + async def run(self): + self.running = True + while self.running: + _, req1 = await self.sniffer1_httpobj_queue.async_get() + print("Got req1") + if not isinstance(req1, RequestData): + raise SnifferException("Expected RequestData from sniffer1") + + _, req2 = await self.sniffer2_httpobj_queue.async_get() + print("Got req2") + if not isinstance(req2, RequestData): + raise SnifferException("Expected RequestData from sniffer2") + + if req1 != req2: + print("Calling on_fail") + await self.on_fail(self, req1, req2) + + self.sniffer1.publish_modification((None, None)) + _, res1 = await self.sniffer1_httpobj_queue.async_get() + self.sniffer1.publish_modification((None, None)) + if not isinstance(res1, ResponseData): + raise SnifferException("Expected RespsoneData from sniffer1") + + self.sniffer2.publish_modification((None, res1)) + _, res2 = await self.sniffer2_httpobj_queue.async_get() + if not isinstance(res2, ResponseData): + raise SnifferException("Expected RespsoneData from sniffer2") + + if res1 != res2: + print("Expected the safe response from sniffer2") 
+ self.on_fail(self, res1, res2) + + self.sniffer2.publish_modification((None, None)) + + self.requests_passed += 1 + + print("Done comparator iteration") + + async def on_request1(self, req: RequestData): + print("Got request1") + async with self.request1_lock: + self.req1 = req + self.request1_ready.release() + print("Released req1_ready") + await self.request2_ready.acquire() + + print("Barrier released") + + if self.req1 != self.req2: + print("Calling on_fail") + self.on_fail(self.req1, self.req2) + + return None, None + + self.req1 = None + + async def on_request2(self, req: RequestData): + print("Got request2") + async with self.request2_lock: + self.req2 = req + self.request2_ready.release() + + print("Released req2_ready") + + await self.requests_ready.acquire() + print("Got response for req2") + self.req2 = None + + return None, self.res1 + + async def on_response1(self, res: ResponseData): + print("Got response1") + # Assume on_request1 was called for the corresponding request + self.res1 = res + print("Well, nice") + self.requests_ready.release() + print("Released") + + async def on_response2(self, res: ResponseData): + print("On response2 was called but this should never happen") + + +@dataclass +class ProxyInfo: + proc: subprocess.Popen + host: str + port: int + name: str + + +class MitmproxySnifferManager: + def __init__(self, sock_suffix: str, event_loop=None): + if event_loop is None: + event_loop = asyncio.get_event_loop() + self.sniffer = Sniffer( + sock_suffix=sock_suffix, + loop=event_loop, + ) + self.thread = None + + self.proxies = {} + + def start_sniffer_on_thread(self): + if self.thread is not None: + raise SnifferException("Sniffer thread already started") + + self.sniffer.running = True + self.thread = threading.Thread(target=self.sniffer.run, daemon=True) + self.thread.start() + # Wait for the sniffer to open the listening socket + time.sleep(0.5) + + def start_proxy_instance(self, host=None, port=None): + if self.thread is None: + raise 
SnifferException("Sniffer must be started before starting a proxy") + + kwargs = {} + if host is not None: + kwargs["host"] = host + if port is not None: + kwargs["port"] = port + + proc, host, port, proxy_name = mitm_runner.run(self.sniffer.sockpath, **kwargs) + + self.proxies[proxy_name] = ProxyInfo( + proc=proc, + host=host, + port=port, + name=proxy_name, + ) + + return self.proxies[proxy_name] + + async def wait_for_proxy_connection_with_sniffer(self): + while True: + sig = await self.sniffer.signal_queue.async_get() + print(f"Got signal on queue: {sig}") + if sig is self.sniffer.SIG_CLIENT_CONNECTED: + break + + def stop_from_thread(self): + self.sniffer.stop() + self.thread.join() + + def stop_proxy_instance(self, proxy_name: str): + try: + proxy_info = self.proxies[proxy_name] + except KeyError: + raise SnifferException(f"No proxy with the name: {proxy_name!r}") + + proxy_info.proc.stop() + del self.proxies[proxy_name] + + def __del__(self): + for proxy in self.proxies.values(): + proxy.proc.terminate() + + self.proxies.clear() diff --git a/cdprecorder/skopo/common_data.py b/cdprecorder/skopo/common_data.py new file mode 100644 index 0000000..66bef73 --- /dev/null +++ b/cdprecorder/skopo/common_data.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +import socket +import pickle +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + import asyncio + from typing import Union + + +class RequestData: + __slots__ = ("http_version", "method", "url", "headers", "raw_content", "trailers", "object_id") + + def __init__( + self, + http_version: str, + method: str, + url: str, + headers: list[tuple[str, str]], + raw_content: bytes, + trailers: list[tuple[str, str]], + object_id: int, + ): + self.http_version = http_version + self.method = method + self.url = url + self.headers = headers + self.raw_content = raw_content + self.trailers = trailers + self.object_id = object_id + + def __eq__(self, other: object) -> bool: + if not isinstance(other, 
RequestData): + raise ValueError("Value must RequestData instance") + # Compares everything but object_id + attr1 = (self.http_version, self.method, self.url, self.headers, self.raw_content, self.trailers) + attr2 = (other.http_version, other.method, other.url, other.headers, other.raw_content, other.trailers) + + return attr1 == attr2 + + def __str__(self) -> str: + text = f"{self.__class__.__name__}(" + text += f"http_version={self.http_version}, " + text += f"method={self.method}, " + text += f"url={self.url}, " + text += f"headers={self.headers}, " + text += f"raw_content={self.raw_content!r}, " + text += f"trailers={self.trailers}, " + text += f"object_id={self.object_id}" + text += ")" + + return text + + +class ResponseData: + __slots__ = ("http_version", "status_code", "reason", "headers", "raw_content", "trailers", "object_id") + + def __init__( + self, + http_version: str, + status_code: int, + reason: str, + headers: list[tuple[str, str]], + raw_content: bytes, + trailers: list[tuple[str, str]], + object_id: int, + ): + self.http_version = http_version + self.status_code = status_code + self.reason = reason + self.headers = headers + self.raw_content = raw_content + self.trailers = trailers + self.object_id = object_id + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ResponseData): + raise ValueError("Value must ResponseData instance") + # Compares everything but object_id + attr1 = (self.http_version, self.status_code, self.reason, self.headers, self.raw_content, self.trailers) + attr2 = (other.http_version, other.status_code, other.reason, other.headers, other.raw_content, other.trailers) + + return attr1 == attr2 + + +class SniffProtocolException(Exception): + pass + + +class SniffProtocol: + def to_datagram(self, sendable_obj: Union[str, RequestData, ResponseData]) -> bytes: + assert isinstance(sendable_obj, (str, RequestData, ResponseData)) + body = pickle.dumps(sendable_obj) + size = len(body) + data = size.to_bytes(8, 
byteorder="big") + body + + return data + + +class SniffConnection: + def __init__(self, sockaddr: str): + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.connect(sockaddr) + self.protocol = SniffProtocol() + + def send(self, sendable_obj: Union[RequestData, ResponseData]) -> None: + data = self.protocol.to_datagram(sendable_obj) + self.sock.send(data) + + def read(self) -> object: + size_data = self.sock.recv(8) + size = int.from_bytes(size_data, "big") + data = self.sock.recv(size) + + obj = pickle.loads(data) + if not isinstance(obj, (str, RequestData, ResponseData)): + raise SniffProtocolException(f"sendable_obj is of unsupported type {type(obj)}") + + return obj + + def read_str(self) -> str: + obj = self.read() + if not isinstance(obj, str): + raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {str}") + + return obj + + def read_httpobj(self) -> Union[RequestData, ResponseData]: + obj = self.read() + if not isinstance(obj, (RequestData, ResponseData)): + raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {(RequestData, ResponseData)}") + + return obj + + def read_request_data(self) -> RequestData: + obj = self.read() + if not isinstance(obj, RequestData): + raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {RequestData}") + + return obj + + def read_response_data(self) -> ResponseData: + obj = self.read() + if not isinstance(obj, ResponseData): + raise SniffProtocolException(f"Object is of type {type(obj)}. 
Expected {ResponseData}") + + return obj + + def close(self) -> None: + self.sock.shutdown(socket.SHUT_RDWR) + self.sock.close() + + +class AsyncioSniffConnection: + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + self.reader = reader + self.writer = writer + self.protocol = SniffProtocol() + + async def send(self, sendable_obj: Union[RequestData, ResponseData]) -> None: + data = self.protocol.to_datagram(sendable_obj) + self.writer.write(data) + await self.writer.drain() + + async def read(self) -> object: + size_data = await self.reader.readexactly(8) + size = int.from_bytes(size_data, "big") + data = await self.reader.readexactly(size) + + obj = pickle.loads(data) + if not isinstance(obj, (str, RequestData, ResponseData)): + raise SniffProtocolException(f"Object is of unsupported type {type(obj)}") + + return obj + + async def read_str(self) -> str: + obj = await self.read() + if not isinstance(obj, str): + raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {str}") + + return obj + + async def read_httpobj(self) -> Union[RequestData, ResponseData]: + obj = await self.read() + if not isinstance(obj, (RequestData, ResponseData)): + raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {(RequestData, ResponseData)}") + + return obj + + async def read_request_data(self) -> RequestData: + obj = await self.read() + if not isinstance(obj, RequestData): + raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {RequestData}") + + return obj + + async def read_response_data(self) -> ResponseData: + obj = await self.read() + if not isinstance(obj, ResponseData): + raise SniffProtocolException(f"Object is of type {type(obj)}. 
Expected {ResponseData}") + + return obj diff --git a/cdprecorder/skopo/intercept_addon.py b/cdprecorder/skopo/intercept_addon.py new file mode 100644 index 0000000..27d39db --- /dev/null +++ b/cdprecorder/skopo/intercept_addon.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import logging +import pickle +import socket +import sys +from urllib.parse import urlunparse +from typing import Optional, TYPE_CHECKING + +from mitmproxy import ctx, http + +import common_data +from common_data import SniffConnection, RequestData, ResponseData + +# :'( +# Used by pickle to load RequestData and ResponseData +sys.modules["cdprecorder"] = common_data +sys.modules["cdprecorder.common_data"] = common_data +sys.modules["cdprecorder.skopo.common_data"] = common_data + +import requests + + +if TYPE_CHECKING: + from mitmproxy import http + + +class WrapperFormatter(logging.Formatter): + def __init__(self, formatter: loggingFormatter): + self.formatter = formatter + + def format(self, record: logging.LogRecord): + msg = self.formatter.format(record) + + msg_prefix = "" + if hasattr(record, "prefix"): + msg_prefix = f"[{record.prefix}]" + + return f"{msg_prefix}{msg}" + + def formatTime(self, *args, **kwargs): + return self.formatter.formatTime(*args, **kwargs) + + def formatException(self, *args, **kwargs): + return self.formatter.formatException(*args, **kwargs) + + def formatStack(self, *args, **kwargs): + return self.formatter.formatStack(*args, **kwargs) + + +class PrefixFilter(logging.Filter): + def __init__(self, prefix: Optional[str] = None): + self.prefix = prefix + + def filter(self, record: logging.LogRecord) -> bool: + if self.prefix is not None: + record.prefix = self.prefix + return True + + +class TheSpy: + def __init__(self): + self.data_sender = None + self.flows = set() + + logging.info("PPpp", extra={"client": "haubau"}) + + self.log_filter = PrefixFilter() + handler = logging.getLogger().handlers[0] + handler.addFilter(self.log_filter) + wrapper_formatter 
= WrapperFormatter(handler.formatter) + handler.setFormatter(wrapper_formatter) + + def load(self, loader): + url = f"{ctx.options.listen_host}:{ctx.options.listen_port}" + self.log_filter.prefix = url + + loader.add_option( + name="socketaddress", + # Wow, they really used typing at runtime + typespec=Optional[str], + default=None, + help="Socket address to send request/response pickled data", + ) + loader.add_option( + name="proxyname", + typespec=str, + default="thespy", + help="Name used to distinguish different mitmdump instances", + ) + + def start_connection(self, socketaddress: str): + if self.data_sender is not None: + self.data_sender.close() + logging.info("Connecting to %s", socketaddress) + self.data_sender = SniffConnection(socketaddress) + assert isinstance(self.proxyname, str) + self.data_sender.send(self.proxyname) + + def running(self): + if ctx.options.socketaddress is not None: + self.start_connection(ctx.options.socketaddress) + + def configure(self, updated: set[str]): + self.proxyname = ctx.options.proxyname + logging.info("Set proxyname=%s", self.proxyname) + if ctx.options.socketaddress is not None: + self.start_connection(ctx.options.socketaddress) + + def request(self, flow: http.HTTPFlow): + self.flows.add(id(flow)) + logging.info("Intercepted request") + mitmreq = flow.request + # TODO: handle mitmreq.authority + url = urlunparse( + ( + mitmreq.scheme, + f"{mitmreq.host}:{mitmreq.port}", + # Can this ever be bytes?? 
+ mitmreq.path, + "", + "", + "", + ) + ) + + req = RequestData( + http_version=mitmreq.http_version, + method=mitmreq.method, + url=url, + headers=list(mitmreq.headers.items(multi=True)), + raw_content=mitmreq.raw_content, + trailers=list(mitmreq.trailers.items(multi=True) if mitmreq.trailers else []), + object_id=id(flow), + ) + + self.data_sender.send(req) + + logging.info("Sent request") + status = self.data_sender.read_str() + + if status == "REPLACE_RESPONSE": + r = self.data_sender.read_response_data() + + logging.info("Headers: %s", r.headers) + headers_bytes = [] + for key, value in r.headers: + headers_bytes.append((key.encode(), value.encode())) + resp = http.Response.make( + status_code=r.status_code, + content=r.raw_content, + headers=headers_bytes, + ) + resp.http_version = r.http_version + resp.reason = r.reason + resp.trailers = http.Headers(r.trailers) + + flow.response = resp + elif status == "REPLACE_REQUEST": + r = self.data_sender.read_request_data() + + req = http.Request.make(method=r.method, url=r.url, content=r.raw_content, headers=http.Headers(r.headers)) + req.http_version = (r.http_version,) + req.reason = (r.reason,) + req.trailers = (http.Headers(r.trailers),) + + flow.request = req + + elif status != "OK": + raise Exception(f"Unknown status: {status}") + + logging.info("OK Intercepted request") + + def response(self, flow: http.HTTPFlow): + logging.info("Intercepted response") + if id(flow) not in self.flows: + logging.error("Got response before existing request") + else: + self.flows.remove(id(flow)) + + try: + mitmres = flow.response + logging.info("Response headers: %s", mitmres.headers) + res = ResponseData( + http_version=mitmres.http_version, + status_code=mitmres.status_code, + reason=mitmres.reason, + headers=list(mitmres.headers.items(multi=True)), + raw_content=mitmres.raw_content, + trailers=list(mitmres.trailers.items(multi=True) if mitmres.trailers else []), + object_id=id(flow), + ) + + logging.info("Constructed 
ResponseData") + + self.data_sender.send(res) + logging.info("Sent response") + + status = self.data_sender.read_str() + logging.info("Got status in response") + + if status == "REPLACE_RESPONSE": + r = self.data_sender.read_response_data() + + resp = http.Response.make( + status_code=r.status_code, + content=r.raw_content, + headers=http.Headers(r.headers), + ) + resp.http_version = (r.http_version,) + resp.reason = (r.reason,) + resp.trailers = (http.Headers(r.trailers),) + + flow.response = resp + elif status != "OK": + raise Exception(f"Unknown status: {status}") + + logging.info("OK Intercepted response") + except Exception as exc: + logging.info("Addon Exception: %s", exc) + raise + + +addons = [TheSpy()] diff --git a/cdprecorder/skopo/mitm_runner.py b/cdprecorder/skopo/mitm_runner.py new file mode 100644 index 0000000..6b5d4a8 --- /dev/null +++ b/cdprecorder/skopo/mitm_runner.py @@ -0,0 +1,39 @@ +import os +import random +import string +import subprocess + +from . import logger + + +def run( + sniffer_socket_address: str, + host: str = "localhost", + port: int = 8080, + addon_script: str = "intercept_addon.py", + binary: str = "mitmdump", +): + proxy_name: str = "".join(random.choices(string.ascii_letters, k=32)) + module_dir = os.path.dirname(os.path.realpath(__file__)) + addon_path = os.path.join(module_dir, addon_script) + cli_args = [ + "--mode", + "regular", + "--listen-host", + host, + "--listen-port", + str(port), + "-s", + addon_path, + "--set", + f"socketaddress={sniffer_socket_address}", + "--set", + f"proxyname={proxy_name}", + ] + args = [binary] + cli_args + + logger.info("Running command: `%s`", " ".join(args)) + # TODO: stderr DEVNULL + p = subprocess.Popen(args) # , stdout=subprocess.DEVNULL) + + return p, host, port, proxy_name From 1e0180f5367767fbf92e136eb046dfbc77f20f2c Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 18 Aug 2025 03:05:23 +0300 Subject: [PATCH 04/18] feat: update 
recorder --- cdprecorder/recorder.py | 181 ++++++++++++++++++++++++++++++++-------- 1 file changed, 148 insertions(+), 33 deletions(-) diff --git a/cdprecorder/recorder.py b/cdprecorder/recorder.py index 5c38fe0..caa7d57 100644 --- a/cdprecorder/recorder.py +++ b/cdprecorder/recorder.py @@ -28,12 +28,14 @@ from cheap_repr import cheap_repr from pycdp import cdp from pycdp.browser import ChromeLauncher -from pycdp.twisted import CDPConnection as _PyCDPConnection +from pycdp.twisted import CDPConnection as _TwistedPyCDPConnection from twisted.internet import defer, threads from twisted.internet.error import ConnectionRefusedError from twisted.internet.interfaces import IReactorCore, IReactorTime from twisted.web.client import Agent +from pycdp.asyncio import CDPConnection as _PyCDPConnection, ClientSession + from . import filters, logger, tkinter_ui from .action import InputAction @@ -98,13 +100,10 @@ async def insert_js_leech_script( await target_session.execute(cdp.page.runtime.enable()) - receiver = None - evt_to_listen = cdp.runtime.ExecutionContextCreated + events_queue = None try: - # Don't call target_session.listen because we need the receiver object to close it at the end - receiver = pycdp.twisted.CDPEventListener(defer.DeferredQueue(1024)) - target_session._listeners[evt_to_listen].add(receiver) - listener = aiter(receiver) + events_queue = target_session.create_listener(cdp.runtime.ExecutionContextCreated) + listener = aiter(events_queue) await target_session.execute( cdp.page.add_script_to_evaluate_on_new_document(expression, run_immediately=True, world_name=context_name) @@ -118,8 +117,8 @@ async def insert_js_leech_script( except defer.TimeoutError as exc: raise Exception from exc finally: - if receiver is not None: - receiver.close() + if events_queue is not None: + events_queue.close() if context_id is None: raise Exception @@ -193,7 +192,7 @@ def __eq__(self, obj: object) -> bool: T = TypeVar("T") -class AsyncIteratorWithTimeout(Generic[T]): +class 
TwistedAsyncIteratorWithTimeout(Generic[T]): def __init__( self, iterator: AsyncIterator[T], @@ -220,6 +219,50 @@ async def __anext__(self) -> T: d.addTimeout(remained, reactor) return await d + def __aiter__(self) -> TwistedAsyncIteratorWithTimeout: + return self + + +class TwistedAsyncIterableWithTimeout(Generic[T]): + def __init__( + self, + iterable: AsyncIterable[T], + timeout: float, + start_time: Optional[float] = None, + ): + self.iterable = iterable + self.timeout = timeout + if start_time is None: + self.start_time = time.time() + else: + self.start_time = start_time + + def __aiter__(self) -> AsyncIterator[T]: + return TwistedAsyncIteratorWithTimeout(self.iterable.__aiter__(), self.timeout, self.start_time) + + +class AsyncIteratorWithTimeout(Generic[T]): + def __init__( + self, + iterator: AsyncIterator[T], + timeout: float, + start_time: Optional[float] = None, + ): + self.iterator = iterator + self.timeout = timeout + if start_time is None: + self.start_time = time.time() + else: + self.start_time = start_time + + async def __anext__(self) -> T: + curr_time = time.time() + remained = self.timeout - (curr_time - self.start_time) + if remained <= 0: + raise StopAsyncIteration + + return await asyncio.wait_for(self.iterator.__anext__(), remained) + def __aiter__(self) -> AsyncIteratorWithTimeout: return self @@ -242,7 +285,7 @@ def __aiter__(self) -> AsyncIterator[T]: return AsyncIteratorWithTimeout(self.iterable.__aiter__(), self.timeout, self.start_time) -class CancelableAsyncIterator(Generic[T]): +class TwistedCancelableAsyncIterator(Generic[T]): def __init__(self, iterator: AsyncIterator[T], stop_event: defer.Deferred): self.iterator = iterator self._stop_event = stop_event @@ -277,11 +320,11 @@ async def __anext__(self) -> T: return res - def __aiter__(self) -> CancelableAsyncIterator: + def __aiter__(self) -> TwistedCancelableAsyncIterator: return self -class CancelableAsyncIterable(Generic[T]): +class TwistedCancelableAsyncIterable(Generic[T]): 
def __init__(self, iterable: AsyncIterable[T]): self.iterable = iterable self._stop_event = defer.Deferred() @@ -289,6 +332,47 @@ def __init__(self, iterable: AsyncIterable[T]): def cancel(self): return self._stop_event.callback(None) + def __aiter__(self) -> AsyncIterator[T]: + return TwistedCancelableAsyncIterator(self.iterable.__aiter__(), self._stop_event) + + +class CancelableAsyncIterator(Generic[T]): + def __init__(self, iterator: AsyncIterator[T], stop_event: asyncio.Event): + self.iterator = iterator + self._stop_event = stop_event + self._wait_cancel = asyncio.create_task(self._stop_event.wait()) + + def cancel(self): + return self._stop_event.set() + + async def __anext__(self) -> T: + if self._stop_event.is_set(): + raise StopAsyncIteration + + it_next = asyncio.create_task(self.iterator.__anext__()) + + done, pending = await asyncio.wait([it_next, self._wait_cancel], return_when=asyncio.FIRST_COMPLETED) + + if it_next in pending: + it_next.cancel() + + if self._wait_cancel in done: + raise StopAsyncIteration + + return it_next.result() + + def __aiter__(self) -> CancelableAsyncIterator: + return self + + +class CancelableAsyncIterable(Generic[T]): + def __init__(self, iterable: AsyncIterable[T]): + self.iterable = iterable + self._stop_event = asyncio.Event() + + def cancel(self): + return self._stop_event.set() + def __aiter__(self) -> AsyncIterator[T]: return CancelableAsyncIterator(self.iterable.__aiter__(), self._stop_event) @@ -300,11 +384,15 @@ def __init__( urlfilter: filters.URLFilter, collect_all: bool, start_origin: Optional[str], + connection, + listener: AsyncIterable, ): self.target_session = target_session self.urlfilter = urlfilter self.start_origin = start_origin self.collect_all = collect_all + self.connection = connection + self.listener = CancelableAsyncIterable(listener) self.communications: list[Union[HttpCommunication, InputAction]] = [] self.request_map: dict[pycdp.cdp.network.RequestId, HttpCommunication] = {} self.runtime_ctx = 
get_runtime_context() @@ -384,14 +472,14 @@ async def on_stop(self): def add_on_stop_callback(self, callback): self.on_stop_cbs.append(callback) + async def close(self): + self.target_session.close_listeners() + await self.connection.close() + async def collect_communications( - target_session: pycdp.twisted.CDPSession, - listener: AsyncIterator[object], - urlfilter: filters.URLFilter, + recorder: Recorder, timeout: int = 120, - collect_all: bool = False, - start_origin: Optional[str] = None, ) -> list[Union[HttpCommunication, InputAction]]: """Takes a cdp session, listens for events, and generates a list of communications. @@ -407,7 +495,7 @@ async def collect_communications( Returns: A list of communications. """ - recorder = Recorder(target_session, urlfilter, collect_all, start_origin) + listener = recorder.listener runtime_context = get_runtime_context() ui = tkinter_ui.TkRecordControl(reactor, recorder.on_start, recorder.on_stop) @@ -448,6 +536,11 @@ async def collect_communications( return recorder.communications +class TwistedCDPConnection(_TwistedPyCDPConnection): + # Remove `retry_on` wrapper from function + connect = _PyCDPConnection.connect.__wrapped__ # type: ignore[attr-defined] # pylint: disable=no-member + + class CDPConnection(_PyCDPConnection): # Remove `retry_on` wrapper from function connect = _PyCDPConnection.connect.__wrapped__ # type: ignore[attr-defined] # pylint: disable=no-member @@ -472,16 +565,25 @@ def find_chrome_binary_path() -> str: class RecorderOptions: start_url: str keep_only_same_origin_urls: bool = True - collect_all: bool = False - binary: str = CHROME_BINARY + collect_all: bool = True + binary: Optional[str] = CHROME_BINARY cdp_host: str = "localhost" cdp_port: int = 9222 - fail_if_no_connection: bool = False + proxy_host: Optional[str] = None + proxy_port: int = 8080 + proxy_scheme: str = "http" @property def cdp_url(self) -> str: return f"http://{self.cdp_host}:{self.cdp_port}" + @property + def proxy_url(self) -> 
Optional[str]: + if self.proxy_host is None: + return None + + return f"{self.proxy_scheme}://{self.proxy_host}:{self.proxy_port}" + async def obtain_active_tab( targets: list[cdp.target.TargetInfo], conn: CDPConnection @@ -688,21 +790,26 @@ async def insert_widget_extension( return await insert_js_leech_script(target_session, expression) -async def record( - options: RecorderOptions, -) -> list[Union[HttpCommunication, InputAction]]: +async def init_recorder(options: RecorderOptions): urlfilter = filters.URLFilter() try: - conn = CDPConnection(options.cdp_url, Agent(reactor), reactor) # type: ignore[no-untyped-call] + http = ClientSession() + conn = CDPConnection(options.cdp_url, http) await conn.connect() + conn.start() except ConnectionRefusedError: - if options.fail_if_no_connection: + if options.binary is None: raise ConnectionRefusedError port = options.cdp_port + + chrome_args = [f"--remote-debugging-port={port}", "--incognito"] + if options.proxy_url is not None: + chrome_args.append(f"--proxy-server={options.proxy_url}") + chrome_args.append("--ignore-certificate-errors") chrome = ChromeLauncher( binary=options.binary, - args=[f"--remote-debugging-port={port}", "--incognito"], + args=chrome_args, ) await threads.deferToThread(chrome.launch) # type: ignore[no-untyped-call] await conn.connect() @@ -718,6 +825,9 @@ async def record( await target_session.execute(cdp.network.enable()) + # await target_session.execute(cdp.security.enable()) + # await target_session.execute(cdp.security.set_ignore_certificate_errors(True)) + # Start the listener before navigating to the page listener = target_session.listen( cdp.runtime.BindingCalled, @@ -745,12 +855,17 @@ async def record( if options.keep_only_same_origin_urls: start_origin = extract_origin(start_url) + return Recorder(target_session, urlfilter, options.collect_all, start_origin, conn, listener) + + +async def record( + options: RecorderOptions, +) -> list[Union[HttpCommunication, InputAction]]: + recorder = 
await init_recorder(options) + try: - communications = await collect_communications( - target_session, listener, urlfilter, 20, options.collect_all, start_origin - ) + communications = await collect_communications(recorder, 20) finally: - target_session.close_listeners() - await conn.close() + await recorder.close() return communications From 250723181eb3d9b4ecd3cf9d0e391cd8f688adeb Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 18 Aug 2025 03:05:48 +0300 Subject: [PATCH 05/18] dev: add mitmproxy requirement --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 6225b91..b13805e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ python-dateutil==2.8.2 python-cdp @ git+https://github.com/RazorBest/python-cdp.git@erpeto adblockparser==0.7 cheap_repr @ git+https://github.com/RazorBest/cheap_repr.git@add-annotations +mitmproxy==11.0.2 From cfb7dbd9626022188079947428d39023bb0d216c Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 18 Aug 2025 03:06:15 +0300 Subject: [PATCH 06/18] feat: add dev requirements for testing with selenium --- dev/requirements-dev.txt | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/dev/requirements-dev.txt b/dev/requirements-dev.txt index edb9605..df37086 100644 --- a/dev/requirements-dev.txt +++ b/dev/requirements-dev.txt @@ -1,8 +1,8 @@ +# Repo coverage-badge pybadges -pytest -pytest-cov==4.* -pytest-asyncio==0.23.5 + +# Styling mypy==1.17.* pylint==3.* isort==5.* @@ -10,6 +10,13 @@ black==24.* mypy-zope==1.0.* pre-commit==4.* +# Testing +pytest +virtualenv +pytest-cov==4.* +pytest-asyncio==0.23.5 +selenium==4.35.* + # Mypy stubs types-requests types-beautifulsoup4==4.* From af5ddeed69257fd340d8d8ce7be51a7a0df4bd2f Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 18 Aug 2025 
03:06:55 +0300 Subject: [PATCH 07/18] dev: add setup for http apps --- dev/requirements-http-apps.txt | 4 ++++ dev/setup.sh | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 dev/requirements-http-apps.txt create mode 100755 dev/setup.sh diff --git a/dev/requirements-http-apps.txt b/dev/requirements-http-apps.txt new file mode 100644 index 0000000..46c69f4 --- /dev/null +++ b/dev/requirements-http-apps.txt @@ -0,0 +1,4 @@ +# Requirements for the Python apps in tests/test_e2e_run/http_apps +Flask==3.1.1 +Flask-WTF==1.2.2 +WTForms==3.2.1 diff --git a/dev/setup.sh b/dev/setup.sh new file mode 100755 index 0000000..e310624 --- /dev/null +++ b/dev/setup.sh @@ -0,0 +1,25 @@ +#!/bin/bash +set -e + +PARENT_PATH=$(cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P) +VENV_HTTP_APPS=".venv_http_apps" + + +# TODO: make sure the script is run from a specific place + +# Create venv if it doesn't exist +if [ ! -d "$VENV_HTTP_APPS" ]; then + echo Creating virtual environment in $VENV_HTTP_APPS... + python3 -m virtualenv "$VENV_HTTP_APPS" +fi + +# Activate venv +echo Activating virtual environment $VENV_HTTP_APPS... +# shellcheck disable=SC1091 +source "$VENV_HTTP_APPS/bin/activate" + +echo Installing dependencies... 
+pip install --upgrade pip +pip install -r $PARENT_PATH/requirements-http-apps.txt + +echo Done From 612c4dac574693d02f9344e4c6c9447ad43564b8 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 18 Aug 2025 03:07:36 +0300 Subject: [PATCH 08/18] feat: add _storage.py --- cdprecorder/_storage.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 cdprecorder/_storage.py diff --git a/cdprecorder/_storage.py b/cdprecorder/_storage.py new file mode 100644 index 0000000..57f7747 --- /dev/null +++ b/cdprecorder/_storage.py @@ -0,0 +1,8 @@ +import os + + +DEFAULT_SOCKET_NAME = "erpeto.sock" + + +def get_runtime_dir(): + return os.getenv("XDG_RUNTIME_DIR") From 7860e5c5bee27174d4d8df77dd19b4acf179b032 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 18 Aug 2025 03:08:07 +0300 Subject: [PATCH 09/18] fix: import IReactorCore --- cdprecorder/user_interface.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cdprecorder/user_interface.py b/cdprecorder/user_interface.py index ca7ebc3..d99b916 100644 --- a/cdprecorder/user_interface.py +++ b/cdprecorder/user_interface.py @@ -1,6 +1,8 @@ import abc from typing import TYPE_CHECKING, Callable +from twisted.internet.interfaces import IReactorCore + if TYPE_CHECKING: from typing import TYPE_CHECKING, Callable from twisted.internet.interfaces import IReactorCore From 5ae5b9346cbbc5c388ee05debf4afa1c7e2c406e Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 1 Sep 2025 01:36:12 +0300 Subject: [PATCH 10/18] feat: remove common_data.py and mitm_runner.py --- cdprecorder/skopo/common_data.py | 204 ------------------------------- cdprecorder/skopo/mitm_runner.py | 39 ------ 2 files changed, 243 deletions(-) delete mode 100644 cdprecorder/skopo/common_data.py delete mode 100644 cdprecorder/skopo/mitm_runner.py diff --git a/cdprecorder/skopo/common_data.py 
b/cdprecorder/skopo/common_data.py deleted file mode 100644 index 66bef73..0000000 --- a/cdprecorder/skopo/common_data.py +++ /dev/null @@ -1,204 +0,0 @@ -from __future__ import annotations - -import socket -import pickle -from typing import TYPE_CHECKING - - -if TYPE_CHECKING: - import asyncio - from typing import Union - - -class RequestData: - __slots__ = ("http_version", "method", "url", "headers", "raw_content", "trailers", "object_id") - - def __init__( - self, - http_version: str, - method: str, - url: str, - headers: list[tuple[str, str]], - raw_content: bytes, - trailers: list[tuple[str, str]], - object_id: int, - ): - self.http_version = http_version - self.method = method - self.url = url - self.headers = headers - self.raw_content = raw_content - self.trailers = trailers - self.object_id = object_id - - def __eq__(self, other: object) -> bool: - if not isinstance(other, RequestData): - raise ValueError("Value must RequestData instance") - # Compares everything but object_id - attr1 = (self.http_version, self.method, self.url, self.headers, self.raw_content, self.trailers) - attr2 = (other.http_version, other.method, other.url, other.headers, other.raw_content, other.trailers) - - return attr1 == attr2 - - def __str__(self) -> str: - text = f"{self.__class__.__name__}(" - text += f"http_version={self.http_version}, " - text += f"method={self.method}, " - text += f"url={self.url}, " - text += f"headers={self.headers}, " - text += f"raw_content={self.raw_content!r}, " - text += f"trailers={self.trailers}, " - text += f"object_id={self.object_id}" - text += ")" - - return text - - -class ResponseData: - __slots__ = ("http_version", "status_code", "reason", "headers", "raw_content", "trailers", "object_id") - - def __init__( - self, - http_version: str, - status_code: int, - reason: str, - headers: list[tuple[str, str]], - raw_content: bytes, - trailers: list[tuple[str, str]], - object_id: int, - ): - self.http_version = http_version - self.status_code = 
status_code - self.reason = reason - self.headers = headers - self.raw_content = raw_content - self.trailers = trailers - self.object_id = object_id - - def __eq__(self, other: object) -> bool: - if not isinstance(other, ResponseData): - raise ValueError("Value must ResponseData instance") - # Compares everything but object_id - attr1 = (self.http_version, self.status_code, self.reason, self.headers, self.raw_content, self.trailers) - attr2 = (other.http_version, other.status_code, other.reason, other.headers, other.raw_content, other.trailers) - - return attr1 == attr2 - - -class SniffProtocolException(Exception): - pass - - -class SniffProtocol: - def to_datagram(self, sendable_obj: Union[str, RequestData, ResponseData]) -> bytes: - assert isinstance(sendable_obj, (str, RequestData, ResponseData)) - body = pickle.dumps(sendable_obj) - size = len(body) - data = size.to_bytes(8, byteorder="big") + body - - return data - - -class SniffConnection: - def __init__(self, sockaddr: str): - self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.sock.connect(sockaddr) - self.protocol = SniffProtocol() - - def send(self, sendable_obj: Union[RequestData, ResponseData]) -> None: - data = self.protocol.to_datagram(sendable_obj) - self.sock.send(data) - - def read(self) -> object: - size_data = self.sock.recv(8) - size = int.from_bytes(size_data, "big") - data = self.sock.recv(size) - - obj = pickle.loads(data) - if not isinstance(obj, (str, RequestData, ResponseData)): - raise SniffProtocolException(f"sendable_obj is of unsupported type {type(obj)}") - - return obj - - def read_str(self) -> str: - obj = self.read() - if not isinstance(obj, str): - raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {str}") - - return obj - - def read_httpobj(self) -> Union[RequestData, ResponseData]: - obj = self.read() - if not isinstance(obj, (RequestData, ResponseData)): - raise SniffProtocolException(f"Object is of type {type(obj)}. 
Expected {(RequestData, ResponseData)}") - - return obj - - def read_request_data(self) -> RequestData: - obj = self.read() - if not isinstance(obj, RequestData): - raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {RequestData}") - - return obj - - def read_response_data(self) -> ResponseData: - obj = self.read() - if not isinstance(obj, ResponseData): - raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {ResponseData}") - - return obj - - def close(self) -> None: - self.sock.shutdown(socket.SHUT_RDWR) - self.sock.close() - - -class AsyncioSniffConnection: - def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - self.reader = reader - self.writer = writer - self.protocol = SniffProtocol() - - async def send(self, sendable_obj: Union[RequestData, ResponseData]) -> None: - data = self.protocol.to_datagram(sendable_obj) - self.writer.write(data) - await self.writer.drain() - - async def read(self) -> object: - size_data = await self.reader.readexactly(8) - size = int.from_bytes(size_data, "big") - data = await self.reader.readexactly(size) - - obj = pickle.loads(data) - if not isinstance(obj, (str, RequestData, ResponseData)): - raise SniffProtocolException(f"Object is of unsupported type {type(obj)}") - - return obj - - async def read_str(self) -> str: - obj = await self.read() - if not isinstance(obj, str): - raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {str}") - - return obj - - async def read_httpobj(self) -> Union[RequestData, ResponseData]: - obj = await self.read() - if not isinstance(obj, (RequestData, ResponseData)): - raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {(RequestData, ResponseData)}") - - return obj - - async def read_request_data(self) -> RequestData: - obj = await self.read() - if not isinstance(obj, RequestData): - raise SniffProtocolException(f"Object is of type {type(obj)}. 
Expected {RequestData}") - - return obj - - async def read_response_data(self) -> ResponseData: - obj = await self.read() - if not isinstance(obj, ResponseData): - raise SniffProtocolException(f"Object is of type {type(obj)}. Expected {ResponseData}") - - return obj diff --git a/cdprecorder/skopo/mitm_runner.py b/cdprecorder/skopo/mitm_runner.py deleted file mode 100644 index 6b5d4a8..0000000 --- a/cdprecorder/skopo/mitm_runner.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -import random -import string -import subprocess - -from . import logger - - -def run( - sniffer_socket_address: str, - host: str = "localhost", - port: int = 8080, - addon_script: str = "intercept_addon.py", - binary: str = "mitmdump", -): - proxy_name: str = "".join(random.choices(string.ascii_letters, k=32)) - module_dir = os.path.dirname(os.path.realpath(__file__)) - addon_path = os.path.join(module_dir, addon_script) - cli_args = [ - "--mode", - "regular", - "--listen-host", - host, - "--listen-port", - str(port), - "-s", - addon_path, - "--set", - f"socketaddress={sniffer_socket_address}", - "--set", - f"proxyname={proxy_name}", - ] - args = [binary] + cli_args - - logger.info("Running command: `%s`", " ".join(args)) - # TODO: stderr DEVNULL - p = subprocess.Popen(args) # , stdout=subprocess.DEVNULL) - - return p, host, port, proxy_name From f99172d6cecff727b0bb7d866d50d9014f7cc964 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 1 Sep 2025 01:40:26 +0300 Subject: [PATCH 11/18] feat: separate sniff protocol from server code --- cdprecorder/skopo/__init__.py | 746 ++++++++++++++-------------- cdprecorder/skopo/sniff_protocol.py | 539 ++++++++++++++++++++ 2 files changed, 918 insertions(+), 367 deletions(-) create mode 100644 cdprecorder/skopo/sniff_protocol.py diff --git a/cdprecorder/skopo/__init__.py b/cdprecorder/skopo/__init__.py index b944eb1..8afe1f0 100644 --- a/cdprecorder/skopo/__init__.py +++ b/cdprecorder/skopo/__init__.py @@ 
-1,419 +1,431 @@ from __future__ import annotations import asyncio -import os -import pickle -import queue +import os.path +import random import socket -import sys -import threading -import time -from dataclasses import dataclass -from typing import TYPE_CHECKING - -from .._storage import get_runtime_dir, DEFAULT_SOCKET_NAME -from .. import logger -from . import mitm_runner, common_data -from .common_data import RequestData, ResponseData - +import string +import subprocess +from typing import TypeVar, TYPE_CHECKING + +from .._storage import DEFAULT_SOCKET_NAME, get_runtime_dir +from .sniff_protocol import ( + async_read_sock_datagram, + ProxyEvent, + ProxyMessage, + RequestData, + ResponseData, + SniffCommand, + sniffer_data_from_bytes, + to_sock_datagram, + SnifferError, +) if TYPE_CHECKING: - import subprocess - from asyncio.events import AbstractEventLoop - from typing import Any, Callable, Coroutine, Generic, Optional, TypeVar - - T = TypeVar("T") - + from asyncio import StreamReader, StreamWriter + from typing import TypeVar -# for pickle to find the module -# Reference: https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory -sys.modules["common_data"] = common_data - -class SnifferException(Exception): +class SkopoException(Exception): pass -def asyncio_run_coroutine_threadsafe(coro: Coroutine) -> None: - asyncio.run_coroutine_threadsafe(coro, asyncio.get_event_loop()) - - -class ThreadsafeAsyncWaker: - def __init__(self, loop: AbstractEventLoop, callback: Callable[[], Any]): - self.loop = loop - self.callback = callback - - def wake(self) -> None: - asyncio.run_coroutine_threadsafe(self.callback(), self.loop) - - -class ThreadSafeAsyncQueue(Generic[T]): - def __init__(self, size: int, loop: Optional[AbstractEventLoop] = None): - if loop is None: - loop = asyncio.get_event_loop() - self._tqueue: queue.Queue[T] = queue.Queue(size) - self._aqueue: asyncio.Queue[T] = asyncio.Queue(size) - self._async_waker = 
ThreadsafeAsyncWaker(loop, self.transfer_to_async) - - async def transfer_to_async(self) -> None: - try: - item = self._tqueue.get(block=False) - await self._aqueue.put(item) - except queue.Empty: - pass - - def put(self, item: T) -> None: - self._tqueue.put(item) - self._async_waker.wake() - - def get(self) -> T: - return self._tqueue.get() - - async def async_put(self, item: T) -> None: - await self._aqueue.put(item) - try: - item = await self._aqueue.get_nowait() - # TODO: This might block :( - self._tqueue.put(item) - except asyncio.QueueEmpty: - pass - - async def async_get(self) -> T: - return await self._aqueue.get() - - -class CrossThreadSnifferBase: - """Thread safe methods used by the Sniffer class.""" - - def __init__(self, loop: Optional[AbstractEventLoop] = None): - self.lock = threading.Lock() - self.signal_queue = None - if loop is not None: - self.signal_queue = ThreadSafeAsyncQueue(0, loop) - self.httpobj_queue = None - self.modification_queue = queue.Queue(maxsize=1) +class SnifferException(SkopoException): + def __init__(self, obj: Optional[SnifferError], *args, **kwargs): + super().__init__(*args, **kwargs) + self.obj = obj - self.client_connected = False - def set_httpobj_queue(self, httpobj_queue: queue.Queue): - with self.lock: - self.httpobj_queue = httpobj_queue +class SnifferProcessTerminated(SkopoException): + def __init__(self, proc: asyncio.subprocess.Process, *args, **kwargs): + super().__init__(*args, **kwargs) + self.proc = proc + + +class Sniffer: + def __init__(self): + self._ignore_reconnects = False + self.session_to_msg_queues = defaultdict(asyncio.Queue) + + def ignore_reconnects(self, ignore=True): + self._ignore_reconnects = ignore + + async def _get_data(self) -> bytes: + raise NotImplementedError + + async def pushback_message(self, obj): + await self.session_to_msg_queues[obj.session].put(obj) + + def _handle_message(self, obj, session: Optional[int]) -> Optional[ProxyMessage]: + if isinstance(obj, SnifferError) and not 
ignore_error: + raise ProxyException(obj) + + if self._ignore_reconnects and isinstance(obj, ProxyEvent): + if obj.event == ProxyEvent.CONNECT or obj.event == ProxyEvent.CLOSE: + return None + + if isinstance(obj, SnifferMessage): + if obj.session is None and session is not None: + raise ProxyException(obj, "Received sesionless message in a context with session") + if obj.session != session: + self.pushback_message(obj) + else: + return obj + + return None + + async def get_message(self, ignore_error: bool = False, session: Optional[int] = None) -> ProxyMessage: + if session is not None: + while True: + t1 = asyncio.create_task(self.session_to_msg_queues[session].get()) + t2 = asyncio.create_task(self._get_data()) + done, pending = await asyncio.wait([t1, t2], return_when=asyncio.FIRST_COMPLETED) + for t in pending: + t.cancel() + + results = [] + if t1 in done: + results.append(t1.result()) + if t2 in done: + data = t1.result() + obj, _ = sniffer_data_from_bytes(data) + results.append(obj) + + while results: + obj = results.pop(0) + obj = self._handle_message(obj) + if obj is not None: + for result in results: + self.pushback_message(result) + return obj - def publish_httpobj(self, httpobj): - if self.httpobj_queue is None: - return - self.httpobj_queue.put((self, httpobj)) + while True: + data = await self._get_data() + obj, _ = sniffer_data_from_bytes(data) + obj = self._handle_message(obj) + if obj is not None: + return obj - def get_modification(self): - return self.modification_queue.get() + async def get_request_data(self, ignore_error: bool = False, session: Optional[int] = None) -> RequestData: + while True: + msg = await self.get_message(ignore_error, session) + if not isinstance(msg, RequestData): + raise SnifferException(msg, "Expected message of type RequestData") + return msg - def publish_modification(self, obj: object): - self.modification_queue.put(obj) + async def get_response_data(self, ignore_error: bool = False, session: Optional[int] = None) 
-> ResponseData: + while True: + msg = await self.get_message(ignore_error, session) + if not isinstance(msg, ResponseData): + raise SnifferException(msg, "Expected message of type ResponseData") + return msg - def send_signal(self, signal): - self.signal_queue.put(signal) + async def get_proxy_event(self, ignore_error: bool = False, session: Optional[int] = None) -> ProxyEvent: + while True: + msg = await self.get_message(ignore_error, session) + if not isinstance(msg, ProxyEvent): + raise SnifferException(msg, "Expected message of type ProxyEvent") + return msg + async def wait_event(self, event_id: int, ignore_error: bool = False, session: Optional[int] = None) -> ProxyEvent: + while True: + msg = await self.get_message(ignore_error, session) + if not isinstance(msg, ProxyEvent): + raise SnifferException(msg, "Expected message of type ProxyEvent") + if msg.event != event_id: + raise SnifferException(msg, f"Expected event id {event_id}") -class Sniffer(CrossThreadSnifferBase): - SIG_CLIENT_CONNECTED = 1 + return msg - def __init__(self, coroutine_runner=None, coroutine_runner_threadsafe=None, sock_suffix: str = "", *args, **kwargs): - super().__init__(*args, **kwargs) - if coroutine_runner is not None: - self.coroutine_runner = coroutine_runner - else: - self.coroutine_runner = asyncio.run - if coroutine_runner_threadsafe is not None: - self.coroutine_runner_threadsafe = coroutine_runner_threadsafe - else: - self.coroutine_runner_threadsafe = asyncio_run_coroutine_threadsafe - - self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.running = False - - dirpath = get_runtime_dir() - self.sockpath = os.path.join(dirpath, DEFAULT_SOCKET_NAME + sock_suffix) - # Try to delete the socket if it already exists - try: - os.unlink(self.sockpath) - except FileNotFoundError: - pass + async def _send_data(self, data: bytes): + raise NotImplementedError - self._id_counter = 0 - self.client_connected = False + async def send_command( + self, + command: int, + 
request: Optional[RequestData] = None, + response: Optional[ResponseData] = None, + session: Optional[int] = None, + ): + sniff_command = SniffCommand(command, request, response) + sniff_command.session = session + data = sniff_command.to_bytes() + await self._send_data(data) - def run(self) -> None: - self.coroutine_runner(self._async_run()) + async def send_error(self, msg: SnifferError, session: Optional[int] = None): + msg.session = session + data = msg.to_bytes() + await self._send_data(data) - async def on_client_connected(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - if self.client_connected: - raise SnifferException("Sniffer only supports one client at a time") - self.client_connected = True - self.send_signal(self.SIG_CLIENT_CONNECTED) - try: - print("Client connected") - conn = common_data.AsyncioSniffConnection(reader, writer) - try: - try: - proxyname = await conn.read_str() - except common_data.SniffProtocolException as exc: - raise SnifferException from exc - - print(f"Proxy name: {proxyname}") - while True: - try: - httpobj = await conn.read_httpobj() - except common_data.SniffProtocolException as exc: - raise SnifferException from exc - - print(f"Got {httpobj}") - - # Yes, the calls to Queue can block the thread, that's - # why we allow only one client connection at a time - self.publish_httpobj(httpobj) - modification = self.get_modification() - - req, resp = modification - - if resp is not None: - await conn.send("REPLACE_RESPONSE") - await conn.send(resp) - elif req is not None and not isinstance(httpobj, ResponseData): - await conn.send("REPLACE_REQUEST") - await conn.send(req) - else: - await conn.send("OK") - except asyncio.IncompleteReadError: - pass - - print("Client done") - except Exception as exc: - print(f"Sniffer Exception: {exc}") - import traceback - - traceback.print_exception(exc) - raise - finally: - self.client_connected = False + async def async_stop(self): + raise NotImplemented - async def 
_async_run(self): - try: - os.unlink(self.sockpath) - except FileNotFoundError: - pass - self.server = await asyncio.start_unix_server(self.on_client_connected, self.sockpath) - await self.server.serve_forever() + def to_session(self, session: int): + return SnifferSession(self, session) - async def _async_stop(self): - self.server.close() - def stop(self) -> None: - self.coroutine_runner_threadsafe(self._stop()) +class SnifferSession: + def __init__(self, sniffer: Sniffer, id_: int): + self.sniffer = sniffer + self.id = id_ + def __getattr__(self, key: str): + value = getattr(self.client, key) + if isinstance(value, Callable): + value = functools.partial(value, session=self.id) -class SnifferThreadCombiner: - """Combines sniffers from different threads and makes them accessible - from a single thread.""" + return value - async def __init__(self, sniffers: list[Sniffer]): - self.sniffers = sniffers - self.read_queue = ThreadSafeAsyncQueue(0, asyncio.get_event_loop()) - for sniffer in self.sniffers: - sniffer.set_httpobj_queue(self.read_queue) +class ComparatorSniffer: + def __init__(self, sniffer1: Sniffer, sniffer2: Sniffer, on_diff: Callable): + self.sniffer1 = sniffer1 + self.sniffer2 = sniffer2 - async def get_message(self): - return await self.read_queue.async_get() + self.on_diff = on_diff - async def send_message(self, sniffer, obj: object): - sniffer.modification_queue.put(obj) + self.http_queues = defaultdict(lambda _: asyncio.Queue(), asyncio.Queue()) + self.waiting_tasks = [] + async def handle_request_response(self, session2, queue1: asyncio.Queue, queue2: asyncio.Queue): + session1 = self.sniffer1 + session2 = self.sniffer2 -# TODO: The sniffer comparator should better have references to the sniffers -class SnifferComparator: - def __init__(self, on_fail, sniffer1: Sniffer, sniffer2: Sniffer, event_loop=None): - self.on_fail = on_fail + req1 = await queue1.get() + assert isinstance(req1, RequestData) + if req1.session is not None: + session1 = 
session1.to_session(req1.session) + await session1.send_command(SniffCommand.NOP) - """ - self.req1 = None - self.req2 = None - self.res1 = None - self.request1_ready = asyncio.Semaphore(0) - self.request2_ready = asyncio.Semaphore(0) - self.requests_ready = asyncio.Semaphore(0) + req2 = await queue2.get() + assert isinstance(req2, RequestData) - self.request1_lock = asyncio.Lock() - self.request2_lock = asyncio.Lock() - """ + if req1 != req2: + self.on_diff(req1, req2) - self.requests_passed = 0 + if req2.session is not None: + session2 = self.sniffer2.to_session(req2.session) + await session2.send_command(SniffCommand.REPLACE, response=res1) - self.sniffer1 = sniffer1 - self.sniffer2 = sniffer2 + res1 = await queue1.get() + assert isinstance(res1, ResponseData) + await session1.send_command(SniffCommand.NOP) - self.sniffer1_httpobj_queue = ThreadSafeAsyncQueue(0, event_loop) - self.sniffer2_httpobj_queue = ThreadSafeAsyncQueue(0, event_loop) + res2 = await queue2.get() + assert isinstance(res2, ResponseData) - sniffer1.set_httpobj_queue(self.sniffer1_httpobj_queue) - sniffer2.set_httpobj_queue(self.sniffer2_httpobj_queue) - # self.combiner = SnifferThreadCombiner([sniffer1, sniffer2]) - self.running = False + # If the proxy works as expected, this should never happen + if res1 != res2: + self.on_diff(res1, res2) + await session2.send_command(SniffCommand.NOP) async def run(self): - self.running = True - while self.running: - _, req1 = await self.sniffer1_httpobj_queue.async_get() - print("Got req1") - if not isinstance(req1, RequestData): - raise SnifferException("Expected RequestData from sniffer1") - - _, req2 = await self.sniffer2_httpobj_queue.async_get() - print("Got req2") - if not isinstance(req2, RequestData): - raise SnifferException("Expected RequestData from sniffer2") - - if req1 != req2: - print("Calling on_fail") - await self.on_fail(self, req1, req2) - - self.sniffer1.publish_modification((None, None)) - _, res1 = await 
self.sniffer1_httpobj_queue.async_get() - self.sniffer1.publish_modification((None, None)) - if not isinstance(res1, ResponseData): - raise SnifferException("Expected RespsoneData from sniffer1") - - self.sniffer2.publish_modification((None, res1)) - _, res2 = await self.sniffer2_httpobj_queue.async_get() - if not isinstance(res2, ResponseData): - raise SnifferException("Expected RespsoneData from sniffer2") - - if res1 != res2: - print("Expected the safe response from sniffer2") - self.on_fail(self, res1, res2) - - self.sniffer2.publish_modification((None, None)) - - self.requests_passed += 1 - - print("Done comparator iteration") - - async def on_request1(self, req: RequestData): - print("Got request1") - async with self.request1_lock: - self.req1 = req - self.request1_ready.release() - print("Released req1_ready") - await self.request2_ready.acquire() - - print("Barrier released") - - if self.req1 != self.req2: - print("Calling on_fail") - self.on_fail(self.req1, self.req2) - - return None, None - - self.req1 = None - - async def on_request2(self, req: RequestData): - print("Got request2") - async with self.request2_lock: - self.req2 = req - self.request2_ready.release() - - print("Released req2_ready") - - await self.requests_ready.acquire() - print("Got response for req2") - self.req2 = None - - return None, self.res1 - - async def on_response1(self, res: ResponseData): - print("Got response1") - # Assume on_request1 was called for the corresponding request - self.res1 = res - print("Well, nice") - self.requests_ready.release() - print("Released") - - async def on_response2(self, res: ResponseData): - print("On response2 was called but this should never happen") - - -@dataclass -class ProxyInfo: - proc: subprocess.Popen - host: str - port: int - name: str - - -class MitmproxySnifferManager: - def __init__(self, sock_suffix: str, event_loop=None): - if event_loop is None: - event_loop = asyncio.get_event_loop() - self.sniffer = Sniffer( - 
sock_suffix=sock_suffix, - loop=event_loop, - ) - self.thread = None - - self.proxies = {} - - def start_sniffer_on_thread(self): - if self.thread is not None: - raise SnifferException("Sniffer thread already started") - - self.sniffer.running = True - self.thread = threading.Thread(target=self.sniffer.run, daemon=True) - self.thread.start() - # Wait for the sniffer to open the listening socket - time.sleep(0.5) - - def start_proxy_instance(self, host=None, port=None): - if self.thread is None: - raise SnifferException("Sniffer must be started before starting a proxy") - - kwargs = {} - if host is not None: - kwargs["host"] = host - if port is not None: - kwargs["port"] = port + on_message1 = asyncio.create_task(self.sniffer1.get_message()) + on_message2 = asyncio.create_task(self.sniffer2.get_message()) + while True: + done, pending = await asyncio.wait([on_message1, on_message2], return_when=asyncio.FIRST_COMPLETED) + + for done_task in done: + obj = done_task.result() + if not isinstance(obj, (RequestData, ResponseData)): + logging.error("Received unwanted object: %s", obj) + continue + + key = obj.meta.object_id + if key not in self.http_queues: + q1, q2 = asyncio.Queue(), asyncio.Queue() + self.http_queues[obj.meta.object_id] = (q1, q2) + self.waiting_tasks.append(asyncio.create_task(handle_request_response(q1, q2))) + + if done_task is on_message1: + q1 = self.http_queues[key][0] + await q1.put(obj) + on_message1 = asyncio.create_task(self.sniffer1.get_message()) + + if done_task is on_message2: + q2 = self.http_queues[key][1] + await q2.put(obj) + on_message2 = asyncio.create_task(self.sniffer2.get_message()) + + # TODO: maybe use a while loop? 
+ + +async def mitmproxy_run( + sniffer_socket_address: str, + host: str = "localhost", + port: int = 8080, + addon_script: str = "intercept_addon.py", + binary: str = "mitmdump", +): + proxy_name: str = "".join(random.choices(string.ascii_letters, k=32)) + module_dir = os.path.dirname(os.path.realpath(__file__)) + addon_path = os.path.join(module_dir, addon_script) + cli_args = [ + "--mode", + "regular", + "--listen-host", + host, + "--listen-port", + str(port), + "-s", + addon_path, + "--set", + f"socketaddress={sniffer_socket_address}", + "--set", + f"proxyname={proxy_name}", + ] + args = [binary] + cli_args + + # TODO: stderr DEVNULL + p = await asyncio.create_subprocess_exec(*args) # , stdout=subprocess.PIPE) + + return p, host, port, proxy_name + + +class MitmproxySniffer(Sniffer): + def __init__( + self, + sockaddr: str, + proxy_host: str, + proxy_port: int, + proxy_name: str, + mitmproxy_proc: asyncio.subprocess.Process, + ): + super().__init__() + self.sockaddr = sockaddr + self._server = None + + self._read_queue: asyncio.Queue[bytes] = asyncio.Queue() + self._write_queue: asyncio.Queue[bytes] = asyncio.Queue() + + self._proc = mitmproxy_proc + self.proxy_host = proxy_host + self.proxy_port = proxy_port + self.proxy_name = proxy_name + + async def init(self): + self._server = await asyncio.start_unix_server(self.on_client_connected, path=self.sockaddr) + + @property + def proxy_url(self): + return f"http://{self.proxy_host}:{self.proxy_port}" + + async def _get_data(self): + wait_task = asyncio.create_task(self._proc.wait()) + read_task = asyncio.create_task(self._read_queue.get()) + try: + done, _pending = await asyncio.wait([wait_task, read_task], return_when=asyncio.FIRST_COMPLETED) + finally: + # Here, I found out that task cancelation can create bugs + # Hence, in a couroutine, you should always expect that when you + # await a coroutine/task, that can generate a CancelledError. 
+ wait_task.cancel() + read_task.cancel() - proc, host, port, proxy_name = mitm_runner.run(self.sniffer.sockpath, **kwargs) + if wait_task in done: + raise SnifferProcessTerminated(None, self._proc) - self.proxies[proxy_name] = ProxyInfo( - proc=proc, - host=host, - port=port, - name=proxy_name, - ) + return read_task.result() - return self.proxies[proxy_name] + async def _send_data(self, data: bytes): + datagram = to_sock_datagram(data) + await self._write_queue.put(datagram) - async def wait_for_proxy_connection_with_sniffer(self): - while True: - sig = await self.sniffer.signal_queue.async_get() - print(f"Got signal on queue: {sig}") - if sig is self.sniffer.SIG_CLIENT_CONNECTED: - break + async def on_client_connected(self, reader: StreamReader, writer: StreamWriter): + """Called when a proxy server has connected to this sniffer.""" + await self._read_queue.put(ProxyEvent(event=ProxyEvent.CONNECT).to_bytes()) - def stop_from_thread(self): - self.sniffer.stop() - self.thread.join() + # Handle both reading and writing concurrently + # Listen on the reader stream and the internal write queue + task1 = asyncio.create_task(async_read_sock_datagram(reader)) + task2 = asyncio.create_task(self._write_queue.get()) + try: + while True: + # ignore pending tasks, because they will be waited again in the next loop iteration + done, _pending = await asyncio.wait([task1, task2], return_when=asyncio.FIRST_COMPLETED) + + if task1 in done: + data = task1.result() + await self._read_queue.put(data) + # recreate the task + task1 = asyncio.create_task(async_read_sock_datagram(reader)) + + if task2 in done: + data = task2.result() + print(f"Writing data: {data}") + writer.write(data) + await writer.drain() + # recreate the task + task2 = asyncio.create_task(self._write_queue.get()) + except asyncio.IncompleteReadError: + pass + finally: + task1.cancel() + task2.cancel() + await self._read_queue.put(ProxyEvent(event=ProxyEvent.CLOSE).to_bytes()) + while not self._write_queue.empty(): 
+ await self._write_queue.get() + + def stop(self): + print("Stopping") + print(f"proc: {self._proc}") + if self._proc is None: + return - def stop_proxy_instance(self, proxy_name: str): + self._proc.terminate() + print("Terminate") try: - proxy_info = self.proxies[proxy_name] - except KeyError: - raise SnifferException(f"No proxy with the name: {proxy_name!r}") + # Imagine having an API that that has both async and blocking functions + # Sadly, this is not the case for asyncio.subprocess.Process + print("Waiting") + p = self._proc._transport._proc.wait(timeout=0.5) + print("Waited") + except subprocess.TimeoutExpired: + print("Killing") + self._proc.kill() - proxy_info.proc.stop() - del self.proxies[proxy_name] + print("Done stopping") - def __del__(self): - for proxy in self.proxies.values(): - proxy.proc.terminate() + self._proc = None - self.proxies.clear() + def __del__(self): + print(f"Called del for {self}") + # The object might be destroyed before __init__ terminates + if hasattr(self, "_proc"): + self.stop() + + +async def create_mitmproxy_sniffer_comparator( + on_diff: Callable, socksuffix1: str = "1", proxy_port1: int = 8080, socksuffix2: str = "2", proxy_port2: int = 8081 +) -> ComparatorSniffer: + if socksuffix1 == socksuffix2: + raise SnifferException("Socket suffixes must be different") + if proxy_port1 == proxy_port2: + raise SnifferException("Proxy ports must be different") + + sockaddr1 = os.path.join(get_runtime_dir(), DEFAULT_SOCKET_NAME + socksuffix1) + proc1, host1, port1, proxy_name1 = await mitmproxy_run(sockaddr1, port=proxy_port1) + sniffer1 = MitmproxySniffer(sockaddr1, host1, port1, proxy_name1, proc1) + + sockaddr2 = os.path.join(get_runtime_dir(), DEFAULT_SOCKET_NAME + socksuffix2) + proc2, host2, port2, proxy_name2 = await mitmproxy_run(sockaddr2, port=proxy_port2) + sniffer2 = MitmproxySniffer(sockaddr2, host2, port2, proxy_name2, proc2) + + await sniffer1.init() + await sniffer2.init() + + try: + await 
sniffer1.wait_event(ProxyEvent.CONNECT) + await sniffer2.wait_event(ProxyEvent.CONNECT) + except: + proc1.kill() + out, err = await proc1.communicate() + print("Proc output: {out}") + print("Proc err: {err}") + raise + + sniffer1.ignore_reconnects(True) + sniffer2.ignore_reconnects(True) + comparator = ComparatorSniffer(sniffer1, sniffer2, on_diff) + + return comparator diff --git a/cdprecorder/skopo/sniff_protocol.py b/cdprecorder/skopo/sniff_protocol.py new file mode 100644 index 0000000..21e027e --- /dev/null +++ b/cdprecorder/skopo/sniff_protocol.py @@ -0,0 +1,539 @@ +from __future__ import annotations + +import asyncio +import os.path +import subprocess +from enum import IntEnum +from typing import Callable, TypeVar, TYPE_CHECKING, Union + + +if TYPE_CHECKING: + import socket + from asyncio import StreamReader, StreamWriter + from typing import TypeVar + + +def bytes_to_varlen_bytes(data: bytes): + size = len(data) + return size.to_bytes(8, "big") + data + + +class SnifferMessageType(IntEnum): + REQUEST_DATA = 1 + RESPONSE_DATA = 2 + PROXY_EVENT = 3 + SNIFF_COMMAND = 4 + INT64 = 5 + NONE = 6 + + +class SnifferNone: + @staticmethod + def to_bytes(): + data = b"" + data += SnifferMessageType.NONE.to_bytes(8, "big") + return data + + +class SnifferInt64: + @staticmethod + def to_bytes(value: int): + data = b"" + data += SnifferMessageType.INT64.to_bytes(8, "big") + data += value.to_bytes(8, "big") + return data + + +class SnifferMetadata: + def __init__(self, object_id: int, timestamp: int, proxyname: str): + self.object_id = object_id + self.timestamp = timestamp + self.proxyname = proxyname + + def to_bytes(self) -> bytes: + data = b"" + data += self.object_id.to_bytes(8, "big") + data += self.timestamp.to_bytes(8, "big") + data += bytes_to_varlen_bytes(self.proxyname.encode()) + + return data + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[SnifferMetadata, int]: + i = 0 + object_id = int.from_bytes(data[i : i + 8], "big") + i += 8 + + timestamp = 
int.from_bytes(data[i : i + 8], "big") + i += 8 + + size = int.from_bytes(data[i : i + 8], "big") + i += 8 + + proxyname = data[i : i + size].decode("utf8") + i += size + + return cls(object_id, timestamp, proxyname, session), i + + def __str__(self) -> str: + text = f"{self.__class__.__name__}(" + text += f"object_id={self.object_id}, " + text += f"timestamp={self.timestamp}, " + text += f"proxyname={self.proxyname}" + text += ")" + + return text + + +class RequestData: + __slots__ = ["http_version", "method", "url", "headers", "content", "trailers", "meta"] + + def __init__( + self, + http_version: bytes, + method: bytes, + url: bytes, + headers: bytes, + content: bytes, + trailers: bytes, + meta: SnifferMetadata, + sesison: Optional[int], + ): + self.http_version = http_version + self.method = method + self.url = url + self.headers = headers + self.content = content + self.trailers = trailers + self.meta = meta + self.session = session + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RequestData): + raise ValueError("eq only supported for RequestData types") + + # Compare everything but the metadata + if ( + self.http_version != other.http_version + or self.method != other.method + or self.url != other.url + or self.headers != other.headers + or self.content != other.content + or self.trailers != other.trailers + ): + return False + + return True + + def to_bytes(self) -> bytes: + data = b"" + data += SnifferMessageType.REQUEST_DATA.to_bytes(8, "big") + data += bytes_to_varlen_bytes(self.http_version) + data += bytes_to_varlen_bytes(self.method) + data += bytes_to_varlen_bytes(self.url) + data += bytes_to_varlen_bytes(self.headers) + data += bytes_to_varlen_bytes(self.content) + data += bytes_to_varlen_bytes(self.trailers) + data += self.meta.to_bytes() + data += SnifferInt64.to_bytes(self.session) + + return data + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[RequestData, int]: + i = 0 + message_type = int.from_bytes(data[i 
: i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.REQUEST_DATA + + components = [] + for _ in range(6): + size = int.from_bytes(data[i : i + 8], "big") + i += 8 + obj = data[i : i + size] + i += size + components.append(obj) + + http_version = components[0] + method = components[1] + url = components[2] + headers = components[3] + content = components[4] + trailers = components[5] + + meta, used = SnifferMetadata.from_bytes(data[i:]) + i += used + + session, used = sniffer_data_from_bytes(data[i:]) + assert isinstance(session, int) or session is None + i += used + + return cls(http_version, method, url, headers, content, trailers, meta), i + + def __str__(self) -> str: + text = f"{self.__class__.__name__}(" + text += f"http_version={self.http_version}, " + text += f"method={self.method}, " + text += f"url={self.url}, " + text += f"headers={self.headers}, " + text += f"content={self.content!r}, " + text += f"trailers={self.trailers}, " + text += f"meta={self.meta}" + text += ")" + + return text + + +class ResponseData: + __slots__ = ["http_version", "status_code", "reason", "headers", "content", "trailers", "meta"] + + def __init__( + self, + http_version: bytes, + status_code: int, + reason: bytes, + headers: bytes, + content: bytes, + trailers: bytes, + meta: SnifferMetadata, + session: Optional[int], + ): + self.http_version = http_version + self.status_code = status_code + self.reason = reason + self.headers = headers + self.content = content + self.trailers = trailers + self.meta = meta + self.session = session + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ResponseData): + raise ValueError("eq only supported for ResponseData types") + + # Compare everything but the metadata + if ( + self.http_version != other.http_version + or self.status_code != other.status_code + or self.reason != other.reason + or self.headers != other.headers + or self.content != other.content + or self.trailers != other.trailers + ): + return 
False + + return True + + def to_bytes(self) -> bytes: + data = b"" + data += SnifferMessageType.RESPONSE_DATA.to_bytes(8, "big") + data += bytes_to_varlen_bytes(self.http_version) + data += self.status_code.to_bytes(8, "big") + data += bytes_to_varlen_bytes(self.reason) + data += bytes_to_varlen_bytes(self.headers) + data += bytes_to_varlen_bytes(self.content) + data += bytes_to_varlen_bytes(self.trailers) + data += self.meta.to_bytes() + data += SnifferInt64.to_bytes(self.session) + + return data + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[ResponseData, int]: + i = 0 + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.RESPONSE_DATA + + size = int.from_bytes(data[i : i + 8], "big") + i += 8 + http_version = data[i : i + size] + i += size + + status_code = int.from_bytes(data[i : i + 8], "big") + i += 8 + + components = [] + for _ in range(4): + size = int.from_bytes(data[i : i + 8], "big") + i += 8 + obj = data[i : i + size] + i += size + components.append(obj) + + reason = components[0] + headers = components[1] + content = components[2] + trailers = components[3] + + meta, used = SnifferMetadata.from_bytes(data[i:]) + i += used + + session, used = sniffer_data_from_bytes(data[i:]) + assert isinstance(session, int) or session is None + i += used + + return cls(http_version, status_code, reason, headers, content, trailers, meta, session), i + + +class ProxyEvent: + CONNECT = 1 + CLOSE = 2 + + def __init__(self, event: int): + self.event = event + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[ProxyEvent, int]: + i = 0 + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.PROXY_EVENT + + event = int.from_bytes(data[i : i + 8], "big") + i += 8 + + return cls(event), i + + def to_bytes(self) -> bytes: + data = b"" + data += SnifferMessageType.PROXY_EVENT.to_bytes(8, "big") + data += self.event.to_bytes(8, "big") + + return data 
+ + +class SnifferError: + def __init__(self, error_msg: str, error_type: str): + self.error_msg = error_msg + self.error_type = error_type + + +class SniffCommand: + NOP = 1 + REPLACE = 2 + CANCEL = 3 + CLOSE_CLIENT = 4 + + def __init__( + self, + command: int, + request: Optional[RequestData] = None, + response: Optional[ResponseData] = None, + meta: Optional[SnifferMetadata] = None, + session: Optional[int] = None, + ): + self.command = command + self.request = request + self.response = response + self.meta = meta + self.session = session + + def to_bytes(self): + data = b"" + data += SnifferMessageType.SNIFF_COMMAND.to_bytes(8, "big") + data += self.command.to_bytes(8, "big") + data += self.request.to_bytes() if self.request is not None else SnifferNone.to_bytes() + data += self.repsonse.to_bytes() if self.response is not None else SnifferNone.to_bytes() + data += self.meta.to_bytes() if self.meta is not None else SnifferNone.to_bytes() + data += SnifferInt64.to_bytes(self.session) if self.session is not None else SnifferNone.to_bytes() + + return data + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[SniffCommand, int]: + i = 0 + + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.SNIFF_COMMAND + + command = int.from_bytes(data[i : i + 8], "big") + i += 8 + + request, used = sniffer_data_from_bytes(data[i:]) + i += used + assert isinstance(request, RequestData) or request is None + + response, used = sniffer_data_from_bytes(data[i:]) + i += used + assert isinstance(response, ResponseData) or response is None + + meta, used = sniffer_data_from_bytes(data[i:]) + i += used + assert isinstance(meta, SnifferMetadata) or meta is None + + session, used = sniffer_data_from_bytes(data[i:]) + assert isinstance(session, int) or session is None + i += used + + return cls(command, request, response, meta, session), i + + +class ProxyException(Exception): + def __init__(self, obj: Optional[SnifferError], 
*args, **kwargs): + super().__init__(*args, **kwargs) + self.obj = obj + + +SnifferMessage: TypeVar = Union[SniffCommand] +ProxyMessage: TypeVar = Union[RequestData, ResponseData, ProxyEvent] + + +def sniffer_data_from_bytes(data: bytes): + message_type = int.from_bytes(data[:8], "big") + if message_type == SnifferMessageType.REQUEST_DATA: + return RequestData.from_bytes(data) + elif message_type == SnifferMessageType.RESPONSE_DATA: + return ResponseData.from_bytes(data) + elif message_type == SnifferMessageType.PROXY_EVENT: + return ProxyEvent.from_bytes(data) + elif message_type == SnifferMessageType.SNIFF_COMMAND: + return SniffCommand.from_bytes(data) + elif message_type == SnifferMessageType.INT64: + assert len(data) >= 16, "Not enough bytes to decode" + return int.from_bytes(data[8:16], "big"), 16 + elif message_type == SnifferMessageType.NONE: + return None, 8 + else: + raise ProxyException(None, f"Unknown message type: {message_type}") + + +class SnifferProxyClient: + def __init__(self): + self.session_to_messages = defaultdict(list) + + def _get_data(self) -> bytes: + raise NotImplementedError + + def pushback_message(self, obj): + self.session_to_messages[obj.session].append(obj) + + def get_message(self, ignore_error: bool = False, session: Optional[int] = None) -> SnifferMessage: + if session is not None: + while len(self.session_to_messages[session]): + obj = self.session_to_messages[session].pop(0) + if isinstance(obj, SnifferError) and not ignore_error: + raise ProxyException(obj) + + if isinstance(obj, SnifferMessage): + if obj.session is None and session is not None: + raise ProxyException(obj, "Received sesionless message in a context with session") + return obj + + del self.session_to_messages[session] + + while True: + data = self._get_data() + obj, _ = sniffer_data_from_bytes(data) + + if isinstance(obj, SnifferError) and not ignore_error: + raise ProxyException(obj) + + if isinstance(obj, SnifferMessage): + if obj.session is None and session 
is not None: + raise ProxyException(obj, "Received sesionless message in a context with session") + if obj.session != session: + self.pushback_message(obj) + else: + return obj + + def get_command(self, ignore_error: bool = False, session: Optional[int] = None) -> SniffCommand: + while True: + msg = self.get_message(ignore_error, session) + if not isinstance(msg, SniffCommand): + raise ProxyException(msg, "Expected message of type SniffCommand") + return msg + + def _send_data(self, data: bytes): + raise NotImplementedError + + def send_proxy_message(self, obj: ProxyMessage, session: Optional[int] = None): + if hasattr(obj, "session"): + obj.session = session + data = obj.to_bytes() + self._send_data(data) + + def send_request_data(self, obj: RequestData, session: Optional[int] = None): + obj.session = session + self.send_proxy_message(obj) + + def send_response_data(self, obj: ResponseData, session: Optional[int] = None): + obj.session = session + self.send_proxy_message(obj) + + def send_proxy_event(self, obj: ProxyEvent, session: Optional[int] = None): + self.send_proxy_message(obj) + + def send_error(self, msg: SnifferError, session: Optional[int] = None): + msg.session = session + data = msg.to_bytes() + self._send_data(data) + + def new_session(self): + return SnifferClientSession(self) + + +class SnifferClientSession: + _LAST_ID = 1 + + def __init__(self, client: SnifferProxyClient): + self.client = client + self.id = SnifferClientSession._LAST_ID + SnifferClientSession._LAST_ID += 1 + + def __getattr__(self, key: str): + value = getattr(self.client, key) + if isinstance(value, Callable): + value = functools.partial(value, session=self.id) + + return value + + +class BufferedSnifferProxyClient(SnifferProxyClient): + def __init__(self): + self.command_queue = asyncio.Queue() + self.request_queue = asyncio.Queue() + self.response_queue = asyncio.Queue() + pass + + def get_message(self, ignore_error: bool = False) -> SnifferMessage: + while True: + data = 
self._get_data() + obj, _ = sniffer_data_from_bytes(data) + + if isinstance(obj, SnifferError) and not ignore_error: + raise ProxyException(obj) + + if isinstance(obj, SnifferMessage): + return obj + + def get_command(self, ignore_error: bool = False) -> SniffCommand: + while True: + msg = self.get_message(ignore_error) + if not isinstance(msg, SniffCommand): + raise ProxyException(msg, "Expected message of type SniffCommand") + return msg + + +def to_sock_datagram(data: bytes): + size = len(data) + return size.to_bytes(8, "big") + data + + +async def async_read_sock_datagram(reader: StreamReader): + data_size = await reader.readexactly(8) + size = int.from_bytes(data_size, "big") + data = await reader.readexactly(size) + return data + + +def read_sock_datagram(sock: socket.socket): + data_size = sock.recv(8) + if len(data_size) != 8: + raise ProxyException("Received less than expected data") + size = int.from_bytes(data_size, "big") + data = sock.recv(size) + if len(data) != size: + raise ProxyException("Received less than expected data") + return data From 41d2fee05b42493254858b7fea63b4bf66066953 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 1 Sep 2025 01:41:56 +0300 Subject: [PATCH 12/18] feat: change the way the proxy uses the sniff protocol --- cdprecorder/skopo/intercept_addon.py | 189 ++++++++++++++++----------- 1 file changed, 113 insertions(+), 76 deletions(-) diff --git a/cdprecorder/skopo/intercept_addon.py b/cdprecorder/skopo/intercept_addon.py index 27d39db..a1386a1 100644 --- a/cdprecorder/skopo/intercept_addon.py +++ b/cdprecorder/skopo/intercept_addon.py @@ -1,24 +1,22 @@ from __future__ import annotations import logging -import pickle import socket import sys from urllib.parse import urlunparse from typing import Optional, TYPE_CHECKING -from mitmproxy import ctx, http +from mitmproxy import ctx, http, tcp -import common_data -from common_data import SniffConnection, RequestData, ResponseData 
- -# :'( -# Used by pickle to load RequestData and ResponseData -sys.modules["cdprecorder"] = common_data -sys.modules["cdprecorder.common_data"] = common_data -sys.modules["cdprecorder.skopo.common_data"] = common_data - -import requests +from sniff_protocol import ( + RequestData, + ResponseData, + SniffCommand, + SnifferMetadata, + SnifferProxyClient, + read_sock_datagram, + to_sock_datagram, +) if TYPE_CHECKING: @@ -58,13 +56,34 @@ def filter(self, record: logging.LogRecord) -> bool: return True +class MitmproxySnifferProxyClient(SnifferProxyClient): + def __init__(self, sockaddr: str): + self.client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.client.connect(sockaddr) + + def _get_data(self) -> bytes: + data = read_sock_datagram(self.client) + return data + + def _send_data(self, data: bytes): + datagram = to_sock_datagram(data) + self.client.sendall(datagram) + + def close(self): + if self.client is not None: + self.client.shutdown(socket.SHUT_RDWR) + self.client.close() + self.client = None + + def __del__(self): + self.close() + + class TheSpy: def __init__(self): - self.data_sender = None + self.client = None self.flows = set() - logging.info("PPpp", extra={"client": "haubau"}) - self.log_filter = PrefixFilter() handler = logging.getLogger().handlers[0] handler.addFilter(self.log_filter) @@ -90,12 +109,11 @@ def load(self, loader): ) def start_connection(self, socketaddress: str): - if self.data_sender is not None: - self.data_sender.close() + if self.client is not None: + self.client.close() logging.info("Connecting to %s", socketaddress) - self.data_sender = SniffConnection(socketaddress) + self.client = MitmproxySnifferProxyClient(socketaddress) assert isinstance(self.proxyname, str) - self.data_sender.send(self.proxyname) def running(self): if ctx.options.socketaddress is not None: @@ -108,6 +126,7 @@ def configure(self, updated: set[str]): self.start_connection(ctx.options.socketaddress) def request(self, flow: http.HTTPFlow): + pass 
self.flows.add(id(flow)) logging.info("Intercepted request") mitmreq = flow.request @@ -125,49 +144,50 @@ def request(self, flow: http.HTTPFlow): ) req = RequestData( - http_version=mitmreq.http_version, - method=mitmreq.method, - url=url, - headers=list(mitmreq.headers.items(multi=True)), - raw_content=mitmreq.raw_content, - trailers=list(mitmreq.trailers.items(multi=True) if mitmreq.trailers else []), - object_id=id(flow), + http_version=mitmreq.http_version.encode(), + method=mitmreq.method.encode(), + url=url.encode(), + headers=bytes(mitmreq.headers), + content=mitmreq.raw_content, + trailers=bytes(mitmreq.trailers) if mitmreq.trailers else b"", + meta=SnifferMetadata(object_id=id(flow), timestamp=0, proxyname=self.proxyname), ) - self.data_sender.send(req) + self.client.send_request_data(req) logging.info("Sent request") - status = self.data_sender.read_str() - - if status == "REPLACE_RESPONSE": - r = self.data_sender.read_response_data() - - logging.info("Headers: %s", r.headers) - headers_bytes = [] - for key, value in r.headers: - headers_bytes.append((key.encode(), value.encode())) - resp = http.Response.make( - status_code=r.status_code, - content=r.raw_content, - headers=headers_bytes, - ) - resp.http_version = r.http_version - resp.reason = r.reason - resp.trailers = http.Headers(r.trailers) + command = self.client.get_command() - flow.response = resp - elif status == "REPLACE_REQUEST": - r = self.data_sender.read_request_data() + if command.command == SniffCommand.REPLACE: + if command.response is not None: + r = command.response - req = http.Request.make(method=r.method, url=r.url, content=r.raw_content, headers=http.Headers(r.headers)) - req.http_version = (r.http_version,) - req.reason = (r.reason,) - req.trailers = (http.Headers(r.trailers),) + headers_bytes = [] + for key, value in r.headers: + headers_bytes.append((key.encode(), value.encode())) + resp = http.Response.make( + status_code=r.status_code, + content=r.raw_content, + 
headers=headers_bytes, + ) + resp.http_version = r.http_version + resp.reason = r.reason + resp.trailers = http.Headers(r.trailers) - flow.request = req + flow.response = resp + if command.request is not None: + r = command.request + + req = http.Request.make( + method=r.method, url=r.url, content=r.raw_content, headers=http.Headers(r.headers) + ) + req.http_version = (r.http_version,) + req.reason = (r.reason,) + req.trailers = (http.Headers(r.trailers),) - elif status != "OK": - raise Exception(f"Unknown status: {status}") + flow.request = req + elif command.command != SniffCommand.NOP: + raise Exception(f"Unknown command: {command.command}") logging.info("OK Intercepted request") @@ -182,43 +202,60 @@ def response(self, flow: http.HTTPFlow): mitmres = flow.response logging.info("Response headers: %s", mitmres.headers) res = ResponseData( - http_version=mitmres.http_version, + http_version=mitmres.http_version.encode(), status_code=mitmres.status_code, - reason=mitmres.reason, - headers=list(mitmres.headers.items(multi=True)), - raw_content=mitmres.raw_content, - trailers=list(mitmres.trailers.items(multi=True) if mitmres.trailers else []), - object_id=id(flow), + reason=mitmres.reason.encode(), + headers=bytes(mitmres.headers), + content=mitmres.raw_content, + trailers=bytes(mitmres.trailers) if mitmres.trailers else b"", + meta=SnifferMetadata(object_id=id(flow), timestamp=0, proxyname=self.proxyname), ) logging.info("Constructed ResponseData") - self.data_sender.send(res) + self.client.send_response_data(res) logging.info("Sent response") - status = self.data_sender.read_str() - logging.info("Got status in response") + command = self.client.get_command() + logging.info("Got command in response") - if status == "REPLACE_RESPONSE": - r = self.data_sender.read_response_data() + if command.command == SniffCommand.REPLACE: + if command.response is not None: + r = command.response - resp = http.Response.make( - status_code=r.status_code, - content=r.raw_content, - 
headers=http.Headers(r.headers), - ) - resp.http_version = (r.http_version,) - resp.reason = (r.reason,) - resp.trailers = (http.Headers(r.trailers),) + resp = http.Response.make( + status_code=r.status_code, + content=r.raw_content, + headers=http.Headers(r.headers), + ) + resp.http_version = (r.http_version,) + resp.reason = (r.reason,) + resp.trailers = (http.Headers(r.trailers),) - flow.response = resp - elif status != "OK": - raise Exception(f"Unknown status: {status}") + flow.response = resp + elif command.command != SniffCommand.NOP: + raise Exception(f"Unknown command: {sommand.command}") logging.info("OK Intercepted response") except Exception as exc: logging.info("Addon Exception: %s", exc) raise + def server_connect(self, data): + logging.info("About to connect to: %s. %s", data.server.address, str(data.server)) + if data.server.address[0] == "local": + data.server.address = ("127.0.0.1", data.server.address[1]) + + +def tcp_message(flow: tcp.TCPFlow): + from mitmproxy.utils import strutils + + message = flow.messages[-1] + # message.content = message.content.replace(b"foo", b"bar") + + logging.info( + f"tcp_message[from_client={message.from_client}), content={strutils.bytes_to_escaped_str(message.content)}]" + ) + addons = [TheSpy()] From 39dc35c86c5a25847685cef281fb89064bb6128f Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 1 Sep 2025 01:42:58 +0300 Subject: [PATCH 13/18] feat: don't navigate to url in run_csrf_form_submitsuccess --- tests/test_e2e_run/selenium_runners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_e2e_run/selenium_runners.py b/tests/test_e2e_run/selenium_runners.py index 5175fbe..76d7060 100644 --- a/tests/test_e2e_run/selenium_runners.py +++ b/tests/test_e2e_run/selenium_runners.py @@ -5,7 +5,7 @@ def run_csrf_form_submitsuccess(driver): # Step # | name | target | value # 1 | open | / | - driver.get("http://127.0.0.1:5000/") + # 
driver.get("http://localhost:5000/") # 2 | setWindowSize | 736x729 | driver.set_window_size(736, 729) # 3 | click | css=html | From f3d0a8d18aa2cc87c213f0f49f938382395108d4 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Mon, 1 Sep 2025 01:43:55 +0300 Subject: [PATCH 14/18] feat: improve http comparator test --- tests/test_e2e_run/test_http_apps.py | 69 +++++++++++++++++++++------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/tests/test_e2e_run/test_http_apps.py b/tests/test_e2e_run/test_http_apps.py index 4d7d1f8..e605742 100644 --- a/tests/test_e2e_run/test_http_apps.py +++ b/tests/test_e2e_run/test_http_apps.py @@ -59,51 +59,81 @@ async def on_fail(comparator, httpobj1, httpobj2): @pytest.mark.asyncio async def test_csrf_form_run_csrf_from_submit_success(): + logging.getLogger("selenium").setLevel(logging.WARNING) + logging.getLogger("pycdp").setLevel(logging.WARNING) logging.info("Starting web app") app = VenvAppRunner("http_apps/csrf_form") app.wait_until_up() logging.info("Web app started") - sniffer_manager1 = skopo.MitmproxySnifferManager("1") - sniffer_manager1.start_sniffer_on_thread() - proxy1 = sniffer_manager1.start_proxy_instance(port=8080) + comparator = await skopo.create_mitmproxy_sniffer_comparator(on_fail) - sniffer_manager2 = skopo.MitmproxySnifferManager("2") - sniffer_manager2.start_sniffer_on_thread() - proxy2 = sniffer_manager2.start_proxy_instance(port=8081) - - await sniffer_manager1.wait_for_proxy_connection_with_sniffer() - await sniffer_manager2.wait_for_proxy_connection_with_sniffer() - - proxy_url1 = f"http://{proxy1.host}:{proxy1.port}" - proxy_url2 = f"http://{proxy2.host}:{proxy2.port}" + proxy_url1 = comparator.sniffer1.proxy_url + proxy_url2 = comparator.sniffer2.proxy_url # proxy_url = f"http://{proxy_info.host}:{proxy_info.port}" options = webdriver.ChromeOptions() cdp_port = 9222 options.add_argument(f"--remote-debugging-port={cdp_port}") 
options.add_argument(f"--proxy-server={proxy_url1}") - options.add_argument("--ignore-ceritifcate-erros") + options.add_argument("--headless=new") + capabilities = options.to_capabilities() + capabilities["acceptInsecureCerts"] = True + print(options.to_capabilities()) + options.binary_location = "/usr/bin/google-chrome-stable" driver = webdriver.Chrome(options) logging.info("Instantiated web driver") recorder_options = recorder.RecorderOptions( - "http://localhost:5000", + "http://local:5000", cdp_host="localhost", cdp_port=cdp_port, collect_all=True, ) + sniffer1 = comparator.sniffer1 + + async def pass_sniffer(): + req = None + try: + while True: + req = await sniffer1.get_message() + await sniffer1.send_command(skopo.SniffCommand.NOP) + except asyncio.CancelledError: + if req is not None: + await sniffer1.send_command(skopo.SniffCommand.NOP) + pass + try: - rec = await recorder.init_recorder(recorder_options) + t1 = asyncio.create_task(recorder.init_recorder(recorder_options)) + t2 = asyncio.create_task(pass_sniffer()) + done, pending = await asyncio.wait([t1, t2], return_when=asyncio.FIRST_COMPLETED) + + if t1 not in done: + assert False, "Recorder should end" + await asyncio.sleep(2) + t2.cancel() + + rec = t1.result() t1 = asyncio.create_task(asyncio.to_thread(run_csrf_form_submitsuccess, driver)) - t2 = asyncio.create_task(recorder.collect_communications(rec, 20)) - done, pending = await asyncio.wait([t1, t2], return_when=asyncio.FIRST_COMPLETED) + t2 = asyncio.create_task(recorder.collect_communications(rec, 30)) + t3 = asyncio.create_task(log_sniffer()) + done, pending = await asyncio.wait([t1, t2, t3], return_when=asyncio.FIRST_COMPLETED) + + if t2 in done: + logging.info("Task2 is done") + logging.info("Result: %s", t2.result()) + + if t3 in done: + logging.info("Task3 is done") if t1 in pending: raise Exception("Recorder stopped before selenium test") + if t3 in pending: + t3.cancel() + if t2 in pending: rec.listener.cancel() communications = await t2 
@@ -112,6 +142,11 @@ async def test_csrf_form_run_csrf_from_submit_success(): logging.info("Recorded communications") + sniffer1.stop() + comparator.sniffer2.stop() + + return + actions = erpeto.parse_communications_into_actions(communications) erpeto.make_action_ids_consecutive_from_list(actions) # actions = await erpeto.run_recorder(recorder_options) From 8958a49329d6687347feddcb29dab7f08786acff Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Tue, 2 Sep 2025 02:09:22 +0300 Subject: [PATCH 15/18] feat: improve session management in sniffer --- cdprecorder/skopo/__init__.py | 130 ++++++++++++++++++++------- cdprecorder/skopo/intercept_addon.py | 21 +++-- cdprecorder/skopo/sniff_protocol.py | 79 ++++++---------- 3 files changed, 144 insertions(+), 86 deletions(-) diff --git a/cdprecorder/skopo/__init__.py b/cdprecorder/skopo/__init__.py index 8afe1f0..13f9d48 100644 --- a/cdprecorder/skopo/__init__.py +++ b/cdprecorder/skopo/__init__.py @@ -1,12 +1,14 @@ from __future__ import annotations import asyncio +import functools import os.path import random import socket import string import subprocess -from typing import TypeVar, TYPE_CHECKING +from collections import defaultdict +from typing import Callable, TYPE_CHECKING from .._storage import DEFAULT_SOCKET_NAME, get_runtime_dir from .sniff_protocol import ( @@ -16,14 +18,16 @@ RequestData, ResponseData, SniffCommand, + SnifferError, + SnifferMessage, sniffer_data_from_bytes, to_sock_datagram, - SnifferError, ) if TYPE_CHECKING: from asyncio import StreamReader, StreamWriter - from typing import TypeVar + +CNT = 0 class SkopoException(Exception): @@ -56,7 +60,7 @@ async def _get_data(self) -> bytes: async def pushback_message(self, obj): await self.session_to_msg_queues[obj.session].put(obj) - def _handle_message(self, obj, session: Optional[int]) -> Optional[ProxyMessage]: + def _handle_message(self, obj, ignore_error: bool = False, session: Optional[int] = None) -> 
Optional[ProxyMessage]: if isinstance(obj, SnifferError) and not ignore_error: raise ProxyException(obj) @@ -67,12 +71,12 @@ def _handle_message(self, obj, session: Optional[int]) -> Optional[ProxyMessage] if isinstance(obj, SnifferMessage): if obj.session is None and session is not None: raise ProxyException(obj, "Received sesionless message in a context with session") - if obj.session != session: + if session is not None and obj.session != session: self.pushback_message(obj) - else: - return obj + return None - return None + # TODO: check if it's a bug it this returns when session is None, but obj.session is not None + return obj async def get_message(self, ignore_error: bool = False, session: Optional[int] = None) -> ProxyMessage: if session is not None: @@ -93,7 +97,7 @@ async def get_message(self, ignore_error: bool = False, session: Optional[int] = while results: obj = results.pop(0) - obj = self._handle_message(obj) + obj = self._handle_message(obj, ignore_error, session) if obj is not None: for result in results: self.pushback_message(result) @@ -102,7 +106,7 @@ async def get_message(self, ignore_error: bool = False, session: Optional[int] = while True: data = await self._get_data() obj, _ = sniffer_data_from_bytes(data) - obj = self._handle_message(obj) + obj = self._handle_message(obj, ignore_error, session) if obj is not None: return obj @@ -170,7 +174,7 @@ def __init__(self, sniffer: Sniffer, id_: int): self.id = id_ def __getattr__(self, key: str): - value = getattr(self.client, key) + value = getattr(self.sniffer, key) if isinstance(value, Callable): value = functools.partial(value, session=self.id) @@ -184,10 +188,14 @@ def __init__(self, sniffer1: Sniffer, sniffer2: Sniffer, on_diff: Callable): self.on_diff = on_diff - self.http_queues = defaultdict(lambda _: asyncio.Queue(), asyncio.Queue()) + # self.http_queues = defaultdict(lambda _: asyncio.Queue(), asyncio.Queue()) + self.sessions1 = defaultdict(lambda _: asyncio.Queue()) + self.sessions2 = 
defaultdict(lambda _: asyncio.Queue()) + self.unpaired_sessions1 = {} + self.unpaired_sessions2 = {} self.waiting_tasks = [] - async def handle_request_response(self, session2, queue1: asyncio.Queue, queue2: asyncio.Queue): + async def handle_request_response(self, task, queue1: asyncio.Queue, queue2: asyncio.Queue): session1 = self.sniffer1 session2 = self.sniffer2 @@ -219,35 +227,70 @@ async def handle_request_response(self, session2, queue1: asyncio.Queue, queue2: self.on_diff(res1, res2) await session2.send_command(SniffCommand.NOP) + @staticmethod + def _request_data_key(obj: RequestData) -> str: + return obj.method + url + + @staticmethod + async def _try_extracting_session(obj: RequestData, unpaired_sessions: dict) -> Optional[asyncio.Queue]: + obj_key = self._request_data_key(obj) + if obj_key in unpaired_sessions: + value = unpaired_sessions[obj_key] + del unpaired_sessions[obj_key] + return value + + return value + + async def _create_comparator_task(self, q1: asyncio.Queue, q2: asyncio.Queue): + task = asyncio.create_task(handle_request_response(q1, q2)) + self.waiting_tasks.append((task, q1, q2)) + return task + async def run(self): + print("Comparator start run") on_message1 = asyncio.create_task(self.sniffer1.get_message()) on_message2 = asyncio.create_task(self.sniffer2.get_message()) while True: + print("Comparator await") done, pending = await asyncio.wait([on_message1, on_message2], return_when=asyncio.FIRST_COMPLETED) + print("Comparator callback") for done_task in done: obj = done_task.result() + print(f"ComparatorSniffer object: {obj}") if not isinstance(obj, (RequestData, ResponseData)): logging.error("Received unwanted object: %s", obj) continue - key = obj.meta.object_id - if key not in self.http_queues: - q1, q2 = asyncio.Queue(), asyncio.Queue() - self.http_queues[obj.meta.object_id] = (q1, q2) - self.waiting_tasks.append(asyncio.create_task(handle_request_response(q1, q2))) + object_id = obj.meta.object_id if done_task is on_message1: - 
q1 = self.http_queues[key][0] + if obj.session not in self.sessions1: + q1 = asyncio.Queue() + self.sessions1[obj.session] = q1 + q2 = self._try_extracting_session(obj, self.unpaired_sessions2) + if q2 is not None: + self._create_comparator_task(q1, q2) + + q1 = self.sessions1[obj.session] await q1.put(obj) on_message1 = asyncio.create_task(self.sniffer1.get_message()) if done_task is on_message2: - q2 = self.http_queues[key][1] + if obj.session not in self.sessions2: + q2 = asyncio.Queue() + self.sessions2[obj.session] = q2 + q1 = self._try_extracting_session(obj, self.unpaired_sessions1) + if q1 is not None: + self._create_comparator_task(q1, q2) + + q2 = self.sessions2[obj.session] await q2.put(obj) on_message2 = asyncio.create_task(self.sniffer2.get_message()) - # TODO: maybe use a while loop? + def stop(self): + self.sniffer1.stop() + self.sniffer2.stop() async def mitmproxy_run( @@ -294,6 +337,7 @@ def __init__( super().__init__() self.sockaddr = sockaddr self._server = None + self.stop_event = asyncio.Event() self._read_queue: asyncio.Queue[bytes] = asyncio.Queue() self._write_queue: asyncio.Queue[bytes] = asyncio.Queue() @@ -323,26 +367,36 @@ async def _get_data(self): read_task.cancel() if wait_task in done: + print("Sniffer proxy error") raise SnifferProcessTerminated(None, self._proc) return read_task.result() async def _send_data(self, data: bytes): datagram = to_sock_datagram(data) + print(f"Sniffer._send_data: {data}") await self._write_queue.put(datagram) + print(f"done put") async def on_client_connected(self, reader: StreamReader, writer: StreamWriter): + global CNT """Called when a proxy server has connected to this sniffer.""" + print(f"On client connected") await self._read_queue.put(ProxyEvent(event=ProxyEvent.CONNECT).to_bytes()) # Handle both reading and writing concurrently # Listen on the reader stream and the internal write queue task1 = asyncio.create_task(async_read_sock_datagram(reader)) task2 = 
asyncio.create_task(self._write_queue.get()) + stop_task = asyncio.create_task(self.stop_event.wait()) + print(f"Created task2 for awaiting {self._write_queue}") try: while True: # ignore pending tasks, because they will be waited again in the next loop iteration - done, _pending = await asyncio.wait([task1, task2], return_when=asyncio.FIRST_COMPLETED) + done, _pending = await asyncio.wait([task1, task2, stop_task], return_when=asyncio.FIRST_COMPLETED) + + if stop_task in done: + return if task1 in done: data = task1.result() @@ -352,23 +406,37 @@ async def on_client_connected(self, reader: StreamReader, writer: StreamWriter): if task2 in done: data = task2.result() - print(f"Writing data: {data}") + print(f"Sniffer writes data: {data}") writer.write(data) await writer.drain() # recreate the task task2 = asyncio.create_task(self._write_queue.get()) - except asyncio.IncompleteReadError: + except asyncio.exceptions.IncompleteReadError: pass - finally: - task1.cancel() - task2.cancel() + except asyncio.CancelledError: + raise + except Exception as e: + print(f"Exception: {e}") + # Hopefully, the event loop is still running await self._read_queue.put(ProxyEvent(event=ProxyEvent.CLOSE).to_bytes()) while not self._write_queue.empty(): await self._write_queue.get() + raise + finally: + task1.cancel() + task2.cancel() + stop_task.cancel() + + await self._read_queue.put(ProxyEvent(event=ProxyEvent.CLOSE).to_bytes()) + + print(f"on_client_connected: return") def stop(self): + if self._server is not None: + self._server.close() + self.stop_event.set() + self._server = None print("Stopping") - print(f"proc: {self._proc}") if self._proc is None: return @@ -420,10 +488,12 @@ async def create_mitmproxy_sniffer_comparator( except: proc1.kill() out, err = await proc1.communicate() - print("Proc output: {out}") - print("Proc err: {err}") + print(f"Proc output: {out}") + print(f"Proc err: {err}") raise + print("The sniffers are ready to start") + sniffer1.ignore_reconnects(True) 
sniffer2.ignore_reconnects(True) comparator = ComparatorSniffer(sniffer1, sniffer2, on_diff) diff --git a/cdprecorder/skopo/intercept_addon.py b/cdprecorder/skopo/intercept_addon.py index a1386a1..828feb8 100644 --- a/cdprecorder/skopo/intercept_addon.py +++ b/cdprecorder/skopo/intercept_addon.py @@ -58,6 +58,7 @@ def filter(self, record: logging.LogRecord) -> bool: class MitmproxySnifferProxyClient(SnifferProxyClient): def __init__(self, sockaddr: str): + super().__init__() self.client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.client.connect(sockaddr) @@ -90,6 +91,8 @@ def __init__(self): wrapper_formatter = WrapperFormatter(handler.formatter) handler.setFormatter(wrapper_formatter) + self.objid_to_session = {} + def load(self, loader): url = f"{ctx.options.listen_host}:{ctx.options.listen_port}" self.log_filter.prefix = url @@ -125,8 +128,14 @@ def configure(self, updated: set[str]): if ctx.options.socketaddress is not None: self.start_connection(ctx.options.socketaddress) + def get_session_for_object_id(self, object_id: int): + if object_id not in self.objid_to_session: + session = self.client.new_session() + self.objid_to_session[object_id] = session + + return self.objid_to_session[object_id] + def request(self, flow: http.HTTPFlow): - pass self.flows.add(id(flow)) logging.info("Intercepted request") mitmreq = flow.request @@ -153,10 +162,11 @@ def request(self, flow: http.HTTPFlow): meta=SnifferMetadata(object_id=id(flow), timestamp=0, proxyname=self.proxyname), ) - self.client.send_request_data(req) + session = self.get_session_for_object_id(id(flow)) + session.send_request_data(req) logging.info("Sent request") - command = self.client.get_command() + command = session.get_command() if command.command == SniffCommand.REPLACE: if command.response is not None: @@ -213,10 +223,11 @@ def response(self, flow: http.HTTPFlow): logging.info("Constructed ResponseData") - self.client.send_response_data(res) + session = 
self.get_session_for_object_id(id(flow)) + session.send_response_data(res) logging.info("Sent response") - command = self.client.get_command() + command = session.get_command() logging.info("Got command in response") if command.command == SniffCommand.REPLACE: diff --git a/cdprecorder/skopo/sniff_protocol.py b/cdprecorder/skopo/sniff_protocol.py index 21e027e..d9989a7 100644 --- a/cdprecorder/skopo/sniff_protocol.py +++ b/cdprecorder/skopo/sniff_protocol.py @@ -1,8 +1,10 @@ from __future__ import annotations import asyncio +import functools import os.path import subprocess +from collections import defaultdict from enum import IntEnum from typing import Callable, TypeVar, TYPE_CHECKING, Union @@ -37,7 +39,10 @@ def to_bytes(): class SnifferInt64: @staticmethod - def to_bytes(value: int): + def to_bytes(value: Optional[int]): + if value is None: + return SnifferNone.to_bytes() + data = b"" data += SnifferMessageType.INT64.to_bytes(8, "big") data += value.to_bytes(8, "big") @@ -73,7 +78,7 @@ def from_bytes(cls, data: bytes) -> tuple[SnifferMetadata, int]: proxyname = data[i : i + size].decode("utf8") i += size - return cls(object_id, timestamp, proxyname, session), i + return cls(object_id, timestamp, proxyname), i def __str__(self) -> str: text = f"{self.__class__.__name__}(" @@ -86,7 +91,7 @@ def __str__(self) -> str: class RequestData: - __slots__ = ["http_version", "method", "url", "headers", "content", "trailers", "meta"] + __slots__ = ["http_version", "method", "url", "headers", "content", "trailers", "meta", "session"] def __init__( self, @@ -97,7 +102,7 @@ def __init__( content: bytes, trailers: bytes, meta: SnifferMetadata, - sesison: Optional[int], + session: Optional[int] = None, ): self.http_version = http_version self.method = method @@ -168,7 +173,7 @@ def from_bytes(cls, data: bytes) -> tuple[RequestData, int]: assert isinstance(session, int) or session is None i += used - return cls(http_version, method, url, headers, content, trailers, meta), i + 
return cls(http_version, method, url, headers, content, trailers, meta, session), i def __str__(self) -> str: text = f"{self.__class__.__name__}(" @@ -185,7 +190,7 @@ def __str__(self) -> str: class ResponseData: - __slots__ = ["http_version", "status_code", "reason", "headers", "content", "trailers", "meta"] + __slots__ = ["http_version", "status_code", "reason", "headers", "content", "trailers", "meta", "session"] def __init__( self, @@ -196,7 +201,7 @@ def __init__( content: bytes, trailers: bytes, meta: SnifferMetadata, - session: Optional[int], + session: Optional[int] = None, ): self.http_version = http_version self.status_code = status_code @@ -410,29 +415,29 @@ def pushback_message(self, obj): self.session_to_messages[obj.session].append(obj) def get_message(self, ignore_error: bool = False, session: Optional[int] = None) -> SnifferMessage: - if session is not None: - while len(self.session_to_messages[session]): - obj = self.session_to_messages[session].pop(0) - if isinstance(obj, SnifferError) and not ignore_error: - raise ProxyException(obj) - - if isinstance(obj, SnifferMessage): - if obj.session is None and session is not None: - raise ProxyException(obj, "Received sesionless message in a context with session") - return obj + print(f"Waiting message with session: {session}") + while len(self.session_to_messages[session]): + obj = self.session_to_messages[session].pop(0) + print(f"Got obj: {obj}") + if isinstance(obj, SnifferError) and not ignore_error: + raise ProxyException(obj) - del self.session_to_messages[session] + if isinstance(obj, SnifferMessage): + if obj.session is None and session is not None: + raise ProxyException(obj, "Received sessionless message in a context with session") + return obj + + del self.session_to_messages[session] while True: data = self._get_data() obj, _ = sniffer_data_from_bytes(data) + print(f"Got obj: {obj}") if isinstance(obj, SnifferError) and not ignore_error: raise ProxyException(obj) if isinstance(obj, 
SnifferMessage): - if obj.session is None and session is not None: - raise ProxyException(obj, "Received sesionless message in a context with session") if obj.session != session: self.pushback_message(obj) else: @@ -455,15 +460,13 @@ def send_proxy_message(self, obj: ProxyMessage, session: Optional[int] = None): self._send_data(data) def send_request_data(self, obj: RequestData, session: Optional[int] = None): - obj.session = session - self.send_proxy_message(obj) + self.send_proxy_message(obj, session) def send_response_data(self, obj: ResponseData, session: Optional[int] = None): - obj.session = session - self.send_proxy_message(obj) + self.send_proxy_message(obj, session) def send_proxy_event(self, obj: ProxyEvent, session: Optional[int] = None): - self.send_proxy_message(obj) + self.send_proxy_message(obj, session) def send_error(self, msg: SnifferError, session: Optional[int] = None): msg.session = session @@ -490,32 +493,6 @@ def __getattr__(self, key: str): return value -class BufferedSnifferProxyClient(SnifferProxyClient): - def __init__(self): - self.command_queue = asyncio.Queue() - self.request_queue = asyncio.Queue() - self.response_queue = asyncio.Queue() - pass - - def get_message(self, ignore_error: bool = False) -> SnifferMessage: - while True: - data = self._get_data() - obj, _ = sniffer_data_from_bytes(data) - - if isinstance(obj, SnifferError) and not ignore_error: - raise ProxyException(obj) - - if isinstance(obj, SnifferMessage): - return obj - - def get_command(self, ignore_error: bool = False) -> SniffCommand: - while True: - msg = self.get_message(ignore_error) - if not isinstance(msg, SniffCommand): - raise ProxyException(msg, "Expected message of type SniffCommand") - return msg - - def to_sock_datagram(data: bytes): size = len(data) return size.to_bytes(8, "big") + data From 58aa27d283f29f884efe94c2693998d0ca981994 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Tue, 2 Sep 2025 02:10:44 
+0300 Subject: [PATCH 16/18] test: now, we can compare requests and chrome --- tests/test_e2e_run/test_http_apps.py | 65 +++++++++++++++++----------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/tests/test_e2e_run/test_http_apps.py b/tests/test_e2e_run/test_http_apps.py index e605742..67dee8f 100644 --- a/tests/test_e2e_run/test_http_apps.py +++ b/tests/test_e2e_run/test_http_apps.py @@ -76,7 +76,7 @@ async def test_csrf_form_run_csrf_from_submit_success(): cdp_port = 9222 options.add_argument(f"--remote-debugging-port={cdp_port}") options.add_argument(f"--proxy-server={proxy_url1}") - options.add_argument("--headless=new") + # options.add_argument("--headless=new") capabilities = options.to_capabilities() capabilities["acceptInsecureCerts"] = True print(options.to_capabilities()) @@ -94,31 +94,37 @@ async def test_csrf_form_run_csrf_from_submit_success(): sniffer1 = comparator.sniffer1 async def pass_sniffer(): - req = None + print("pass sniffer") + msg = None try: while True: - req = await sniffer1.get_message() - await sniffer1.send_command(skopo.SniffCommand.NOP) + msg = await sniffer1.get_message() + + logging.debug("Msg: %s", msg) + logging.debug("Session: %s", msg.session) + await sniffer1.to_session(msg.session).send_command(skopo.SniffCommand.NOP) except asyncio.CancelledError: - if req is not None: - await sniffer1.send_command(skopo.SniffCommand.NOP) + if msg is not None: + await sniffer1.to_session(msg.session).send_command(skopo.SniffCommand.NOP) pass + except: + logging.exception("Pass sniffer ended with exception") + pass_task = asyncio.create_task(pass_sniffer()) + rec = await recorder.init_recorder(recorder_options) + await asyncio.sleep(2) + pass_task.cancel() try: - t1 = asyncio.create_task(recorder.init_recorder(recorder_options)) - t2 = asyncio.create_task(pass_sniffer()) - done, pending = await asyncio.wait([t1, t2], return_when=asyncio.FIRST_COMPLETED) - - if t1 not in done: - assert False, "Recorder should end" - await 
asyncio.sleep(2) - t2.cancel() - - rec = t1.result() + pass_task.result() + except (asyncio.CancelledError, asyncio.InvalidStateError): + pass + except: + logging.exception("pass_task ended with exception") + try: t1 = asyncio.create_task(asyncio.to_thread(run_csrf_form_submitsuccess, driver)) t2 = asyncio.create_task(recorder.collect_communications(rec, 30)) - t3 = asyncio.create_task(log_sniffer()) + t3 = asyncio.create_task(pass_sniffer()) done, pending = await asyncio.wait([t1, t2, t3], return_when=asyncio.FIRST_COMPLETED) if t2 in done: @@ -140,12 +146,9 @@ async def pass_sniffer(): finally: await rec.close() - logging.info("Recorded communications") - - sniffer1.stop() - comparator.sniffer2.stop() + driver.quit() - return + logging.info("Recorded communications") actions = erpeto.parse_communications_into_actions(communications) erpeto.make_action_ids_consecutive_from_list(actions) @@ -155,16 +158,26 @@ async def pass_sniffer(): # TODO: probably, a sniffer that's not linked to a comparator should not block # comparator = skopo.SnifferComparator(on_fail, sniffer_manager1.sniffer, sniffer_manager2.sniffer) + driver = webdriver.Chrome(options) + t1 = asyncio.create_task(asyncio.to_thread(run_csrf_form_submitsuccess, driver)) - t2 = asyncio.create_task(asyncio.to_thread(erpeto.run_replicate, actions)) + t2 = asyncio.create_task(asyncio.to_thread(erpeto.run_replicate, actions, [proxy_url2])) t3 = asyncio.create_task(comparator.run()) - pair = asyncio.wait([t1, t2], return_when=asyncio.ALL_COMPLETED) + pair = asyncio.create_task(asyncio.wait([t1, t2], return_when=asyncio.ALL_COMPLETED)) done, pending = await asyncio.wait([pair, t3], return_when=asyncio.FIRST_COMPLETED) + logging.debug("Pending state: %s, %s", pair.done(), t3.done()) + + t3.cancel() + if t3 in done: + assert t3.result() is True + # comparator.stop() + return + # TODO: this probably can't cancel t1 and t2 - for t in pending: - t.cancel() + for pair in pending: + pair.cancel() if t3 in done: res = 
t3.result() From 88d155b41a66f6100b773c201bda714b05584f25 Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Fri, 5 Sep 2025 03:09:28 +0300 Subject: [PATCH 17/18] feat: improve sniff protocol --- cdprecorder/_storage.py | 2 +- cdprecorder/skopo/__init__.py | 279 ++++++++++++++++----------- cdprecorder/skopo/intercept_addon.py | 35 ++-- cdprecorder/skopo/sniff_protocol.py | 104 +++++----- 4 files changed, 240 insertions(+), 180 deletions(-) diff --git a/cdprecorder/_storage.py b/cdprecorder/_storage.py index 57f7747..f1e74aa 100644 --- a/cdprecorder/_storage.py +++ b/cdprecorder/_storage.py @@ -4,5 +4,5 @@ DEFAULT_SOCKET_NAME = "erpeto.sock" -def get_runtime_dir(): +def get_runtime_dir() -> str: return os.getenv("XDG_RUNTIME_DIR") diff --git a/cdprecorder/skopo/__init__.py b/cdprecorder/skopo/__init__.py index 13f9d48..863ec24 100644 --- a/cdprecorder/skopo/__init__.py +++ b/cdprecorder/skopo/__init__.py @@ -2,32 +2,37 @@ import asyncio import functools +import logging import os.path import random import socket import string import subprocess from collections import defaultdict -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING from .._storage import DEFAULT_SOCKET_NAME, get_runtime_dir from .sniff_protocol import ( - async_read_sock_datagram, ProxyEvent, ProxyMessage, + ProxyException, RequestData, ResponseData, SniffCommand, SnifferError, SnifferMessage, + SkopoMessage, + async_read_sock_datagram, sniffer_data_from_bytes, to_sock_datagram, ) if TYPE_CHECKING: from asyncio import StreamReader, StreamWriter + from typing import Callable, Optional, Union -CNT = 0 + +logger = logging.getLogger(__name__) class SkopoException(Exception): @@ -35,32 +40,34 @@ class SkopoException(Exception): class SnifferException(SkopoException): - def __init__(self, obj: Optional[SnifferError], *args, **kwargs): + def __init__(self, obj: Optional[SkopoMessage], *args: object, **kwargs: object) -> None: 
super().__init__(*args, **kwargs) self.obj = obj class SnifferProcessTerminated(SkopoException): - def __init__(self, proc: asyncio.subprocess.Process, *args, **kwargs): + def __init__(self, proc: asyncio.subprocess.Process, *args: object, **kwargs: object) -> None: super().__init__(*args, **kwargs) self.proc = proc class Sniffer: - def __init__(self): + def __init__(self) -> None: self._ignore_reconnects = False - self.session_to_msg_queues = defaultdict(asyncio.Queue) + self.session_to_msg_queues: dict[Optional[int], asyncio.Queue[ProxyMessage]] = defaultdict(asyncio.Queue) - def ignore_reconnects(self, ignore=True): + def ignore_reconnects(self, ignore: bool = True) -> None: self._ignore_reconnects = ignore async def _get_data(self) -> bytes: raise NotImplementedError - async def pushback_message(self, obj): - await self.session_to_msg_queues[obj.session].put(obj) + def pushback_message(self, obj: ProxyMessage) -> None: + self.session_to_msg_queues[obj.session].put_nowait(obj) - def _handle_message(self, obj, ignore_error: bool = False, session: Optional[int] = None) -> Optional[ProxyMessage]: + def _handle_message( + self, obj: ProxyMessage, ignore_error: bool = False, session: Optional[int] = None + ) -> Optional[ProxyMessage]: if isinstance(obj, SnifferError) and not ignore_error: raise ProxyException(obj) @@ -70,7 +77,7 @@ def _handle_message(self, obj, ignore_error: bool = False, session: Optional[int if isinstance(obj, SnifferMessage): if obj.session is None and session is not None: - raise ProxyException(obj, "Received sesionless message in a context with session") + raise SnifferException(obj, "Received sesionless message in a context with session") if session is not None and obj.session != session: self.pushback_message(obj) return None @@ -91,7 +98,7 @@ async def get_message(self, ignore_error: bool = False, session: Optional[int] = if t1 in done: results.append(t1.result()) if t2 in done: - data = t1.result() + data = t2.result() obj, _ = 
sniffer_data_from_bytes(data) results.append(obj) @@ -141,7 +148,7 @@ async def wait_event(self, event_id: int, ignore_error: bool = False, session: O return msg - async def _send_data(self, data: bytes): + async def _send_data(self, data: bytes) -> None: raise NotImplementedError async def send_command( @@ -150,21 +157,24 @@ async def send_command( request: Optional[RequestData] = None, response: Optional[ResponseData] = None, session: Optional[int] = None, - ): + ) -> None: sniff_command = SniffCommand(command, request, response) sniff_command.session = session data = sniff_command.to_bytes() await self._send_data(data) - async def send_error(self, msg: SnifferError, session: Optional[int] = None): + async def send_error(self, msg: SnifferError, session: Optional[int] = None) -> None: msg.session = session data = msg.to_bytes() await self._send_data(data) - async def async_stop(self): - raise NotImplemented + async def async_stop(self) -> None: + raise NotImplementedError + + def stop(self) -> None: + raise NotImplementedError - def to_session(self, session: int): + def to_session(self, session: int) -> SnifferSession: return SnifferSession(self, session) @@ -173,9 +183,9 @@ def __init__(self, sniffer: Sniffer, id_: int): self.sniffer = sniffer self.id = id_ - def __getattr__(self, key: str): + def __getattr__(self, key: str): # type: ignore value = getattr(self.sniffer, key) - if isinstance(value, Callable): + if callable(value): value = functools.partial(value, session=self.id) return value @@ -189,15 +199,21 @@ def __init__(self, sniffer1: Sniffer, sniffer2: Sniffer, on_diff: Callable): self.on_diff = on_diff # self.http_queues = defaultdict(lambda _: asyncio.Queue(), asyncio.Queue()) - self.sessions1 = defaultdict(lambda _: asyncio.Queue()) - self.sessions2 = defaultdict(lambda _: asyncio.Queue()) - self.unpaired_sessions1 = {} - self.unpaired_sessions2 = {} - self.waiting_tasks = [] + self.sessions1: dict[Optional[int], asyncio.Queue[ProxyMessage]] = 
defaultdict(lambda: asyncio.Queue()) + self.sessions2: dict[Optional[int], asyncio.Queue[ProxyMessage]] = defaultdict(lambda: asyncio.Queue()) + self.unpaired_sessions1: dict[bytes, asyncio.Queue[ProxyMessage]] = {} + self.unpaired_sessions2: dict[bytes, asyncio.Queue[ProxyMessage]] = {} + self.waiting_tasks: list[asyncio.Task] = [] - async def handle_request_response(self, task, queue1: asyncio.Queue, queue2: asyncio.Queue): - session1 = self.sniffer1 - session2 = self.sniffer2 + self.filtered_sessions: set[int] = set() + self.filtered_url_keywords: list[bytes] = [] + + def add_filtered_url_keywords(self, keywords: list[str]) -> None: + self.filtered_url_keywords.extend(kw.encode() for kw in keywords) + + async def handle_request_response(self, queue1: asyncio.Queue, queue2: asyncio.Queue) -> None: + session1: Union[Sniffer, SnifferSession] = self.sniffer1 + session2: Union[Sniffer, SnifferSession] = self.sniffer2 req1 = await queue1.get() assert isinstance(req1, RequestData) @@ -209,16 +225,16 @@ async def handle_request_response(self, task, queue1: asyncio.Queue, queue2: asy assert isinstance(req2, RequestData) if req1 != req2: - self.on_diff(req1, req2) - - if req2.session is not None: - session2 = self.sniffer2.to_session(req2.session) - await session2.send_command(SniffCommand.REPLACE, response=res1) + await self.on_diff(self, req1, req2) res1 = await queue1.get() assert isinstance(res1, ResponseData) await session1.send_command(SniffCommand.NOP) + if req2.session is not None: + session2 = self.sniffer2.to_session(req2.session) + await session2.send_command(SniffCommand.REPLACE, response=res1) + res2 = await queue2.get() assert isinstance(res2, ResponseData) @@ -227,68 +243,108 @@ async def handle_request_response(self, task, queue1: asyncio.Queue, queue2: asy self.on_diff(res1, res2) await session2.send_command(SniffCommand.NOP) - @staticmethod - def _request_data_key(obj: RequestData) -> str: - return obj.method + url + async def _pass_message(self, 
sniffer: Sniffer, msg: Union[RequestData, ResponseData]) -> None: + session: Union[Sniffer, SnifferSession] = sniffer + if msg.session is not None: + session = sniffer.to_session(msg.session) + + await session.send_command(SniffCommand.NOP) @staticmethod - async def _try_extracting_session(obj: RequestData, unpaired_sessions: dict) -> Optional[asyncio.Queue]: - obj_key = self._request_data_key(obj) + def _request_data_key(obj: RequestData) -> bytes: + return obj.method + obj.url + + @classmethod + def _try_extracting_session( + cls, obj: RequestData, unpaired_sessions: dict[bytes, asyncio.Queue[ProxyMessage]] + ) -> Optional[asyncio.Queue[ProxyMessage]]: + obj_key = cls._request_data_key(obj) if obj_key in unpaired_sessions: value = unpaired_sessions[obj_key] del unpaired_sessions[obj_key] return value - return value + return None - async def _create_comparator_task(self, q1: asyncio.Queue, q2: asyncio.Queue): - task = asyncio.create_task(handle_request_response(q1, q2)) - self.waiting_tasks.append((task, q1, q2)) + async def _create_comparator_task(self, q1: asyncio.Queue, q2: asyncio.Queue) -> asyncio.Task: + task = asyncio.create_task(self.handle_request_response(q1, q2)) + self.waiting_tasks.append(task) return task - async def run(self): + def is_object_filtered(self, obj: Union[RequestData, ResponseData]) -> bool: + if isinstance(obj, RequestData): + return any(kw in obj.url for kw in self.filtered_url_keywords) + + if isinstance(obj, ResponseData): + if obj.session is not None and obj.session in self.filtered_sessions: + return True + + return False + + def remember_filtered_object(self, obj: Union[RequestData, ResponseData]) -> None: + if obj.session is None: + raise SkopoException("Can't remember a filtered object with no session.") + self.filtered_sessions.add(obj.session) + + async def run(self) -> None: print("Comparator start run") on_message1 = asyncio.create_task(self.sniffer1.get_message()) on_message2 = 
asyncio.create_task(self.sniffer2.get_message()) while True: - print("Comparator await") + logger.debug("Comparator await") done, pending = await asyncio.wait([on_message1, on_message2], return_when=asyncio.FIRST_COMPLETED) - print("Comparator callback") + logger.debug("Comparator callback") - for done_task in done: - obj = done_task.result() - print(f"ComparatorSniffer object: {obj}") - if not isinstance(obj, (RequestData, ResponseData)): - logging.error("Received unwanted object: %s", obj) + for task in [on_message1, on_message2]: + if task not in done: continue - object_id = obj.meta.object_id - - if done_task is on_message1: - if obj.session not in self.sessions1: - q1 = asyncio.Queue() - self.sessions1[obj.session] = q1 - q2 = self._try_extracting_session(obj, self.unpaired_sessions2) - if q2 is not None: - self._create_comparator_task(q1, q2) - - q1 = self.sessions1[obj.session] - await q1.put(obj) - on_message1 = asyncio.create_task(self.sniffer1.get_message()) - - if done_task is on_message2: - if obj.session not in self.sessions2: - q2 = asyncio.Queue() - self.sessions2[obj.session] = q2 - q1 = self._try_extracting_session(obj, self.unpaired_sessions1) - if q1 is not None: - self._create_comparator_task(q1, q2) - - q2 = self.sessions2[obj.session] - await q2.put(obj) - on_message2 = asyncio.create_task(self.sniffer2.get_message()) - - def stop(self): + try: + obj = task.result() + if not isinstance(obj, (RequestData, ResponseData)): + logger.error("Received unwanted object: %s", obj) + continue + + if self.is_object_filtered(obj): + self.remember_filtered_object(obj) + sniffer = self.sniffer1 if task is on_message1 else self.sniffer2 + await self._pass_message(sniffer, obj) + continue + + if task is on_message1: + if obj.session not in self.sessions1: + q1: asyncio.Queue[ProxyMessage] = asyncio.Queue() + self.sessions1[obj.session] = q1 + q2 = self._try_extracting_session(obj, self.unpaired_sessions2) + if q2 is not None: + await 
self._create_comparator_task(q1, q2) + else: + obj_key = self._request_data_key(obj) + self.unpaired_sessions1[obj_key] = q1 + + q1 = self.sessions1[obj.session] + await q1.put(obj) + + if task is on_message2: + if obj.session not in self.sessions2: + q2 = asyncio.Queue() + self.sessions2[obj.session] = q2 + q1 = self._try_extracting_session(obj, self.unpaired_sessions1) + if q1 is not None: + await self._create_comparator_task(q1, q2) + else: + obj_key = self._request_data_key(obj) + self.unpaired_sessions2[obj_key] = q2 + + q2 = self.sessions2[obj.session] + await q2.put(obj) + finally: + if task is on_message1: + on_message1 = asyncio.create_task(self.sniffer1.get_message()) + elif task is on_message2: + on_message2 = asyncio.create_task(self.sniffer2.get_message()) + + def stop(self) -> None: self.sniffer1.stop() self.sniffer2.stop() @@ -299,7 +355,7 @@ async def mitmproxy_run( port: int = 8080, addon_script: str = "intercept_addon.py", binary: str = "mitmdump", -): +) -> tuple[asyncio.subprocess.Process, str, int, str]: proxy_name: str = "".join(random.choices(string.ascii_letters, k=32)) module_dir = os.path.dirname(os.path.realpath(__file__)) addon_path = os.path.join(module_dir, addon_script) @@ -336,52 +392,53 @@ def __init__( ): super().__init__() self.sockaddr = sockaddr - self._server = None + self._server: Optional[asyncio.Server] = None self.stop_event = asyncio.Event() self._read_queue: asyncio.Queue[bytes] = asyncio.Queue() self._write_queue: asyncio.Queue[bytes] = asyncio.Queue() - self._proc = mitmproxy_proc + self._proc: Optional[asyncio.subprocess.Process] = mitmproxy_proc self.proxy_host = proxy_host self.proxy_port = proxy_port self.proxy_name = proxy_name - async def init(self): + async def init(self) -> None: self._server = await asyncio.start_unix_server(self.on_client_connected, path=self.sockaddr) @property - def proxy_url(self): + def proxy_url(self) -> str: return f"http://{self.proxy_host}:{self.proxy_port}" - async def 
_get_data(self): - wait_task = asyncio.create_task(self._proc.wait()) + async def _get_data(self) -> bytes: + tasks: list[asyncio.Task] = [] + wait_task = None + if self._proc is not None: + wait_task = asyncio.create_task(self._proc.wait()) + tasks.append(wait_task) read_task = asyncio.create_task(self._read_queue.get()) + tasks.append(read_task) try: - done, _pending = await asyncio.wait([wait_task, read_task], return_when=asyncio.FIRST_COMPLETED) + done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) finally: # Here, I found out that task cancelation can create bugs # Hence, in a couroutine, you should always expect that when you # await a coroutine/task, that can generate a CancelledError. - wait_task.cancel() + if wait_task is not None: + wait_task.cancel() read_task.cancel() - if wait_task in done: - print("Sniffer proxy error") - raise SnifferProcessTerminated(None, self._proc) + if wait_task in done and self._proc is not None: + raise SnifferProcessTerminated(self._proc) return read_task.result() - async def _send_data(self, data: bytes): + async def _send_data(self, data: bytes) -> None: datagram = to_sock_datagram(data) - print(f"Sniffer._send_data: {data}") await self._write_queue.put(datagram) - print(f"done put") - async def on_client_connected(self, reader: StreamReader, writer: StreamWriter): - global CNT + async def on_client_connected(self, reader: StreamReader, writer: StreamWriter) -> None: """Called when a proxy server has connected to this sniffer.""" - print(f"On client connected") await self._read_queue.put(ProxyEvent(event=ProxyEvent.CONNECT).to_bytes()) # Handle both reading and writing concurrently @@ -389,7 +446,6 @@ async def on_client_connected(self, reader: StreamReader, writer: StreamWriter): task1 = asyncio.create_task(async_read_sock_datagram(reader)) task2 = asyncio.create_task(self._write_queue.get()) stop_task = asyncio.create_task(self.stop_event.wait()) - print(f"Created task2 for awaiting 
{self._write_queue}") try: while True: # ignore pending tasks, because they will be waited again in the next loop iteration @@ -406,7 +462,6 @@ async def on_client_connected(self, reader: StreamReader, writer: StreamWriter): if task2 in done: data = task2.result() - print(f"Sniffer writes data: {data}") writer.write(data) await writer.drain() # recreate the task @@ -416,7 +471,7 @@ async def on_client_connected(self, reader: StreamReader, writer: StreamWriter): except asyncio.CancelledError: raise except Exception as e: - print(f"Exception: {e}") + logger.exception(f"Exception") # Hopefully, the event loop is still running await self._read_queue.put(ProxyEvent(event=ProxyEvent.CLOSE).to_bytes()) while not self._write_queue.empty(): @@ -429,35 +484,25 @@ async def on_client_connected(self, reader: StreamReader, writer: StreamWriter): await self._read_queue.put(ProxyEvent(event=ProxyEvent.CLOSE).to_bytes()) - print(f"on_client_connected: return") - - def stop(self): + def stop(self) -> None: if self._server is not None: self._server.close() self.stop_event.set() self._server = None - print("Stopping") if self._proc is None: return self._proc.terminate() - print("Terminate") try: # Imagine having an API that that has both async and blocking functions # Sadly, this is not the case for asyncio.subprocess.Process - print("Waiting") - p = self._proc._transport._proc.wait(timeout=0.5) - print("Waited") + p = self._proc._transport._proc.wait(timeout=0.5) # type: ignore[attr-defined] except subprocess.TimeoutExpired: - print("Killing") self._proc.kill() - print("Done stopping") - self._proc = None - def __del__(self): - print(f"Called del for {self}") + def __del__(self) -> None: # The object might be destroyed before __init__ terminates if hasattr(self, "_proc"): self.stop() @@ -467,9 +512,9 @@ async def create_mitmproxy_sniffer_comparator( on_diff: Callable, socksuffix1: str = "1", proxy_port1: int = 8080, socksuffix2: str = "2", proxy_port2: int = 8081 ) -> 
ComparatorSniffer: if socksuffix1 == socksuffix2: - raise SnifferException("Socket suffixes must be different") + raise SkopoException("Socket suffixes must be different") if proxy_port1 == proxy_port2: - raise SnifferException("Proxy ports must be different") + raise SkopoException("Proxy ports must be different") sockaddr1 = os.path.join(get_runtime_dir(), DEFAULT_SOCKET_NAME + socksuffix1) proc1, host1, port1, proxy_name1 = await mitmproxy_run(sockaddr1, port=proxy_port1) @@ -488,12 +533,10 @@ async def create_mitmproxy_sniffer_comparator( except: proc1.kill() out, err = await proc1.communicate() - print(f"Proc output: {out}") - print(f"Proc err: {err}") + logger.error("Proc output: %s", out) + logger.error("Proc err: %s", err) raise - print("The sniffers are ready to start") - sniffer1.ignore_reconnects(True) sniffer2.ignore_reconnects(True) comparator = ComparatorSniffer(sniffer1, sniffer2, on_diff) diff --git a/cdprecorder/skopo/intercept_addon.py b/cdprecorder/skopo/intercept_addon.py index 828feb8..256fc71 100644 --- a/cdprecorder/skopo/intercept_addon.py +++ b/cdprecorder/skopo/intercept_addon.py @@ -27,7 +27,7 @@ class WrapperFormatter(logging.Formatter): def __init__(self, formatter: loggingFormatter): self.formatter = formatter - def format(self, record: logging.LogRecord): + def format(self, record: logging.LogRecord) -> str: msg = self.formatter.format(record) msg_prefix = "" @@ -36,13 +36,13 @@ def format(self, record: logging.LogRecord): return f"{msg_prefix}{msg}" - def formatTime(self, *args, **kwargs): + def formatTime(self, *args: object, **kwargs: object) -> str: return self.formatter.formatTime(*args, **kwargs) - def formatException(self, *args, **kwargs): + def formatException(self, *args: object, **kwargs: object) -> str: return self.formatter.formatException(*args, **kwargs) - def formatStack(self, *args, **kwargs): + def formatStack(self, *args: object, **kwargs: object) -> str: return self.formatter.formatStack(*args, **kwargs) @@ -57,7 
+57,7 @@ def filter(self, record: logging.LogRecord) -> bool: class MitmproxySnifferProxyClient(SnifferProxyClient): - def __init__(self, sockaddr: str): + def __init__(self, sockaddr: str) -> None: super().__init__() self.client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.client.connect(sockaddr) @@ -66,11 +66,11 @@ def _get_data(self) -> bytes: data = read_sock_datagram(self.client) return data - def _send_data(self, data: bytes): + def _send_data(self, data: bytes) -> None: datagram = to_sock_datagram(data) self.client.sendall(datagram) - def close(self): + def close(self) -> None: if self.client is not None: self.client.shutdown(socket.SHUT_RDWR) self.client.close() @@ -81,7 +81,7 @@ def __del__(self): class TheSpy: - def __init__(self): + def __init__(self) -> None: self.client = None self.flows = set() @@ -118,24 +118,24 @@ def start_connection(self, socketaddress: str): self.client = MitmproxySnifferProxyClient(socketaddress) assert isinstance(self.proxyname, str) - def running(self): + def running(self) -> None: if ctx.options.socketaddress is not None: self.start_connection(ctx.options.socketaddress) - def configure(self, updated: set[str]): + def configure(self, updated: set[str]) -> None: self.proxyname = ctx.options.proxyname logging.info("Set proxyname=%s", self.proxyname) if ctx.options.socketaddress is not None: self.start_connection(ctx.options.socketaddress) - def get_session_for_object_id(self, object_id: int): + def get_session_for_object_id(self, object_id: int) -> None: if object_id not in self.objid_to_session: session = self.client.new_session() self.objid_to_session[object_id] = session return self.objid_to_session[object_id] - def request(self, flow: http.HTTPFlow): + def request(self, flow: http.HTTPFlow) -> None: self.flows.add(id(flow)) logging.info("Intercepted request") mitmreq = flow.request @@ -165,9 +165,11 @@ def request(self, flow: http.HTTPFlow): session = self.get_session_for_object_id(id(flow)) 
session.send_request_data(req) - logging.info("Sent request") + logging.info("Sent request session=%s", session.id) command = session.get_command() + logging.info("Got command in request") + if command.command == SniffCommand.REPLACE: if command.response is not None: r = command.response @@ -192,7 +194,6 @@ def request(self, flow: http.HTTPFlow): method=r.method, url=r.url, content=r.raw_content, headers=http.Headers(r.headers) ) req.http_version = (r.http_version,) - req.reason = (r.reason,) req.trailers = (http.Headers(r.trailers),) flow.request = req @@ -201,7 +202,7 @@ def request(self, flow: http.HTTPFlow): logging.info("OK Intercepted request") - def response(self, flow: http.HTTPFlow): + def response(self, flow: http.HTTPFlow) -> None: logging.info("Intercepted response") if id(flow) not in self.flows: logging.error("Got response before existing request") @@ -225,7 +226,7 @@ def response(self, flow: http.HTTPFlow): session = self.get_session_for_object_id(id(flow)) session.send_response_data(res) - logging.info("Sent response") + logging.info("Sent response session=%s", session.id) command = session.get_command() logging.info("Got command in response") @@ -245,7 +246,7 @@ def response(self, flow: http.HTTPFlow): flow.response = resp elif command.command != SniffCommand.NOP: - raise Exception(f"Unknown command: {sommand.command}") + raise Exception(f"Unknown command: {command.command}") logging.info("OK Intercepted response") except Exception as exc: diff --git a/cdprecorder/skopo/sniff_protocol.py b/cdprecorder/skopo/sniff_protocol.py index d9989a7..8d4d7b2 100644 --- a/cdprecorder/skopo/sniff_protocol.py +++ b/cdprecorder/skopo/sniff_protocol.py @@ -6,13 +6,13 @@ import subprocess from collections import defaultdict from enum import IntEnum -from typing import Callable, TypeVar, TYPE_CHECKING, Union +from typing import Callable, TYPE_CHECKING if TYPE_CHECKING: import socket from asyncio import StreamReader, StreamWriter - from typing import TypeVar + from 
typing import Optional, TypeAlias, Protocol, Union def bytes_to_varlen_bytes(data: bytes): @@ -113,23 +113,6 @@ def __init__( self.meta = meta self.session = session - def __eq__(self, other: object) -> bool: - if not isinstance(other, RequestData): - raise ValueError("eq only supported for RequestData types") - - # Compare everything but the metadata - if ( - self.http_version != other.http_version - or self.method != other.method - or self.url != other.url - or self.headers != other.headers - or self.content != other.content - or self.trailers != other.trailers - ): - return False - - return True - def to_bytes(self) -> bytes: data = b"" data += SnifferMessageType.REQUEST_DATA.to_bytes(8, "big") @@ -183,11 +166,25 @@ def __str__(self) -> str: text += f"headers={self.headers}, " text += f"content={self.content!r}, " text += f"trailers={self.trailers}, " - text += f"meta={self.meta}" + text += f"meta={self.meta}," + text += f"session={self.session}" text += ")" return text + def __eq__(self, other: object) -> bool: + if not isinstance(other, RequestData): + raise NotImplementedError + + return ( + self.http_version == other.http_version + and self.method == other.method + and self.url == other.url + and self.headers == other.headers + and self.content == other.content + and self.trailers == other.trailers + ) + class ResponseData: __slots__ = ["http_version", "status_code", "reason", "headers", "content", "trailers", "meta", "session"] @@ -212,23 +209,6 @@ def __init__( self.meta = meta self.session = session - def __eq__(self, other: object) -> bool: - if not isinstance(other, ResponseData): - raise ValueError("eq only supported for ResponseData types") - - # Compare everything but the metadata - if ( - self.http_version != other.http_version - or self.status_code != other.status_code - or self.reason != other.reason - or self.headers != other.headers - or self.content != other.content - or self.trailers != other.trailers - ): - return False - - return True - def 
to_bytes(self) -> bytes: data = b"" data += SnifferMessageType.RESPONSE_DATA.to_bytes(8, "big") @@ -280,6 +260,33 @@ def from_bytes(cls, data: bytes) -> tuple[ResponseData, int]: return cls(http_version, status_code, reason, headers, content, trailers, meta, session), i + def __str__(self) -> str: + text = f"{self.__class__.__name__}(" + text += f"http_version={self.http_version}, " + text += f"status_code={self.status_code}, " + text += f"reasom={self.reason}, " + text += f"headers={self.headers}, " + text += f"content={self.content!r}, " + text += f"trailers={self.trailers}, " + text += f"meta={self.meta}," + text += f"session={self.session}" + text += ")" + + return text + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RequestData): + raise NotImplementedError + + return ( + self.http_version == other.http_version + and self.status_code == other.status_code + and self.reason == other.reason + and self.headers == other.headers + and self.content == other.content + and self.trailers == other.trailers + ) + class ProxyEvent: CONNECT = 1 @@ -287,6 +294,7 @@ class ProxyEvent: def __init__(self, event: int): self.event = event + self.session: Optional[int] = None @classmethod def from_bytes(cls, data: bytes) -> tuple[ProxyEvent, int]: @@ -309,9 +317,18 @@ def to_bytes(self) -> bytes: class SnifferError: - def __init__(self, error_msg: str, error_type: str): + def __init__(self, error_msg: str, error_type: str, session: Optional[int] = None): self.error_msg = error_msg self.error_type = error_type + self.session = session + + def to_bytes(self) -> bytes: + data = b"" + data += bytes_to_varlen_bytes(self.error_msg) + data += bytes_to_varlen_bytes(self.error_type) + data += SnifferInt64.to_bytes(self.session) + + return data class SniffCommand: @@ -334,7 +351,7 @@ def __init__( self.meta = meta self.session = session - def to_bytes(self): + def to_bytes(self) -> bytes: data = b"" data += SnifferMessageType.SNIFF_COMMAND.to_bytes(8, "big") data += 
self.command.to_bytes(8, "big") @@ -381,8 +398,9 @@ def __init__(self, obj: Optional[SnifferError], *args, **kwargs): self.obj = obj -SnifferMessage: TypeVar = Union[SniffCommand] -ProxyMessage: TypeVar = Union[RequestData, ResponseData, ProxyEvent] +SnifferMessage: TypeAlias = Union[RequestData] +ProxyMessage: TypeAlias = Union[RequestData, ResponseData, ProxyEvent] +SkopoMessage: TypeAlias = Union[ProxyMessage, SnifferError] def sniffer_data_from_bytes(data: bytes): @@ -415,10 +433,9 @@ def pushback_message(self, obj): self.session_to_messages[obj.session].append(obj) def get_message(self, ignore_error: bool = False, session: Optional[int] = None) -> SnifferMessage: - print(f"Waiting message with session: {session}") while len(self.session_to_messages[session]): obj = self.session_to_messages[session].pop(0) - print(f"Got obj: {obj}") + if isinstance(obj, SnifferError) and not ignore_error: raise ProxyException(obj) @@ -432,7 +449,6 @@ def get_message(self, ignore_error: bool = False, session: Optional[int] = None) while True: data = self._get_data() obj, _ = sniffer_data_from_bytes(data) - print(f"Got obj: {obj}") if isinstance(obj, SnifferError) and not ignore_error: raise ProxyException(obj) From f33e8b7e8968796c78de86be73340053918c270a Mon Sep 17 00:00:00 2001 From: Marius Pricop <22615594+RazorBest@users.noreply.github.com> Date: Sun, 2 Nov 2025 15:57:25 +0200 Subject: [PATCH 18/18] feat: update skopo --- cdprecorder/erpeto.py | 12 +- cdprecorder/recorder.py | 26 ++++- cdprecorder/skopo/__init__.py | 2 +- cdprecorder/skopo/sniff_protocol.py | 153 +++++++++++++++++-------- main.py | 22 +--- tests/test_e2e_run/selenium_runners.py | 40 ++++--- tests/test_e2e_run/test_http_apps.py | 49 +++++--- 7 files changed, 196 insertions(+), 108 deletions(-) diff --git a/cdprecorder/erpeto.py b/cdprecorder/erpeto.py index b19fce5..39e2ff0 100644 --- a/cdprecorder/erpeto.py +++ b/cdprecorder/erpeto.py @@ -15,7 +15,7 @@ from pycdp import cdp import cdprecorder -from 
cdprecorder import generate_python +from cdprecorder import generate_python, logger from cdprecorder.action import ( BrowserAction, InputAction, @@ -55,7 +55,7 @@ def generate_action(action: HttpAction, prev_new_actions: list[Optional[HttpActi return new_action -def run_actions(actions: list[HttpAction]) -> None: +def run_actions(actions: list[HttpAction], proxies: Optional[list[str]] = None) -> None: new_actions: list[Optional[HttpAction]] = [] for action in actions: @@ -72,7 +72,8 @@ def run_actions(actions: list[HttpAction]) -> None: cookies=new_action.cookies_to_dict(), ) prepared_request = req.prepare() - resp = session.send(prepared_request, allow_redirects=False) + logger.debug("Replicating request: %s", prepared_request) + resp = session.send(prepared_request, allow_redirects=False, proxies=proxies) resp_action = response_action_from_python_response(resp) new_actions.append(resp_action) @@ -342,8 +343,9 @@ def run_analyse(actions): cdprecorder.analyser.analyse_actions(actions) -def run_replicate(actions): - run_actions(actions) +def run_replicate(actions, proxies: Optional[list[str]] = None): + logger.debug("Start run_replicate") + run_actions(actions, proxies) async def run(options: RecorderOptions) -> None: diff --git a/cdprecorder/recorder.py b/cdprecorder/recorder.py index caa7d57..542ac19 100644 --- a/cdprecorder/recorder.py +++ b/cdprecorder/recorder.py @@ -692,7 +692,11 @@ async def bind_func_to_context_id( # Using execution_context_id is deprecated # But, adding the biding with execution_context_name doesn't work when the recorder is restarted # on an already started Chrome - await target_session.execute(cdp.runtime.add_binding(name, execution_context_id=context_id)) + try: + await target_session.execute(cdp.runtime.add_binding(name, execution_context_id=context_id)) + except: + logger.exception("Exception in bind_func_to_context_id") + raise async def init_runtime_scripts( @@ -744,7 +748,23 @@ async def on_binding_called(self, evt: 
cdp.runtime.BindingCalled) -> None: async def on_execution_context_created(self, evt: cdp.runtime.ExecutionContextCreated) -> None: if evt.context.name == self.listener_context_name: self.listener_context_id = evt.context.id_ - await bind_func_to_context_id(self.target_session, self.EVENT_SEND_BINDING, self.listener_context_id) + try: + await bind_func_to_context_id(self.target_session, self.EVENT_SEND_BINDING, self.listener_context_id) + except pycdp.exceptions.CDPBrowserError as exc: + if "Cannot find execution context with given executionContextId" not in str(exc): + raise + # This happens when the execution gets destroyed before we even + # had the chance to create the binding + # One issue in the case we don't raise the error is that we + # might miss cases where an active page doesn't have the + # bindings set up. Which means that the javascript events + # (like clicks) + # can't be captured by our javascript extension. + # TODO: make this error traceable, to make the case above + # detectable + logger.debug("Couldn't attach binding to execution context %s because it has been destroyed already", evt.context.id_) + + def pop_actions(self) -> list[InputAction]: actions = self.actions @@ -791,7 +811,7 @@ async def insert_widget_extension( async def init_recorder(options: RecorderOptions): - urlfilter = filters.URLFilter() + urlfilter = None # filters.URLFilter() try: http = ClientSession() diff --git a/cdprecorder/skopo/__init__.py b/cdprecorder/skopo/__init__.py index 863ec24..77c40c5 100644 --- a/cdprecorder/skopo/__init__.py +++ b/cdprecorder/skopo/__init__.py @@ -75,7 +75,7 @@ def _handle_message( if obj.event == ProxyEvent.CONNECT or obj.event == ProxyEvent.CLOSE: return None - if isinstance(obj, SnifferMessage): + if isinstance(obj, (SniffCommand, SnifferError)): if obj.session is None and session is not None: raise SnifferException(obj, "Received sesionless message in a context with session") if session is not None and obj.session != session: diff --git 
a/cdprecorder/skopo/sniff_protocol.py b/cdprecorder/skopo/sniff_protocol.py index 8d4d7b2..4f958ca 100644 --- a/cdprecorder/skopo/sniff_protocol.py +++ b/cdprecorder/skopo/sniff_protocol.py @@ -15,7 +15,7 @@ from typing import Optional, TypeAlias, Protocol, Union -def bytes_to_varlen_bytes(data: bytes): +def bytes_to_varlen_bytes(data: bytes) -> bytes: size = len(data) return size.to_bytes(8, "big") + data @@ -25,13 +25,15 @@ class SnifferMessageType(IntEnum): RESPONSE_DATA = 2 PROXY_EVENT = 3 SNIFF_COMMAND = 4 - INT64 = 5 - NONE = 6 + SNIFFER_ERROR = 5 + INT64 = 6 + STRING = 7 + NONE = 8 class SnifferNone: @staticmethod - def to_bytes(): + def to_bytes() -> bytes: data = b"" data += SnifferMessageType.NONE.to_bytes(8, "big") return data @@ -39,7 +41,7 @@ def to_bytes(): class SnifferInt64: @staticmethod - def to_bytes(value: Optional[int]): + def to_bytes(value: Optional[int]) -> bytes: if value is None: return SnifferNone.to_bytes() @@ -49,8 +51,36 @@ def to_bytes(value: Optional[int]): return data +class SnifferString: + @staticmethod + def to_bytes(value: Optional[str]) -> bytes: + if value is None: + return SnifferNone.to_bytes() + + data = b"" + data += SnifferMessageType.STRING.to_bytes(8, "big") + data += bytes_to_varlen_bytes(value.encode()) + + return data + + @staticmethod + def from_bytes(data: bytes) -> tuple[str, int]: + i = 0 + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.STRING + + length = int.from_bytes(data[i : i + 8], "big") + i += 8 + + s = data[i : i + length].decode("utf-8") + i += length + + return s, i + length + + class SnifferMetadata: - def __init__(self, object_id: int, timestamp: int, proxyname: str): + def __init__(self, object_id: int, timestamp: int, proxyname: str) -> None: self.object_id = object_id self.timestamp = timestamp self.proxyname = proxyname @@ -160,12 +190,12 @@ def from_bytes(cls, data: bytes) -> tuple[RequestData, int]: def __str__(self) -> str: text = 
f"{self.__class__.__name__}(" - text += f"http_version={self.http_version}, " - text += f"method={self.method}, " - text += f"url={self.url}, " - text += f"headers={self.headers}, " + text += f"http_version={self.http_version!r}, " + text += f"method={self.method!r}, " + text += f"url={self.url!r}, " + text += f"headers={self.headers!r}, " text += f"content={self.content!r}, " - text += f"trailers={self.trailers}, " + text += f"trailers={self.trailers!r}, " text += f"meta={self.meta}," text += f"session={self.session}" text += ")" @@ -262,12 +292,12 @@ def from_bytes(cls, data: bytes) -> tuple[ResponseData, int]: def __str__(self) -> str: text = f"{self.__class__.__name__}(" - text += f"http_version={self.http_version}, " + text += f"http_version={self.http_version!r}, " text += f"status_code={self.status_code}, " - text += f"reasom={self.reason}, " - text += f"headers={self.headers}, " + text += f"reasom={self.reason!r}, " + text += f"headers={self.headers!r}, " text += f"content={self.content!r}, " - text += f"trailers={self.trailers}, " + text += f"trailers={self.trailers!r}, " text += f"meta={self.meta}," text += f"session={self.session}" text += ")" @@ -275,7 +305,7 @@ def __str__(self) -> str: return text def __eq__(self, other: object) -> bool: - if not isinstance(other, RequestData): + if not isinstance(other, ResponseData): raise NotImplementedError return ( @@ -324,12 +354,31 @@ def __init__(self, error_msg: str, error_type: str, session: Optional[int] = Non def to_bytes(self) -> bytes: data = b"" - data += bytes_to_varlen_bytes(self.error_msg) - data += bytes_to_varlen_bytes(self.error_type) + data += SnifferMessageType.SNIFFER_ERROR.to_bytes(8, "big") + data += SnifferString.to_bytes(self.error_msg) + data += SnifferString.to_bytes(self.error_type) data += SnifferInt64.to_bytes(self.session) return data + @classmethod + def from_bytes(cls, data: bytes) -> tuple[SnifferError, int]: + i = 0 + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 
+ assert message_type == SnifferMessageType.SNIFFER_ERROR + + error_msg, used = SnifferString.from_bytes(data[i:]) + i += used + + error_type, used = SnifferString.from_bytes(data[i:]) + i += used + + session = int.from_bytes(data[i : i + 8], "big") + i += 8 + + return cls(error_msg, error_type, session), i + class SniffCommand: NOP = 1 @@ -337,6 +386,8 @@ class SniffCommand: CANCEL = 3 CLOSE_CLIENT = 4 + __slots__ = ["command", "request", "response", "meta", "session"] + def __init__( self, command: int, @@ -356,7 +407,7 @@ def to_bytes(self) -> bytes: data += SnifferMessageType.SNIFF_COMMAND.to_bytes(8, "big") data += self.command.to_bytes(8, "big") data += self.request.to_bytes() if self.request is not None else SnifferNone.to_bytes() - data += self.repsonse.to_bytes() if self.response is not None else SnifferNone.to_bytes() + data += self.response.to_bytes() if self.response is not None else SnifferNone.to_bytes() data += self.meta.to_bytes() if self.meta is not None else SnifferNone.to_bytes() data += SnifferInt64.to_bytes(self.session) if self.session is not None else SnifferNone.to_bytes() @@ -393,17 +444,23 @@ def from_bytes(cls, data: bytes) -> tuple[SniffCommand, int]: class ProxyException(Exception): - def __init__(self, obj: Optional[SnifferError], *args, **kwargs): + def __init__(self, obj: Optional[SnifferError], *args: object, **kwargs: object) -> None: super().__init__(*args, **kwargs) self.obj = obj -SnifferMessage: TypeAlias = Union[RequestData] -ProxyMessage: TypeAlias = Union[RequestData, ResponseData, ProxyEvent] +class SnifferClientException(Exception): + def __init__(self, obj: Optional[SnifferMessage], *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self.obj = obj + + +SnifferMessage: TypeAlias = Union[SniffCommand, SnifferError] +ProxyMessage: TypeAlias = Union[RequestData, ResponseData, ProxyEvent, SnifferError] SkopoMessage: TypeAlias = Union[ProxyMessage, SnifferError] -def sniffer_data_from_bytes(data: 
bytes): +def sniffer_data_from_bytes(data: bytes) -> tuple[Union[RequestData, ResponseData, ProxyEvent, SniffCommand, int, str, None], int]: message_type = int.from_bytes(data[:8], "big") if message_type == SnifferMessageType.REQUEST_DATA: return RequestData.from_bytes(data) @@ -416,6 +473,8 @@ def sniffer_data_from_bytes(data: bytes): elif message_type == SnifferMessageType.INT64: assert len(data) >= 16, "Not enough bytes to decode" return int.from_bytes(data[8:16], "big"), 16 + elif message_type == SnifferMessageType.STRING: + return SnifferString.from_bytes(data) elif message_type == SnifferMessageType.NONE: return None, 8 else: @@ -423,26 +482,26 @@ def sniffer_data_from_bytes(data: bytes): class SnifferProxyClient: - def __init__(self): - self.session_to_messages = defaultdict(list) + def __init__(self) -> None: + self.session_to_messages: dict[Optional[int], list[SnifferMessage]] = defaultdict(list) def _get_data(self) -> bytes: raise NotImplementedError - def pushback_message(self, obj): + def pushback_message(self, obj: SnifferMessage) -> None: self.session_to_messages[obj.session].append(obj) def get_message(self, ignore_error: bool = False, session: Optional[int] = None) -> SnifferMessage: while len(self.session_to_messages[session]): - obj = self.session_to_messages[session].pop(0) + queue_obj = self.session_to_messages[session].pop(0) - if isinstance(obj, SnifferError) and not ignore_error: - raise ProxyException(obj) + if isinstance(queue_obj, SnifferError) and not ignore_error: + raise ProxyException(queue_obj) - if isinstance(obj, SnifferMessage): - if obj.session is None and session is not None: - raise ProxyException(obj, "Received sessionless message in a context with session") - return obj + if isinstance(queue_obj, SniffCommand): + if queue_obj.session is None and session is not None: + raise SnifferClientException(queue_obj, "Received sessionless message in a context with session") + return queue_obj del self.session_to_messages[session] @@ 
-453,7 +512,7 @@ def get_message(self, ignore_error: bool = False, session: Optional[int] = None) if isinstance(obj, SnifferError) and not ignore_error: raise ProxyException(obj) - if isinstance(obj, SnifferMessage): + if isinstance(obj, SniffCommand): if obj.session != session: self.pushback_message(obj) else: @@ -466,62 +525,62 @@ def get_command(self, ignore_error: bool = False, session: Optional[int] = None) raise ProxyException(msg, "Expected message of type SniffCommand") return msg - def _send_data(self, data: bytes): + def _send_data(self, data: bytes) -> None: raise NotImplementedError - def send_proxy_message(self, obj: ProxyMessage, session: Optional[int] = None): + def send_proxy_message(self, obj: ProxyMessage, session: Optional[int] = None) -> None: if hasattr(obj, "session"): obj.session = session data = obj.to_bytes() self._send_data(data) - def send_request_data(self, obj: RequestData, session: Optional[int] = None): + def send_request_data(self, obj: RequestData, session: Optional[int] = None) -> None: self.send_proxy_message(obj, session) - def send_response_data(self, obj: ResponseData, session: Optional[int] = None): + def send_response_data(self, obj: ResponseData, session: Optional[int] = None) -> None: self.send_proxy_message(obj, session) - def send_proxy_event(self, obj: ProxyEvent, session: Optional[int] = None): + def send_proxy_event(self, obj: ProxyEvent, session: Optional[int] = None) -> None: self.send_proxy_message(obj, session) - def send_error(self, msg: SnifferError, session: Optional[int] = None): + def send_error(self, msg: SnifferError, session: Optional[int] = None) -> None: msg.session = session data = msg.to_bytes() self._send_data(data) - def new_session(self): + def new_session(self) -> SnifferClientSession: return SnifferClientSession(self) class SnifferClientSession: _LAST_ID = 1 - def __init__(self, client: SnifferProxyClient): + def __init__(self, client: SnifferProxyClient) -> None: self.client = client self.id = 
SnifferClientSession._LAST_ID SnifferClientSession._LAST_ID += 1 - def __getattr__(self, key: str): + def __getattr__(self, key: str) -> object: value = getattr(self.client, key) - if isinstance(value, Callable): + if callable(value): value = functools.partial(value, session=self.id) return value -def to_sock_datagram(data: bytes): +def to_sock_datagram(data: bytes) -> bytes: size = len(data) return size.to_bytes(8, "big") + data -async def async_read_sock_datagram(reader: StreamReader): +async def async_read_sock_datagram(reader: StreamReader) -> bytes: data_size = await reader.readexactly(8) size = int.from_bytes(data_size, "big") data = await reader.readexactly(size) return data -def read_sock_datagram(sock: socket.socket): +def read_sock_datagram(sock: socket.socket) -> bytes: data_size = sock.recv(8) if len(data_size) != 8: raise ProxyException("Received less than expected data") diff --git a/main.py b/main.py index 6c0fcf4..cd079c6 100644 --- a/main.py +++ b/main.py @@ -15,26 +15,10 @@ async def on_fail(comparator, httpobj1, httpobj2): async def main() -> None: - sniffer_manager1 = skopo.MitmproxySnifferManager("1") - sniffer_manager1.start_sniffer_on_thread() - proxy1 = sniffer_manager1.start_proxy_instance(port=8082) + comparator = await skopo.create_mitmproxy_sniffer_comparator(on_fail) - sniffer_manager2 = skopo.MitmproxySnifferManager("2") - sniffer_manager2.start_sniffer_on_thread() - proxy2 = sniffer_manager2.start_proxy_instance(port=8083) - - print(f"Event loop: {asyncio.get_event_loop()}") - comparator = skopo.SnifferComparator( - on_fail, sniffer_manager1.sniffer, sniffer_manager2.sniffer, asyncio.get_event_loop() - ) - - await sniffer_manager1.wait_for_proxy_connection_with_sniffer() - await sniffer_manager2.wait_for_proxy_connection_with_sniffer() - - print("Waited") - - proxy_url1 = f"http://{proxy1.host}:{proxy1.port}" - proxy_url2 = f"http://{proxy2.host}:{proxy2.port}" + proxy_url1 = comparator.sniffer1.proxy_url + proxy_url2 = 
comparator.sniffer2.proxy_url import requests diff --git a/tests/test_e2e_run/selenium_runners.py b/tests/test_e2e_run/selenium_runners.py index 76d7060..69ccdb1 100644 --- a/tests/test_e2e_run/selenium_runners.py +++ b/tests/test_e2e_run/selenium_runners.py @@ -4,24 +4,28 @@ def run_csrf_form_submitsuccess(driver): # Step # | name | target | value - # 1 | open | / | - # driver.get("http://localhost:5000/") - # 2 | setWindowSize | 736x729 | - driver.set_window_size(736, 729) - # 3 | click | css=html | - driver.find_element(By.CSS_SELECTOR, "html").click() - # 4 | click | css=form | - driver.find_element(By.CSS_SELECTOR, "form").click() - # 5 | click | id=field2 | - driver.find_element(By.ID, "field2").click() - # 6 | click | id=field1 | - driver.find_element(By.ID, "field1").click() - # 7 | type | id=field1 | alpha - driver.find_element(By.ID, "field1").send_keys("alpha") - # 8 | type | id=field2 | beta - driver.find_element(By.ID, "field2").send_keys("beta") - # 9 | click | id=submit | - driver.find_element(By.ID, "submit").click() + driver.implicitly_wait(1000000) + try: + # 1 | open | / | + driver.get("http://local:5000/") + # 2 | setWindowSize | 736x729 | + driver.set_window_size(736, 729) + # 3 | click | css=html | + driver.find_element(By.CSS_SELECTOR, "html").click() + # 4 | click | css=form | + driver.find_element(By.CSS_SELECTOR, "form").click() + # 5 | click | id=field2 | + driver.find_element(By.ID, "field2").click() + # 6 | click | id=field1 | + driver.find_element(By.ID, "field1").click() + # 7 | type | id=field1 | alpha + driver.find_element(By.ID, "field1").send_keys("alpha") + # 8 | type | id=field2 | beta + driver.find_element(By.ID, "field2").send_keys("beta") + # 9 | click | id=submit | + driver.find_element(By.ID, "submit").click() + finally: + driver.implicitly_wait(0) selenium_runs = { diff --git a/tests/test_e2e_run/test_http_apps.py b/tests/test_e2e_run/test_http_apps.py index 67dee8f..b23cac6 100644 --- 
a/tests/test_e2e_run/test_http_apps.py +++ b/tests/test_e2e_run/test_http_apps.py @@ -51,9 +51,8 @@ def __del__(self): async def on_fail(comparator, httpobj1, httpobj2): - logging.error(f"Comparator failed after {comparator.requests_passed} passes") - logging.error(f"First httpobj: {httpobj1}") - logging.error(f"Second httpobj: {httpobj2}") + logging.error("First httpobj: %s", httpobj1) + logging.error("Second httpobj %s", httpobj2) raise Exception("Failure") @@ -61,12 +60,14 @@ async def on_fail(comparator, httpobj1, httpobj2): async def test_csrf_form_run_csrf_from_submit_success(): logging.getLogger("selenium").setLevel(logging.WARNING) logging.getLogger("pycdp").setLevel(logging.WARNING) + logging.getLogger("cdprecorder").setLevel(logging.WARNING) logging.info("Starting web app") app = VenvAppRunner("http_apps/csrf_form") app.wait_until_up() logging.info("Web app started") comparator = await skopo.create_mitmproxy_sniffer_comparator(on_fail) + comparator.add_filtered_url_keywords(["clients2.google.com"]) proxy_url1 = comparator.sniffer1.proxy_url proxy_url2 = comparator.sniffer2.proxy_url @@ -91,26 +92,24 @@ async def test_csrf_form_run_csrf_from_submit_success(): collect_all=True, ) - sniffer1 = comparator.sniffer1 - - async def pass_sniffer(): + async def pass_sniffer(sniffer): print("pass sniffer") msg = None try: while True: - msg = await sniffer1.get_message() + msg = await sniffer.get_message() - logging.debug("Msg: %s", msg) - logging.debug("Session: %s", msg.session) - await sniffer1.to_session(msg.session).send_command(skopo.SniffCommand.NOP) + # logging.debug("Msg: %s", msg) + # logging.debug("Session: %s", msg.session) + await sniffer.to_session(msg.session).send_command(skopo.SniffCommand.NOP) except asyncio.CancelledError: if msg is not None: - await sniffer1.to_session(msg.session).send_command(skopo.SniffCommand.NOP) + await sniffer.to_session(msg.session).send_command(skopo.SniffCommand.NOP) pass except: logging.exception("Pass sniffer ended 
with exception") - pass_task = asyncio.create_task(pass_sniffer()) + pass_task = asyncio.create_task(pass_sniffer(comparator.sniffer1)) rec = await recorder.init_recorder(recorder_options) await asyncio.sleep(2) pass_task.cancel() @@ -121,10 +120,12 @@ async def pass_sniffer(): except: logging.exception("pass_task ended with exception") + logging.info("Start recording") + try: t1 = asyncio.create_task(asyncio.to_thread(run_csrf_form_submitsuccess, driver)) t2 = asyncio.create_task(recorder.collect_communications(rec, 30)) - t3 = asyncio.create_task(pass_sniffer()) + t3 = asyncio.create_task(pass_sniffer(comparator.sniffer1)) done, pending = await asyncio.wait([t1, t2, t3], return_when=asyncio.FIRST_COMPLETED) if t2 in done: @@ -148,6 +149,8 @@ async def pass_sniffer(): driver.quit() + await asyncio.sleep(2) + logging.info("Recorded communications") actions = erpeto.parse_communications_into_actions(communications) @@ -158,17 +161,33 @@ async def pass_sniffer(): # TODO: probably, a sniffer that's not linked to a comparator should not block # comparator = skopo.SnifferComparator(on_fail, sniffer_manager1.sniffer, sniffer_manager2.sniffer) + logging.info("Starting replication") + logging.getLogger("cdprecorder").setLevel(logging.DEBUG) + + #import requests + #requests.get("http://local:5000", ) + #time.sleep(2) + replicator_proxies = { + 'http': proxy_url2, + 'https': proxy_url2, + } + + #return + driver = webdriver.Chrome(options) t1 = asyncio.create_task(asyncio.to_thread(run_csrf_form_submitsuccess, driver)) - t2 = asyncio.create_task(asyncio.to_thread(erpeto.run_replicate, actions, [proxy_url2])) + t2 = asyncio.create_task(asyncio.to_thread(erpeto.run_replicate, actions, replicator_proxies)) t3 = asyncio.create_task(comparator.run()) pair = asyncio.create_task(asyncio.wait([t1, t2], return_when=asyncio.ALL_COMPLETED)) - done, pending = await asyncio.wait([pair, t3], return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait([t1, t2, t3], 
return_when=asyncio.FIRST_COMPLETED) logging.debug("Pending state: %s, %s", pair.done(), t3.done()) + if t2.done(): + t2.result() + t3.cancel() if t3 in done: assert t3.result() is True