diff --git a/examples/datafusion-ffi-example/src/utils.rs b/examples/datafusion-ffi-example/src/utils.rs index 5f2865aa2..14fab59e4 100644 --- a/examples/datafusion-ffi-example/src/utils.rs +++ b/examples/datafusion-ffi-example/src/utils.rs @@ -19,7 +19,6 @@ use std::ptr::NonNull; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use pyo3::exceptions::PyValueError; -use pyo3::ffi::c_str; use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods}; use pyo3::types::PyCapsule; use pyo3::{Bound, PyAny, PyResult}; @@ -35,30 +34,10 @@ pub(crate) fn ffi_logical_codec_from_pycapsule( }; let capsule = capsule.cast::()?; - validate_pycapsule(capsule, "datafusion_logical_extension_codec")?; - let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_logical_extension_codec")))? + .pointer_checked(Some(c"datafusion_logical_extension_codec"))? .cast(); let codec = unsafe { data.as_ref() }; Ok(codec.clone()) } - -pub(crate) fn validate_pycapsule(capsule: &Bound, name: &str) -> PyResult<()> { - let capsule_name = capsule.name()?; - if capsule_name.is_none() { - return Err(PyValueError::new_err(format!( - "Expected {name} PyCapsule to have name set." - ))); - } - - let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? }; - if capsule_name != name { - return Err(PyValueError::new_err(format!( - "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'" - ))); - } - - Ok(()) -} diff --git a/src/array.rs b/src/array.rs index 1ff08dfb2..f284fa9de 100644 --- a/src/array.rs +++ b/src/array.rs @@ -22,13 +22,11 @@ use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{Field, FieldRef}; use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; use arrow::pyarrow::ToPyArrow; -use pyo3::ffi::c_str; use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods}; use pyo3::types::PyCapsule; use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods}; use crate::errors::PyDataFusionResult; -use crate::utils::validate_pycapsule; /// A Python object which implements the Arrow PyCapsule for importing /// into other libraries. @@ -53,10 +51,8 @@ impl PyArrowArrayExportable { requested_schema: Option>, ) -> PyDataFusionResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> { let field = if let Some(schema_capsule) = requested_schema { - validate_pycapsule(&schema_capsule, "arrow_schema")?; - let data: NonNull = schema_capsule - .pointer_checked(Some(c_str!("arrow_schema")))? + .pointer_checked(Some(c"arrow_schema"))? .cast(); let schema_ptr = unsafe { data.as_ref() }; let desired_field = Field::try_from(schema_ptr)?; diff --git a/src/catalog.rs b/src/catalog.rs index 43325c30d..e571ef490 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -32,7 +32,6 @@ use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::schema_provider::FFI_SchemaProvider; use pyo3::IntoPyObjectExt; use pyo3::exceptions::PyKeyError; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::PyCapsule; @@ -40,8 +39,7 @@ use crate::dataset::Dataset; use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err}; use crate::table::PyTable; use crate::utils::{ - create_logical_extension_capsule, extract_logical_extension_codec, validate_pycapsule, - wait_for_future, + create_logical_extension_capsule, extract_logical_extension_codec, wait_for_future, }; #[pyclass( @@ -658,9 +656,8 @@ fn extract_catalog_provider_from_pyobj( } let provider = if let Ok(capsule) = catalog_provider.cast::() { - validate_pycapsule(capsule, "datafusion_catalog_provider")?; let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_catalog_provider")))? + .pointer_checked(Some(c"datafusion_catalog_provider"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); @@ -691,10 +688,8 @@ fn extract_schema_provider_from_pyobj( } let provider = if let Ok(capsule) = schema_provider.cast::() { - validate_pycapsule(capsule, "datafusion_schema_provider")?; - let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_schema_provider")))? + .pointer_checked(Some(c"datafusion_schema_provider"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); diff --git a/src/context.rs b/src/context.rs index 2eaf5a737..d83826518 100644 --- a/src/context.rs +++ b/src/context.rs @@ -55,7 +55,6 @@ use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; use object_store::ObjectStore; use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyKeyError, PyValueError}; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple}; use url::Url; @@ -84,7 +83,7 @@ use crate::udtf::PyTableFunction; use crate::udwf::PyWindowUDF; use crate::utils::{ create_logical_extension_capsule, extract_logical_extension_codec, get_global_ctx, - get_tokio_runtime, spawn_future, validate_pycapsule, wait_for_future, + get_tokio_runtime, spawn_future, wait_for_future, }; /// Configuration options for a SessionContext @@ -671,12 +670,9 @@ impl PySessionContext { .call1((codec_capsule,))?; } - let provider = if let Ok(capsule) = provider.cast::().map_err(py_datafusion_err) - { - validate_pycapsule(capsule, "datafusion_catalog_provider_list")?; - + let provider = if let Ok(capsule) = provider.cast::() { let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_catalog_provider_list")))? + .pointer_checked(Some(c"datafusion_catalog_provider_list"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); @@ -709,12 +705,9 @@ impl PySessionContext { .call1((codec_capsule,))?; } - let provider = if let Ok(capsule) = provider.cast::().map_err(py_datafusion_err) - { - validate_pycapsule(capsule, "datafusion_catalog_provider")?; - + let provider = if let Ok(capsule) = provider.cast::() { let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_catalog_provider")))? + .pointer_checked(Some(c"datafusion_catalog_provider"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); diff --git a/src/dataframe.rs b/src/dataframe.rs index eb1fa4a81..65c20ecd5 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -45,7 +45,6 @@ use futures::{StreamExt, TryStreamExt}; use parking_lot::Mutex; use pyo3::PyErr; use pyo3::exceptions::PyValueError; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods}; @@ -58,7 +57,7 @@ use crate::physical_plan::PyExecutionPlan; use crate::record_batch::{PyRecordBatchStream, poll_next_batch}; use crate::sql::logical::PyLogicalPlan; use crate::table::{PyTable, TempViewTable}; -use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future}; +use crate::utils::{is_ipython_env, spawn_future, wait_for_future}; /// File-level static CStr for the Arrow array stream capsule name. static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream"); @@ -1117,10 +1116,8 @@ impl PyDataFrame { let mut projection: Option = None; if let Some(schema_capsule) = requested_schema { - validate_pycapsule(&schema_capsule, "arrow_schema")?; - let data: NonNull = schema_capsule - .pointer_checked(Some(c_str!("arrow_schema")))? + .pointer_checked(Some(c"arrow_schema"))? .cast(); let schema_ptr = unsafe { data.as_ref() }; let desired_schema = Schema::try_from(schema_ptr)?; diff --git a/src/udaf.rs b/src/udaf.rs index 7ba499c66..956b648c3 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -27,14 +27,13 @@ use datafusion::logical_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf, }; use datafusion_ffi::udaf::FFI_AggregateUDF; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyTuple}; use crate::common::data_type::PyScalarValue; use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err}; use crate::expr::PyExpr; -use crate::utils::{parse_volatility, validate_pycapsule}; +use crate::utils::parse_volatility; #[derive(Debug)] struct RustAccumulator { @@ -157,10 +156,8 @@ pub fn to_rust_accumulator(accum: Py) -> AccumulatorFactoryFunction { } fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult { - validate_pycapsule(capsule, "datafusion_aggregate_udf")?; - let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_aggregate_udf")))? + .pointer_checked(Some(c"datafusion_aggregate_udf"))? .cast(); let udaf = unsafe { data.as_ref() }; let udaf: Arc = udaf.into(); diff --git a/src/udf.rs b/src/udf.rs index 2d60abc09..d95acd3ba 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -32,14 +32,13 @@ use datafusion::logical_expr::{ Volatility, }; use datafusion_ffi::udf::FFI_ScalarUDF; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyTuple}; use crate::array::PyArrowArrayExportable; -use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err}; +use crate::errors::{PyDataFusionResult, to_datafusion_err}; use crate::expr::PyExpr; -use crate::utils::{parse_volatility, validate_pycapsule}; +use crate::utils::parse_volatility; /// This struct holds the Python written function that is a /// ScalarUDF. @@ -194,11 +193,9 @@ impl PyScalarUDF { pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { if func.hasattr("__datafusion_scalar_udf__")? { let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?; - let capsule = capsule.cast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_scalar_udf")?; - + let capsule = capsule.cast::().map_err(to_datafusion_err)?; let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_scalar_udf")))? + .pointer_checked(Some(c"datafusion_scalar_udf"))? .cast(); let udf = unsafe { data.as_ref() }; let udf: Arc = udf.into(); diff --git a/src/udtf.rs b/src/udtf.rs index 24df93e2b..9371732dc 100644 --- a/src/udtf.rs +++ b/src/udtf.rs @@ -24,7 +24,6 @@ use datafusion::logical_expr::Expr; use datafusion_ffi::udtf::FFI_TableFunction; use pyo3::IntoPyObjectExt; use pyo3::exceptions::{PyImportError, PyTypeError}; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyTuple, PyType}; @@ -32,7 +31,6 @@ use crate::context::PySessionContext; use crate::errors::{py_datafusion_err, to_datafusion_err}; use crate::expr::PyExpr; use crate::table::PyTable; -use crate::utils::validate_pycapsule; /// Represents a user defined table function #[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")] @@ -73,11 +71,9 @@ impl PyTableFunction { err } })?; - let capsule = capsule.cast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_table_function")?; - + let capsule = capsule.cast::()?; let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_table_function")))? + .pointer_checked(Some(c"datafusion_table_function"))? .cast(); let ffi_func = unsafe { data.as_ref() }; let foreign_func: Arc = ffi_func.to_owned().into(); diff --git a/src/udwf.rs b/src/udwf.rs index de63e2f9a..e935cc764 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -33,14 +33,13 @@ use datafusion::logical_expr::{ use datafusion::scalar::ScalarValue; use datafusion_ffi::udwf::FFI_WindowUDF; use pyo3::exceptions::PyValueError; -use pyo3::ffi::c_str; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyList, PyTuple}; use crate::common::data_type::PyScalarValue; -use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err}; +use crate::errors::{PyDataFusionResult, to_datafusion_err}; use crate::expr::PyExpr; -use crate::utils::{parse_volatility, validate_pycapsule}; +use crate::utils::parse_volatility; #[derive(Debug)] struct RustPartitionEvaluator { @@ -262,11 +261,9 @@ impl PyWindowUDF { func }; - let capsule = capsule.cast::().map_err(py_datafusion_err)?; - validate_pycapsule(capsule, "datafusion_window_udf")?; - + let capsule = capsule.cast::().map_err(to_datafusion_err)?; let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_window_udf")))? + .pointer_checked(Some(c"datafusion_window_udf"))? .cast(); let udwf = unsafe { data.as_ref() }; let udwf: Arc = udwf.into(); diff --git a/src/utils.rs b/src/utils.rs index 5085018f7..e69890d5c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -26,8 +26,7 @@ use datafusion::logical_expr::Volatility; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; use pyo3::IntoPyObjectExt; -use pyo3::exceptions::{PyImportError, PyTypeError, PyValueError}; -use pyo3::ffi::c_str; +use pyo3::exceptions::{PyImportError, PyTypeError}; use pyo3::prelude::*; use pyo3::types::{PyCapsule, PyType}; use tokio::runtime::Runtime; @@ -36,7 +35,7 @@ use tokio::time::sleep; use crate::TokioRuntime; use crate::context::PySessionContext; -use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err}; +use crate::errors::{PyDataFusionError, PyDataFusionResult, to_datafusion_err}; /// Utility to get the Tokio Runtime from Python #[inline] @@ -152,24 +151,6 @@ pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult { }) } -pub(crate) fn validate_pycapsule(capsule: &Bound, name: &str) -> PyResult<()> { - let capsule_name = capsule.name()?; - if capsule_name.is_none() { - return Err(PyValueError::new_err(format!( - "Expected {name} PyCapsule to have name set." - ))); - } - - let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? }; - if capsule_name != name { - return Err(PyValueError::new_err(format!( - "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'" - ))); - } - - Ok(()) -} - pub(crate) fn table_provider_from_pycapsule<'py>( mut obj: Bound<'py, PyAny>, session: Bound<'py, PyAny>, @@ -187,11 +168,9 @@ pub(crate) fn table_provider_from_pycapsule<'py>( })?; } - if let Ok(capsule) = obj.cast::().map_err(py_datafusion_err) { - validate_pycapsule(capsule, "datafusion_table_provider")?; - + if let Ok(capsule) = obj.cast::() { let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_table_provider")))? + .pointer_checked(Some(c"datafusion_table_provider"))? .cast(); let provider = unsafe { data.as_ref() }; let provider: Arc = provider.into(); @@ -216,12 +195,10 @@ pub(crate) fn extract_logical_extension_codec( } else { obj }; - let capsule = capsule.cast::().map_err(py_datafusion_err)?; - - validate_pycapsule(capsule, "datafusion_logical_extension_codec")?; + let capsule = capsule.cast::()?; let data: NonNull = capsule - .pointer_checked(Some(c_str!("datafusion_logical_extension_codec")))? + .pointer_checked(Some(c"datafusion_logical_extension_codec"))? .cast(); let codec = unsafe { data.as_ref() }; Ok(Arc::new(codec.clone()))