Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,17 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter
// Load sundial to each device
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES));
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
// Unload sundial from each device
statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES));
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());

// Load timer_xl to each device
statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES));
checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());

// Clean every device
statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES));
// Unload timer_xl from each device
statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES));
checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());

}

private static final int LOOP_CNT = 10;
Expand Down
110 changes: 76 additions & 34 deletions iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def load_model(self, model_id: str, device_id_list: list[torch.device]):
model_id (str): The ID of the model to be loaded.
device_id_list (list[torch.device]): List of device_ids where the model should be loaded.
"""
self._task_queue.put((self._load_model_task, (model_id, device_id_list), {}))
self._task_queue.put(
(self._load_one_model_task, (model_id, device_id_list), {})
)

def unload_model(self, model_id: str, device_id_list: list[torch.device]):
"""
Expand All @@ -150,7 +152,9 @@ def unload_model(self, model_id: str, device_id_list: list[torch.device]):
model_id (str): The ID of the model to be unloaded.
device_id_list (list[torch.device]): List of device_ids where the model should be unloaded.
"""
self._task_queue.put((self._unload_model_task, (model_id, device_id_list), {}))
self._task_queue.put(
(self._unload_one_model_task, (model_id, device_id_list), {})
)

def show_loaded_models(
self, device_id_list: list[torch.device]
Expand Down Expand Up @@ -196,60 +200,92 @@ def _worker_loop(self):
finally:
self._task_queue.task_done()

def _load_model_task(self, model_id: str, device_id_list: list[torch.device]):
def _load_model_on_device_task(device_id: torch.device):
if not self.has_request_pools(model_id, device_id):
actions = self._pool_scheduler.schedule_load_model_to_device(
self._model_manager.get_model_info(model_id), device_id
)
for action in actions:
if action.action == ScaleActionType.SCALE_UP:
self._expand_pools_on_device(
action.model_id, device_id, action.amount
)
elif action.action == ScaleActionType.SCALE_DOWN:
self._shrink_pools_on_device(
action.model_id, device_id, action.amount
)
def _load_one_model_task(self, model_id: str, device_id_list: list[torch.device]):
    """
    Worker task: ensure an inference pool exists for ``model_id`` on every
    device in ``device_id_list``. Blocks until all per-device subtasks finish.

    Args:
        model_id (str): The ID of the model to be loaded.
        device_id_list (list[torch.device]): Devices to load the model on.
    """

    def _load_one_model_on_device_task(device_id: torch.device):
        # Create exactly one pool per device, and only if the device has
        # none yet; otherwise just report that it is already provisioned.
        if not self.has_pool_on_device(device_id):
            self._expand_pools_on_device(model_id, device_id, 1)
        else:
            logger.info(
                f"[Inference][{device_id}] There are already pools on this device."
            )

    # Fan the per-device work out to the executor, then wait for every
    # future so the task queue sees load_model as fully completed.
    load_model_futures = self._executor.submit_batch(
        device_id_list, _load_one_model_on_device_task
    )
    concurrent.futures.wait(
        load_model_futures, return_when=concurrent.futures.ALL_COMPLETED
    )

def _unload_model_task(self, model_id: str, device_id_list: list[torch.device]):
def _unload_model_on_device_task(device_id: torch.device):
def _unload_one_model_task(self, model_id: str, device_id_list: list[torch.device]):
    """
    Worker task: remove one inference pool for ``model_id`` from every
    device in ``device_id_list``. Blocks until all per-device subtasks finish.

    Args:
        model_id (str): The ID of the model to be unloaded.
        device_id_list (list[torch.device]): Devices to unload the model from.
    """

    def _unload_one_model_on_device_task(device_id: torch.device):
        # Shrink by exactly one pool, mirroring the single-pool-per-device
        # policy of _load_one_model_task; skip devices without this model.
        if self.has_request_pools(model_id, device_id):
            self._shrink_pools_on_device(model_id, device_id, 1)
        else:
            logger.info(
                f"[Inference][{device_id}] Model {model_id} is not installed."
            )

    # Fan the per-device work out to the executor, then wait for every
    # future so the task queue sees unload_model as fully completed.
    unload_model_futures = self._executor.submit_batch(
        device_id_list, _unload_one_model_on_device_task
    )
    concurrent.futures.wait(
        unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED
    )

# def _load_model_task(self, model_id: str, device_id_list: list[torch.device]):
# def _load_model_on_device_task(device_id: torch.device):
# if not self.has_request_pools(model_id, device_id):
# actions = self._pool_scheduler.schedule_load_model_to_device(
# self._model_manager.get_model_info(model_id), device_id
# )
# for action in actions:
# if action.action == ScaleActionType.SCALE_UP:
# self._expand_pools_on_device(
# action.model_id, device_id, action.amount
# )
# elif action.action == ScaleActionType.SCALE_DOWN:
# self._shrink_pools_on_device(
# action.model_id, device_id, action.amount
# )
# else:
# logger.info(
# f"[Inference][{device_id}] Model {model_id} is already installed."
# )
#
# load_model_futures = self._executor.submit_batch(
# device_id_list, _load_model_on_device_task
# )
# concurrent.futures.wait(
# load_model_futures, return_when=concurrent.futures.ALL_COMPLETED
# )
#
# def _unload_model_task(self, model_id: str, device_id_list: list[torch.device]):
# def _unload_model_on_device_task(device_id: torch.device):
# if self.has_request_pools(model_id, device_id):
# actions = self._pool_scheduler.schedule_unload_model_from_device(
# self._model_manager.get_model_info(model_id), device_id
# )
# for action in actions:
# if action.action == ScaleActionType.SCALE_DOWN:
# self._shrink_pools_on_device(
# action.model_id, device_id, action.amount
# )
# elif action.action == ScaleActionType.SCALE_UP:
# self._expand_pools_on_device(
# action.model_id, device_id, action.amount
# )
# else:
# logger.info(
# f"[Inference][{device_id}] Model {model_id} is not installed."
# )
#
# unload_model_futures = self._executor.submit_batch(
# device_id_list, _unload_model_on_device_task
# )
# concurrent.futures.wait(
# unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED
# )

def _expand_pools_on_device(
self, model_id: str, device_id: torch.device, count: int
):
Expand Down Expand Up @@ -462,6 +498,12 @@ def has_running_pools(self, model_id: str) -> bool:
return True
return False

def has_pool_on_device(self, device_id: torch.device) -> bool:
    """
    Return True if at least one model currently owns a pool group on
    ``device_id``, scanning every model's per-device pool mapping.
    """
    for device_pools in self._request_pool_map.values():
        if device_id in device_pools:
            return True
    return False

def get_request_pools_group(
self, model_id: str, device_id: torch.device
) -> Optional[PoolGroup]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _estimate_shared_pool_size_by_total_mem(
new_model_info: Optional[ModelInfo] = None,
) -> Dict[str, int]:
"""
Estimate pool counts for (existing_model_ids + new_model_id) by equally
Estimate pool counts for (existing_model_infos + new_model_info) by equally
splitting the device's TOTAL memory among models.

Returns:
Expand All @@ -60,7 +60,7 @@ def _estimate_shared_pool_size_by_total_mem(
)
raise ModelNotExistException(new_model_info.model_id)

# Extract unique model IDs
# Extract unique model infos
all_models = existing_model_infos + (
[new_model_info] if new_model_info is not None else []
)
Expand Down