Skip to content
Merged
2 changes: 2 additions & 0 deletions mocket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Mocket - socket mocking library for Python."""

import importlib
import sys

Expand Down
36 changes: 35 additions & 1 deletion mocket/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,57 @@


def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes:
"""Encode a string or bytes to bytes.

Args:
s: String or bytes to encode
encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var)

Returns:
Encoded bytes
"""
if isinstance(s, str):
s = s.encode(encoding)
return bytes(s)


def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str:
"""Decode bytes or string to string.

Args:
s: String or bytes to decode
encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var)

Returns:
Decoded string
"""
if isinstance(s, bytes):
s = codecs.decode(s, encoding, "ignore")
return str(s)


def shsplit(s: str | bytes) -> list[str]:
"""Split a shell command string into arguments.

Args:
s: Shell command string or bytes

Returns:
List of shell command arguments
"""
s = decode_from_bytes(s)
return shlex.split(s)


def do_the_magic(body):
def do_the_magic(body: bytes) -> str:
"""Detect MIME type of binary data using puremagic.

Args:
body: Binary data to analyze

Returns:
MIME type string
"""
try:
magic = puremagic.magic_string(body)
except puremagic.PureError:
Expand Down
33 changes: 26 additions & 7 deletions mocket/decorators/async_mocket.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
"""Async version of Mocket decorator."""

from __future__ import annotations

from typing import Any, Callable

from mocket.decorators.mocketizer import Mocketizer
from mocket.utils import get_mocketize


async def wrapper(
test,
truesocket_recording_dir=None,
strict_mode=False,
strict_mode_allowed=None,
*args,
**kwargs,
):
test: Callable,
truesocket_recording_dir: str | None = None,
strict_mode: bool = False,
strict_mode_allowed: list | None = None,
*args: Any,
**kwargs: Any,
) -> Any:
"""Async wrapper function for @async_mocketize decorator.

Args:
test: Async test function to wrap
truesocket_recording_dir: Directory for recording true socket calls
strict_mode: Enable STRICT mode to forbid real socket calls
strict_mode_allowed: List of allowed hosts in STRICT mode
*args: Test arguments
**kwargs: Test keyword arguments

Returns:
Result of the test function
"""
async with Mocketizer.factory(
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
):
Expand Down
120 changes: 99 additions & 21 deletions mocket/decorators/mocketizer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
"""Mocketizer decorator for managing Mocket lifecycle in tests."""

from __future__ import annotations

from typing import Any, Callable

from mocket.mocket import Mocket
from mocket.mode import MocketMode
from mocket.utils import get_mocketize


class Mocketizer:
"""Context manager and decorator for managing Mocket lifecycle in tests."""

def __init__(
self,
instance=None,
namespace=None,
truesocket_recording_dir=None,
strict_mode=False,
strict_mode_allowed=None,
):
instance: Any | None = None,
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
strict_mode: bool = False,
strict_mode_allowed: list | None = None,
) -> None:
"""Initialize the Mocketizer.

Args:
instance: Test instance (optional)
namespace: Namespace for recordings
truesocket_recording_dir: Directory for recording true socket calls
strict_mode: Enable STRICT mode to forbid real socket calls
strict_mode_allowed: List of allowed hosts in STRICT mode
"""
self.instance = instance
self.truesocket_recording_dir = truesocket_recording_dir
self.namespace = namespace or str(id(self))
Expand All @@ -23,41 +40,89 @@ def __init__(
"Allowed locations are only accepted when STRICT mode is active."
)

def enter(self):
def enter(self) -> None:
"""Enter the Mocketizer context (enable Mocket)."""
Mocket.enable(
namespace=self.namespace,
truesocket_recording_dir=self.truesocket_recording_dir,
)
if self.instance:
self.check_and_call("mocketize_setup")

def __enter__(self):
def __enter__(self) -> Mocketizer:
"""Enter context manager.

Returns:
Self for use in `with` statements
"""
self.enter()
return self

def exit(self):
def exit(self) -> None:
"""Exit the Mocketizer context (disable Mocket)."""
if self.instance:
self.check_and_call("mocketize_teardown")

Mocket.disable()

def __exit__(self, type, value, tb):
def __exit__(self, type: Any, value: Any, tb: Any) -> None:
"""Exit context manager.

