From c92d4b5417b48eab809c538f19897b8005d71293 Mon Sep 17 00:00:00 2001 From: Liu Zhengyun Date: Sat, 14 Mar 2026 14:44:14 +0800 Subject: [PATCH] modify model loading --- .../ainode/it/AINodeInstanceManagementIT.java | 9 +- .../ainode/core/inference/pool_controller.py | 110 ++++++++++++------ .../pool_scheduler/basic_pool_scheduler.py | 4 +- 3 files changed, 83 insertions(+), 40 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java index 15ddce11ede11..f59490613db2a 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java @@ -87,16 +87,17 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter // Load sundial to each device statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES)); checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); + // Unload sundial from each device + statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES)); + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); // Load timer_xl to each device statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES)); checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString()); - - // Clean every device - statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES)); + // Unload timer_xl from each device statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES)); checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString()); - checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); + } private static final int LOOP_CNT = 10; diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py index f4b3d23d36fc1..73bf01c3f7a63 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py @@ -141,7 +141,9 @@ def load_model(self, model_id: str, device_id_list: list[torch.device]): model_id (str): The ID of the model to be loaded. device_id_list (list[torch.device]): List of device_ids where the model should be loaded. """ - self._task_queue.put((self._load_model_task, (model_id, device_id_list), {})) + self._task_queue.put( + (self._load_one_model_task, (model_id, device_id_list), {}) + ) def unload_model(self, model_id: str, device_id_list: list[torch.device]): """ @@ -150,7 +152,9 @@ def unload_model(self, model_id: str, device_id_list: list[torch.device]): model_id (str): The ID of the model to be unloaded. device_id_list (list[torch.device]): List of device_ids where the model should be unloaded. """ - self._task_queue.put((self._unload_model_task, (model_id, device_id_list), {})) + self._task_queue.put( + (self._unload_one_model_task, (model_id, device_id_list), {}) + ) def show_loaded_models( self, device_id_list: list[torch.device] @@ -196,60 +200,92 @@ def _worker_loop(self): finally: self._task_queue.task_done() - def _load_model_task(self, model_id: str, device_id_list: list[torch.device]): - def _load_model_on_device_task(device_id: torch.device): - if not self.has_request_pools(model_id, device_id): - actions = self._pool_scheduler.schedule_load_model_to_device( - self._model_manager.get_model_info(model_id), device_id - ) - for action in actions: - if action.action == ScaleActionType.SCALE_UP: - self._expand_pools_on_device( - action.model_id, device_id, action.amount - ) - elif action.action == ScaleActionType.SCALE_DOWN: - self._shrink_pools_on_device( - action.model_id, device_id, action.amount - ) + def _load_one_model_task(self, model_id: str, device_id_list: list[torch.device]): + def _load_one_model_on_device_task(device_id: torch.device): + if not self.has_pool_on_device(device_id): + self._expand_pools_on_device(model_id, device_id, 1) else: logger.info( - f"[Inference][{device_id}] Model {model_id} is already installed." + f"[Inference][{device_id}] There are already pools on this device." ) load_model_futures = self._executor.submit_batch( - device_id_list, _load_model_on_device_task + device_id_list, _load_one_model_on_device_task ) concurrent.futures.wait( load_model_futures, return_when=concurrent.futures.ALL_COMPLETED ) - def _unload_model_task(self, model_id: str, device_id_list: list[torch.device]): - def _unload_model_on_device_task(device_id: torch.device): + def _unload_one_model_task(self, model_id: str, device_id_list: list[torch.device]): + def _unload_one_model_on_device_task(device_id: torch.device): if self.has_request_pools(model_id, device_id): - actions = self._pool_scheduler.schedule_unload_model_from_device( - self._model_manager.get_model_info(model_id), device_id - ) - for action in actions: - if action.action == ScaleActionType.SCALE_DOWN: - self._shrink_pools_on_device( - action.model_id, device_id, action.amount - ) - elif action.action == ScaleActionType.SCALE_UP: - self._expand_pools_on_device( - action.model_id, device_id, action.amount - ) + self._shrink_pools_on_device(model_id, device_id, 1) else: logger.info( f"[Inference][{device_id}] Model {model_id} is not installed." ) unload_model_futures = self._executor.submit_batch( - device_id_list, _unload_model_on_device_task + device_id_list, _unload_one_model_on_device_task ) concurrent.futures.wait( unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED ) + # def _load_model_task(self, model_id: str, device_id_list: list[torch.device]): + # def _load_model_on_device_task(device_id: torch.device): + # if not self.has_request_pools(model_id, device_id): + # actions = self._pool_scheduler.schedule_load_model_to_device( + # self._model_manager.get_model_info(model_id), device_id + # ) + # for action in actions: + # if action.action == ScaleActionType.SCALE_UP: + # self._expand_pools_on_device( + # action.model_id, device_id, action.amount + # ) + # elif action.action == ScaleActionType.SCALE_DOWN: + # self._shrink_pools_on_device( + # action.model_id, device_id, action.amount + # ) + # else: + # logger.info( + # f"[Inference][{device_id}] Model {model_id} is already installed." + # ) + # + # load_model_futures = self._executor.submit_batch( + # device_id_list, _load_model_on_device_task + # ) + # concurrent.futures.wait( + # load_model_futures, return_when=concurrent.futures.ALL_COMPLETED + # ) + # + # def _unload_model_task(self, model_id: str, device_id_list: list[torch.device]): + # def _unload_model_on_device_task(device_id: torch.device): + # if self.has_request_pools(model_id, device_id): + # actions = self._pool_scheduler.schedule_unload_model_from_device( + # self._model_manager.get_model_info(model_id), device_id + # ) + # for action in actions: + # if action.action == ScaleActionType.SCALE_DOWN: + # self._shrink_pools_on_device( + # action.model_id, device_id, action.amount + # ) + # elif action.action == ScaleActionType.SCALE_UP: + # self._expand_pools_on_device( + # action.model_id, device_id, action.amount + # ) + # else: + # logger.info( + # f"[Inference][{device_id}] Model {model_id} is not installed." + # ) + # + # unload_model_futures = self._executor.submit_batch( + # device_id_list, _unload_model_on_device_task + # ) + # concurrent.futures.wait( + # unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED + # ) + def _expand_pools_on_device( self, model_id: str, device_id: torch.device, count: int ): @@ -462,6 +498,12 @@ def has_running_pools(self, model_id: str) -> bool: return True return False + def has_pool_on_device(self, device_id: torch.device) -> bool: + """ + Check if there are pools on the given device_id. + """ + return any(device_id in pools for pools in self._request_pool_map.values()) + def get_request_pools_group( self, model_id: str, device_id: torch.device ) -> Optional[PoolGroup]: diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 5ce9eceba1453..49aebe8a89e70 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -46,7 +46,7 @@ def _estimate_shared_pool_size_by_total_mem( new_model_info: Optional[ModelInfo] = None, ) -> Dict[str, int]: """ - Estimate pool counts for (existing_model_ids + new_model_id) by equally + Estimate pool counts for (existing_model_infos + new_model_info) by equally splitting the device's TOTAL memory among models. Returns: @@ -60,7 +60,7 @@ def _estimate_shared_pool_size_by_total_mem( ) raise ModelNotExistException(new_model_info.model_id) - # Extract unique model IDs + # Extract unique model infos all_models = existing_model_infos + ( [new_model_info] if new_model_info is not None else [] )