Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 1 addition & 22 deletions examples/datafusion-ffi-example/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use std::ptr::NonNull;

use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use pyo3::exceptions::PyValueError;
use pyo3::ffi::c_str;
use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
use pyo3::types::PyCapsule;
use pyo3::{Bound, PyAny, PyResult};
Expand All @@ -35,30 +34,10 @@ pub(crate) fn ffi_logical_codec_from_pycapsule(
};

let capsule = capsule.cast::<PyCapsule>()?;
validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;

let data: NonNull<FFI_LogicalExtensionCodec> = capsule
.pointer_checked(Some(c_str!("datafusion_logical_extension_codec")))?
.pointer_checked(Some(c"datafusion_logical_extension_codec"))?
.cast();
let codec = unsafe { data.as_ref() };

Ok(codec.clone())
}

pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
let capsule_name = capsule.name()?;
if capsule_name.is_none() {
return Err(PyValueError::new_err(format!(
"Expected {name} PyCapsule to have name set."
)));
}

let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? };
if capsule_name != name {
return Err(PyValueError::new_err(format!(
"Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
)));
}
Comment on lines -50 to -61
Copy link
Contributor

@kevinjqliu kevinjqliu Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We lose the error messages by removing this function. Would this be potentially confusing to end users?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. The error messages from PyCapsule_GetPointer are ValueError("PyCapsule_GetPointer called with invalid PyCapsule object") and ValueError("PyCapsule_GetPointer called with incorrect name"). These errors only happen if the "magic" methods return invalid capsules, so they should only be encountered by people trying to implement them, not by regular DataFusion users. Also, the stack trace should give a hint of which capsule is raising the error.

An option might be to wrap the error in a nicer message. WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing out the error messages from PyCapsule_GetPointer! They do seem fairly similar 😄

I still think it could be helpful to expose validate_pycapsule as a public function for implementers. While PyCapsule_GetPointer errors will mostly be encountered by developers implementing the “magic” methods, having a helper with clearer validation and error messages could make debugging easier. Since https://github.com/apache/datafusion-python/pull/1414/files#diff-6dc4fc730855f068c83d61b03b141832727bb30bc71788f763761ef8cd2bde7eR154 introduces a new datafusion-python-util crate that includes this helper, exposing it there could provide implementers with a simple way to validate capsules before extracting pointers.


Ok(())
}
6 changes: 1 addition & 5 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@ use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{Field, FieldRef};
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow::pyarrow::ToPyArrow;
use pyo3::ffi::c_str;
use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
use pyo3::types::PyCapsule;
use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods};

use crate::errors::PyDataFusionResult;
use crate::utils::validate_pycapsule;