Args:
type: Exception type
value: Exception value
tb: Traceback
"""
self.exit()

async def __aenter__(self, *args, **kwargs):
async def __aenter__(self, *args: Any, **kwargs: Any) -> Mocketizer:
"""Enter async context manager.

Returns:
Self for use in `async with` statements
"""
self.enter()
return self

async def __aexit__(self, *args, **kwargs):
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
"""Exit async context manager.

Args:
*args: Exception arguments
**kwargs: Exception keyword arguments
"""
self.exit()

def check_and_call(self, method_name):
def check_and_call(self, method_name: str) -> None:
"""Check if instance has a method and call it.

Args:
method_name: Name of method to check and call
"""
method = getattr(self.instance, method_name, None)
if callable(method):
method()

@staticmethod
def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args):
def factory(
test: Callable,
truesocket_recording_dir: str | None,
strict_mode: bool,
strict_mode_allowed: list | None,
args: tuple,
) -> Mocketizer:
"""Create a Mocketizer instance for a test function.

Args:
test: Test function being decorated
truesocket_recording_dir: Recording directory
strict_mode: Enable STRICT mode
strict_mode_allowed: Allowed hosts in STRICT mode
args: Positional arguments to test

Returns:
Configured Mocketizer instance
"""
instance = args[0] if args else None
namespace = None
if truesocket_recording_dir:
Expand All @@ -79,13 +144,26 @@ def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, ar


def wrapper(
test,
truesocket_recording_dir=None,
strict_mode=False,
strict_mode_allowed=None,
*args,
**kwargs,
):
test: Callable,
truesocket_recording_dir: str | None = None,
strict_mode: bool = False,
strict_mode_allowed: list | None = None,
*args: Any,
**kwargs: Any,
) -> Any:
"""Wrapper function for @mocketize decorator.

Args:
test: Test function to wrap
truesocket_recording_dir: Recording directory
strict_mode: Enable STRICT mode
strict_mode_allowed: Allowed hosts in STRICT mode
*args: Test arguments
**kwargs: Test keyword arguments

Returns:
Result of the test function
"""
with Mocketizer.factory(
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
):
Expand Down
60 changes: 49 additions & 11 deletions mocket/entry.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,38 @@
"""Mocket entry base class for registering mock responses."""

from __future__ import annotations

import collections.abc
from typing import Any

from mocket.compat import encode_to_bytes
from mocket.mocket import Mocket


class MocketEntry:
"""Base class for Mocket entries that match requests and return responses."""

class Response(bytes):
"""Response wrapper class that extends bytes."""

@property
def data(self):
def data(self) -> bytes:
"""Get the response data."""
return self

response_index = 0
request_cls = bytes
response_cls = Response
responses = None
_served = None
response_index: int = 0
request_cls: type = bytes
response_cls: type = Response
responses: list | None = None
_served: bool | None = None

def __init__(self, location: tuple, responses: Any) -> None:
"""Initialize a Mocket entry.

def __init__(self, location, responses):
Args:
location: Tuple of (host, port)
responses: Single response or list of responses to cycle through
"""
self._served = False
self.location = location

Expand All @@ -34,18 +50,40 @@ def __init__(self, location, responses):
r = self.response_cls(r)
self.responses.append(r)

def __repr__(self):
def __repr__(self) -> str:
"""Return a string representation of the entry."""
return f"{self.__class__.__name__}(location={self.location})"

@staticmethod
def can_handle(data):
def can_handle(data: bytes) -> bool:
"""Check if this entry can handle the given request data.

Args:
data: Request data to check

Returns:
True if this entry can handle the request, False otherwise
"""
return True

def collect(self, data):
def collect(self, data: bytes) -> None:
"""Collect the request data in the Mocket singleton.

Args:
data: Request data to collect
"""
req = self.request_cls(data)
Mocket.collect(req)

def get_response(self):
def get_response(self) -> bytes:
"""Get the next response to send.

Returns:
Response bytes to send to the client

Raises:
BaseException: If a response is an exception, it will be raised
"""
response = self.responses[self.response_index]
if self.response_index < len(self.responses) - 1:
self.response_index += 1
Expand Down
7 changes: 7 additions & 0 deletions mocket/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Mocket exception classes."""


class MocketException(Exception):
"""Base exception class for Mocket errors."""

pass


class StrictMocketException(MocketException):
"""Exception raised when a socket operation is not allowed in STRICT mode."""

pass
Loading