diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index c0902371..47ca2dd9 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -32,10 +32,13 @@ from taskiq.utils import maybe_awaitable from taskiq.warnings import TaskiqDeprecationWarning -if sys.version_info >= (3, 11): +if sys.version_info >= ( + 3, + 11, +): # Check which python version are we running to import correctly from typing import Self else: - from typing_extensions import Self + from typing_extensions import Self # pragma: no cover if TYPE_CHECKING: # pragma: no cover @@ -46,6 +49,7 @@ _FuncParams = ParamSpec("_FuncParams") _ReturnType = TypeVar("_ReturnType") +# an event handler can be either a sync or an async function that has one parameter of type TaskiqState EventHandler: TypeAlias = Callable[[TaskiqState], Awaitable[None] | None] logger = getLogger("taskiq") diff --git a/taskiq/cli/worker/process_manager.py b/taskiq/cli/worker/process_manager.py index 22257e86..394b05d8 100644 --- a/taskiq/cli/worker/process_manager.py +++ b/taskiq/cli/worker/process_manager.py @@ -169,7 +169,7 @@ def __init__( for path_to_watch in watch_paths: logger.debug(f"Watching directory: {path_to_watch}") observer.schedule( - FileWatcher( + FileWatcher( # type: ignore callback=schedule_workers_reload, path=Path(path_to_watch), use_gitignore=not args.no_gitignore, diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index 53cef7c0..e48ac173 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -163,6 +163,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None: receiver = receiver_type( broker=broker, executor=pool, + observer=getattr(broker, "_receiver_observer", None), validate_params=not args.no_parse, max_async_tasks=args.max_async_tasks, max_prefetch=args.max_prefetch, diff --git a/taskiq/middlewares/prometheus_middleware.py b/taskiq/middlewares/prometheus_middleware.py index 56837cf3..704c55c8 100644 --- a/taskiq/middlewares/prometheus_middleware.py +++ b/taskiq/middlewares/prometheus_middleware.py @@ -1,3 +1,4 @@ +import datetime import os from logging import getLogger from pathlib import Path @@ -6,6 +7,7 @@ from taskiq.abc.middleware import TaskiqMiddleware from taskiq.message import TaskiqMessage +from taskiq.receiver.observer import ReceiverObserver from taskiq.result import TaskiqResult logger = getLogger("taskiq.prometheus") @@ -20,7 +22,7 @@ class PrometheusMiddleware(TaskiqMiddleware): :param server_port: The port to listen on. :param server_addr: The address to listen on. - :paam metrics_path: The path to store metrics for multiproc env. + :param metrics_path: The path to store metrics for multiproc env. """ def __init__( @@ -74,6 +76,18 @@ def __init__( "Time of function execution", ["task_name"], ) + + self.queue_wait_seconds = Histogram( + "queue_wait_seconds", + "time task spent in message queue", + ["task_name"], + ) + self.task_errors_by_type = Counter( + "task_errors_by_type", + "Number of errors raised in tasks by their type", + ["task_name", "error_type"], + ) + self.server_port = server_port self.server_addr = server_addr @@ -104,6 +118,24 @@ def startup(self) -> None: except OSError as exc: logger.debug("Cannot start prometheus server: %s", exc) + def pre_send( + self, + message: "TaskiqMessage", + ) -> "TaskiqMessage": + """ + Function to track the time a task spend in queue. + + This function tracks the time a task spends in a queue until it is executed. + + :param message: current message. + :return: message + """ + if not message.labels.get("_taskiq_enqueue_timestamp"): + message.labels["_taskiq_enqueue_timestamp"] = datetime.datetime.now( + datetime.UTC, + ).isoformat() # Might conside using timezones too + return message + def pre_execute( self, message: "TaskiqMessage", @@ -117,9 +149,40 @@ def pre_execute( :param message: current message. :return: message """ + if message.labels.get( + "_taskiq_enqueue_timestamp", + ): # Handle case where the sender doesn't use the prometheus middleware + time_delta = datetime.datetime.now( + datetime.UTC, + ) - datetime.datetime.fromisoformat( + message.labels["_taskiq_enqueue_timestamp"], + ) + time_delta = max(0, time_delta.total_seconds()) + self.queue_wait_seconds.labels(message.task_name).observe( + time_delta, + ) + self.received_tasks.labels(message.task_name).inc() return message + def on_error( + self, + message: TaskiqMessage, + result: TaskiqResult[Any], # pylint: disable=unused-argument + exception: BaseException, + ) -> None: + """ + This function tracks the number of errors raised by tasks. + + :param message: the received task message + :param result: the result of task + :param exception: exception raised + """ + self.task_errors_by_type.labels( + message.task_name, + type(exception).__name__, + ).inc() + def post_execute( self, message: "TaskiqMessage", @@ -137,6 +200,15 @@ def post_execute( self.success_tasks.labels(message.task_name).inc() self.execution_time.labels(message.task_name).observe(result.execution_time) + def set_broker(self, broker: "AsyncBroker") -> None: # noqa: F821 + """ + Set broker and attach receiver observer. + + :param broker: broker to set. + """ + super().set_broker(broker) + broker._receiver_observer = PrometheusReceiverObserver() # noqa: SLF001 + def post_save( self, message: "TaskiqMessage", @@ -149,3 +221,60 @@ def post_save( :param result: result of execution. """ self.saved_results.labels(message.task_name).inc() + + +class PrometheusReceiverObserver(ReceiverObserver): + """Receiver observer implementation for prometheus.""" + + def __init__(self) -> None: + try: + from prometheus_client import Counter, Gauge # noqa: PLC0415 + except ImportError as exc: + raise ImportError( + "Cannot initialize metrics. Please install 'taskiq[metrics]'.", + ) from exc + + self.prefetch_queue_size = Gauge( + "prefetch_queue_size", + "The number of task in the prefetch queue.", + multiprocess_mode="livesum", + ) + self.semaphore_available = Gauge( + "semaphore_available", + "Number of semaphore slots available in broker", + multiprocess_mode="livesum", + ) + self.active_tasks_count = Gauge( + "worker_active_tasks_count", + "Number of active tasks in worker", + multiprocess_mode="livesum", + ) + self.task_not_found_total = Counter( + "task_not_found_total", + "Number of times the worker got a task not registered", + ["task_name"], + ) + self.deserialize_error = Counter( + "deserialize_error_count", + "Number of times broker faced a deserialization error", + ) + + def on_prefetch_queue_size(self, size: int) -> None: + """Record current prefetch queue depth.""" + self.prefetch_queue_size.set(size) + + def on_semaphore_status(self, available: int) -> None: + """Record available semaphore slots.""" + self.semaphore_available.set(available) + + def on_active_tasks_count(self, count: int) -> None: + """Record number of currently executing tasks.""" + self.active_tasks_count.set(count) + + def on_task_not_found(self, task_name: str) -> None: + """Increment counter for unregistered task lookups.""" + self.task_not_found_total.labels(task_name).inc() + + def on_deserialize_error(self, raw: bytes, error: Exception) -> None: + """Increment counter for message deserialization failures.""" + self.deserialize_error.inc() diff --git a/taskiq/receiver/__init__.py b/taskiq/receiver/__init__.py index c6a7e66b..b9527fb3 100644 --- a/taskiq/receiver/__init__.py +++ b/taskiq/receiver/__init__.py @@ -1,5 +1,6 @@ """Package for message receiver.""" from taskiq.receiver.receiver import Receiver +from taskiq.receiver.observer import ReceiverObserver -__all__ = ["Receiver"] +__all__ = ["Receiver", "ReceiverObserver"] diff --git a/taskiq/receiver/observer.py b/taskiq/receiver/observer.py new file mode 100644 index 00000000..70a7ccd8 --- /dev/null +++ b/taskiq/receiver/observer.py @@ -0,0 +1,35 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class ReceiverObserver(Protocol): + """ + Observer for receiver stats. + + This class is used to observe/collect metrics for the receiver. + This includes semaphore usage, tasks in queue, etc. + + metrics tracked: + - Number of tasks in queue + - Number of tasks in execution (semaphore usage) + """ + + def on_prefetch_queue_size(self, size: int) -> None: + """Called when the prefetch queue size changes.""" + ... + + def on_semaphore_status(self, available: int) -> None: + """Called when semaphore availability changes.""" + ... + + def on_active_tasks_count(self, count: int) -> None: + """Called when the number of active tasks changes.""" + ... + + def on_task_not_found(self, task_name: str) -> None: + """Called when a received task is not registered.""" + ... + + def on_deserialize_error(self, raw: bytes, error: Exception) -> None: + """Called when a message fails to deserialize.""" + ... diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 99298af2..6afd1101 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -18,6 +18,7 @@ from taskiq.context import Context from taskiq.exceptions import NoResultError from taskiq.message import TaskiqMessage +from taskiq.receiver.observer import ReceiverObserver from taskiq.receiver.params_parser import parse_params from taskiq.result import TaskiqResult from taskiq.state import TaskiqState @@ -35,6 +36,7 @@ def __init__( self, broker: AsyncBroker, executor: Executor | None = None, + observer: ReceiverObserver | None = None, validate_params: bool = True, max_async_tasks: "int | None" = None, max_prefetch: int = 0, @@ -54,6 +56,7 @@ def __init__( self.dependency_graphs: dict[str, DependencyGraph] = {} self.propagate_exceptions = propagate_exceptions self.on_exit = on_exit + self.observer = observer self.ack_time = ack_type or AcknowledgeType.WHEN_SAVED self.known_tasks: set[str] = set() self.max_tasks_to_execute = max_tasks_to_execute @@ -92,6 +95,11 @@ async def callback( # noqa: C901, PLR0912 taskiq_msg = self.broker.formatter.loads(message=message_data) taskiq_msg.parse_labels() except Exception as exc: + if self.observer is not None: + self.observer.on_deserialize_error( + raw=message_data, + error=exc, + ) logger.warning( "Cannot parse message: %s. Skipping execution.\n %s", message_data, @@ -102,6 +110,11 @@ async def callback( # noqa: C901, PLR0912 logger.debug(f"Received message: {taskiq_msg}") task = self.broker.find_task(taskiq_msg.task_name) if task is None: + if self.observer is not None: + self.observer.on_task_not_found( + taskiq_msg.task_name, + ) + logger.warning( 'task "%s" is not found. Maybe you forgot to import it?', taskiq_msg.task_name, @@ -363,6 +376,7 @@ async def prefetcher( break try: await self.sem_prefetch.acquire() + if ( self.max_tasks_to_execute and fetched_tasks >= self.max_tasks_to_execute @@ -376,6 +390,7 @@ async def prefetcher( # and continue the loop. So it will check if finished event was set. if not done: self.sem_prefetch.release() + continue # We're done, so now we need to check # whether task has returned an error. @@ -383,6 +398,12 @@ async def prefetcher( current_message = asyncio.create_task(iterator.__anext__()) # type: ignore fetched_tasks += 1 await queue.put(message) + + if self.observer is not None: + self.observer.on_prefetch_queue_size( + queue.qsize(), + ) + except (asyncio.CancelledError, StopAsyncIteration): break # We don't want to fetch new messages if we are shutting down. @@ -391,7 +412,7 @@ async def prefetcher( await queue.put(QUEUE_DONE) self.sem_prefetch.release() - async def runner( + async def runner( # noqa: C901 self, queue: "asyncio.Queue[bytes | AckableMessage]", ) -> None: @@ -413,17 +434,29 @@ def task_cb(task: "asyncio.Task[Any]") -> None: :param task: finished task """ tasks.discard(task) + if self.observer is not None: + self.observer.on_active_tasks_count( + len(tasks), + ) + if self.sem is not None: self.sem.release() + if self.observer is not None: + self.observer.on_semaphore_status(self.sem._value) # noqa + while True: try: # Waits for semaphore to be released. if self.sem is not None: await self.sem.acquire() + if self.observer is not None: + self.observer.on_semaphore_status(self.sem._value) # noqa self.sem_prefetch.release() message = await queue.get() + if self.observer is not None: + self.observer.on_prefetch_queue_size(queue.qsize()) if message is QUEUE_DONE: # asyncio.wait will throw an error if there is nothing to wait for if tasks: @@ -438,7 +471,10 @@ def task_cb(task: "asyncio.Task[Any]") -> None: self.callback(message=message, raise_err=False), ) tasks.add(task) - + if self.observer is not None: + self.observer.on_active_tasks_count( + len(tasks), + ) # We want the task to remove itself from the set when it's done. # # Because if we won't save it anywhere,