/// A Python object which implements the Arrow PyCapsule for importing
/// into other libraries.
Expand All @@ -53,10 +51,8 @@ impl PyArrowArrayExportable {
requested_schema: Option<Bound<'py, PyCapsule>>,
) -> PyDataFusionResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
let field = if let Some(schema_capsule) = requested_schema {
validate_pycapsule(&schema_capsule, "arrow_schema")?;

let data: NonNull<FFI_ArrowSchema> = schema_capsule
.pointer_checked(Some(c_str!("arrow_schema")))?
.pointer_checked(Some(c"arrow_schema"))?
.cast();
let schema_ptr = unsafe { data.as_ref() };
let desired_field = Field::try_from(schema_ptr)?;
Expand Down
11 changes: 3 additions & 8 deletions src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,14 @@ use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::schema_provider::FFI_SchemaProvider;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::PyKeyError;
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

use crate::dataset::Dataset;
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::table::PyTable;
use crate::utils::{
create_logical_extension_capsule, extract_logical_extension_codec, validate_pycapsule,
wait_for_future,
create_logical_extension_capsule, extract_logical_extension_codec, wait_for_future,
};

#[pyclass(
Expand Down Expand Up @@ -658,9 +656,8 @@ fn extract_catalog_provider_from_pyobj(
}

let provider = if let Ok(capsule) = catalog_provider.cast::<PyCapsule>() {
validate_pycapsule(capsule, "datafusion_catalog_provider")?;
let data: NonNull<FFI_CatalogProvider> = capsule
.pointer_checked(Some(c_str!("datafusion_catalog_provider")))?
.pointer_checked(Some(c"datafusion_catalog_provider"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn CatalogProvider + Send> = provider.into();
Expand Down Expand Up @@ -691,10 +688,8 @@ fn extract_schema_provider_from_pyobj(
}

let provider = if let Ok(capsule) = schema_provider.cast::<PyCapsule>() {
validate_pycapsule(capsule, "datafusion_schema_provider")?;

let data: NonNull<FFI_SchemaProvider> = capsule
.pointer_checked(Some(c_str!("datafusion_schema_provider")))?
.pointer_checked(Some(c"datafusion_schema_provider"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn SchemaProvider + Send> = provider.into();
Expand Down
17 changes: 5 additions & 12 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
use object_store::ObjectStore;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
use url::Url;
Expand Down Expand Up @@ -84,7 +83,7 @@ use crate::udtf::PyTableFunction;
use crate::udwf::PyWindowUDF;
use crate::utils::{
create_logical_extension_capsule, extract_logical_extension_codec, get_global_ctx,
get_tokio_runtime, spawn_future, validate_pycapsule, wait_for_future,
get_tokio_runtime, spawn_future, wait_for_future,
};

/// Configuration options for a SessionContext
Expand Down Expand Up @@ -671,12 +670,9 @@ impl PySessionContext {
.call1((codec_capsule,))?;
}

let provider = if let Ok(capsule) = provider.cast::<PyCapsule>().map_err(py_datafusion_err)
{
validate_pycapsule(capsule, "datafusion_catalog_provider_list")?;

let provider = if let Ok(capsule) = provider.cast::<PyCapsule>() {
let data: NonNull<FFI_CatalogProviderList> = capsule
.pointer_checked(Some(c_str!("datafusion_catalog_provider_list")))?
.pointer_checked(Some(c"datafusion_catalog_provider_list"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn CatalogProviderList + Send> = provider.into();
Expand Down Expand Up @@ -709,12 +705,9 @@ impl PySessionContext {
.call1((codec_capsule,))?;
}

let provider = if let Ok(capsule) = provider.cast::<PyCapsule>().map_err(py_datafusion_err)
{
validate_pycapsule(capsule, "datafusion_catalog_provider")?;

let provider = if let Ok(capsule) = provider.cast::<PyCapsule>() {
let data: NonNull<FFI_CatalogProvider> = capsule
.pointer_checked(Some(c_str!("datafusion_catalog_provider")))?
.pointer_checked(Some(c"datafusion_catalog_provider"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn CatalogProvider + Send> = provider.into();
Expand Down
7 changes: 2 additions & 5 deletions src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ use futures::{StreamExt, TryStreamExt};
use parking_lot::Mutex;
use pyo3::PyErr;
use pyo3::exceptions::PyValueError;
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
Expand All @@ -58,7 +57,7 @@ use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::{PyRecordBatchStream, poll_next_batch};
use crate::sql::logical::PyLogicalPlan;
use crate::table::{PyTable, TempViewTable};
use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};
use crate::utils::{is_ipython_env, spawn_future, wait_for_future};

/// File-level static CStr for the Arrow array stream capsule name.
static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
Expand Down Expand Up @@ -1117,10 +1116,8 @@ impl PyDataFrame {
let mut projection: Option<SchemaRef> = None;

if let Some(schema_capsule) = requested_schema {
validate_pycapsule(&schema_capsule, "arrow_schema")?;

let data: NonNull<FFI_ArrowSchema> = schema_capsule
.pointer_checked(Some(c_str!("arrow_schema")))?
.pointer_checked(Some(c"arrow_schema"))?
.cast();
let schema_ptr = unsafe { data.as_ref() };
let desired_schema = Schema::try_from(schema_ptr)?;
Expand Down
7 changes: 2 additions & 5 deletions src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@ use datafusion::logical_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf,
};
use datafusion_ffi::udaf::FFI_AggregateUDF;
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};

use crate::common::data_type::PyScalarValue;
use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::expr::PyExpr;
use crate::utils::{parse_volatility, validate_pycapsule};
use crate::utils::parse_volatility;

#[derive(Debug)]
struct RustAccumulator {
Expand Down Expand Up @@ -157,10 +156,8 @@ pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
}

fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
validate_pycapsule(capsule, "datafusion_aggregate_udf")?;

let data: NonNull<FFI_AggregateUDF> = capsule
.pointer_checked(Some(c_str!("datafusion_aggregate_udf")))?
.pointer_checked(Some(c"datafusion_aggregate_udf"))?
.cast();
let udaf = unsafe { data.as_ref() };
let udaf: Arc<dyn AggregateUDFImpl> = udaf.into();
Expand Down
11 changes: 4 additions & 7 deletions src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ use datafusion::logical_expr::{
Volatility,
};
use datafusion_ffi::udf::FFI_ScalarUDF;
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};

use crate::array::PyArrowArrayExportable;
use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::errors::{PyDataFusionResult, to_datafusion_err};
use crate::expr::PyExpr;
use crate::utils::{parse_volatility, validate_pycapsule};
use crate::utils::parse_volatility;

/// This struct holds the Python written function that is a
/// ScalarUDF.
Expand Down Expand Up @@ -194,11 +193,9 @@ impl PyScalarUDF {
pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
if func.hasattr("__datafusion_scalar_udf__")? {
let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_scalar_udf")?;

let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?;
let data: NonNull<FFI_ScalarUDF> = capsule
.pointer_checked(Some(c_str!("datafusion_scalar_udf")))?
.pointer_checked(Some(c"datafusion_scalar_udf"))?
.cast();
let udf = unsafe { data.as_ref() };
let udf: Arc<dyn ScalarUDFImpl> = udf.into();
Expand Down
8 changes: 2 additions & 6 deletions src/udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@ use datafusion::logical_expr::Expr;
use datafusion_ffi::udtf::FFI_TableFunction;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyImportError, PyTypeError};
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple, PyType};

use crate::context::PySessionContext;
use crate::errors::{py_datafusion_err, to_datafusion_err};
use crate::expr::PyExpr;
use crate::table::PyTable;
use crate::utils::validate_pycapsule;

/// Represents a user defined table function
#[pyclass(from_py_object, frozen, name = "TableFunction", module = "datafusion")]
Expand Down Expand Up @@ -73,11 +71,9 @@ impl PyTableFunction {
err
}
})?;
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_table_function")?;

let capsule = capsule.cast::<PyCapsule>()?;
let data: NonNull<FFI_TableFunction> = capsule
.pointer_checked(Some(c_str!("datafusion_table_function")))?
.pointer_checked(Some(c"datafusion_table_function"))?
.cast();
let ffi_func = unsafe { data.as_ref() };
let foreign_func: Arc<dyn TableFunctionImpl> = ffi_func.to_owned().into();
Expand Down
11 changes: 4 additions & 7 deletions src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@ use datafusion::logical_expr::{
use datafusion::scalar::ScalarValue;
use datafusion_ffi::udwf::FFI_WindowUDF;
use pyo3::exceptions::PyValueError;
use pyo3::ffi::c_str;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyList, PyTuple};

use crate::common::data_type::PyScalarValue;
use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::errors::{PyDataFusionResult, to_datafusion_err};
use crate::expr::PyExpr;
use crate::utils::{parse_volatility, validate_pycapsule};
use crate::utils::parse_volatility;

#[derive(Debug)]
struct RustPartitionEvaluator {
Expand Down Expand Up @@ -262,11 +261,9 @@ impl PyWindowUDF {
func
};

let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_window_udf")?;

let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?;
let data: NonNull<FFI_WindowUDF> = capsule
.pointer_checked(Some(c_str!("datafusion_window_udf")))?
.pointer_checked(Some(c"datafusion_window_udf"))?
.cast();
let udwf = unsafe { data.as_ref() };
let udwf: Arc<dyn WindowUDFImpl> = udwf.into();
Expand Down
35 changes: 6 additions & 29 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ use datafusion::logical_expr::Volatility;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::table_provider::FFI_TableProvider;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyImportError, PyTypeError, PyValueError};
use pyo3::ffi::c_str;
use pyo3::exceptions::{PyImportError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyType};
use tokio::runtime::Runtime;
Expand All @@ -36,7 +35,7 @@ use tokio::time::sleep;

use crate::TokioRuntime;
use crate::context::PySessionContext;
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::errors::{PyDataFusionError, PyDataFusionResult, to_datafusion_err};

/// Utility to get the Tokio Runtime from Python
#[inline]
Expand Down Expand Up @@ -152,24 +151,6 @@ pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
})
}

/// Checks that a `PyCapsule` carries the expected `name` before its raw
/// pointer is extracted elsewhere.
///
/// Returns `Err(PyValueError)` when the capsule has no name set, or when the
/// capsule's name differs from `name`; returns `Ok(())` otherwise.
pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
    let capsule_name = capsule.name()?;
    // A capsule with no name cannot be matched against `name` at all.
    if capsule_name.is_none() {
        return Err(PyValueError::new_err(format!(
            "Expected {name} PyCapsule to have name set."
        )));
    }

    // `unwrap` cannot panic: the `is_none` case returned above.
    // NOTE(review): assumes the CStr returned by pyo3 stays valid for the
    // duration of this borrow — per pyo3's capsule-name contract.
    let capsule_name = unsafe { capsule_name.unwrap().as_cstr().to_str()? };
    if capsule_name != name {
        return Err(PyValueError::new_err(format!(
            "Expected name '{name}' in PyCapsule, instead got '{capsule_name}'"
        )));
    }

    Ok(())
}

pub(crate) fn table_provider_from_pycapsule<'py>(
mut obj: Bound<'py, PyAny>,
session: Bound<'py, PyAny>,
Expand All @@ -187,11 +168,9 @@ pub(crate) fn table_provider_from_pycapsule<'py>(
})?;
}

if let Ok(capsule) = obj.cast::<PyCapsule>().map_err(py_datafusion_err) {
validate_pycapsule(capsule, "datafusion_table_provider")?;

if let Ok(capsule) = obj.cast::<PyCapsule>() {
let data: NonNull<FFI_TableProvider> = capsule
.pointer_checked(Some(c_str!("datafusion_table_provider")))?
.pointer_checked(Some(c"datafusion_table_provider"))?
.cast();
let provider = unsafe { data.as_ref() };
let provider: Arc<dyn TableProvider> = provider.into();
Expand All @@ -216,12 +195,10 @@ pub(crate) fn extract_logical_extension_codec(
} else {
obj
};
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;

validate_pycapsule(capsule, "datafusion_logical_extension_codec")?;
let capsule = capsule.cast::<PyCapsule>()?;

let data: NonNull<FFI_LogicalExtensionCodec> = capsule
.pointer_checked(Some(c_str!("datafusion_logical_extension_codec")))?
.pointer_checked(Some(c"datafusion_logical_extension_codec"))?
.cast();
let codec = unsafe { data.as_ref() };
Ok(Arc::new(codec.clone()))
Expand Down