{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Modifying Operators with the Interceptor Pattern" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from typing import Any, Callable\n", "from functools import wraps\n", "from types import MethodType" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In software engineering, the interceptor design pattern changes the behaviour of an existing service [^1]. This approach is used, for example, to implement the `evokit.evolvables.selectors.Elitist` wrapper.\n", "\n", "This short tutorial illustrates how this is done. You can define new wrappers using this approach.\n", "\n", "[^1]: \"Decorator\" is more fitting, but the name conflicts with the Python decorator. \"Interceptor\" seems to be the next best thing." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A Trivial Problem\n", "\n", "To begin, consider a simple example. Declare a class `NumberBox` with attribute `.value` and method `.increment`. For now, calling `.increment` increments `.value` by 1.\n", "\n", "The problem, then, is to find a way to modify `.increment` so that it increments `.value` by 2, 3, and so on." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from typing import Self\n", "from typing import override" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class NumberBox:\n", "    def __init__(self: Self, value: int) -> None:\n", "        self.value: int = value\n", "\n", "    def increment(self: Self) -> None:\n", "        self.value = self.value + 1\n", "\n", "\n", "def check_increment(nb: NumberBox) -> None:\n", "    # Report how much a single call to `.increment()` changes `.value`.\n", "    old_value = nb.value\n", "    print(f\"Initial value is {old_value}\")\n", "    nb.increment()\n", "    new_value = nb.value\n", "    print(\"Calling `.increment()` increases the value by\"\n", "          f\" {new_value - old_value}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check that `NumberBox` behaves as expected." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial value is 1\n", "Calling `.increment()` increases the value by 1\n" ] } ], "source": [ "nb = NumberBox(1)\n", "check_increment(nb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Attempt 1: Overriding `.increment`\n", "\n", "There are many ways to change how much `.increment` adds to `.value`. For example, one can define a new class, `NumberBoxBy2`, which extends `NumberBox` and overrides `.increment`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class NumberBoxBy2(NumberBox):\n", "    @override\n", "    def increment(self: Self) -> None:\n", "        self.value = self.value + 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This approach is inflexible, however. `NumberBoxBy2.increment` now increases `.value` by exactly 2; changing the behaviour of `.increment` further requires extending `NumberBoxBy2` in turn." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial value is 1\n", "Calling `.increment()` increases the value by 2\n" ] } ], "source": [ "nb_new = NumberBoxBy2(1)\n", "check_increment(nb_new)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Attempt 2: Constructing an Interceptor\n", "\n", "A more flexible approach follows. Suppose there is a function `by1more: NumberBox -> NumberBox` that, when given a `NumberBox`, returns a `NumberBox` whose `.increment` increases `.value` by 1 more.\n", "\n", "```python\n", "def by1more(numbox: NumberBox) -> NumberBox:\n", "    pass\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The key is to wrap the function behind the original method and rebind the result to the instance with `types.MethodType`; `functools.wraps` copies the original's name and docstring onto the wrapper. A minimal example follows:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def by1more(numbox: NumberBox) -> NumberBox:\n", "    def wrap_function(original_increment:\n", "                      Callable[[NumberBox], None]) -> Callable:\n", "\n", "        @wraps(original_increment)\n", "        def wrapper(self: NumberBox) -> None:\n", "            # Run the existing behaviour first, then add 1 more.\n", "            original_increment(self)\n", "            self.value = self.value + 1\n", "        return wrapper\n", "\n", "    # Unwrap the bound method with `__func__`, wrap the plain function,\n", "    # then bind the wrapper to this instance only.\n", "    setattr(numbox, 'increment',\n", "            MethodType(\n", "                wrap_function(numbox.increment.__func__),  # type:ignore\n", "                numbox))\n", "    return numbox" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial value is 1\n", "Calling `.increment()` increases the value by 1\n" ] } ], "source": [ "new_nb = NumberBox(1)\n", "check_increment(new_nb)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initial value is 1\n", "Calling `.increment()` increases the value by 3\n" ] } ], "source": [ "modified_nb = NumberBox(1)\n", "check_increment(by1more(by1more(modified_nb)))" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "Note that `by1more` modifies its argument in-place." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" } }, "nbformat": 4, "nbformat_minor": 2 }