From 856abb9d5a9d75028001eca3b1038c77aca52727 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 21 Mar 2026 19:32:08 +0100 Subject: [PATCH] Added better API for timeouts. --- python/natsrpy/_natsrpy_rs/__init__.pyi | 4 +- python/natsrpy/_natsrpy_rs/js/consumers.pyi | 1 + src/exceptions/rust_err.rs | 3 +- src/js/consumers/pull/consumer.rs | 14 ++++-- src/nats_cls.rs | 56 ++++++++++++--------- src/subscription.rs | 13 +---- src/utils/futures.rs | 32 +++++++++++- src/utils/py_types.rs | 34 +++++++++++++ 8 files changed, 113 insertions(+), 44 deletions(-) diff --git a/python/natsrpy/_natsrpy_rs/__init__.pyi b/python/natsrpy/_natsrpy_rs/__init__.pyi index 54b4ae0..ff8402d 100644 --- a/python/natsrpy/_natsrpy_rs/__init__.pyi +++ b/python/natsrpy/_natsrpy_rs/__init__.pyi @@ -20,8 +20,8 @@ class Nats: read_buffer_capacity: int = 65535, sender_capacity: int = 128, max_reconnects: int | None = None, - connection_timeout: timedelta = ..., - request_timeout: timedelta = ..., + connection_timeout: float | timedelta = ..., + request_timeout: float | timedelta = ..., ) -> None: ... async def startup(self) -> None: ... async def shutdown(self) -> None: ... diff --git a/python/natsrpy/_natsrpy_rs/js/consumers.pyi b/python/natsrpy/_natsrpy_rs/js/consumers.pyi index 14bb850..7cba0d8 100644 --- a/python/natsrpy/_natsrpy_rs/js/consumers.pyi +++ b/python/natsrpy/_natsrpy_rs/js/consumers.pyi @@ -160,4 +160,5 @@ class PullConsumer: expires: timedelta | None = None, min_pending: int | None = None, min_ack_pending: int | None = None, + timeout: float | timedelta | None = None, ) -> list[JetStreamMessage]: ... diff --git a/src/exceptions/rust_err.rs b/src/exceptions/rust_err.rs index cfac744..86ceb87 100644 --- a/src/exceptions/rust_err.rs +++ b/src/exceptions/rust_err.rs @@ -1,4 +1,4 @@ -use pyo3::exceptions::PyTypeError; +use pyo3::exceptions::{PyTimeoutError, PyTypeError}; use crate::exceptions::py_err::{NatsrpyPublishError, NatsrpySessionError}; @@ -72,6 +72,7 @@ impl From for pyo3::PyErr { fn from(value: NatsrpyError) -> Self { match value { NatsrpyError::PublishError(_) => NatsrpyPublishError::new_err(value.to_string()), + NatsrpyError::Timeout(_) => PyTimeoutError::new_err(value.to_string()), NatsrpyError::PyError(py_err) => py_err, NatsrpyError::InvalidArgument(descr) => PyTypeError::new_err(descr), _ => NatsrpySessionError::new_err(value.to_string()), diff --git a/src/js/consumers/pull/consumer.rs b/src/js/consumers/pull/consumer.rs index e2e4960..fd14f61 100644 --- a/src/js/consumers/pull/consumer.rs +++ b/src/js/consumers/pull/consumer.rs @@ -4,7 +4,10 @@ use futures_util::StreamExt; use pyo3::{Bound, PyAny, Python}; use tokio::sync::RwLock; -use crate::{exceptions::rust_err::NatsrpyResult, utils::natsrpy_future}; +use crate::{ + exceptions::rust_err::NatsrpyResult, + utils::{futures::natsrpy_future_with_timeout, py_types::TimeoutValue}, +}; type NatsPullConsumer = async_nats::jetstream::consumer::Consumer; @@ -35,6 +38,7 @@ impl PullConsumer { expires=None, min_pending=None, min_ack_pending=None, + timeout=None, ))] pub fn fetch<'py>( &self, @@ -47,12 +51,14 @@ impl PullConsumer { expires: Option, min_pending: Option, min_ack_pending: Option, + timeout: Option, ) -> NatsrpyResult> { let ctx = self.consumer.clone(); + + // Because we borrow cosnumer lock + // later for modifications of fetchbuilder. #[allow(clippy::significant_drop_tightening)] - natsrpy_future(py, async move { - // Because we borrow created value - // later for modifications. + natsrpy_future_with_timeout(py, timeout, async move { let consumer = ctx.read().await; let mut fetch_builder = consumer.fetch(); if let Some(max_messages) = max_messages { diff --git a/src/nats_cls.rs b/src/nats_cls.rs index 83b3e72..c4c4e6e 100644 --- a/src/nats_cls.rs +++ b/src/nats_cls.rs @@ -9,7 +9,12 @@ use tokio::sync::RwLock; use crate::{ exceptions::rust_err::NatsrpyError, subscription::Subscription, - utils::{headers::NatsrpyHeadermapExt, natsrpy_future, py_types::SendableValue}, + utils::{ + futures::natsrpy_future_with_timeout, + headers::NatsrpyHeadermapExt, + natsrpy_future, + py_types::{SendableValue, TimeoutValue}, + }, }; #[pyo3::pyclass(name = "Nats")] @@ -23,8 +28,8 @@ pub struct NatsCls { read_buffer_capacity: u16, sender_capacity: usize, max_reconnects: Option, - connection_timeout: Duration, - request_timeout: Option, + connection_timeout: TimeoutValue, + request_timeout: Option, } #[pyo3::pymethods] @@ -40,8 +45,8 @@ impl NatsCls { read_buffer_capacity=65535, sender_capacity=128, max_reconnects=None, - connection_timeout=Duration::from_secs(5), - request_timeout=Duration::from_secs(10), + connection_timeout=TimeoutValue::FloatSecs(5.0), + request_timeout=TimeoutValue::FloatSecs(10.0), ))] fn __new__( addrs: Vec, @@ -52,8 +57,8 @@ impl NatsCls { read_buffer_capacity: u16, sender_capacity: usize, max_reconnects: Option, - connection_timeout: Duration, - request_timeout: Option, + connection_timeout: TimeoutValue, + request_timeout: Option, ) -> Self { Self { nats_session: Arc::new(RwLock::new(None)), @@ -80,8 +85,8 @@ impl NatsCls { } conn_opts = conn_opts .max_reconnects(self.max_reconnects) - .connection_timeout(self.connection_timeout) - .request_timeout(self.request_timeout) + .connection_timeout(self.connection_timeout.into()) + .request_timeout(self.request_timeout.map(Into::into)) .read_buffer_capacity(self.read_buffer_capacity) .client_capacity(self.sender_capacity); @@ -94,23 +99,24 @@ impl NatsCls { let session = self.nats_session.clone(); let address = self.addr.clone(); - let startup_future = async move { - if session.read().await.is_some() { - return Err(NatsrpyError::SessionError( - "NATS session already exists".to_string(), - )); - } - // Scoping for early-dropping of a guard. - { - let mut sesion_guard = session.write().await; - *sesion_guard = Some(conn_opts.connect(address).await?); - } - Ok(()) - }; let timeout = self.connection_timeout; - return Ok(natsrpy_future(py, async move { - tokio::time::timeout(timeout, startup_future).await? - })?); + return Ok(natsrpy_future_with_timeout( + py, + Some(timeout), + async move { + if session.read().await.is_some() { + return Err(NatsrpyError::SessionError( + "NATS session already exists".to_string(), + )); + } + // Scoping for early-dropping of a guard. + { + let mut sesion_guard = session.write().await; + *sesion_guard = Some(conn_opts.connect(address).await?); + } + Ok(()) + }, + )?); } #[pyo3(signature = (subject, payload, *, headers=None, reply=None, err_on_disconnect = false))] diff --git a/src/subscription.rs b/src/subscription.rs index bce844a..b9b1aff 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -7,7 +7,7 @@ use tokio::sync::Mutex; use crate::{ exceptions::rust_err::{NatsrpyError, NatsrpyResult}, - utils::natsrpy_future, + utils::futures::natsrpy_future_with_timeout, }; #[pyo3::pyclass] @@ -39,21 +39,12 @@ impl Subscription { let Some(inner) = self.inner.clone() else { return Err(NatsrpyError::NotInitialized); }; - - let future = async move { + natsrpy_future_with_timeout(py, timeout, async move { let Some(message) = inner.lock().await.next().await else { return Err(PyStopAsyncIteration::new_err("End of the stream.").into()); }; crate::message::Message::try_from(message) - }; - - natsrpy_future(py, async move { - if let Some(timeout) = timeout { - tokio::time::timeout(timeout, future).await? - } else { - future.await - } }) } diff --git a/src/utils/futures.rs b/src/utils/futures.rs index bf83534..a6b2a10 100644 --- a/src/utils/futures.rs +++ b/src/utils/futures.rs @@ -1,6 +1,8 @@ +use std::time::Duration; + use pyo3::{Bound, IntoPyObject, PyAny, Python}; -use crate::exceptions::rust_err::NatsrpyResult; +use crate::exceptions::rust_err::{NatsrpyError, NatsrpyResult}; pub fn natsrpy_future(py: Python, fut: F) -> NatsrpyResult> where @@ -11,3 +13,31 @@ where pyo3_async_runtimes::tokio::future_into_py(py, async { fut.await.map_err(Into::into) })?; Ok(res) } + +pub fn natsrpy_future_with_timeout( + py: Python, + timeout: Option, + fut: F, +) -> NatsrpyResult> +where + F: Future> + Send + 'static, + T: for<'py> IntoPyObject<'py> + Send + 'static, + D: Into, +{ + let timeout = timeout.map(Into::into); + let res = pyo3_async_runtimes::tokio::future_into_py(py, async move { + if let Some(timeout) = timeout { + tokio::time::timeout(timeout, fut) + .await + // First map_err is for timeout + .map_err(NatsrpyError::from)? + // This one is for result returned from + // a future. + .map_err(Into::into) + } else { + // Simple return with error mapping. + fut.await.map_err(Into::into) + } + })?; + Ok(res) +} diff --git a/src/utils/py_types.rs b/src/utils/py_types.rs index f1130c6..b0b546a 100644 --- a/src/utils/py_types.rs +++ b/src/utils/py_types.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use pyo3::{ FromPyObject, types::{PyBytes, PyBytesMethods}, @@ -40,3 +42,35 @@ impl From for bytes::Bytes { } } } + +#[derive(Clone, Debug, Copy, PartialEq, PartialOrd)] +pub enum TimeoutValue { + Duration(Duration), + FloatSecs(f32), +} + +impl From for Duration { + fn from(value: TimeoutValue) -> Self { + match value { + TimeoutValue::Duration(duration) => duration, + TimeoutValue::FloatSecs(fsecs) => Self::from_secs_f32(fsecs), + } + } +} + +impl<'py> FromPyObject<'_, 'py> for TimeoutValue { + type Error = NatsrpyError; + + fn extract(obj: pyo3::Borrowed<'_, 'py, pyo3::PyAny>) -> Result { + #[allow(clippy::option_if_let_else)] + if let Ok(fsec) = obj.extract::() { + Ok(Self::FloatSecs(fsec)) + } else if let Ok(duration) = obj.extract::() { + Ok(Self::Duration(duration)) + } else { + Err(NatsrpyError::InvalidArgument(String::from( + "As timeouts only float or timedelta are accepted.", + ))) + } + } +}