diff --git a/cdprecorder/_storage.py b/cdprecorder/_storage.py
new file mode 100644
index 0000000..f1e74aa
--- /dev/null
+++ b/cdprecorder/_storage.py
@@ -0,0 +1,8 @@
+import os
+
+
+DEFAULT_SOCKET_NAME = "erpeto.sock"
+
+
+def get_runtime_dir() -> "str | None":
+    return os.getenv("XDG_RUNTIME_DIR")
diff --git a/cdprecorder/erpeto.py b/cdprecorder/erpeto.py
new file mode 100644
index 0000000..39e2ff0
--- /dev/null
+++ b/cdprecorder/erpeto.py
@@ -0,0 +1,363 @@
+from __future__ import annotations
+
+from typing import cast, Optional, Union, TYPE_CHECKING
+
+import bs4
+import bs4.builder._htmlparser
+import pycdp
+import requests
+import sys
+import twisted.internet.reactor
+
+from twisted.python.log import err
+from twisted.internet import defer, threads
+from twisted.internet.interfaces import IReactorCore
+from pycdp import cdp
+
+import cdprecorder
+from cdprecorder import generate_python, logger
+from cdprecorder.action import (
+    BrowserAction,
+    InputAction,
+    HttpAction,
+    LowercaseStr,
+    RequestAction,
+    ResponseAction,
+    response_action_from_python_response,
+)
+from cdprecorder.recorder import (
+    HttpCommunication,
+    RecorderOptions,
+    record,
+)
+
+import cdprecorder.analyser
+
+if TYPE_CHECKING:
+    import bs4
+
+    from pycdp.cdp.util import T_JSON_DICT
+    from twisted.python.failure import Failure
+
+    from cdprecorder.type_checking import CdpEvent, HttpTarget
+
+
+# https://github.com/twisted/twisted/issues/9909
+reactor = cast(IReactorCore, twisted.internet.reactor)
+
+
+def generate_action(action: HttpAction, prev_new_actions: list[Optional[HttpAction]]) -> RequestAction:
+    new_action = RequestAction()
+    new_action.shallow_copy_from_action(action)
+    for target in action.targets:
+        target.apply(new_action, prev_new_actions)
+
+    return new_action
+
+
+def run_actions(actions: list[HttpAction], proxies: Optional[list[str]] = None) -> None:
+    new_actions: list[Optional[HttpAction]] = []
+
+    for action in actions:
+        if isinstance(action, RequestAction):
+            new_action = 
generate_action(action, new_actions)
+            new_actions.append(new_action)
+
+            with requests.Session() as session:
+                req = requests.Request(
+                    method=new_action.method,
+                    url=new_action.url,
+                    headers=new_action.headers,
+                    data=new_action.body,
+                    cookies=new_action.cookies_to_dict(),
+                )
+                prepared_request = req.prepare()
+                logger.debug("Replicating request: %s", prepared_request)
+                resp = session.send(prepared_request, allow_redirects=False, proxies=proxies)
+                resp_action = response_action_from_python_response(resp)
+                new_actions.append(resp_action)
+
+                print(f"{new_action.method} {new_action.url} - {resp.status_code}")
+
+        elif not isinstance(action, ResponseAction):
+            new_actions.append(None)
+
+
+def to_cdp_event(event: CdpEvent) -> dict[str, Union[str, T_JSON_DICT]]:
+    cdp_method = None
+    for key, val in cdp.util._event_parsers.items():
+        if val == event.__class__:
+            cdp_method = key
+            break
+    else:
+        raise Exception
+
+    return {
+        "method": cdp_method,
+        "params": event.to_json(),
+        "type": "recv",
+        "domain": "-",
+    }
+
+
+def get_only_http_actions(actions: list[BrowserAction]) -> list[HttpAction]:
+    return [action for action in actions if isinstance(action, HttpAction)]
+
+
+def _generate_events_with_redirects_extracted(events: list[CdpEvent]) -> list[CdpEvent]:
+    new_events = []
+    future_events: list[CdpEvent] = []
+    wait_response_extra = False
+    wait_request_extra = False
+    wait_extra = False
+    for evt in events:
+        if wait_extra:
+            if (
+                isinstance(evt, cdp.network.RequestWillBeSentExtraInfo)
+                and not wait_request_extra
+                or isinstance(evt, cdp.network.ResponseReceivedExtraInfo)
+                and not wait_response_extra
+            ):
+                wait_extra = False
+                new_events += future_events
+                future_events = []
+
+            if isinstance(evt, cdp.network.RequestWillBeSentExtraInfo):
+                wait_request_extra = False
+            elif isinstance(evt, cdp.network.ResponseReceivedExtraInfo):
+                wait_response_extra = False
+            wait_extra = wait_response_extra or wait_request_extra
+
+            new_events.append(evt)
+            
continue
+        else:
+            new_events += future_events
+            future_events = []
+
+        future_events.append(evt)
+        if not isinstance(evt, cdp.network.RequestWillBeSent):
+            continue
+
+        if not evt.redirect_response:
+            new_events += future_events
+            future_events = []
+            wait_extra = False
+        else:
+            if evt.redirect_has_extra_info:
+                wait_response_extra = True
+                wait_request_extra = True
+                wait_extra = True
+
+            response_evt = cdp.network.ResponseReceived(
+                request_id=evt.request_id,
+                loader_id=evt.loader_id,
+                timestamp=evt.timestamp,
+                type_=evt.type_,
+                response=evt.redirect_response,
+                has_extra_info=evt.redirect_has_extra_info,
+                frame_id=evt.frame_id,
+            )
+            new_events.append(response_evt)
+
+    new_events += future_events
+
+    return new_events
+
+
+def parse_communications_into_actions(
+    communications: list[Union[HttpCommunication, InputAction]]
+) -> list[BrowserAction]:
+    from cdprecorder import logger
+
+    actions: list[BrowserAction] = []
+
+    for comm in communications:
+        logger.debug("Comm: %s", repr(comm))
+
+        if not isinstance(comm, HttpCommunication):
+            actions.append(comm)
+            continue
+
+        if comm.ignored:
+            continue
+
+        response_bodies = list(comm.response_bodies)
+
+        curr_request: Optional[RequestAction] = None
+        request_extra: Optional[RequestAction] = None
+        curr_response: Optional[ResponseAction] = None
+        response_extra: Optional[ResponseAction] = None
+        events = _generate_events_with_redirects_extracted(comm.events)
+        print("--------------------------------------------------------")
+        # Append to actions the requests/responses from each event
+        for evt in events:
+            if isinstance(evt, cdp.network.RequestWillBeSent):
+                if curr_request is not None:
+                    if all((curr_request, request_extra, curr_response)):
+                        curr_request.has_response = True
+                        actions.append(curr_request)
+                        actions.append(curr_response)
+                    else:
+                        actions.append(curr_request)
+                        if curr_response is not None:
+                            curr_request.has_response = True
+                            actions.append(curr_response)
+
+                    curr_request = None
+                    
request_extra = None + curr_response = None + + """ + if curr_request is not None: + # Consume the previous request + if curr_response: + curr_request.has_response = True + actions.append(curr_request) + curr_request = None + request_extra = None + + if curr_response: + # Consume the previous response + actions.append(curr_response) + if response_extra: + raise Exception + curr_response = None + """ + + curr_request = RequestAction() + curr_request.update_info(evt.request) + if evt.request.has_post_data and evt.request.post_data: + # TODO: Check if bytes in other entry + curr_request.set_body(evt.request.post_data.encode()) + + if request_extra is not None: + curr_request.merge(request_extra) + + elif isinstance(evt, cdp.network.RequestWillBeSentExtraInfo): + if request_extra is not None: + if all((curr_request, request_extra, curr_response)): + curr_request.has_response = True + actions.append(curr_request) + curr_request = None + request_extra = None + + actions.append(curr_response) + curr_response = None + + """ + if curr_request is not None and request_extra is not None: + # Consume the previous request + if curr_response: + curr_request.has_response = True + actions.append(curr_request) + curr_request = None + request_extra = None + + if curr_response: + # Consume the previous response + actions.append(curr_response) + if response_extra: + raise Exception + curr_response = None + """ + + if request_extra is not None: + raise Exception + request_extra = RequestAction() + request_extra.update_info(evt) + + if curr_request is not None: + curr_request.merge(request_extra) + # request_extra = None + + elif isinstance(evt, cdp.network.ResponseReceived): + if curr_response is None: + curr_response = ResponseAction(evt.response) + else: + raise Exception + + if response_extra is not None: + # Always merge response_extra over curr_response, not the other way + curr_response.merge(response_extra) + response_extra = None + + elif isinstance(evt, 
cdp.network.ResponseReceivedExtraInfo): + if curr_response is not None: + # Always merge response_extra over curr_response, not the other way + curr_response.merge(ResponseAction(evt)) + elif response_extra is None: + response_extra = ResponseAction(evt) + else: + raise Exception + + elif isinstance(evt, cdp.network.LoadingFinished): + # Manually inserted + response_body = response_bodies.pop(0) + if response_body is not None: + if curr_response: + curr_response.set_body(response_body) + elif response_extra: + response_extra.set_body(response_body) + else: + raise Exception + + if curr_request is not None: + if curr_response or response_extra: + curr_request.has_response = True + actions.append(curr_request) + curr_request = None + if curr_response is not None: + actions.append(curr_response) + curr_response = None + elif response_extra is not None: + actions.append(response_extra) + response_extra = None + + if curr_request is not None: + if curr_response is not None: + curr_request.has_response = True + curr_request.merge(request_extra) + actions.append(curr_request) + if curr_response is not None: + # Always merge response_extra over curr_response, not the other way + curr_response.merge(response_extra) + actions.append(curr_response) + + return actions + + +def make_action_ids_consecutive_from_list(actions: list[BrowserAction]): + for i, action in enumerate(actions): + action.ID = i + + +async def run_recorder(options: RecorderOptions): + communications = await record(options) + actions = parse_communications_into_actions(communications) + make_action_ids_consecutive_from_list(actions) + + return actions + + +def run_analyse(actions): + cdprecorder.analyser.analyse_actions(actions) + + +def run_replicate(actions, proxies: Optional[list[str]] = None): + logger.debug("Start run_replicate") + run_actions(actions, proxies) + + +async def run(options: RecorderOptions) -> None: + actions = await run_recorder(options) + run_analyse(actions) + run_replicate(actions) + 
+ # actions = get_only_http_actions(actions) + # run_actions(actions) + + # generate_python.write_python_code(actions, "generated.py") + + """ + await threads.deferToThread(chrome.kill) + """ diff --git a/cdprecorder/recorder.py b/cdprecorder/recorder.py index 5c38fe0..542ac19 100644 --- a/cdprecorder/recorder.py +++ b/cdprecorder/recorder.py @@ -28,12 +28,14 @@ from cheap_repr import cheap_repr from pycdp import cdp from pycdp.browser import ChromeLauncher -from pycdp.twisted import CDPConnection as _PyCDPConnection +from pycdp.twisted import CDPConnection as _TwistedPyCDPConnection from twisted.internet import defer, threads from twisted.internet.error import ConnectionRefusedError from twisted.internet.interfaces import IReactorCore, IReactorTime from twisted.web.client import Agent +from pycdp.asyncio import CDPConnection as _PyCDPConnection, ClientSession + from . import filters, logger, tkinter_ui from .action import InputAction @@ -98,13 +100,10 @@ async def insert_js_leech_script( await target_session.execute(cdp.page.runtime.enable()) - receiver = None - evt_to_listen = cdp.runtime.ExecutionContextCreated + events_queue = None try: - # Don't call target_session.listen because we need the receiver object to close it at the end - receiver = pycdp.twisted.CDPEventListener(defer.DeferredQueue(1024)) - target_session._listeners[evt_to_listen].add(receiver) - listener = aiter(receiver) + events_queue = target_session.create_listener(cdp.runtime.ExecutionContextCreated) + listener = aiter(events_queue) await target_session.execute( cdp.page.add_script_to_evaluate_on_new_document(expression, run_immediately=True, world_name=context_name) @@ -118,8 +117,8 @@ async def insert_js_leech_script( except defer.TimeoutError as exc: raise Exception from exc finally: - if receiver is not None: - receiver.close() + if events_queue is not None: + events_queue.close() if context_id is None: raise Exception @@ -193,7 +192,7 @@ def __eq__(self, obj: object) -> bool: T = 
TypeVar("T") -class AsyncIteratorWithTimeout(Generic[T]): +class TwistedAsyncIteratorWithTimeout(Generic[T]): def __init__( self, iterator: AsyncIterator[T], @@ -220,6 +219,50 @@ async def __anext__(self) -> T: d.addTimeout(remained, reactor) return await d + def __aiter__(self) -> TwistedAsyncIteratorWithTimeout: + return self + + +class TwistedAsyncIterableWithTimeout(Generic[T]): + def __init__( + self, + iterable: AsyncIterable[T], + timeout: float, + start_time: Optional[float] = None, + ): + self.iterable = iterable + self.timeout = timeout + if start_time is None: + self.start_time = time.time() + else: + self.start_time = start_time + + def __aiter__(self) -> AsyncIterator[T]: + return TwistedAsyncIteratorWithTimeout(self.iterable.__aiter__(), self.timeout, self.start_time) + + +class AsyncIteratorWithTimeout(Generic[T]): + def __init__( + self, + iterator: AsyncIterator[T], + timeout: float, + start_time: Optional[float] = None, + ): + self.iterator = iterator + self.timeout = timeout + if start_time is None: + self.start_time = time.time() + else: + self.start_time = start_time + + async def __anext__(self) -> T: + curr_time = time.time() + remained = self.timeout - (curr_time - self.start_time) + if remained <= 0: + raise StopAsyncIteration + + return await asyncio.wait_for(self.iterator.__anext__(), remained) + def __aiter__(self) -> AsyncIteratorWithTimeout: return self @@ -242,7 +285,7 @@ def __aiter__(self) -> AsyncIterator[T]: return AsyncIteratorWithTimeout(self.iterable.__aiter__(), self.timeout, self.start_time) -class CancelableAsyncIterator(Generic[T]): +class TwistedCancelableAsyncIterator(Generic[T]): def __init__(self, iterator: AsyncIterator[T], stop_event: defer.Deferred): self.iterator = iterator self._stop_event = stop_event @@ -277,11 +320,11 @@ async def __anext__(self) -> T: return res - def __aiter__(self) -> CancelableAsyncIterator: + def __aiter__(self) -> TwistedCancelableAsyncIterator: return self -class 
CancelableAsyncIterable(Generic[T]): +class TwistedCancelableAsyncIterable(Generic[T]): def __init__(self, iterable: AsyncIterable[T]): self.iterable = iterable self._stop_event = defer.Deferred() @@ -289,6 +332,47 @@ def __init__(self, iterable: AsyncIterable[T]): def cancel(self): return self._stop_event.callback(None) + def __aiter__(self) -> AsyncIterator[T]: + return TwistedCancelableAsyncIterator(self.iterable.__aiter__(), self._stop_event) + + +class CancelableAsyncIterator(Generic[T]): + def __init__(self, iterator: AsyncIterator[T], stop_event: asyncio.Event): + self.iterator = iterator + self._stop_event = stop_event + self._wait_cancel = asyncio.create_task(self._stop_event.wait()) + + def cancel(self): + return self._stop_event.set() + + async def __anext__(self) -> T: + if self._stop_event.is_set(): + raise StopAsyncIteration + + it_next = asyncio.create_task(self.iterator.__anext__()) + + done, pending = await asyncio.wait([it_next, self._wait_cancel], return_when=asyncio.FIRST_COMPLETED) + + if it_next in pending: + it_next.cancel() + + if self._wait_cancel in done: + raise StopAsyncIteration + + return it_next.result() + + def __aiter__(self) -> CancelableAsyncIterator: + return self + + +class CancelableAsyncIterable(Generic[T]): + def __init__(self, iterable: AsyncIterable[T]): + self.iterable = iterable + self._stop_event = asyncio.Event() + + def cancel(self): + return self._stop_event.set() + def __aiter__(self) -> AsyncIterator[T]: return CancelableAsyncIterator(self.iterable.__aiter__(), self._stop_event) @@ -300,11 +384,15 @@ def __init__( urlfilter: filters.URLFilter, collect_all: bool, start_origin: Optional[str], + connection, + listener: AsyncIterable, ): self.target_session = target_session self.urlfilter = urlfilter self.start_origin = start_origin self.collect_all = collect_all + self.connection = connection + self.listener = CancelableAsyncIterable(listener) self.communications: list[Union[HttpCommunication, InputAction]] = [] 
self.request_map: dict[pycdp.cdp.network.RequestId, HttpCommunication] = {} self.runtime_ctx = get_runtime_context() @@ -384,14 +472,14 @@ async def on_stop(self): def add_on_stop_callback(self, callback): self.on_stop_cbs.append(callback) + async def close(self): + self.target_session.close_listeners() + await self.connection.close() + async def collect_communications( - target_session: pycdp.twisted.CDPSession, - listener: AsyncIterator[object], - urlfilter: filters.URLFilter, + recorder: Recorder, timeout: int = 120, - collect_all: bool = False, - start_origin: Optional[str] = None, ) -> list[Union[HttpCommunication, InputAction]]: """Takes a cdp session, listens for events, and generates a list of communications. @@ -407,7 +495,7 @@ async def collect_communications( Returns: A list of communications. """ - recorder = Recorder(target_session, urlfilter, collect_all, start_origin) + listener = recorder.listener runtime_context = get_runtime_context() ui = tkinter_ui.TkRecordControl(reactor, recorder.on_start, recorder.on_stop) @@ -448,6 +536,11 @@ async def collect_communications( return recorder.communications +class TwistedCDPConnection(_TwistedPyCDPConnection): + # Remove `retry_on` wrapper from function + connect = _PyCDPConnection.connect.__wrapped__ # type: ignore[attr-defined] # pylint: disable=no-member + + class CDPConnection(_PyCDPConnection): # Remove `retry_on` wrapper from function connect = _PyCDPConnection.connect.__wrapped__ # type: ignore[attr-defined] # pylint: disable=no-member @@ -472,16 +565,25 @@ def find_chrome_binary_path() -> str: class RecorderOptions: start_url: str keep_only_same_origin_urls: bool = True - collect_all: bool = False - binary: str = CHROME_BINARY + collect_all: bool = True + binary: Optional[str] = CHROME_BINARY cdp_host: str = "localhost" cdp_port: int = 9222 - fail_if_no_connection: bool = False + proxy_host: Optional[str] = None + proxy_port: int = 8080 + proxy_scheme: str = "http" @property def cdp_url(self) -> str: 
return f"http://{self.cdp_host}:{self.cdp_port}" + @property + def proxy_url(self) -> Optional[str]: + if self.proxy_host is None: + return None + + return f"{self.proxy_scheme}://{self.proxy_host}:{self.proxy_port}" + async def obtain_active_tab( targets: list[cdp.target.TargetInfo], conn: CDPConnection @@ -590,7 +692,11 @@ async def bind_func_to_context_id( # Using execution_context_id is deprecated # But, adding the biding with execution_context_name doesn't work when the recorder is restarted # on an already started Chrome - await target_session.execute(cdp.runtime.add_binding(name, execution_context_id=context_id)) + try: + await target_session.execute(cdp.runtime.add_binding(name, execution_context_id=context_id)) + except: + logger.exception("Exception in bind_func_to_context_id") + raise async def init_runtime_scripts( @@ -642,7 +748,23 @@ async def on_binding_called(self, evt: cdp.runtime.BindingCalled) -> None: async def on_execution_context_created(self, evt: cdp.runtime.ExecutionContextCreated) -> None: if evt.context.name == self.listener_context_name: self.listener_context_id = evt.context.id_ - await bind_func_to_context_id(self.target_session, self.EVENT_SEND_BINDING, self.listener_context_id) + try: + await bind_func_to_context_id(self.target_session, self.EVENT_SEND_BINDING, self.listener_context_id) + except pycdp.exceptions.CDPBrowserError as exc: + if "Cannot find execution context with given executionContextId" not in str(exc): + raise + # This happens when the execution gets destroyed before we even + # had the chance to create the binding + # One issue in the case we don't raise the error is that we + # might miss cases where an active page doesn't have the + # bindings set up. Which means that the javascript events + # (like clicks) + # can't be captured by our javascript extension. 
+ # TODO: make this error traceable, to make the case above + # detectable + logger.debug("Couldn't attach binding to execution context %s because it has been destroyed already", evt.context.id_) + + def pop_actions(self) -> list[InputAction]: actions = self.actions @@ -688,21 +810,26 @@ async def insert_widget_extension( return await insert_js_leech_script(target_session, expression) -async def record( - options: RecorderOptions, -) -> list[Union[HttpCommunication, InputAction]]: - urlfilter = filters.URLFilter() +async def init_recorder(options: RecorderOptions): + urlfilter = None # filters.URLFilter() try: - conn = CDPConnection(options.cdp_url, Agent(reactor), reactor) # type: ignore[no-untyped-call] + http = ClientSession() + conn = CDPConnection(options.cdp_url, http) await conn.connect() + conn.start() except ConnectionRefusedError: - if options.fail_if_no_connection: + if options.binary is None: raise ConnectionRefusedError port = options.cdp_port + + chrome_args = [f"--remote-debugging-port={port}", "--incognito"] + if options.proxy_url is not None: + chrome_args.append(f"--proxy-server={options.proxy_url}") + chrome_args.append("--ignore-certificate-errors") chrome = ChromeLauncher( binary=options.binary, - args=[f"--remote-debugging-port={port}", "--incognito"], + args=chrome_args, ) await threads.deferToThread(chrome.launch) # type: ignore[no-untyped-call] await conn.connect() @@ -718,6 +845,9 @@ async def record( await target_session.execute(cdp.network.enable()) + # await target_session.execute(cdp.security.enable()) + # await target_session.execute(cdp.security.set_ignore_certificate_errors(True)) + # Start the listener before navigating to the page listener = target_session.listen( cdp.runtime.BindingCalled, @@ -745,12 +875,17 @@ async def record( if options.keep_only_same_origin_urls: start_origin = extract_origin(start_url) + return Recorder(target_session, urlfilter, options.collect_all, start_origin, conn, listener) + + +async def record( + 
options: RecorderOptions, +) -> list[Union[HttpCommunication, InputAction]]: + recorder = await init_recorder(options) + try: - communications = await collect_communications( - target_session, listener, urlfilter, 20, options.collect_all, start_origin - ) + communications = await collect_communications(recorder, 20) finally: - target_session.close_listeners() - await conn.close() + await recorder.close() return communications diff --git a/cdprecorder/skopo/__init__.py b/cdprecorder/skopo/__init__.py new file mode 100644 index 0000000..77c40c5 --- /dev/null +++ b/cdprecorder/skopo/__init__.py @@ -0,0 +1,544 @@ +from __future__ import annotations + +import asyncio +import functools +import logging +import os.path +import random +import socket +import string +import subprocess +from collections import defaultdict +from typing import TYPE_CHECKING + +from .._storage import DEFAULT_SOCKET_NAME, get_runtime_dir +from .sniff_protocol import ( + ProxyEvent, + ProxyMessage, + ProxyException, + RequestData, + ResponseData, + SniffCommand, + SnifferError, + SnifferMessage, + SkopoMessage, + async_read_sock_datagram, + sniffer_data_from_bytes, + to_sock_datagram, +) + +if TYPE_CHECKING: + from asyncio import StreamReader, StreamWriter + from typing import Callable, Optional, Union + + +logger = logging.getLogger(__name__) + + +class SkopoException(Exception): + pass + + +class SnifferException(SkopoException): + def __init__(self, obj: Optional[SkopoMessage], *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self.obj = obj + + +class SnifferProcessTerminated(SkopoException): + def __init__(self, proc: asyncio.subprocess.Process, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self.proc = proc + + +class Sniffer: + def __init__(self) -> None: + self._ignore_reconnects = False + self.session_to_msg_queues: dict[Optional[int], asyncio.Queue[ProxyMessage]] = defaultdict(asyncio.Queue) + + def ignore_reconnects(self, 
ignore: bool = True) -> None: + self._ignore_reconnects = ignore + + async def _get_data(self) -> bytes: + raise NotImplementedError + + def pushback_message(self, obj: ProxyMessage) -> None: + self.session_to_msg_queues[obj.session].put_nowait(obj) + + def _handle_message( + self, obj: ProxyMessage, ignore_error: bool = False, session: Optional[int] = None + ) -> Optional[ProxyMessage]: + if isinstance(obj, SnifferError) and not ignore_error: + raise ProxyException(obj) + + if self._ignore_reconnects and isinstance(obj, ProxyEvent): + if obj.event == ProxyEvent.CONNECT or obj.event == ProxyEvent.CLOSE: + return None + + if isinstance(obj, (SniffCommand, SnifferError)): + if obj.session is None and session is not None: + raise SnifferException(obj, "Received sesionless message in a context with session") + if session is not None and obj.session != session: + self.pushback_message(obj) + return None + + # TODO: check if it's a bug it this returns when session is None, but obj.session is not None + return obj + + async def get_message(self, ignore_error: bool = False, session: Optional[int] = None) -> ProxyMessage: + if session is not None: + while True: + t1 = asyncio.create_task(self.session_to_msg_queues[session].get()) + t2 = asyncio.create_task(self._get_data()) + done, pending = await asyncio.wait([t1, t2], return_when=asyncio.FIRST_COMPLETED) + for t in pending: + t.cancel() + + results = [] + if t1 in done: + results.append(t1.result()) + if t2 in done: + data = t2.result() + obj, _ = sniffer_data_from_bytes(data) + results.append(obj) + + while results: + obj = results.pop(0) + obj = self._handle_message(obj, ignore_error, session) + if obj is not None: + for result in results: + self.pushback_message(result) + return obj + + while True: + data = await self._get_data() + obj, _ = sniffer_data_from_bytes(data) + obj = self._handle_message(obj, ignore_error, session) + if obj is not None: + return obj + + async def get_request_data(self, ignore_error: bool = 
False, session: Optional[int] = None) -> RequestData: + while True: + msg = await self.get_message(ignore_error, session) + if not isinstance(msg, RequestData): + raise SnifferException(msg, "Expected message of type RequestData") + return msg + + async def get_response_data(self, ignore_error: bool = False, session: Optional[int] = None) -> ResponseData: + while True: + msg = await self.get_message(ignore_error, session) + if not isinstance(msg, ResponseData): + raise SnifferException(msg, "Expected message of type ResponseData") + return msg + + async def get_proxy_event(self, ignore_error: bool = False, session: Optional[int] = None) -> ProxyEvent: + while True: + msg = await self.get_message(ignore_error, session) + if not isinstance(msg, ProxyEvent): + raise SnifferException(msg, "Expected message of type ProxyEvent") + return msg + + async def wait_event(self, event_id: int, ignore_error: bool = False, session: Optional[int] = None) -> ProxyEvent: + while True: + msg = await self.get_message(ignore_error, session) + if not isinstance(msg, ProxyEvent): + raise SnifferException(msg, "Expected message of type ProxyEvent") + if msg.event != event_id: + raise SnifferException(msg, f"Expected event id {event_id}") + + return msg + + async def _send_data(self, data: bytes) -> None: + raise NotImplementedError + + async def send_command( + self, + command: int, + request: Optional[RequestData] = None, + response: Optional[ResponseData] = None, + session: Optional[int] = None, + ) -> None: + sniff_command = SniffCommand(command, request, response) + sniff_command.session = session + data = sniff_command.to_bytes() + await self._send_data(data) + + async def send_error(self, msg: SnifferError, session: Optional[int] = None) -> None: + msg.session = session + data = msg.to_bytes() + await self._send_data(data) + + async def async_stop(self) -> None: + raise NotImplementedError + + def stop(self) -> None: + raise NotImplementedError + + def to_session(self, session: int) 
-> SnifferSession: + return SnifferSession(self, session) + + +class SnifferSession: + def __init__(self, sniffer: Sniffer, id_: int): + self.sniffer = sniffer + self.id = id_ + + def __getattr__(self, key: str): # type: ignore + value = getattr(self.sniffer, key) + if callable(value): + value = functools.partial(value, session=self.id) + + return value + + +class ComparatorSniffer: + def __init__(self, sniffer1: Sniffer, sniffer2: Sniffer, on_diff: Callable): + self.sniffer1 = sniffer1 + self.sniffer2 = sniffer2 + + self.on_diff = on_diff + + # self.http_queues = defaultdict(lambda _: asyncio.Queue(), asyncio.Queue()) + self.sessions1: dict[Optional[int], asyncio.Queue[ProxyMessage]] = defaultdict(lambda: asyncio.Queue()) + self.sessions2: dict[Optional[int], asyncio.Queue[ProxyMessage]] = defaultdict(lambda: asyncio.Queue()) + self.unpaired_sessions1: dict[bytes, asyncio.Queue[ProxyMessage]] = {} + self.unpaired_sessions2: dict[bytes, asyncio.Queue[ProxyMessage]] = {} + self.waiting_tasks: list[asyncio.Task] = [] + + self.filtered_sessions: set[int] = set() + self.filtered_url_keywords: list[bytes] = [] + + def add_filtered_url_keywords(self, keywords: list[str]) -> None: + self.filtered_url_keywords.extend(kw.encode() for kw in keywords) + + async def handle_request_response(self, queue1: asyncio.Queue, queue2: asyncio.Queue) -> None: + session1: Union[Sniffer, SnifferSession] = self.sniffer1 + session2: Union[Sniffer, SnifferSession] = self.sniffer2 + + req1 = await queue1.get() + assert isinstance(req1, RequestData) + if req1.session is not None: + session1 = session1.to_session(req1.session) + await session1.send_command(SniffCommand.NOP) + + req2 = await queue2.get() + assert isinstance(req2, RequestData) + + if req1 != req2: + await self.on_diff(self, req1, req2) + + res1 = await queue1.get() + assert isinstance(res1, ResponseData) + await session1.send_command(SniffCommand.NOP) + + if req2.session is not None: + session2 = 
self.sniffer2.to_session(req2.session)
+        await session2.send_command(SniffCommand.REPLACE, response=res1)
+
+        res2 = await queue2.get()
+        assert isinstance(res2, ResponseData)
+
+        # If the proxy works as expected, this should never happen
+        if res1 != res2:
+            await self.on_diff(self, res1, res2)
+        await session2.send_command(SniffCommand.NOP)
+
+    async def _pass_message(self, sniffer: Sniffer, msg: Union[RequestData, ResponseData]) -> None:
+        session: Union[Sniffer, SnifferSession] = sniffer
+        if msg.session is not None:
+            session = sniffer.to_session(msg.session)
+
+        await session.send_command(SniffCommand.NOP)
+
+    @staticmethod
+    def _request_data_key(obj: RequestData) -> bytes:
+        return obj.method + obj.url
+
+    @classmethod
+    def _try_extracting_session(
+        cls, obj: RequestData, unpaired_sessions: dict[bytes, asyncio.Queue[ProxyMessage]]
+    ) -> Optional[asyncio.Queue[ProxyMessage]]:
+        obj_key = cls._request_data_key(obj)
+        if obj_key in unpaired_sessions:
+            value = unpaired_sessions[obj_key]
+            del unpaired_sessions[obj_key]
+            return value
+
+        return None
+
+    async def _create_comparator_task(self, q1: asyncio.Queue, q2: asyncio.Queue) -> asyncio.Task:
+        task = asyncio.create_task(self.handle_request_response(q1, q2))
+        self.waiting_tasks.append(task)
+        return task
+
+    def is_object_filtered(self, obj: Union[RequestData, ResponseData]) -> bool:
+        if isinstance(obj, RequestData):
+            return any(kw in obj.url for kw in self.filtered_url_keywords)
+
+        if isinstance(obj, ResponseData):
+            if obj.session is not None and obj.session in self.filtered_sessions:
+                return True
+
+        return False
+
+    def remember_filtered_object(self, obj: Union[RequestData, ResponseData]) -> None:
+        if obj.session is None:
+            raise SkopoException("Can't remember a filtered object with no session.")
+        self.filtered_sessions.add(obj.session)
+
+    async def run(self) -> None:
+        print("Comparator start run")
+        on_message1 = asyncio.create_task(self.sniffer1.get_message())
+        on_message2 = 
asyncio.create_task(self.sniffer2.get_message()) + while True: + logger.debug("Comparator await") + done, pending = await asyncio.wait([on_message1, on_message2], return_when=asyncio.FIRST_COMPLETED) + logger.debug("Comparator callback") + + for task in [on_message1, on_message2]: + if task not in done: + continue + + try: + obj = task.result() + if not isinstance(obj, (RequestData, ResponseData)): + logger.error("Received unwanted object: %s", obj) + continue + + if self.is_object_filtered(obj): + self.remember_filtered_object(obj) + sniffer = self.sniffer1 if task is on_message1 else self.sniffer2 + await self._pass_message(sniffer, obj) + continue + + if task is on_message1: + if obj.session not in self.sessions1: + q1: asyncio.Queue[ProxyMessage] = asyncio.Queue() + self.sessions1[obj.session] = q1 + q2 = self._try_extracting_session(obj, self.unpaired_sessions2) + if q2 is not None: + await self._create_comparator_task(q1, q2) + else: + obj_key = self._request_data_key(obj) + self.unpaired_sessions1[obj_key] = q1 + + q1 = self.sessions1[obj.session] + await q1.put(obj) + + if task is on_message2: + if obj.session not in self.sessions2: + q2 = asyncio.Queue() + self.sessions2[obj.session] = q2 + q1 = self._try_extracting_session(obj, self.unpaired_sessions1) + if q1 is not None: + await self._create_comparator_task(q1, q2) + else: + obj_key = self._request_data_key(obj) + self.unpaired_sessions2[obj_key] = q2 + + q2 = self.sessions2[obj.session] + await q2.put(obj) + finally: + if task is on_message1: + on_message1 = asyncio.create_task(self.sniffer1.get_message()) + elif task is on_message2: + on_message2 = asyncio.create_task(self.sniffer2.get_message()) + + def stop(self) -> None: + self.sniffer1.stop() + self.sniffer2.stop() + + +async def mitmproxy_run( + sniffer_socket_address: str, + host: str = "localhost", + port: int = 8080, + addon_script: str = "intercept_addon.py", + binary: str = "mitmdump", +) -> tuple[asyncio.subprocess.Process, str, int, str]: 
+ proxy_name: str = "".join(random.choices(string.ascii_letters, k=32)) + module_dir = os.path.dirname(os.path.realpath(__file__)) + addon_path = os.path.join(module_dir, addon_script) + cli_args = [ + "--mode", + "regular", + "--listen-host", + host, + "--listen-port", + str(port), + "-s", + addon_path, + "--set", + f"socketaddress={sniffer_socket_address}", + "--set", + f"proxyname={proxy_name}", + ] + args = [binary] + cli_args + + # TODO: stderr DEVNULL + p = await asyncio.create_subprocess_exec(*args) # , stdout=subprocess.PIPE) + + return p, host, port, proxy_name + + +class MitmproxySniffer(Sniffer): + def __init__( + self, + sockaddr: str, + proxy_host: str, + proxy_port: int, + proxy_name: str, + mitmproxy_proc: asyncio.subprocess.Process, + ): + super().__init__() + self.sockaddr = sockaddr + self._server: Optional[asyncio.Server] = None + self.stop_event = asyncio.Event() + + self._read_queue: asyncio.Queue[bytes] = asyncio.Queue() + self._write_queue: asyncio.Queue[bytes] = asyncio.Queue() + + self._proc: Optional[asyncio.subprocess.Process] = mitmproxy_proc + self.proxy_host = proxy_host + self.proxy_port = proxy_port + self.proxy_name = proxy_name + + async def init(self) -> None: + self._server = await asyncio.start_unix_server(self.on_client_connected, path=self.sockaddr) + + @property + def proxy_url(self) -> str: + return f"http://{self.proxy_host}:{self.proxy_port}" + + async def _get_data(self) -> bytes: + tasks: list[asyncio.Task] = [] + wait_task = None + if self._proc is not None: + wait_task = asyncio.create_task(self._proc.wait()) + tasks.append(wait_task) + read_task = asyncio.create_task(self._read_queue.get()) + tasks.append(read_task) + try: + done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + finally: + # Here, I found out that task cancelation can create bugs + # Hence, in a couroutine, you should always expect that when you + # await a coroutine/task, that can generate a CancelledError. 
+ if wait_task is not None: + wait_task.cancel() + read_task.cancel() + + if wait_task in done and self._proc is not None: + raise SnifferProcessTerminated(self._proc) + + return read_task.result() + + async def _send_data(self, data: bytes) -> None: + datagram = to_sock_datagram(data) + await self._write_queue.put(datagram) + + async def on_client_connected(self, reader: StreamReader, writer: StreamWriter) -> None: + """Called when a proxy server has connected to this sniffer.""" + await self._read_queue.put(ProxyEvent(event=ProxyEvent.CONNECT).to_bytes()) + + # Handle both reading and writing concurrently + # Listen on the reader stream and the internal write queue + task1 = asyncio.create_task(async_read_sock_datagram(reader)) + task2 = asyncio.create_task(self._write_queue.get()) + stop_task = asyncio.create_task(self.stop_event.wait()) + try: + while True: + # ignore pending tasks, because they will be waited again in the next loop iteration + done, _pending = await asyncio.wait([task1, task2, stop_task], return_when=asyncio.FIRST_COMPLETED) + + if stop_task in done: + return + + if task1 in done: + data = task1.result() + await self._read_queue.put(data) + # recreate the task + task1 = asyncio.create_task(async_read_sock_datagram(reader)) + + if task2 in done: + data = task2.result() + writer.write(data) + await writer.drain() + # recreate the task + task2 = asyncio.create_task(self._write_queue.get()) + except asyncio.exceptions.IncompleteReadError: + pass + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception(f"Exception") + # Hopefully, the event loop is still running + await self._read_queue.put(ProxyEvent(event=ProxyEvent.CLOSE).to_bytes()) + while not self._write_queue.empty(): + await self._write_queue.get() + raise + finally: + task1.cancel() + task2.cancel() + stop_task.cancel() + + await self._read_queue.put(ProxyEvent(event=ProxyEvent.CLOSE).to_bytes()) + + def stop(self) -> None: + if self._server is not None: + 
self._server.close() + self.stop_event.set() + self._server = None + if self._proc is None: + return + + self._proc.terminate() + try: + # Imagine having an API that that has both async and blocking functions + # Sadly, this is not the case for asyncio.subprocess.Process + p = self._proc._transport._proc.wait(timeout=0.5) # type: ignore[attr-defined] + except subprocess.TimeoutExpired: + self._proc.kill() + + self._proc = None + + def __del__(self) -> None: + # The object might be destroyed before __init__ terminates + if hasattr(self, "_proc"): + self.stop() + + +async def create_mitmproxy_sniffer_comparator( + on_diff: Callable, socksuffix1: str = "1", proxy_port1: int = 8080, socksuffix2: str = "2", proxy_port2: int = 8081 +) -> ComparatorSniffer: + if socksuffix1 == socksuffix2: + raise SkopoException("Socket suffixes must be different") + if proxy_port1 == proxy_port2: + raise SkopoException("Proxy ports must be different") + + sockaddr1 = os.path.join(get_runtime_dir(), DEFAULT_SOCKET_NAME + socksuffix1) + proc1, host1, port1, proxy_name1 = await mitmproxy_run(sockaddr1, port=proxy_port1) + sniffer1 = MitmproxySniffer(sockaddr1, host1, port1, proxy_name1, proc1) + + sockaddr2 = os.path.join(get_runtime_dir(), DEFAULT_SOCKET_NAME + socksuffix2) + proc2, host2, port2, proxy_name2 = await mitmproxy_run(sockaddr2, port=proxy_port2) + sniffer2 = MitmproxySniffer(sockaddr2, host2, port2, proxy_name2, proc2) + + await sniffer1.init() + await sniffer2.init() + + try: + await sniffer1.wait_event(ProxyEvent.CONNECT) + await sniffer2.wait_event(ProxyEvent.CONNECT) + except: + proc1.kill() + out, err = await proc1.communicate() + logger.error("Proc output: %s", out) + logger.error("Proc err: %s", err) + raise + + sniffer1.ignore_reconnects(True) + sniffer2.ignore_reconnects(True) + comparator = ComparatorSniffer(sniffer1, sniffer2, on_diff) + + return comparator diff --git a/cdprecorder/skopo/intercept_addon.py b/cdprecorder/skopo/intercept_addon.py new file mode 100644 
index 0000000..256fc71 --- /dev/null +++ b/cdprecorder/skopo/intercept_addon.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import logging +import socket +import sys +from urllib.parse import urlunparse +from typing import Optional, TYPE_CHECKING + +from mitmproxy import ctx, http, tcp + +from sniff_protocol import ( + RequestData, + ResponseData, + SniffCommand, + SnifferMetadata, + SnifferProxyClient, + read_sock_datagram, + to_sock_datagram, +) + + +if TYPE_CHECKING: + from mitmproxy import http + + +class WrapperFormatter(logging.Formatter): + def __init__(self, formatter: loggingFormatter): + self.formatter = formatter + + def format(self, record: logging.LogRecord) -> str: + msg = self.formatter.format(record) + + msg_prefix = "" + if hasattr(record, "prefix"): + msg_prefix = f"[{record.prefix}]" + + return f"{msg_prefix}{msg}" + + def formatTime(self, *args: object, **kwargs: object) -> str: + return self.formatter.formatTime(*args, **kwargs) + + def formatException(self, *args: object, **kwargs: object) -> str: + return self.formatter.formatException(*args, **kwargs) + + def formatStack(self, *args: object, **kwargs: object) -> str: + return self.formatter.formatStack(*args, **kwargs) + + +class PrefixFilter(logging.Filter): + def __init__(self, prefix: Optional[str] = None): + self.prefix = prefix + + def filter(self, record: logging.LogRecord) -> bool: + if self.prefix is not None: + record.prefix = self.prefix + return True + + +class MitmproxySnifferProxyClient(SnifferProxyClient): + def __init__(self, sockaddr: str) -> None: + super().__init__() + self.client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.client.connect(sockaddr) + + def _get_data(self) -> bytes: + data = read_sock_datagram(self.client) + return data + + def _send_data(self, data: bytes) -> None: + datagram = to_sock_datagram(data) + self.client.sendall(datagram) + + def close(self) -> None: + if self.client is not None: + self.client.shutdown(socket.SHUT_RDWR) 
+ self.client.close() + self.client = None + + def __del__(self): + self.close() + + +class TheSpy: + def __init__(self) -> None: + self.client = None + self.flows = set() + + self.log_filter = PrefixFilter() + handler = logging.getLogger().handlers[0] + handler.addFilter(self.log_filter) + wrapper_formatter = WrapperFormatter(handler.formatter) + handler.setFormatter(wrapper_formatter) + + self.objid_to_session = {} + + def load(self, loader): + url = f"{ctx.options.listen_host}:{ctx.options.listen_port}" + self.log_filter.prefix = url + + loader.add_option( + name="socketaddress", + # Wow, they really used typing at runtime + typespec=Optional[str], + default=None, + help="Socket address to send request/response pickled data", + ) + loader.add_option( + name="proxyname", + typespec=str, + default="thespy", + help="Name used to distinguish different mitmdump instances", + ) + + def start_connection(self, socketaddress: str): + if self.client is not None: + self.client.close() + logging.info("Connecting to %s", socketaddress) + self.client = MitmproxySnifferProxyClient(socketaddress) + assert isinstance(self.proxyname, str) + + def running(self) -> None: + if ctx.options.socketaddress is not None: + self.start_connection(ctx.options.socketaddress) + + def configure(self, updated: set[str]) -> None: + self.proxyname = ctx.options.proxyname + logging.info("Set proxyname=%s", self.proxyname) + if ctx.options.socketaddress is not None: + self.start_connection(ctx.options.socketaddress) + + def get_session_for_object_id(self, object_id: int) -> None: + if object_id not in self.objid_to_session: + session = self.client.new_session() + self.objid_to_session[object_id] = session + + return self.objid_to_session[object_id] + + def request(self, flow: http.HTTPFlow) -> None: + self.flows.add(id(flow)) + logging.info("Intercepted request") + mitmreq = flow.request + # TODO: handle mitmreq.authority + url = urlunparse( + ( + mitmreq.scheme, + f"{mitmreq.host}:{mitmreq.port}", 
+ # Can this ever be bytes?? + mitmreq.path, + "", + "", + "", + ) + ) + + req = RequestData( + http_version=mitmreq.http_version.encode(), + method=mitmreq.method.encode(), + url=url.encode(), + headers=bytes(mitmreq.headers), + content=mitmreq.raw_content, + trailers=bytes(mitmreq.trailers) if mitmreq.trailers else b"", + meta=SnifferMetadata(object_id=id(flow), timestamp=0, proxyname=self.proxyname), + ) + + session = self.get_session_for_object_id(id(flow)) + session.send_request_data(req) + + logging.info("Sent request session=%s", session.id) + command = session.get_command() + + logging.info("Got command in request") + + if command.command == SniffCommand.REPLACE: + if command.response is not None: + r = command.response + + headers_bytes = [] + for key, value in r.headers: + headers_bytes.append((key.encode(), value.encode())) + resp = http.Response.make( + status_code=r.status_code, + content=r.raw_content, + headers=headers_bytes, + ) + resp.http_version = r.http_version + resp.reason = r.reason + resp.trailers = http.Headers(r.trailers) + + flow.response = resp + if command.request is not None: + r = command.request + + req = http.Request.make( + method=r.method, url=r.url, content=r.raw_content, headers=http.Headers(r.headers) + ) + req.http_version = (r.http_version,) + req.trailers = (http.Headers(r.trailers),) + + flow.request = req + elif command.command != SniffCommand.NOP: + raise Exception(f"Unknown command: {command.command}") + + logging.info("OK Intercepted request") + + def response(self, flow: http.HTTPFlow) -> None: + logging.info("Intercepted response") + if id(flow) not in self.flows: + logging.error("Got response before existing request") + else: + self.flows.remove(id(flow)) + + try: + mitmres = flow.response + logging.info("Response headers: %s", mitmres.headers) + res = ResponseData( + http_version=mitmres.http_version.encode(), + status_code=mitmres.status_code, + reason=mitmres.reason.encode(), + headers=bytes(mitmres.headers), + 
content=mitmres.raw_content, + trailers=bytes(mitmres.trailers) if mitmres.trailers else b"", + meta=SnifferMetadata(object_id=id(flow), timestamp=0, proxyname=self.proxyname), + ) + + logging.info("Constructed ResponseData") + + session = self.get_session_for_object_id(id(flow)) + session.send_response_data(res) + logging.info("Sent response session=%s", session.id) + + command = session.get_command() + logging.info("Got command in response") + + if command.command == SniffCommand.REPLACE: + if command.response is not None: + r = command.response + + resp = http.Response.make( + status_code=r.status_code, + content=r.raw_content, + headers=http.Headers(r.headers), + ) + resp.http_version = (r.http_version,) + resp.reason = (r.reason,) + resp.trailers = (http.Headers(r.trailers),) + + flow.response = resp + elif command.command != SniffCommand.NOP: + raise Exception(f"Unknown command: {command.command}") + + logging.info("OK Intercepted response") + except Exception as exc: + logging.info("Addon Exception: %s", exc) + raise + + def server_connect(self, data): + logging.info("About to connect to: %s. 
%s", data.server.address, str(data.server)) + if data.server.address[0] == "local": + data.server.address = ("127.0.0.1", data.server.address[1]) + + +def tcp_message(flow: tcp.TCPFlow): + from mitmproxy.utils import strutils + + message = flow.messages[-1] + # message.content = message.content.replace(b"foo", b"bar") + + logging.info( + f"tcp_message[from_client={message.from_client}), content={strutils.bytes_to_escaped_str(message.content)}]" + ) + + +addons = [TheSpy()] diff --git a/cdprecorder/skopo/sniff_protocol.py b/cdprecorder/skopo/sniff_protocol.py new file mode 100644 index 0000000..4f958ca --- /dev/null +++ b/cdprecorder/skopo/sniff_protocol.py @@ -0,0 +1,591 @@ +from __future__ import annotations + +import asyncio +import functools +import os.path +import subprocess +from collections import defaultdict +from enum import IntEnum +from typing import Callable, TYPE_CHECKING + + +if TYPE_CHECKING: + import socket + from asyncio import StreamReader, StreamWriter + from typing import Optional, TypeAlias, Protocol, Union + + +def bytes_to_varlen_bytes(data: bytes) -> bytes: + size = len(data) + return size.to_bytes(8, "big") + data + + +class SnifferMessageType(IntEnum): + REQUEST_DATA = 1 + RESPONSE_DATA = 2 + PROXY_EVENT = 3 + SNIFF_COMMAND = 4 + SNIFFER_ERROR = 5 + INT64 = 6 + STRING = 7 + NONE = 8 + + +class SnifferNone: + @staticmethod + def to_bytes() -> bytes: + data = b"" + data += SnifferMessageType.NONE.to_bytes(8, "big") + return data + + +class SnifferInt64: + @staticmethod + def to_bytes(value: Optional[int]) -> bytes: + if value is None: + return SnifferNone.to_bytes() + + data = b"" + data += SnifferMessageType.INT64.to_bytes(8, "big") + data += value.to_bytes(8, "big") + return data + + +class SnifferString: + @staticmethod + def to_bytes(value: Optional[str]) -> bytes: + if value is None: + return SnifferNone.to_bytes() + + data = b"" + data += SnifferMessageType.STRING.to_bytes(8, "big") + data += bytes_to_varlen_bytes(value.encode()) + + 
return data + + @staticmethod + def from_bytes(data: bytes) -> tuple[str, int]: + i = 0 + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.STRING + + length = int.from_bytes(data[i : i + 8], "big") + i += 8 + + s = data[i : i + length].decode("utf-8") + i += length + + return s, i + length + + +class SnifferMetadata: + def __init__(self, object_id: int, timestamp: int, proxyname: str) -> None: + self.object_id = object_id + self.timestamp = timestamp + self.proxyname = proxyname + + def to_bytes(self) -> bytes: + data = b"" + data += self.object_id.to_bytes(8, "big") + data += self.timestamp.to_bytes(8, "big") + data += bytes_to_varlen_bytes(self.proxyname.encode()) + + return data + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[SnifferMetadata, int]: + i = 0 + object_id = int.from_bytes(data[i : i + 8], "big") + i += 8 + + timestamp = int.from_bytes(data[i : i + 8], "big") + i += 8 + + size = int.from_bytes(data[i : i + 8], "big") + i += 8 + + proxyname = data[i : i + size].decode("utf8") + i += size + + return cls(object_id, timestamp, proxyname), i + + def __str__(self) -> str: + text = f"{self.__class__.__name__}(" + text += f"object_id={self.object_id}, " + text += f"timestamp={self.timestamp}, " + text += f"proxyname={self.proxyname}" + text += ")" + + return text + + +class RequestData: + __slots__ = ["http_version", "method", "url", "headers", "content", "trailers", "meta", "session"] + + def __init__( + self, + http_version: bytes, + method: bytes, + url: bytes, + headers: bytes, + content: bytes, + trailers: bytes, + meta: SnifferMetadata, + session: Optional[int] = None, + ): + self.http_version = http_version + self.method = method + self.url = url + self.headers = headers + self.content = content + self.trailers = trailers + self.meta = meta + self.session = session + + def to_bytes(self) -> bytes: + data = b"" + data += SnifferMessageType.REQUEST_DATA.to_bytes(8, "big") + data += 
bytes_to_varlen_bytes(self.http_version) + data += bytes_to_varlen_bytes(self.method) + data += bytes_to_varlen_bytes(self.url) + data += bytes_to_varlen_bytes(self.headers) + data += bytes_to_varlen_bytes(self.content) + data += bytes_to_varlen_bytes(self.trailers) + data += self.meta.to_bytes() + data += SnifferInt64.to_bytes(self.session) + + return data + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[RequestData, int]: + i = 0 + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.REQUEST_DATA + + components = [] + for _ in range(6): + size = int.from_bytes(data[i : i + 8], "big") + i += 8 + obj = data[i : i + size] + i += size + components.append(obj) + + http_version = components[0] + method = components[1] + url = components[2] + headers = components[3] + content = components[4] + trailers = components[5] + + meta, used = SnifferMetadata.from_bytes(data[i:]) + i += used + + session, used = sniffer_data_from_bytes(data[i:]) + assert isinstance(session, int) or session is None + i += used + + return cls(http_version, method, url, headers, content, trailers, meta, session), i + + def __str__(self) -> str: + text = f"{self.__class__.__name__}(" + text += f"http_version={self.http_version!r}, " + text += f"method={self.method!r}, " + text += f"url={self.url!r}, " + text += f"headers={self.headers!r}, " + text += f"content={self.content!r}, " + text += f"trailers={self.trailers!r}, " + text += f"meta={self.meta}," + text += f"session={self.session}" + text += ")" + + return text + + def __eq__(self, other: object) -> bool: + if not isinstance(other, RequestData): + raise NotImplementedError + + return ( + self.http_version == other.http_version + and self.method == other.method + and self.url == other.url + and self.headers == other.headers + and self.content == other.content + and self.trailers == other.trailers + ) + + +class ResponseData: + __slots__ = ["http_version", "status_code", "reason", 
"headers", "content", "trailers", "meta", "session"] + + def __init__( + self, + http_version: bytes, + status_code: int, + reason: bytes, + headers: bytes, + content: bytes, + trailers: bytes, + meta: SnifferMetadata, + session: Optional[int] = None, + ): + self.http_version = http_version + self.status_code = status_code + self.reason = reason + self.headers = headers + self.content = content + self.trailers = trailers + self.meta = meta + self.session = session + + def to_bytes(self) -> bytes: + data = b"" + data += SnifferMessageType.RESPONSE_DATA.to_bytes(8, "big") + data += bytes_to_varlen_bytes(self.http_version) + data += self.status_code.to_bytes(8, "big") + data += bytes_to_varlen_bytes(self.reason) + data += bytes_to_varlen_bytes(self.headers) + data += bytes_to_varlen_bytes(self.content) + data += bytes_to_varlen_bytes(self.trailers) + data += self.meta.to_bytes() + data += SnifferInt64.to_bytes(self.session) + + return data + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[ResponseData, int]: + i = 0 + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.RESPONSE_DATA + + size = int.from_bytes(data[i : i + 8], "big") + i += 8 + http_version = data[i : i + size] + i += size + + status_code = int.from_bytes(data[i : i + 8], "big") + i += 8 + + components = [] + for _ in range(4): + size = int.from_bytes(data[i : i + 8], "big") + i += 8 + obj = data[i : i + size] + i += size + components.append(obj) + + reason = components[0] + headers = components[1] + content = components[2] + trailers = components[3] + + meta, used = SnifferMetadata.from_bytes(data[i:]) + i += used + + session, used = sniffer_data_from_bytes(data[i:]) + assert isinstance(session, int) or session is None + i += used + + return cls(http_version, status_code, reason, headers, content, trailers, meta, session), i + + def __str__(self) -> str: + text = f"{self.__class__.__name__}(" + text += 
f"http_version={self.http_version!r}, " + text += f"status_code={self.status_code}, " + text += f"reasom={self.reason!r}, " + text += f"headers={self.headers!r}, " + text += f"content={self.content!r}, " + text += f"trailers={self.trailers!r}, " + text += f"meta={self.meta}," + text += f"session={self.session}" + text += ")" + + return text + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ResponseData): + raise NotImplementedError + + return ( + self.http_version == other.http_version + and self.status_code == other.status_code + and self.reason == other.reason + and self.headers == other.headers + and self.content == other.content + and self.trailers == other.trailers + ) + + +class ProxyEvent: + CONNECT = 1 + CLOSE = 2 + + def __init__(self, event: int): + self.event = event + self.session: Optional[int] = None + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[ProxyEvent, int]: + i = 0 + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.PROXY_EVENT + + event = int.from_bytes(data[i : i + 8], "big") + i += 8 + + return cls(event), i + + def to_bytes(self) -> bytes: + data = b"" + data += SnifferMessageType.PROXY_EVENT.to_bytes(8, "big") + data += self.event.to_bytes(8, "big") + + return data + + +class SnifferError: + def __init__(self, error_msg: str, error_type: str, session: Optional[int] = None): + self.error_msg = error_msg + self.error_type = error_type + self.session = session + + def to_bytes(self) -> bytes: + data = b"" + data += SnifferMessageType.SNIFFER_ERROR.to_bytes(8, "big") + data += SnifferString.to_bytes(self.error_msg) + data += SnifferString.to_bytes(self.error_type) + data += SnifferInt64.to_bytes(self.session) + + return data + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[SnifferError, int]: + i = 0 + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.SNIFFER_ERROR + + error_msg, used 
= SnifferString.from_bytes(data[i:]) + i += used + + error_type, used = SnifferString.from_bytes(data[i:]) + i += used + + session = int.from_bytes(data[i : i + 8], "big") + i += 8 + + return cls(error_msg, error_type, session), i + + +class SniffCommand: + NOP = 1 + REPLACE = 2 + CANCEL = 3 + CLOSE_CLIENT = 4 + + __slots__ = ["command", "request", "response", "meta", "session"] + + def __init__( + self, + command: int, + request: Optional[RequestData] = None, + response: Optional[ResponseData] = None, + meta: Optional[SnifferMetadata] = None, + session: Optional[int] = None, + ): + self.command = command + self.request = request + self.response = response + self.meta = meta + self.session = session + + def to_bytes(self) -> bytes: + data = b"" + data += SnifferMessageType.SNIFF_COMMAND.to_bytes(8, "big") + data += self.command.to_bytes(8, "big") + data += self.request.to_bytes() if self.request is not None else SnifferNone.to_bytes() + data += self.response.to_bytes() if self.response is not None else SnifferNone.to_bytes() + data += self.meta.to_bytes() if self.meta is not None else SnifferNone.to_bytes() + data += SnifferInt64.to_bytes(self.session) if self.session is not None else SnifferNone.to_bytes() + + return data + + @classmethod + def from_bytes(cls, data: bytes) -> tuple[SniffCommand, int]: + i = 0 + + message_type = int.from_bytes(data[i : i + 8], "big") + i += 8 + assert message_type == SnifferMessageType.SNIFF_COMMAND + + command = int.from_bytes(data[i : i + 8], "big") + i += 8 + + request, used = sniffer_data_from_bytes(data[i:]) + i += used + assert isinstance(request, RequestData) or request is None + + response, used = sniffer_data_from_bytes(data[i:]) + i += used + assert isinstance(response, ResponseData) or response is None + + meta, used = sniffer_data_from_bytes(data[i:]) + i += used + assert isinstance(meta, SnifferMetadata) or meta is None + + session, used = sniffer_data_from_bytes(data[i:]) + assert isinstance(session, int) or session 
is None + i += used + + return cls(command, request, response, meta, session), i + + +class ProxyException(Exception): + def __init__(self, obj: Optional[SnifferError], *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self.obj = obj + + +class SnifferClientException(Exception): + def __init__(self, obj: Optional[SnifferMessage], *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self.obj = obj + + +SnifferMessage: TypeAlias = Union[SniffCommand, SnifferError] +ProxyMessage: TypeAlias = Union[RequestData, ResponseData, ProxyEvent, SnifferError] +SkopoMessage: TypeAlias = Union[ProxyMessage, SnifferError] + + +def sniffer_data_from_bytes(data: bytes) -> tuple[Union[RequestData, ResponseData, ProxyEvent, SniffCommand, int, str, None], int]: + message_type = int.from_bytes(data[:8], "big") + if message_type == SnifferMessageType.REQUEST_DATA: + return RequestData.from_bytes(data) + elif message_type == SnifferMessageType.RESPONSE_DATA: + return ResponseData.from_bytes(data) + elif message_type == SnifferMessageType.PROXY_EVENT: + return ProxyEvent.from_bytes(data) + elif message_type == SnifferMessageType.SNIFF_COMMAND: + return SniffCommand.from_bytes(data) + elif message_type == SnifferMessageType.INT64: + assert len(data) >= 16, "Not enough bytes to decode" + return int.from_bytes(data[8:16], "big"), 16 + elif message_type == SnifferMessageType.STRING: + return SnifferString.from_bytes(data) + elif message_type == SnifferMessageType.NONE: + return None, 8 + else: + raise ProxyException(None, f"Unknown message type: {message_type}") + + +class SnifferProxyClient: + def __init__(self) -> None: + self.session_to_messages: dict[Optional[int], list[SnifferMessage]] = defaultdict(list) + + def _get_data(self) -> bytes: + raise NotImplementedError + + def pushback_message(self, obj: SnifferMessage) -> None: + self.session_to_messages[obj.session].append(obj) + + def get_message(self, ignore_error: bool = False, 
session: Optional[int] = None) -> SnifferMessage: + while len(self.session_to_messages[session]): + queue_obj = self.session_to_messages[session].pop(0) + + if isinstance(queue_obj, SnifferError) and not ignore_error: + raise ProxyException(queue_obj) + + if isinstance(queue_obj, SniffCommand): + if queue_obj.session is None and session is not None: + raise SnifferClientException(queue_obj, "Received sessionless message in a context with session") + return queue_obj + + del self.session_to_messages[session] + + while True: + data = self._get_data() + obj, _ = sniffer_data_from_bytes(data) + + if isinstance(obj, SnifferError) and not ignore_error: + raise ProxyException(obj) + + if isinstance(obj, SniffCommand): + if obj.session != session: + self.pushback_message(obj) + else: + return obj + + def get_command(self, ignore_error: bool = False, session: Optional[int] = None) -> SniffCommand: + while True: + msg = self.get_message(ignore_error, session) + if not isinstance(msg, SniffCommand): + raise ProxyException(msg, "Expected message of type SniffCommand") + return msg + + def _send_data(self, data: bytes) -> None: + raise NotImplementedError + + def send_proxy_message(self, obj: ProxyMessage, session: Optional[int] = None) -> None: + if hasattr(obj, "session"): + obj.session = session + data = obj.to_bytes() + self._send_data(data) + + def send_request_data(self, obj: RequestData, session: Optional[int] = None) -> None: + self.send_proxy_message(obj, session) + + def send_response_data(self, obj: ResponseData, session: Optional[int] = None) -> None: + self.send_proxy_message(obj, session) + + def send_proxy_event(self, obj: ProxyEvent, session: Optional[int] = None) -> None: + self.send_proxy_message(obj, session) + + def send_error(self, msg: SnifferError, session: Optional[int] = None) -> None: + msg.session = session + data = msg.to_bytes() + self._send_data(data) + + def new_session(self) -> SnifferClientSession: + return SnifferClientSession(self) + + +class 
SnifferClientSession: + _LAST_ID = 1 + + def __init__(self, client: SnifferProxyClient) -> None: + self.client = client + self.id = SnifferClientSession._LAST_ID + SnifferClientSession._LAST_ID += 1 + + def __getattr__(self, key: str) -> object: + value = getattr(self.client, key) + if callable(value): + value = functools.partial(value, session=self.id) + + return value + + +def to_sock_datagram(data: bytes) -> bytes: + size = len(data) + return size.to_bytes(8, "big") + data + + +async def async_read_sock_datagram(reader: StreamReader) -> bytes: + data_size = await reader.readexactly(8) + size = int.from_bytes(data_size, "big") + data = await reader.readexactly(size) + return data + + +def read_sock_datagram(sock: socket.socket) -> bytes: + data_size = sock.recv(8) + if len(data_size) != 8: + raise ProxyException("Received less than expected data") + size = int.from_bytes(data_size, "big") + data = sock.recv(size) + if len(data) != size: + raise ProxyException("Received less than expected data") + return data diff --git a/cdprecorder/user_interface.py b/cdprecorder/user_interface.py index ca7ebc3..d99b916 100644 --- a/cdprecorder/user_interface.py +++ b/cdprecorder/user_interface.py @@ -1,6 +1,8 @@ import abc from typing import TYPE_CHECKING, Callable +from twisted.internet.interfaces import IReactorCore + if TYPE_CHECKING: from typing import TYPE_CHECKING, Callable from twisted.internet.interfaces import IReactorCore diff --git a/dev/requirements-dev.txt b/dev/requirements-dev.txt index edb9605..df37086 100644 --- a/dev/requirements-dev.txt +++ b/dev/requirements-dev.txt @@ -1,8 +1,8 @@ +# Repo coverage-badge pybadges -pytest -pytest-cov==4.* -pytest-asyncio==0.23.5 + +# Styling mypy==1.17.* pylint==3.* isort==5.* @@ -10,6 +10,13 @@ black==24.* mypy-zope==1.0.* pre-commit==4.* +# Testing +pytest +virtualenv +pytest-cov==4.* +pytest-asyncio==0.23.5 +selenium==4.35.* + # Mypy stubs types-requests types-beautifulsoup4==4.* diff --git a/dev/requirements-http-apps.txt 
b/dev/requirements-http-apps.txt new file mode 100644 index 0000000..46c69f4 --- /dev/null +++ b/dev/requirements-http-apps.txt @@ -0,0 +1,4 @@ +# Requirements for the Python apps in tests/test_e2e_run/http_apps +Flask==3.1.1 +Flask-WTF==1.2.2 +WTForms==3.2.1 diff --git a/dev/setup.sh b/dev/setup.sh new file mode 100755 index 0000000..e310624 --- /dev/null +++ b/dev/setup.sh @@ -0,0 +1,25 @@ +#!/bin/bash +set -e + +PARENT_PATH=$(cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P) +VENV_HTTP_APPS=".venv_http_apps" + + +# TODO: make sure the script is run from a specific place + +# Create venv if it doesn't exist +if [ ! -d "$VENV_HTTP_APPS" ]; then + echo Creating virtual environment in $VENV_HTTP_APPS... + python3 -m virtualenv "$VENV_HTTP_APPS" +fi + +# Activate venv +echo Activating virtual environment $VENV_HTTP_APPS... +# shellcheck disable=SC1091 +source "$VENV_HTTP_APPS/bin/activate" + +echo Installing dependencies... +pip install --upgrade pip +pip install -r $PARENT_PATH/requirements-http-apps.txt + +echo Done diff --git a/main.py b/main.py index 791a07d..cd079c6 100644 --- a/main.py +++ b/main.py @@ -1,385 +1,60 @@ from __future__ import annotations -from typing import cast, Optional, Union, TYPE_CHECKING - -import bs4 -import bs4.builder._htmlparser -import pycdp -import requests +import asyncio import sys -import twisted.internet.reactor - -from twisted.python.log import err -from twisted.internet import defer, threads -from twisted.internet.interfaces import IReactorCore -from pycdp import cdp import cdprecorder -from cdprecorder import generate_python -from cdprecorder.action import ( - BrowserAction, - InputAction, - HttpAction, - LowercaseStr, - RequestAction, - ResponseAction, - response_action_from_python_response, -) -from cdprecorder.recorder import ( - HttpCommunication, - RecorderOptions, - record, -) - -import cdprecorder.analyser - -if TYPE_CHECKING: - import bs4 - - from pycdp.cdp.util import T_JSON_DICT - from twisted.python.failure import 
Failure - - from cdprecorder.type_checking import CdpEvent, HttpTarget - - -# https://github.com/twisted/twisted/issues/9909 -reactor = cast(IReactorCore, twisted.internet.reactor) - - -def generate_action( - action: HttpAction, prev_new_actions: list[Optional[HttpAction]] -) -> RequestAction: - new_action = RequestAction() - new_action.shallow_copy_from_action(action) - for target in action.targets: - target.apply(new_action, prev_new_actions) - - return new_action - - -def run_actions(actions: list[HttpAction]) -> None: - new_actions: list[Optional[HttpAction]] = [] - - for action in actions: - if isinstance(action, RequestAction): - new_action = generate_action(action, new_actions) - new_actions.append(new_action) - - with requests.Session() as session: - req = requests.Request( - method=new_action.method, - url=new_action.url, - headers=new_action.headers, - data=new_action.body, - cookies=new_action.cookies_to_dict(), - ) - prepared_request = req.prepare() - resp = session.send(prepared_request, allow_redirects=False) - resp_action = response_action_from_python_response(resp) - new_actions.append(resp_action) - - print(f"{new_action.method} {new_action.url} - {resp.status_code}") - - elif not isinstance(action, ResponseAction): - new_actions.append(None) - - -def to_cdp_event(event: CdpEvent) -> dict[str, Union[str, T_JSON_DICT]]: - cdp_method = None - for key, val in cdp.util._event_parsers.items(): - if val == event.__class__: - cdp_method = key - break - else: - raise Exception - - return { - "method": cdp_method, - "params": event.to_json(), - "type": "recv", - "domain": "-", - } - - -def get_only_http_actions(actions: list[BrowserActions]) -> list[HttpActions]: - return [action for action in actions if isinstance(action, HttpAction)] - - -def _generate_events_with_redirects_extracted(events: list[CdpEvent]) -> list[CdpEvent]: - new_events = [] - future_events: list[CdpEvents] = [] - wait_response_extra = False - wait_request_extra = False - wait_extra = 
False - for evt in events: - if wait_extra: - if ( - isinstance(evt, cdp.network.RequestWillBeSentExtraInfo) - and not wait_request_extra - or isinstance(evt, cdp.network.ResponseReceivedExtraInfo) - and not wait_response_extra - ): - wait_extra = False - new_events += future_events - future_events = [] - - if isinstance(evt, cdp.network.RequestWillBeSentExtraInfo): - wait_request_extra = False - elif isinstance(evt, cdp.network.ResponseReceivedExtraInfo): - wait_response_extra = False - wait_extra = wait_response_extra or wait_request_extra - - new_events.append(evt) - continue - else: - new_events += future_events - future_events = [] - - future_events.append(evt) - if not isinstance(evt, cdp.network.RequestWillBeSent): - continue - - if not evt.redirect_response: - new_events += future_events - future_events = [] - wait_extra = False - else: - if evt.redirect_has_extra_info: - wait_response_extra = True - wait_request_extra = True - wait_extra = True - - response_evt = cdp.network.ResponseReceived( - request_id=evt.request_id, - loader_id=evt.loader_id, - timestamp=evt.timestamp, - type_=evt.type_, - response=evt.redirect_response, - has_extra_info=evt.redirect_has_extra_info, - frame_id=evt.frame_id, - ) - new_events.append(response_evt) - - new_events += future_events - - return new_events - +from cdprecorder import erpeto, recorder, skopo -def parse_communications_into_actions( - communications: list[Union[HttpCommunication, InputActioni]] -) -> list[BrowserAction]: - from cdprecorder import logger - actions: list[BrowserAction] = [] +async def on_fail(comparator, httpobj1, httpobj2): + print(f"Comparator failed after {comparator.requests_passed} passes") + print(f"First httpobj: {httpobj1}") + print(f"Second httpobj: {httpobj2}") + raise Exception("Failure") - for comm in communications: - logger.debug("Comm: %s", repr(comm)) - if not isinstance(comm, HttpCommunication): - actions.append(comm) - continue - - if comm.ignored: - continue - - response_bodies = 
list(comm.response_bodies) - - curr_request: Optional[RequestAction] = None - request_extra: Optional[RequestAction] = None - curr_response: Optional[ResponseAction] = None - response_extra: Optional[ResponseAction] = None - events = _generate_events_with_redirects_extracted(comm.events) - print("--------------------------------------------------------") - # Append to actions the requests/responses from each event - for evt in events: - if isinstance(evt, cdp.network.RequestWillBeSent): - if curr_request is not None: - if all((curr_request, request_extra, curr_response)): - curr_request.has_response = True - actions.append(curr_request) - actions.append(curr_response) - else: - actions.append(curr_request) - if curr_response is not None: - curr_request.has_response = True - actions.append(curr_response) - - curr_request = None - request_extra = None - curr_response = None - - """ - if curr_request is not None: - # Consume the previous request - if curr_response: - curr_request.has_response = True - actions.append(curr_request) - curr_request = None - request_extra = None - - if curr_response: - # Consume the previous response - actions.append(curr_response) - if response_extra: - raise Exception - curr_response = None - """ - - curr_request = RequestAction() - curr_request.update_info(evt.request) - if evt.request.has_post_data and evt.request.post_data: - # TODO: Check if bytes in other entry - curr_request.set_body(evt.request.post_data.encode()) - - if request_extra is not None: - curr_request.merge(request_extra) - - elif isinstance(evt, cdp.network.RequestWillBeSentExtraInfo): - if request_extra is not None: - if all((curr_request, request_extra, curr_response)): - curr_request.has_response = True - actions.append(curr_request) - curr_request = None - request_extra = None - - actions.append(curr_response) - curr_response = None - - """ - if curr_request is not None and request_extra is not None: - # Consume the previous request - if curr_response: - 
curr_request.has_response = True - actions.append(curr_request) - curr_request = None - request_extra = None - - if curr_response: - # Consume the previous response - actions.append(curr_response) - if response_extra: - raise Exception - curr_response = None - """ - - if request_extra is not None: - raise Exception - request_extra = RequestAction() - request_extra.update_info(evt) - - if curr_request is not None: - curr_request.merge(request_extra) - # request_extra = None - - elif isinstance(evt, cdp.network.ResponseReceived): - if curr_response is None: - curr_response = ResponseAction(evt.response) - else: - raise Exception - - if response_extra is not None: - # Always merge response_extra over curr_response, not the other way - curr_response.merge(response_extra) - response_extra = None - - elif isinstance(evt, cdp.network.ResponseReceivedExtraInfo): - if curr_response is not None: - # Always merge response_extra over curr_response, not the other way - curr_response.merge(ResponseAction(evt)) - elif response_extra is None: - response_extra = ResponseAction(evt) - else: - raise Exception - - elif isinstance(evt, cdp.network.LoadingFinished): - # Manually inserted - response_body = response_bodies.pop(0) - if response_body is not None: - if curr_response: - curr_response.set_body(response_body) - elif response_extra: - response_extra.set_body(response_body) - else: - raise Exception - - if curr_request is not None: - if curr_response or response_extra: - curr_request.has_response = True - actions.append(curr_request) - curr_request = None - if curr_response is not None: - actions.append(curr_response) - curr_response = None - elif response_extra is not None: - actions.append(response_extra) - response_extra = None - - if curr_request is not None: - if curr_response is not None: - curr_request.has_response = True - curr_request.merge(request_extra) - actions.append(curr_request) - if curr_response is not None: - # Always merge response_extra over curr_response, 
not the other way - curr_response.merge(response_extra) - actions.append(curr_response) - - return actions +async def main() -> None: + comparator = await skopo.create_mitmproxy_sniffer_comparator(on_fail) + proxy_url1 = comparator.sniffer1.proxy_url + proxy_url2 = comparator.sniffer2.proxy_url -def make_action_ids_consecutive_from_list(actions: list[BrowserAction]): - for i, action in enumerate(actions): - action.ID = i + import requests + def threaded_run1(): + proxies = {"http": proxy_url1, "https": proxy_url1} + r = requests.get("https://google.com", proxies=proxies, verify=False) + print("done") + r = requests.get("https://google.com/test", proxies=proxies, verify=False) + r = requests.get("https://google.com/test2", proxies=proxies, verify=False) -async def run(options: RecorderOptions) -> None: - communications = await record(options) - actions = parse_communications_into_actions(communications) - make_action_ids_consecutive_from_list(actions) - cdprecorder.analyser.analyse_actions(actions) - # actions = get_only_http_actions(actions) - # run_actions(actions) + def threaded_run2(): + proxies = {"http": proxy_url2, "https": proxy_url2} + r = requests.get("https://google.com", proxies=proxies, verify=False) + r = requests.get("https://google.com/test", proxies=proxies, verify=False) + r = requests.get("https://google.com/test", proxies=proxies, verify=False) - generate_python.write_python_code(actions, "generated.py") + import threading - """ - async with target_session.wait_for(cdp.runtime.ConsoleAPICalled) as evt: - print(evt) - if evt.args[0].value == "click": - for obj in evt.args: - if obj.class_name == "PointerEvent": + thread1 = threading.Thread(target=threaded_run1, daemon=True) + thread1.start() + thread2 = threading.Thread(target=threaded_run2, daemon=True) + thread2.start() - result = await target_session.execute(cdp.runtime.get_properties(obj.object_id)) - pointer_evt_props = result[0] - break + await comparator.run() - for prop in 
pointer_evt_props: - if prop.name == "srcElement": - print(prop) - result = await target_session.execute(cdp.dom.describe_node(object_id=prop.value.object_id)) - print(result) - """ + time.sleep(20) + exit() """ - await threads.deferToThread(chrome.kill) - """ - - -async def main() -> None: cdprecorder.enable_logger() cdprecorder.configure_root_logger(stream=sys.stdout) start_url = "https://github.com" - options = RecorderOptions(start_url) - await run(options) - - -def main_error(failure: Failure) -> None: - err(failure) # type: ignore[no-untyped-call] - reactor.stop() + options = recorder.RecorderOptions(start_url) + await erpeto.run(options) + """ if __name__ == "__main__": - d = defer.ensureDeferred(main()) - d.addErrback(main_error) - d.addCallback(lambda *args: reactor.stop()) - reactor.run() + asyncio.run(main()) diff --git a/requirements.txt b/requirements.txt index 6225b91..b13805e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ python-dateutil==2.8.2 python-cdp @ git+https://github.com/RazorBest/python-cdp.git@erpeto adblockparser==0.7 cheap_repr @ git+https://github.com/RazorBest/cheap_repr.git@add-annotations +mitmproxy==11.0.2 diff --git a/tests/test_e2e_run/http_apps/csrf_form/app.py b/tests/test_e2e_run/http_apps/csrf_form/app.py new file mode 100644 index 0000000..1a63b26 --- /dev/null +++ b/tests/test_e2e_run/http_apps/csrf_form/app.py @@ -0,0 +1,41 @@ +from flask import Flask, request, render_template_string +from flask_wtf import FlaskForm +from wtforms import StringField, SubmitField +from wtforms.validators import DataRequired +import secrets + +app = Flask(__name__) +app.secret_key = secrets.token_hex(16) # Needed for CSRF + +SECRET1 = "alpha" +SECRET2 = "beta" + + +class MyForm(FlaskForm): + field1 = StringField("Field 1", validators=[DataRequired()]) + field2 = StringField("Field 2", validators=[DataRequired()]) + submit = SubmitField("Submit") + + +@app.route("/", methods=["GET", "POST"]) +def index(): + form = MyForm() 
+ if form.validate_on_submit() and form.field1.data == SECRET1 and form.field2.data == SECRET2: + return "SUCCESS", 200 + return render_template_string( + """ +
+ +The secret is not alpha
+ """, + form=form, + ) + + +if __name__ == "__main__": + app.run(debug=False) diff --git a/tests/test_e2e_run/selenium_runners.py b/tests/test_e2e_run/selenium_runners.py new file mode 100644 index 0000000..69ccdb1 --- /dev/null +++ b/tests/test_e2e_run/selenium_runners.py @@ -0,0 +1,33 @@ +from selenium import webdriver +from selenium.webdriver.common.by import By + + +def run_csrf_form_submitsuccess(driver): + # Step # | name | target | value + driver.implicitly_wait(1000000) + try: + # 1 | open | / | + driver.get("http://local:5000/") + # 2 | setWindowSize | 736x729 | + driver.set_window_size(736, 729) + # 3 | click | css=html | + driver.find_element(By.CSS_SELECTOR, "html").click() + # 4 | click | css=form | + driver.find_element(By.CSS_SELECTOR, "form").click() + # 5 | click | id=field2 | + driver.find_element(By.ID, "field2").click() + # 6 | click | id=field1 | + driver.find_element(By.ID, "field1").click() + # 7 | type | id=field1 | alpha + driver.find_element(By.ID, "field1").send_keys("alpha") + # 8 | type | id=field2 | beta + driver.find_element(By.ID, "field2").send_keys("beta") + # 9 | click | id=submit | + driver.find_element(By.ID, "submit").click() + finally: + driver.implicitly_wait(0) + + +selenium_runs = { + "http_apps/csrf_form": [run_csrf_form_submitsuccess], +} diff --git a/tests/test_e2e_run/test_http_apps.py b/tests/test_e2e_run/test_http_apps.py new file mode 100644 index 0000000..b23cac6 --- /dev/null +++ b/tests/test_e2e_run/test_http_apps.py @@ -0,0 +1,204 @@ +import asyncio +import logging +import os +import subprocess +import time +import pathlib +import urllib.request + +import pytest +from selenium import webdriver +from selenium_runners import run_csrf_form_submitsuccess + +from cdprecorder import erpeto, skopo, recorder + + +VENV_HTTP_APPS = ".venv_http_apps" + + +class VenvAppRunner: + def __init__(self, app_path: str): + abs_app_path = pathlib.Path(__file__).parent.resolve() / app_path / "app.py" + python_binary = 
os.path.join(VENV_HTTP_APPS, "bin/python") + self.proc = subprocess.Popen( + [python_binary, abs_app_path], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + def wait_until_up(self, timeout: int = 3): + start_time = time.time() + while time.time() - start_time < timeout: + try: + ret = self.proc.poll() + if ret is not None: + out, err = self.proc.communicate() + raise RuntimeError(f"Processed terminated. retcode: {ret}. stderr: {err}") + urllib.request.urlopen("http://localhost:5000") + break + except urllib.error.URLError: + time.sleep(0.1) + else: + self.proc.kill() + out, err = self.proc.communicate() + print(f"stdout: {out}") + print(f"stderr: {err}") + raise TimeoutError(f"Timeout exceeded: {timeout} seconds") + + def __del__(self): + self.proc.terminate() + + +async def on_fail(comparator, httpobj1, httpobj2): + logging.error("First httpobj: %s", httpobj1) + logging.error("Second httpobj %s", httpobj2) + raise Exception("Failure") + + +@pytest.mark.asyncio +async def test_csrf_form_run_csrf_from_submit_success(): + logging.getLogger("selenium").setLevel(logging.WARNING) + logging.getLogger("pycdp").setLevel(logging.WARNING) + logging.getLogger("cdprecorder").setLevel(logging.WARNING) + logging.info("Starting web app") + app = VenvAppRunner("http_apps/csrf_form") + app.wait_until_up() + logging.info("Web app started") + + comparator = await skopo.create_mitmproxy_sniffer_comparator(on_fail) + comparator.add_filtered_url_keywords(["clients2.google.com"]) + + proxy_url1 = comparator.sniffer1.proxy_url + proxy_url2 = comparator.sniffer2.proxy_url + + # proxy_url = f"http://{proxy_info.host}:{proxy_info.port}" + options = webdriver.ChromeOptions() + cdp_port = 9222 + options.add_argument(f"--remote-debugging-port={cdp_port}") + options.add_argument(f"--proxy-server={proxy_url1}") + # options.add_argument("--headless=new") + capabilities = options.to_capabilities() + capabilities["acceptInsecureCerts"] = True + 
print(options.to_capabilities()) + options.binary_location = "/usr/bin/google-chrome-stable" + driver = webdriver.Chrome(options) + logging.info("Instantiated web driver") + + recorder_options = recorder.RecorderOptions( + "http://local:5000", + cdp_host="localhost", + cdp_port=cdp_port, + collect_all=True, + ) + + async def pass_sniffer(sniffer): + print("pass sniffer") + msg = None + try: + while True: + msg = await sniffer.get_message() + + # logging.debug("Msg: %s", msg) + # logging.debug("Session: %s", msg.session) + await sniffer.to_session(msg.session).send_command(skopo.SniffCommand.NOP) + except asyncio.CancelledError: + if msg is not None: + await sniffer.to_session(msg.session).send_command(skopo.SniffCommand.NOP) + pass + except: + logging.exception("Pass sniffer ended with exception") + + pass_task = asyncio.create_task(pass_sniffer(comparator.sniffer1)) + rec = await recorder.init_recorder(recorder_options) + await asyncio.sleep(2) + pass_task.cancel() + try: + pass_task.result() + except (asyncio.CancelledError, asyncio.InvalidStateError): + pass + except: + logging.exception("pass_task ended with exception") + + logging.info("Start recording") + + try: + t1 = asyncio.create_task(asyncio.to_thread(run_csrf_form_submitsuccess, driver)) + t2 = asyncio.create_task(recorder.collect_communications(rec, 30)) + t3 = asyncio.create_task(pass_sniffer(comparator.sniffer1)) + done, pending = await asyncio.wait([t1, t2, t3], return_when=asyncio.FIRST_COMPLETED) + + if t2 in done: + logging.info("Task2 is done") + logging.info("Result: %s", t2.result()) + + if t3 in done: + logging.info("Task3 is done") + + if t1 in pending: + raise Exception("Recorder stopped before selenium test") + + if t3 in pending: + t3.cancel() + + if t2 in pending: + rec.listener.cancel() + communications = await t2 + finally: + await rec.close() + + driver.quit() + + await asyncio.sleep(2) + + logging.info("Recorded communications") + + actions = 
erpeto.parse_communications_into_actions(communications) + erpeto.make_action_ids_consecutive_from_list(actions) + # actions = await erpeto.run_recorder(recorder_options) + erpeto.run_analyse(actions) + + # TODO: probably, a sniffer that's not linked to a comparator should not block + # comparator = skopo.SnifferComparator(on_fail, sniffer_manager1.sniffer, sniffer_manager2.sniffer) + + logging.info("Starting replication") + logging.getLogger("cdprecorder").setLevel(logging.DEBUG) + + #import requests + #requests.get("http://local:5000", ) + #time.sleep(2) + replicator_proxies = { + 'http': proxy_url2, + 'https': proxy_url2, + } + + #return + + driver = webdriver.Chrome(options) + + t1 = asyncio.create_task(asyncio.to_thread(run_csrf_form_submitsuccess, driver)) + t2 = asyncio.create_task(asyncio.to_thread(erpeto.run_replicate, actions, replicator_proxies)) + t3 = asyncio.create_task(comparator.run()) + + pair = asyncio.create_task(asyncio.wait([t1, t2], return_when=asyncio.ALL_COMPLETED)) + done, pending = await asyncio.wait([t1, t2, t3], return_when=asyncio.FIRST_COMPLETED) + + logging.debug("Pending state: %s, %s", pair.done(), t3.done()) + + if t2.done(): + t2.result() + + t3.cancel() + if t3 in done: + assert t3.result() is True + # comparator.stop() + return + + # TODO: this probably can't cancel t1 and t2 + for pair in pending: + pair.cancel() + + if t3 in done: + res = t3.result() + + assert res is True