diff --git a/ucm/transport/kv/asu/CMakeLists.txt b/ucm/transport/kv/asu/CMakeLists.txt index 1776d8683..8a2aa2dc6 100644 --- a/ucm/transport/kv/asu/CMakeLists.txt +++ b/ucm/transport/kv/asu/CMakeLists.txt @@ -1,71 +1,72 @@ -if(DEFINED ASCEND_ROOT AND NOT DEFINED Ascend_ROOT) - set(Ascend_ROOT "${ASCEND_ROOT}" CACHE PATH "Path to Ascend root directory") +if(NOT RUNTIME_ENVIRONMENT STREQUAL "ascend") + message(FATAL_ERROR "BUILD_UCM_ASU requires RUNTIME_ENVIRONMENT=ascend. Current value: ${RUNTIME_ENVIRONMENT}") endif() -if(NOT DEFINED Ascend_ROOT) - set(Ascend_ROOT "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Path to Ascend root directory") + +if(NOT DEFINED ASCEND_ROOT) + set(ASCEND_ROOT "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Path to Ascend root directory") endif() find_path(ASU_ASCEND_INCLUDE_DIR NAMES acl/acl.h HINTS - ${Ascend_ROOT}/include - ${Ascend_ROOT}/aarch64-linux/include - ${Ascend_ROOT}/arm64-linux/include + ${ASCEND_ROOT}/include + ${ASCEND_ROOT}/aarch64-linux/include + ${ASCEND_ROOT}/arm64-linux/include NO_DEFAULT_PATH ) if(NOT ASU_ASCEND_INCLUDE_DIR) - message(FATAL_ERROR "Cannot find acl/acl.h under Ascend_ROOT=${Ascend_ROOT}") + message(FATAL_ERROR "Cannot find acl/acl.h under ASCEND_ROOT=${ASCEND_ROOT}") endif() -if(RUNTIME_ENVIRONMENT STREQUAL "ascend") - file(GLOB ASU_TRANSPORT_SOURCES CONFIGURE_DEPENDS trans/src/*.cpp) - add_library(asu_transport SHARED ${ASU_TRANSPORT_SOURCES}) - target_include_directories(asu_transport - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/trans/include - ${ASU_ASCEND_INCLUDE_DIR} - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/common - ${CMAKE_CURRENT_SOURCE_DIR}/trans/src - ${UCM_ROOT_DIR}/ucm/shared/infra - ) - target_link_libraries(asu_transport PUBLIC infra_logger pthread) +file(GLOB ASU_COMMON_SOURCES CONFIGURE_DEPENDS common/*.cpp) +file(GLOB ASU_TRANSPORT_SOURCES CONFIGURE_DEPENDS trans/src/*.cpp) +list(APPEND ASU_TRANSPORT_SOURCES ${ASU_COMMON_SOURCES}) +add_library(asu_transport SHARED ${ASU_TRANSPORT_SOURCES}) +target_include_directories(asu_transport + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/trans/include + ${ASU_ASCEND_INCLUDE_DIR} + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/common + ${CMAKE_CURRENT_SOURCE_DIR}/trans/src + ${UCM_ROOT_DIR}/ucm/shared/infra +) +target_link_libraries(asu_transport PUBLIC infra_logger pthread) - target_link_libraries(asu_transport PUBLIC trans) +target_link_libraries(asu_transport PUBLIC trans) - file(GLOB ASU_CLIENT_SOURCES CONFIGURE_DEPENDS client/src/*.cpp) - add_library(asu_client SHARED ${ASU_CLIENT_SOURCES}) - target_include_directories(asu_client - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/client/include +file(GLOB ASU_CLIENT_SOURCES CONFIGURE_DEPENDS client/src/*.cpp) +add_library(asu_client SHARED ${ASU_CLIENT_SOURCES}) +target_include_directories(asu_client + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/client/include +) +target_include_directories(asu_client + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/common + ${CMAKE_CURRENT_SOURCE_DIR}/client/src + ${UCM_ROOT_DIR}/ucm/shared/infra +) +target_link_libraries(asu_client PUBLIC asu_transport kv_common pthread) + +if(BUILD_UNIT_TESTS) + target_compile_definitions(asu_transport PRIVATE ASU_BUILD_TESTS) + include(GoogleTest) + file(GLOB_RECURSE ASU_TEST_SOURCE_FILES CONFIGURE_DEPENDS test/case/*.cc) + file(GLOB_RECURSE ASU_TRANSPORT_TEST_SOURCE_FILES CONFIGURE_DEPENDS trans/test/*.cpp) + list(APPEND ASU_TEST_SOURCE_FILES ${ASU_TRANSPORT_TEST_SOURCE_FILES}) + file(GLOB_RECURSE ASU_CLIENT_TEST_SOURCE_FILES CONFIGURE_DEPENDS client/test/*.cpp) + list(APPEND ASU_TEST_SOURCE_FILES ${ASU_CLIENT_TEST_SOURCE_FILES}) + add_executable(asu.test ${ASU_TEST_SOURCE_FILES}) + target_compile_definitions(asu.test PRIVATE ASU_BUILD_TESTS) + target_include_directories(asu.test PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/common + ${CMAKE_CURRENT_SOURCE_DIR}/client/src + ${CMAKE_CURRENT_SOURCE_DIR}/trans/src + ${UCM_ROOT_DIR}/ucm/shared/infra ) - target_include_directories(asu_client - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/common - ${CMAKE_CURRENT_SOURCE_DIR}/client/src - ${UCM_ROOT_DIR}/ucm/shared/infra + target_link_libraries(asu.test PRIVATE + asu_client + gtest_main gtest ) - target_link_libraries(asu_client PUBLIC asu_transport kv_common pthread) - - if(BUILD_UNIT_TESTS) - target_compile_definitions(asu_transport PRIVATE ASU_BUILD_TESTS) - include(GoogleTest) - file(GLOB_RECURSE ASU_TEST_SOURCE_FILES CONFIGURE_DEPENDS test/case/*.cc) - file(GLOB_RECURSE ASU_TRANSPORT_TEST_SOURCE_FILES CONFIGURE_DEPENDS trans/test/*.cpp) - list(APPEND ASU_TEST_SOURCE_FILES ${ASU_TRANSPORT_TEST_SOURCE_FILES}) - file(GLOB_RECURSE ASU_CLIENT_TEST_SOURCE_FILES CONFIGURE_DEPENDS client/test/*.cpp) - list(APPEND ASU_TEST_SOURCE_FILES ${ASU_CLIENT_TEST_SOURCE_FILES}) - add_executable(asu.test ${ASU_TEST_SOURCE_FILES}) - target_compile_definitions(asu.test PRIVATE ASU_BUILD_TESTS) - target_include_directories(asu.test PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/common - ${CMAKE_CURRENT_SOURCE_DIR}/client/src - ${CMAKE_CURRENT_SOURCE_DIR}/trans/src - ${UCM_ROOT_DIR}/ucm/shared/infra - ) - target_link_libraries(asu.test PRIVATE - asu_client - gtest_main gtest - ) - gtest_discover_tests(asu.test) - endif() + gtest_discover_tests(asu.test) endif() diff --git a/ucm/transport/kv/asu/client/src/asu_client_impl.cpp b/ucm/transport/kv/asu/client/src/asu_client_impl.cpp index 0bf3f2d14..0b711ac77 100644 --- a/ucm/transport/kv/asu/client/src/asu_client_impl.cpp +++ b/ucm/transport/kv/asu/client/src/asu_client_impl.cpp @@ -566,8 +566,7 @@ bool AsuClientImpl::PollTask(const ClientTaskContextPtr& ctx) ctx->finalStatus = anyFailed ? Status::Error(StatusCode::PARTIAL_FAILED, "client task partially failed") : Status::OK(); - ctx->state.store(anyFailed ? ClientTaskState::FAILED : ClientTaskState::COMPLETED, - std::memory_order_release); + ctx->state.store(ClientTaskState::COMPLETED, std::memory_order_release); ctx->cv.notify_all(); return true; } @@ -664,8 +663,7 @@ Status AsuClientImpl::WaitTaskContext(const ClientTaskContextPtr& ctx, std::uint ctx->finalStatus = anyFailed ? Status::Error(StatusCode::PARTIAL_FAILED, "client task partially failed") : Status::OK(); - ctx->state.store(anyFailed ? ClientTaskState::FAILED : ClientTaskState::COMPLETED, - std::memory_order_release); + ctx->state.store(ClientTaskState::COMPLETED, std::memory_order_release); ctx->cv.notify_all(); break; } diff --git a/ucm/transport/kv/asu/client/src/client_config_parser.cpp b/ucm/transport/kv/asu/client/src/client_config_parser.cpp index 13b683d39..047e9ec25 100644 --- a/ucm/transport/kv/asu/client/src/client_config_parser.cpp +++ b/ucm/transport/kv/asu/client/src/client_config_parser.cpp @@ -25,121 +25,21 @@ #include #include #include -#include #include #include +#include #include "asu_client/asu_client.h" +#include "config_parser_common.h" #include "view_server.h" namespace UC::ASU { namespace { -std::string Trim(const std::string& value) -{ - const auto begin = value.find_first_not_of(" \t\r\n"); - if (begin == std::string::npos) { return ""; } - const auto end = value.find_last_not_of(" \t\r\n"); - return value.substr(begin, end - begin + 1); -} - -std::vector Split(const std::string& value, char delimiter) -{ - std::vector parts; - std::stringstream stream{value}; - std::string part; - while (std::getline(stream, part, delimiter)) { - part = Trim(part); - if (!part.empty()) { parts.emplace_back(std::move(part)); } - } - return parts; -} - -std::uint64_t ParseUint64(const std::string& value) { return std::stoull(value, nullptr, 0); } - -Protocol ToTransportProtocol(const std::string& value) -{ - auto protocol = value; - std::transform(protocol.begin(), protocol.end(), protocol.begin(), - [](unsigned char ch) { return static_cast(std::toupper(ch)); }); - if (protocol == "UB" || protocol == "UBOE") { return Protocol::UB; } - if (protocol == "ROCE") { return Protocol::ROCE; } - if (protocol == "TCP") { return Protocol::TCP; } - return Protocol::TCP; -} - -bool TryParseAsuInfoKey(const std::string& key, AsuId& asuId) -{ - constexpr const char* kCamelPrefix = "asuInfo."; - constexpr const char* kSnakePrefix = "asu_info."; - if (key.rfind(kCamelPrefix, 0) == 0) { - asuId = std::stoull(key.substr(std::string{kCamelPrefix}.size())); - return true; - } - if (key.rfind(kSnakePrefix, 0) == 0) { - asuId = std::stoull(key.substr(std::string{kSnakePrefix}.size())); - return true; - } - return false; -} - -void SetEndpointAttr(AsuEndpoint& endpoint, const std::string& key, const std::string& value) -{ - endpoint.attrs[key] = value; -} - -AsuEndpoint ParseAsuEndpoint(const std::string& value) -{ - AsuEndpoint endpoint; - if (value.find('=') == std::string::npos) { - auto parts = Split(value, ':'); - if (!parts.empty()) { endpoint.ip = parts[0]; } - if (parts.size() > 1) { endpoint.port = static_cast(ParseUint64(parts[1])); } - if (parts.size() > 2) { - endpoint.protocol = ToTransportProtocol(parts[2]); - SetEndpointAttr(endpoint, "protocol", parts[2]); - } - return endpoint; - } - - for (const auto& item : Split(value, ',')) { - const auto pos = item.find('='); - if (pos == std::string::npos) { continue; } - - const auto key = Trim(item.substr(0, pos)); - const auto fieldValue = Trim(item.substr(pos + 1)); - if (key == "protocol") { - endpoint.protocol = ToTransportProtocol(fieldValue); - SetEndpointAttr(endpoint, "protocol", fieldValue); - } else if (key == "placement") { - SetEndpointAttr(endpoint, "placement", fieldValue); - } else if (key == "port") { - endpoint.port = static_cast(ParseUint64(fieldValue)); - } else if (key == "local.comm_id" || key == "localCommId") { - endpoint.ip = fieldValue; - } else if (key == "local.phy_device_id" || key == "localPhyDeviceId") { - endpoint.deviceId = static_cast(ParseUint64(fieldValue)); - } else if (key == "tc") { - SetEndpointAttr(endpoint, "tc", fieldValue); - } else if (key == "sl") { - SetEndpointAttr(endpoint, "sl", fieldValue); - } else if (key == "send_size" || key == "sendSize") { - SetEndpointAttr(endpoint, "send_size", fieldValue); - } else if (key == "flag_size" || key == "flagSize") { - SetEndpointAttr(endpoint, "flag_size", fieldValue); - } else if (key == "remote_send_addr" || key == "remoteSendAddr") { - SetEndpointAttr(endpoint, "remote_send_addr", fieldValue); - } else if (key == "remote_flag_addr" || key == "remoteFlagAddr") { - SetEndpointAttr(endpoint, "remote_flag_addr", fieldValue); - } - } - return endpoint; -} - AsuInfo ParseAsuInfo(const std::string& value) { AsuInfo info; - for (const auto& endpointValue : Split(value, ';')) { - info.endpoints.emplace_back(ParseAsuEndpoint(endpointValue)); + for (const auto& endpointValue : SplitConfigValue(value, ';')) { + info.endpoints.emplace_back(ParseClientViewEndpoint(endpointValue)); } return info; } @@ -156,22 +56,23 @@ Status LoadAsuClientConfig(const std::string& configPath, AsuClientConfig& confi config = AsuClientConfig{}; std::unordered_map asuInfos; + std::vector> transportFields; std::string line; while (std::getline(configFile, line)) { - line = Trim(line); + line = TrimConfigValue(line); if (line.empty() || line[0] == '#') { continue; } const auto pos = line.find('='); if (pos == std::string::npos) { continue; } - const auto key = Trim(line.substr(0, pos)); - const auto value = Trim(line.substr(pos + 1)); + const auto key = TrimConfigValue(line.substr(0, pos)); + const auto value = TrimConfigValue(line.substr(pos + 1)); if (key == "clientId" || key == "client_id") { config.clientId = value; } else if (key == "viewServiceAddrs" || key == "view_service_addrs") { - config.viewServiceAddrs = Split(value, ','); + config.viewServiceAddrs = SplitConfigValue(value, ','); } else if (key == "defaultWaitTimeoutMs" || key == "default_wait_timeout_ms") { - config.defaultWaitTimeoutMs = ParseUint64(value); + config.defaultWaitTimeoutMs = ParseConfigUint64(value); } else if (key == "hashTable.type" || key == "hash_table.type") { auto type = value; std::transform(type.begin(), type.end(), type.begin(), @@ -190,20 +91,34 @@ Status LoadAsuClientConfig(const std::string& configPath, AsuClientConfig& confi config.attrs["ring_hash.virtual_node_count"] = value; } else if (key == "hashTable.maglev.tableSize" || key == "maglev.table_size") { config.attrs["maglev.table_size"] = value; - } else if (key == "transport.asuIds" || key == "transport_asu_ids" || key == "asuIds" || - key == "asu_ids") { - for (const auto& asuIdText : Split(value, ',')) { + } else if (key == "transport.asuIds" || key == "asuIds" || key == "asu_ids") { + for (const auto& asuIdText : SplitConfigValue(value, ',')) { TransportConfig transportConfig; - transportConfig.asuId = ParseUint64(asuIdText); + transportConfig.asuId = ParseConfigUint64(asuIdText); config.transportConfigs.emplace_back(std::move(transportConfig)); } } else { AsuId asuId{0}; - if (TryParseAsuInfoKey(key, asuId)) { asuInfos[asuId] = ParseAsuInfo(value); } + std::string attrKey; + if (TryParseAsuInfoKey(key, asuId)) { + asuInfos[asuId] = ParseAsuInfo(value); + } else if (TryGetTransportAttrKey(key, attrKey)) { + transportFields.emplace_back(attrKey, value); + } } } for (auto& transportConfig : config.transportConfigs) { + for (const auto& field : transportFields) { + if (ApplyTransportBufferConfigField(transportConfig, field.first, field.second)) { + continue; + } + if (ApplyTransportIoNumConfigField(transportConfig, field.first, field.second)) { + continue; + } + transportConfig.attrs.emplace(field); + } + auto iter = asuInfos.find(transportConfig.asuId); if (iter == asuInfos.end()) { continue; } ApplyAsuInfoToTransportConfig(iter->second, transportConfig); diff --git a/ucm/transport/kv/asu/client/src/client_task_manager.h b/ucm/transport/kv/asu/client/src/client_task_manager.h index 3360191a2..2743872b7 100644 --- a/ucm/transport/kv/asu/client/src/client_task_manager.h +++ b/ucm/transport/kv/asu/client/src/client_task_manager.h @@ -53,8 +53,7 @@ enum class ClientTaskState { PENDING = 0, INFLIGHT = 1, COMPLETED = 2, - FAILED = 3, - CANCELED = 4, + CANCELED = 3, }; enum class ClientOpType { @@ -80,8 +79,7 @@ struct ClientTaskContext { bool Done() const { auto s = state.load(std::memory_order_acquire); - return s == ClientTaskState::COMPLETED || s == ClientTaskState::FAILED || - s == ClientTaskState::CANCELED; + return s == ClientTaskState::COMPLETED || s == ClientTaskState::CANCELED; } }; diff --git a/ucm/transport/kv/asu/client/src/view_server.cpp b/ucm/transport/kv/asu/client/src/view_server.cpp index 287a721c2..43623854a 100644 --- a/ucm/transport/kv/asu/client/src/view_server.cpp +++ b/ucm/transport/kv/asu/client/src/view_server.cpp @@ -23,48 +23,14 @@ * */ #include "view_server.h" #include -#include #include -#include #include #include "asu_client/asu_client.h" +#include "config_parser_common.h" namespace UC::ASU { namespace { -std::string Trim(const std::string& value) -{ - const auto begin = value.find_first_not_of(" \t\r\n"); - if (begin == std::string::npos) { return ""; } - const auto end = value.find_last_not_of(" \t\r\n"); - return value.substr(begin, end - begin + 1); -} - -std::vector Split(const std::string& value, char delimiter) -{ - std::vector parts; - std::stringstream stream{value}; - std::string part; - while (std::getline(stream, part, delimiter)) { - part = Trim(part); - if (!part.empty()) { parts.emplace_back(std::move(part)); } - } - return parts; -} - -std::uint64_t ParseUint64(const std::string& value) { return std::stoull(value, nullptr, 0); } - -Protocol ToTransportProtocol(const std::string& value) -{ - auto protocol = value; - std::transform(protocol.begin(), protocol.end(), protocol.begin(), - [](unsigned char ch) { return static_cast(std::toupper(ch)); }); - if (protocol == "UB" || protocol == "UBOE") { return Protocol::UB; } - if (protocol == "ROCE") { return Protocol::ROCE; } - if (protocol == "TCP") { return Protocol::TCP; } - return Protocol::TCP; -} - AsuInfo ExtractAsuInfo(const TransportConfig& config) { AsuInfo info; @@ -72,79 +38,11 @@ AsuInfo ExtractAsuInfo(const TransportConfig& config) return info; } -bool TryParseAsuInfoKey(const std::string& key, AsuId& asuId) -{ - constexpr const char* kCamelPrefix = "asuInfo."; - constexpr const char* kSnakePrefix = "asu_info."; - if (key.rfind(kCamelPrefix, 0) == 0) { - asuId = std::stoull(key.substr(std::string{kCamelPrefix}.size())); - return true; - } - if (key.rfind(kSnakePrefix, 0) == 0) { - asuId = std::stoull(key.substr(std::string{kSnakePrefix}.size())); - return true; - } - return false; -} - -void SetEndpointAttr(AsuEndpoint& endpoint, const std::string& key, const std::string& value) -{ - endpoint.attrs[key] = value; -} - -AsuEndpoint ParseAsuEndpoint(const std::string& value) -{ - AsuEndpoint endpoint; - if (value.find('=') == std::string::npos) { - auto parts = Split(value, ':'); - if (!parts.empty()) { endpoint.ip = parts[0]; } - if (parts.size() > 1) { endpoint.port = static_cast(ParseUint64(parts[1])); } - if (parts.size() > 2) { - endpoint.protocol = ToTransportProtocol(parts[2]); - SetEndpointAttr(endpoint, "protocol", parts[2]); - } - return endpoint; - } - - for (const auto& item : Split(value, ',')) { - const auto pos = item.find('='); - if (pos == std::string::npos) { continue; } - - const auto key = Trim(item.substr(0, pos)); - const auto fieldValue = Trim(item.substr(pos + 1)); - if (key == "protocol") { - endpoint.protocol = ToTransportProtocol(fieldValue); - SetEndpointAttr(endpoint, "protocol", fieldValue); - } else if (key == "placement") { - SetEndpointAttr(endpoint, "placement", fieldValue); - } else if (key == "port") { - endpoint.port = static_cast(ParseUint64(fieldValue)); - } else if (key == "local.comm_id" || key == "localCommId") { - endpoint.ip = fieldValue; - } else if (key == "local.phy_device_id" || key == "localPhyDeviceId") { - endpoint.deviceId = static_cast(ParseUint64(fieldValue)); - } else if (key == "tc") { - SetEndpointAttr(endpoint, "tc", fieldValue); - } else if (key == "sl") { - SetEndpointAttr(endpoint, "sl", fieldValue); - } else if (key == "send_size" || key == "sendSize") { - SetEndpointAttr(endpoint, "send_size", fieldValue); - } else if (key == "flag_size" || key == "flagSize") { - SetEndpointAttr(endpoint, "flag_size", fieldValue); - } else if (key == "remote_send_addr" || key == "remoteSendAddr") { - SetEndpointAttr(endpoint, "remote_send_addr", fieldValue); - } else if (key == "remote_flag_addr" || key == "remoteFlagAddr") { - SetEndpointAttr(endpoint, "remote_flag_addr", fieldValue); - } - } - return endpoint; -} - AsuInfo ParseAsuInfo(const std::string& value) { AsuInfo info; - for (const auto& endpointValue : Split(value, ';')) { - info.endpoints.emplace_back(ParseAsuEndpoint(endpointValue)); + for (const auto& endpointValue : SplitConfigValue(value, ';')) { + info.endpoints.emplace_back(ParseClientViewEndpoint(endpointValue)); } return info; } @@ -166,26 +64,26 @@ class ConfigFileViewServer final : public ViewServer { GlobalView nextView; std::string line; while (std::getline(configFile, line)) { - line = Trim(line); + line = TrimConfigValue(line); if (line.empty() || line[0] == '#') { continue; } const auto pos = line.find('='); if (pos == std::string::npos) { continue; } - const auto key = Trim(line.substr(0, pos)); - const auto value = Trim(line.substr(pos + 1)); + const auto key = TrimConfigValue(line.substr(0, pos)); + const auto value = TrimConfigValue(line.substr(pos + 1)); if (key == "viewEpoch" || key == "view_epoch") { - nextView.viewEpoch = std::stoull(value); + nextView.viewEpoch = ParseConfigUint64(value); } else if (key == "viewId" || key == "view_id") { - nextView.viewId = std::stoull(value); + nextView.viewId = ParseConfigUint64(value); } else if (key == "createTimeMs" || key == "create_time_ms") { - nextView.createTimeMs = std::stoull(value); + nextView.createTimeMs = ParseConfigUint64(value); } else if (key == "expireTimeMs" || key == "expire_time_ms") { - nextView.expireTimeMs = std::stoull(value); + nextView.expireTimeMs = ParseConfigUint64(value); } else if (key == "asuIds" || key == "asu_ids") { nextView.asuMap.clear(); - for (const auto& asuId : Split(value, ',')) { - nextView.asuMap.emplace(std::stoull(asuId), AsuInfo{}); + for (const auto& asuId : SplitConfigValue(value, ',')) { + nextView.asuMap.emplace(ParseConfigUint64(asuId), AsuInfo{}); } } else { AsuId asuId{0}; diff --git a/ucm/transport/kv/asu/client/test/asu_client_impl_test.cpp b/ucm/transport/kv/asu/client/test/asu_client_impl_test.cpp index 3ebe183ba..07dbce49c 100644 --- a/ucm/transport/kv/asu/client/test/asu_client_impl_test.cpp +++ b/ucm/transport/kv/asu/client/test/asu_client_impl_test.cpp @@ -475,6 +475,14 @@ TEST(AsuClientImplTest, Lifecycle_PublicInitLoadsClientConfigFile) ASSERT_TRUE(configFile.is_open()); configFile << "clientId=file-init-test\n"; configFile << "transport.asuIds=10,20\n"; + configFile << "transport.send_buffer_slot_size=8192\n"; + configFile << "transport.send_buffer_slot_num=2\n"; + configFile << "transport.flag_buffer_slot_size=256\n"; + configFile << "transport.flag_buffer_slot_num=32\n"; + configFile << "transport.batch_load_io_num=11\n"; + configFile << "transport.batch_store_io_num=12\n"; + configFile << "transport.delete_io_num=13\n"; + configFile << "transport.query_io_num=14\n"; configFile << "asuInfo.20=protocol=roce,placement=device,port=6000," << "local.comm_id=192.168.1.20,local.phy_device_id=0\n"; } @@ -489,6 +497,16 @@ TEST(AsuClientImplTest, Lifecycle_PublicInitLoadsClientConfigFile) ASSERT_EQ(state->initConfigs[20].endpoints.size(), std::size_t{1}); EXPECT_EQ(state->initConfigs[20].endpoints[0].ip, "192.168.1.20"); EXPECT_EQ(state->initConfigs[20].endpoints[0].protocol, Protocol::ROCE); + for (auto asuId : {AsuId{10}, AsuId{20}}) { + EXPECT_EQ(state->initConfigs[asuId].sendBufferSlotSize, std::size_t{8192}); + EXPECT_EQ(state->initConfigs[asuId].sendBufferSlotNum, std::size_t{2}); + EXPECT_EQ(state->initConfigs[asuId].flagBufferSlotSize, std::size_t{256}); + EXPECT_EQ(state->initConfigs[asuId].flagBufferSlotNum, std::size_t{32}); + EXPECT_EQ(state->initConfigs[asuId].asuBatchLoadIoNum, std::size_t{11}); + EXPECT_EQ(state->initConfigs[asuId].asuBatchStoreIoNum, std::size_t{12}); + EXPECT_EQ(state->initConfigs[asuId].asuDeleteIoNum, std::size_t{13}); + EXPECT_EQ(state->initConfigs[asuId].asuQueryIoNum, std::size_t{14}); + } } TEST(AsuClientImplTest, ViewServer_InitFailsWhenViewReferencesMissingTransportConfig) diff --git a/ucm/transport/kv/asu/common/config_parser_common.cpp b/ucm/transport/kv/asu/common/config_parser_common.cpp new file mode 100644 index 000000000..f1f21799e --- /dev/null +++ b/ucm/transport/kv/asu/common/config_parser_common.cpp @@ -0,0 +1,234 @@ +/** + * MIT License + * + * Copyright (c) 2026 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include "config_parser_common.h" +#include +#include +#include +#include + +namespace UC::ASU { +namespace { + +void SetEndpointAttr(AsuEndpoint& endpoint, const std::string& key, const std::string& value) +{ + endpoint.attrs[key] = value; +} + +void ApplyTransportEndpointField(AsuEndpoint& endpoint, const std::string& key, + const std::string& value) +{ + if (key == "ip" || key == "local.comm_id" || key == "localCommId") { + endpoint.ip = value; + } else if (key == "port") { + endpoint.port = static_cast(ParseConfigUint64(value)); + } else if (key == "protocol") { + endpoint.protocol = ParseConfigProtocol(value); + } else if (key == "numa_node" || key == "numaNode") { + endpoint.numaNode = static_cast(ParseConfigUint64(value)); + } else if (key == "device_id" || key == "deviceId" || key == "local.phy_device_id" || + key == "localPhyDeviceId") { + endpoint.deviceId = static_cast(ParseConfigUint64(value)); + } else if (key == "hca_name" || key == "hcaName") { + endpoint.hcaName = value; + } else if (key == "hca_port" || key == "hcaPort") { + endpoint.hcaPort = static_cast(ParseConfigUint64(value)); + } else { + endpoint.attrs[key] = value; + } +} + +void ApplyClientViewEndpointField(AsuEndpoint& endpoint, const std::string& key, + const std::string& value) +{ + if (key == "protocol") { + endpoint.protocol = ParseConfigProtocol(value); + SetEndpointAttr(endpoint, "protocol", value); + } else if (key == "placement") { + SetEndpointAttr(endpoint, "placement", value); + } else if (key == "port") { + endpoint.port = static_cast(ParseConfigUint64(value)); + } else if (key == "local.comm_id" || key == "localCommId") { + endpoint.ip = value; + } else if (key == "local.phy_device_id" || key == "localPhyDeviceId") { + endpoint.deviceId = static_cast(ParseConfigUint64(value)); + } else if (key == "tc") { + SetEndpointAttr(endpoint, "tc", value); + } else if (key == "sl") { + SetEndpointAttr(endpoint, "sl", value); + } else if (key == "send_size" || key == "sendSize") { + SetEndpointAttr(endpoint, "send_size", value); + } else if (key == "flag_size" || key == "flagSize") { + SetEndpointAttr(endpoint, "flag_size", value); + } else if (key == "remote_send_addr" || key == "remoteSendAddr") { + SetEndpointAttr(endpoint, "remote_send_addr", value); + } else if (key == "remote_flag_addr" || key == "remoteFlagAddr") { + SetEndpointAttr(endpoint, "remote_flag_addr", value); + } +} + +} // namespace + +std::string TrimConfigValue(const std::string& value) +{ + const auto begin = value.find_first_not_of(" \t\r\n"); + if (begin == std::string::npos) { return ""; } + const auto end = value.find_last_not_of(" \t\r\n"); + return value.substr(begin, end - begin + 1); +} + +std::vector SplitConfigValue(const std::string& value, char delimiter) +{ + std::vector parts; + std::stringstream stream{value}; + std::string part; + while (std::getline(stream, part, delimiter)) { + part = TrimConfigValue(part); + if (!part.empty()) { parts.emplace_back(std::move(part)); } + } + return parts; +} + +std::uint64_t ParseConfigUint64(const std::string& value) { return std::stoull(value, nullptr, 0); } + +Protocol ParseConfigProtocol(std::string value) +{ + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char ch) { return static_cast(std::toupper(ch)); }); + if (value == "UB" || value == "UBOE") { return Protocol::UB; } + if (value == "ROCE") { return Protocol::ROCE; } + if (value == "TCP") { return Protocol::TCP; } + return Protocol::TCP; +} + +bool ApplyTransportBufferConfigField(TransportConfig& config, const std::string& key, + const std::string& value) +{ + if (key == "sendBufferSlotSize" || key == "send_buffer_slot_size" || + key == "ioBuffer.sendBufferSlotSize" || key == "io_buffer.send_buffer_slot_size") { + config.sendBufferSlotSize = static_cast(ParseConfigUint64(value)); + } else if (key == "sendBufferSlotNum" || key == "send_buffer_slot_num" || + key == "ioBuffer.sendBufferSlotNum" || key == "io_buffer.send_buffer_slot_num") { + config.sendBufferSlotNum = static_cast(ParseConfigUint64(value)); + } else if (key == "flagBufferSlotSize" || key == "flag_buffer_slot_size" || + key == "ioBuffer.flagBufferSlotSize" || key == "io_buffer.flag_buffer_slot_size") { + config.flagBufferSlotSize = static_cast(ParseConfigUint64(value)); + } else if (key == "flagBufferSlotNum" || key == "flag_buffer_slot_num" || + key == "ioBuffer.flagBufferSlotNum" || key == "io_buffer.flag_buffer_slot_num") { + config.flagBufferSlotNum = static_cast(ParseConfigUint64(value)); + } else { + return false; + } + return true; +} + +bool ApplyTransportIoNumConfigField(TransportConfig& config, const std::string& key, + const std::string& value) +{ + if (key == "batchLoadIoNum" || key == "batch_load_io_num") { + config.asuBatchLoadIoNum = static_cast(ParseConfigUint64(value)); + } else if (key == "batchStoreIoNum" || key == "batch_store_io_num") { + config.asuBatchStoreIoNum = static_cast(ParseConfigUint64(value)); + } else if (key == "deleteIoNum" || key == "delete_io_num") { + config.asuDeleteIoNum = static_cast(ParseConfigUint64(value)); + } else if (key == "queryIoNum" || key == "query_io_num") { + config.asuQueryIoNum = static_cast(ParseConfigUint64(value)); + } else { + return false; + } + return true; +} + +bool TryParseAsuInfoKey(const std::string& key, AsuId& asuId) +{ + constexpr const char* kCamelPrefix = "asuInfo."; + constexpr const char* kSnakePrefix = "asu_info."; + if (key.rfind(kCamelPrefix, 0) == 0) { + asuId = std::stoull(key.substr(std::string{kCamelPrefix}.size())); + return true; + } + if (key.rfind(kSnakePrefix, 0) == 0) { + asuId = std::stoull(key.substr(std::string{kSnakePrefix}.size())); + return true; + } + return false; +} + +bool TryGetTransportAttrKey(const std::string& key, std::string& attrKey) +{ + constexpr const char* kCamelPrefix = "transport."; + if (key.rfind(kCamelPrefix, 0) == 0) { + attrKey = key.substr(std::string{kCamelPrefix}.size()); + return !attrKey.empty(); + } + return false; +} + +AsuEndpoint ParseTransportEndpoint(const std::string& value) +{ + AsuEndpoint endpoint; + if (value.find('=') == std::string::npos) { + auto parts = SplitConfigValue(value, ':'); + if (!parts.empty()) { endpoint.ip = parts[0]; } + if (parts.size() > 1) { + endpoint.port = static_cast(ParseConfigUint64(parts[1])); + } + if (parts.size() > 2) { endpoint.protocol = ParseConfigProtocol(parts[2]); } + return endpoint; + } + + for (const auto& item : SplitConfigValue(value, ',')) { + const auto pos = item.find('='); + if (pos == std::string::npos) { continue; } + ApplyTransportEndpointField(endpoint, TrimConfigValue(item.substr(0, pos)), + TrimConfigValue(item.substr(pos + 1))); + } + return endpoint; +} + +AsuEndpoint ParseClientViewEndpoint(const std::string& value) +{ + AsuEndpoint endpoint; + if (value.find('=') == std::string::npos) { + auto parts = SplitConfigValue(value, ':'); + if (!parts.empty()) { endpoint.ip = parts[0]; } + if (parts.size() > 1) { + endpoint.port = static_cast(ParseConfigUint64(parts[1])); + } + if (parts.size() > 2) { + endpoint.protocol = ParseConfigProtocol(parts[2]); + SetEndpointAttr(endpoint, "protocol", parts[2]); + } + return endpoint; + } + + for (const auto& item : SplitConfigValue(value, ',')) { + const auto pos = item.find('='); + if (pos == std::string::npos) { continue; } + ApplyClientViewEndpointField(endpoint, TrimConfigValue(item.substr(0, pos)), + TrimConfigValue(item.substr(pos + 1))); + } + return endpoint; +} + +} // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/transport_task_completion.h b/ucm/transport/kv/asu/common/config_parser_common.h similarity index 57% rename from ucm/transport/kv/asu/trans/src/transport_task_completion.h rename to ucm/transport/kv/asu/common/config_parser_common.h index 4fb87de90..4e9c4ec3e 100644 --- a/ucm/transport/kv/asu/trans/src/transport_task_completion.h +++ b/ucm/transport/kv/asu/common/config_parser_common.h @@ -23,25 +23,26 @@ * */ #pragma once -#include "transport_task_manager.h" +#include +#include +#include +#include "asu_transport/asu_transport.h" namespace UC::ASU { -class BufferManager; +std::string TrimConfigValue(const std::string& value); +std::vector SplitConfigValue(const std::string& value, char delimiter); +std::uint64_t ParseConfigUint64(const std::string& value); +Protocol ParseConfigProtocol(std::string value); -void InitializeTerminalSubBatchCount(TransportTaskContext& ctx); +bool ApplyTransportBufferConfigField(TransportConfig& config, const std::string& key, + const std::string& value); +bool ApplyTransportIoNumConfigField(TransportConfig& config, const std::string& key, + const std::string& value); +bool TryParseAsuInfoKey(const std::string& key, AsuId& asuId); +bool TryGetTransportAttrKey(const std::string& key, std::string& attrKey); -Status ReleaseSubBatchResources(TransportSubBatchContext& subBatchContext, - BufferManager& sendBufferManager, BufferManager& flagBufferManager); - -Status ReleaseAllSubBatchResources(std::vector& subBatchContexts, - BufferManager& sendBufferManager, - BufferManager& flagBufferManager); - -void CompleteSubBatch(TransportTaskContext& ctx, TransportSubBatchContext& subBatchContext, - TransportSubBatchState state, const Status& status, - BufferManager& sendBufferManager, BufferManager& flagBufferManager); - -void TryFinalizeTaskFromSubBatches(TransportTaskContext& ctx); +AsuEndpoint ParseTransportEndpoint(const std::string& value); +AsuEndpoint ParseClientViewEndpoint(const std::string& value); } // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/include/asu_transport/asu_transport.h b/ucm/transport/kv/asu/trans/include/asu_transport/asu_transport.h index 590bee4a5..1a3877768 100644 --- a/ucm/transport/kv/asu/trans/include/asu_transport/asu_transport.h +++ b/ucm/transport/kv/asu/trans/include/asu_transport/asu_transport.h @@ -23,6 +23,7 @@ * */ #pragma once +#include #include #include #include @@ -69,6 +70,17 @@ struct TransportConfig { bool preconnect{true}; bool bindCqPoller{true}; + std::size_t sendBufferSlotSize{4096}; + std::size_t sendBufferSlotNum{1}; + std::size_t flagBufferSlotSize{128}; + std::size_t flagBufferSlotNum{4096}; + std::size_t asuBatchLoadIoNum{110}; + std::size_t asuBatchStoreIoNum{110}; + std::size_t asuDeleteIoNum{254}; + std::size_t asuQueryIoNum{256}; + + // Transport attrs loaded from config, including SQE request attrs + // (kv_ns_id, dtype, dspec, lr, sc) and send attrs (kernel_count, quiet_count). std::unordered_map attrs; }; diff --git a/ucm/transport/kv/asu/trans/src/asu_response_status.cpp b/ucm/transport/kv/asu/trans/src/asu_response_status.cpp index 1c34044ac..91c84da18 100644 --- a/ucm/transport/kv/asu/trans/src/asu_response_status.cpp +++ b/ucm/transport/kv/asu/trans/src/asu_response_status.cpp @@ -55,9 +55,9 @@ StatusCode CqeStatusCode(std::uint16_t rawStatus) Status ResultBufferEntryToStatus(TransportOpType opType, std::uint8_t rawResult) { - if (opType == TransportOpType::BATCH_STORE || opType == TransportOpType::BATCH_LOAD) { - if (rawResult == 0x00) { return Status::OK(); } + if (opType != TransportOpType::QUERY && rawResult == 0x00) { return Status::OK(); } + if (opType == TransportOpType::BATCH_STORE || opType == TransportOpType::BATCH_LOAD) { switch (EntryStatusCode(opType, rawResult)) { case StatusCode::ASU_ENTRY_RETRY_ADVISED: return Status::Error(StatusCode::ASU_ENTRY_RETRY_ADVISED, @@ -76,8 +76,6 @@ Status ResultBufferEntryToStatus(TransportOpType opType, std::uint8_t rawResult) } if (opType == TransportOpType::DELETE) { - if (rawResult == 0x00) { return Status::OK(); } - switch (EntryStatusCode(opType, rawResult)) { case StatusCode::ASU_ENTRY_DELETE_FAILED: return Status::Error(StatusCode::ASU_ENTRY_DELETE_FAILED, "delete failed"); @@ -99,9 +97,7 @@ Status ResultBufferEntryToStatus(TransportOpType opType, std::uint8_t rawResult) } } - return rawResult == 0 ? Status::OK() - : Status::Error(StatusCode::IO_ERROR, - "entry CQE status is " + std::to_string(rawResult)); + return Status::Error(StatusCode::IO_ERROR, "entry CQE status is " + std::to_string(rawResult)); } } // namespace diff --git a/ucm/transport/kv/asu/trans/src/asu_submit_flow.cpp b/ucm/transport/kv/asu/trans/src/asu_submit_flow.cpp index 2d07a3a01..2ac78973c 100644 --- a/ucm/transport/kv/asu/trans/src/asu_submit_flow.cpp +++ b/ucm/transport/kv/asu/trans/src/asu_submit_flow.cpp @@ -21,12 +21,11 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#include "asu_submit_flow.h" #include #include #include +#include "asu_transport_impl.h" #include "connection_internal.h" -#include "transport_task_completion.h" namespace UC::ASU { @@ -41,60 +40,49 @@ std::uint32_t GetSendCountAttr(const std::unordered_map& attrs, - const SqeCidAllocator& allocateSqeCid, BufferManager& sendBufferManager, - BufferManager& flagBufferManager, ProtocolManager& protocolManager, - std::vector& subBatchContexts) +Status AsuTransportImpl::SubmitTaskRequests(const TransportTaskContext& ctx, + std::vector& subBatchContexts) { Status finalStatus = Status::OK(); + if (IsEntryBatchOp(ctx.opType)) { - const auto subBatches = ioScheduler.SplitForAsu(ctx.entries, GetSqeBatchLimit(ctx.opType)); + const auto subBatches = ioScheduler_.SplitForAsu(ctx.entries, ctx.opType); subBatchContexts.reserve(subBatches.size()); for (const auto& subBatch : subBatches) { - TransportSubBatchContext subBatchContext; - auto status = SubmitEntrySubBatchRequest(ctx.opType, subBatch, attrs, allocateSqeCid, - sendBufferManager, flagBufferManager, - protocolManager, subBatchContext); + auto& subBatchContext = subBatchContexts.emplace_back(); + auto status = SubmitEntrySubBatchRequest(ctx.opType, subBatch, subBatchContext); subBatchContext.status = status; if (!status.ok() && finalStatus.ok()) { finalStatus = status; } - subBatchContexts.push_back(std::move(subBatchContext)); } } else if (IsKeyBatchOp(ctx.opType)) { - const auto subBatches = ioScheduler.SplitForAsu(ctx.keys, GetSqeBatchLimit(ctx.opType)); + const auto subBatches = ioScheduler_.SplitForAsu(ctx.keys, ctx.opType); subBatchContexts.reserve(subBatches.size()); for (const auto& subBatch : subBatches) { - TransportSubBatchContext subBatchContext; - auto status = SubmitKeySubBatchRequest(ctx.opType, subBatch, attrs, allocateSqeCid, - sendBufferManager, flagBufferManager, - protocolManager, subBatchContext); + auto& subBatchContext = subBatchContexts.emplace_back(); + auto status = SubmitKeySubBatchRequest(ctx.opType, subBatch, subBatchContext); subBatchContext.status = status; if (!status.ok() && finalStatus.ok()) { finalStatus = status; } - subBatchContexts.push_back(std::move(subBatchContext)); } - } else if (ctx.opType == TransportOpType::KEEP_ALIVE) { - TransportSubBatchContext subBatchContext; - auto status = SubmitKeepAliveRequest(allocateSqeCid, sendBufferManager, flagBufferManager, - protocolManager, subBatchContext); + } else if (IsKeepAliveOp(ctx.opType)) { + auto& subBatchContext = subBatchContexts.emplace_back(); + auto status = SubmitKeepAliveRequest(subBatchContext); subBatchContext.status = status; if (!status.ok() && finalStatus.ok()) { finalStatus = status; } - subBatchContexts.push_back(std::move(subBatchContext)); } else { finalStatus = Status::Error(StatusCode::UNSUPPORTED, "transport operation is unsupported"); } return finalStatus; } -Status BuildSubBatchSendBuffers(std::vector& subBatchContexts, - std::vector& ioBatches, - std::vector& subBatchIndexes, - BufferManager& sendBufferManager, BufferManager& flagBufferManager) +Status AsuTransportImpl::BuildSubBatchSendBuffers( + std::vector& subBatchContexts, std::vector& ioBatches, + std::vector& subBatchIndexes) { Status finalStatus = Status::OK(); ioBatches.reserve(subBatchContexts.size()); @@ -102,14 +90,12 @@ Status BuildSubBatchSendBuffers(std::vector& subBatchC for (std::size_t index = 0; index < subBatchContexts.size(); ++index) { auto& subBatchContext = subBatchContexts[index]; - if (subBatchContext.state == TransportSubBatchState::FAILED) { + if (!subBatchContext.status.ok()) { if (finalStatus.ok()) { finalStatus = Status::Error(StatusCode::PARTIAL_FAILED, "one or more sub-batches failed before send"); } - const auto releaseStatus = - ReleaseSubBatchResources(subBatchContext, sendBufferManager, flagBufferManager); - if (finalStatus.ok() && !releaseStatus.ok()) { finalStatus = releaseStatus; } + ReleaseSubBatchResources(subBatchContext); continue; } @@ -118,31 +104,27 @@ Status BuildSubBatchSendBuffers(std::vector& subBatchC Status::Error(StatusCode::NOT_INITIALIZED, "sub-batch flag buffer is not ready"); SetSubBatchSendFailed(subBatchContext, status); if (finalStatus.ok()) { finalStatus = status; } - const auto releaseStatus = - ReleaseSubBatchResources(subBatchContext, sendBufferManager, flagBufferManager); - if (finalStatus.ok() && !releaseStatus.ok()) { finalStatus = releaseStatus; } + ReleaseSubBatchResources(subBatchContext); continue; } ioBatches.push_back( SendIoBatch{subBatchContext.channel->GetNativeQp(), &subBatchContext.sendSge}); - subBatchIndexes.push_back(index); + subBatchIndexes.emplace_back(index); } return finalStatus; } -Status SendSubBatchBuffers(std::vector& subBatchContexts, - const std::vector& ioBatches, - const std::vector& subBatchIndexes, - const std::unordered_map& attrs, - ConnectionManager& connManager) +Status AsuTransportImpl::SendSubBatchBuffers( + std::vector& subBatchContexts, + const std::vector& ioBatches, const std::vector& subBatchIndexes) { Status finalStatus = Status::OK(); if (ioBatches.empty()) { return finalStatus; } - const auto kernelCount = GetSendCountAttr(attrs, "kernel_count"); - const auto quietCount = GetSendCountAttr(attrs, "quiet_count"); + const auto kernelCount = GetSendCountAttr(config_.attrs, "kernel_count"); + const auto quietCount = GetSendCountAttr(config_.attrs, "quiet_count"); const auto sendStatuses = Send(ioBatches, kernelCount, quietCount); if (sendStatuses.size() != ioBatches.size()) { @@ -161,7 +143,7 @@ Status SendSubBatchBuffers(std::vector& subBatchContex if (status.ok()) { continue; } SetSubBatchSendFailed(subBatchContext, status); - connManager.ReportFailure(subBatchContext.channel); + connManager_.ReportFailure(subBatchContext.channel); if (finalStatus.ok()) { finalStatus = status; } } return finalStatus; diff --git a/ucm/transport/kv/asu/trans/src/asu_submit_flow.h b/ucm/transport/kv/asu/trans/src/asu_submit_flow.h deleted file mode 100644 index 60b74fb18..000000000 --- a/ucm/transport/kv/asu/trans/src/asu_submit_flow.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2026 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#pragma once - -#include -#include -#include -#include -#include "connection_manager.h" -#include "io_scheduler.h" -#include "sqe_request.h" -#include "transport_task_manager.h" - -namespace UC::ASU { - -class BufferManager; -class ProtocolManager; - -Status SubmitTaskRequests(const TransportTaskContext& ctx, const IoScheduler& ioScheduler, - const std::unordered_map& attrs, - const SqeCidAllocator& allocateSqeCid, BufferManager& sendBufferManager, - BufferManager& flagBufferManager, ProtocolManager& protocolManager, - std::vector& subBatchContexts); - -Status BuildSubBatchSendBuffers(std::vector& subBatchContexts, - std::vector& ioBatches, - std::vector& subBatchIndexes, - BufferManager& sendBufferManager, BufferManager& flagBufferManager); - -Status SendSubBatchBuffers(std::vector& subBatchContexts, - const std::vector& ioBatches, - const std::vector& subBatchIndexes, - const std::unordered_map& attrs, - ConnectionManager& connManager); - -} // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp b/ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp index 957596735..264cbbe98 100644 --- a/ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp +++ b/ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp @@ -26,14 +26,20 @@ #include #include #include +#include +#include "asu_response_status.h" #include "asu_transport/asu_transport.h" -#include "asu_transport/types.h" #include "connection_internal.h" +#include "connection_manager.h" #include "logger.h" #include "transport_config_parser.h" namespace UC::ASU { +namespace { + +} // namespace + AsuTransportImpl::~AsuTransportImpl() { Shutdown(); } Status AsuTransportImpl::Init(const std::string& configPath) @@ -52,6 +58,7 @@ Status AsuTransportImpl::Init(const TransportConfig& config) return Status::OK(); } config_ = config; + ioScheduler_ = IoScheduler(config_); connManager_.PrepareForInit(); @@ -68,10 +75,23 @@ Status AsuTransportImpl::Init(const TransportConfig& config) connManager_.StartRecoverLoop(); + auto status = ValidateSqeRequestAttrs(); + if (!status.ok()) { return status; } + + status = sendBufferManager_.Init("asu send buffer", MemoryType::HOST, + config_.sendBufferSlotSize, config_.sendBufferSlotNum); + if (!status.ok()) { return status; } + + status = flagBufferManager_.Init("asu flag buffer", MemoryType::HOST, + config_.flagBufferSlotSize, config_.flagBufferSlotNum); + if (!status.ok()) { return status; } + protocolManager_ = std::make_unique(); + auto queueDepth = std::max(2, static_cast(config_.maxInflightTasks)); executeQueue_.Setup(queueDepth + 1); stop_.store(false, std::memory_order_release); worker_ = std::thread(&AsuTransportImpl::WorkerLoop, this); + completionWorker_ = std::thread(&AsuTransportImpl::CompletionLoop, this); UC_DEBUG("AsuTransportImpl::Init OK: queueDepth={}", queueDepth); return Status::OK(); } @@ -85,6 +105,7 @@ Status AsuTransportImpl::Shutdown() UC_DEBUG("AsuTransportImpl::Shutdown stopping worker thread"); worker_.join(); } + if (completionWorker_.joinable()) { completionWorker_.join(); } for (const auto& ctx : taskManager_.GetAll()) { if (ctx != nullptr) { (void)taskManager_.Remove(ctx->taskId); } } @@ -95,7 +116,7 @@ Status AsuTransportImpl::Shutdown() Status AsuTransportImpl::CheckHealth() { - if (!worker_.joinable()) { + if (!worker_.joinable() || !completionWorker_.joinable()) { return Status::Error(StatusCode::NOT_INITIALIZED, "transport worker is not running"); } return Status::OK(); @@ -130,7 +151,7 @@ Status AsuTransportImpl::QueryAsync(const std::vector& keys, const Que Status AsuTransportImpl::LoadAsync(const std::vector& entries, TaskId& taskId) { auto ctx = std::make_unique(); - ctx->opType = TransportOpType::LOAD; + ctx->opType = TransportOpType::BATCH_LOAD; ctx->entries = BatchView{entries.data(), entries.size()}; ctx->entryStatus.assign(entries.size(), Status::OK()); return SubmitAsync(std::move(ctx), taskId); @@ -139,7 +160,7 @@ Status AsuTransportImpl::LoadAsync(const std::vector& entries, TaskId& Status AsuTransportImpl::StoreAsync(const std::vector& entries, TaskId& taskId) { auto ctx = std::make_unique(); - ctx->opType = TransportOpType::STORE; + ctx->opType = TransportOpType::BATCH_STORE; ctx->entries = BatchView{entries.data(), entries.size()}; ctx->entryStatus.assign(entries.size(), Status::OK()); return SubmitAsync(std::move(ctx), taskId); @@ -238,6 +259,13 @@ Status AsuTransportImpl::UnregisterRegions(const std::vector& handles) return Status::OK(); } +std::uint16_t AsuTransportImpl::AllocateRequestCid() +{ + auto requestCid = nextRequestCid_.fetch_add(1, std::memory_order_relaxed); + if (requestCid == 0) { requestCid = nextRequestCid_.fetch_add(1, std::memory_order_relaxed); } + return requestCid; +} + Status AsuTransportImpl::SubmitAsync(std::unique_ptr ctx, TaskId& taskId) { if (!worker_.joinable()) { @@ -268,14 +296,45 @@ void AsuTransportImpl::WorkerLoop() { executeQueue_.ConsumerLoop(stop_, [this](TransportTaskContextPtr ctx) { if (!ctx) { return; } - CompleteTask(ctx); + ProcessTask(ctx); }); UC_DEBUG("AsuTransportImpl::WorkerLoop stopped"); } -void AsuTransportImpl::CompleteTask(const TransportTaskContextPtr& ctx) +void AsuTransportImpl::CompletionLoop() +{ + while (!stop_.load(std::memory_order_acquire)) { + for (const auto& ctx : taskManager_.GetAll()) { PollTaskCompletions(ctx); } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } +} + +Status AsuTransportImpl::AssignSubBatchConnections( + std::vector& subBatchContexts) +{ + Status finalStatus = Status::OK(); + for (auto& subBatchContext : subBatchContexts) { + if (!subBatchContext.status.ok()) { continue; } + + auto* channel = connManager_.SelectConnection(); + if (channel == nullptr) { + const auto status = + Status::Error(StatusCode::CONNECTION_ERROR, "no available connection channel"); + std::fill(subBatchContext.entryStatus.begin(), subBatchContext.entryStatus.end(), + status); + subBatchContext.state = TransportSubBatchState::COMPLETED; + subBatchContext.status = status; + if (finalStatus.ok()) { finalStatus = status; } + continue; + } + + subBatchContext.channel = channel; + } + return finalStatus; +} + +void AsuTransportImpl::ProcessTask(const TransportTaskContextPtr& ctx) { - // TODO: do REAL work here TransportTaskState expected = TransportTaskState::PENDING; if (!ctx->state.compare_exchange_strong(expected, TransportTaskState::INFLIGHT, std::memory_order_acq_rel)) { @@ -285,30 +344,104 @@ void AsuTransportImpl::CompleteTask(const TransportTaskContextPtr& ctx) return; } + std::vector subBatchContexts; + auto finalStatus = SubmitTaskRequests(*ctx, subBatchContexts); + auto connectionStatus = AssignSubBatchConnections(subBatchContexts); + if (!connectionStatus.ok() && finalStatus.ok()) { finalStatus = connectionStatus; } + + std::vector ioBatches; + std::vector subBatchIndexes; + auto buildStatus = BuildSubBatchSendBuffers(subBatchContexts, ioBatches, subBatchIndexes); + if (!buildStatus.ok() && finalStatus.ok()) { finalStatus = buildStatus; } + + auto sendStatus = SendSubBatchBuffers(subBatchContexts, ioBatches, subBatchIndexes); + if (!sendStatus.ok() && finalStatus.ok()) { finalStatus = sendStatus; } + std::lock_guard lock(ctx->waitMu); if (ctx->state.load(std::memory_order_acquire) == TransportTaskState::CANCELED) { + ReleaseAllSubBatchResources(subBatchContexts); ctx->cv.notify_all(); return; } - if (ctx->state.load(std::memory_order_acquire) == TransportTaskState::CANCELED) { - ctx->cv.notify_all(); + + if (!subBatchContexts.empty()) { ctx->subBatchContexts = std::move(subBatchContexts); } + ctx->finalStatus = finalStatus; + ctx->InitializeTerminalSubBatchCount(); + ctx->TryFinalizeFromSubBatches(); + + for (auto& subBatchContext : ctx->subBatchContexts) { + if (subBatchContext.status.ok()) { continue; } + ReleaseSubBatchResources(subBatchContext); + } + if (ctx->Done()) { ctx->cv.notify_all(); } +} + +void AsuTransportImpl::PollTaskCompletions(const TransportTaskContextPtr& ctx) +{ + if (!ctx || ctx->state.load(std::memory_order_acquire) != TransportTaskState::INFLIGHT) { return; } - if (ctx->opType == TransportOpType::QUERY) { - ctx->queryResult.exists.assign(ctx->keys.size, 0); - ctx->queryResult.prefixHitKeys = 0; + + std::lock_guard lock(ctx->waitMu); + if (ctx->subBatchContexts.empty()) { return; } + + for (auto& subBatchContext : ctx->subBatchContexts) { + if (subBatchContext.state != TransportSubBatchState::PENDING) { continue; } + + std::uint16_t completedCid = 0; + auto cidStatus = protocolManager_->PollResponseCid( + reinterpret_cast(subBatchContext.flagBuffer.addr), completedCid); + if (!cidStatus.ok()) { continue; } + if (completedCid == 0 || completedCid != subBatchContext.cid) { continue; } + + KvResponse response; + const auto batchNumber = static_cast(subBatchContext.entryStatus.size()); + const auto unpackStatus = protocolManager_->UnpackResponse( + reinterpret_cast(subBatchContext.flagBuffer.addr), + ToKvOpcode(subBatchContext.opType), batchNumber, response); + if (!unpackStatus.ok()) { + std::fill(subBatchContext.entryStatus.begin(), subBatchContext.entryStatus.end(), + unpackStatus); + CompleteSubBatch(*ctx, subBatchContext, unpackStatus); + continue; + } + + subBatchContext.status = KvResponseStatusToSubBatchStatus(response.status); + FillEntryStatusFromCqeResult(response, subBatchContext); + + if (subBatchContext.status.ok()) { + CompleteSubBatch(*ctx, subBatchContext, Status::OK()); + continue; + } + + if (subBatchContext.status.code == StatusCode::ASU_CQE_INTERNAL_ERROR || + subBatchContext.status.code == StatusCode::ASU_CQE_IO_TIMEOUT) { + connManager_.ReportFailure(subBatchContext.channel); + } + CompleteSubBatch(*ctx, subBatchContext, subBatchContext.status); } - ctx->finalStatus = Status::OK(); - ctx->state.store(TransportTaskState::COMPLETED, std::memory_order_release); - ctx->cv.notify_all(); + ctx->TryFinalizeFromSubBatches(); + if (ctx->Done()) { ctx->cv.notify_all(); } } void AsuTransportImpl::BuildResult(const TransportTaskContext& ctx, TaskResult& result) { result.status = ctx.finalStatus; result.entryStatus = ctx.entryStatus; + if (!ctx.subBatchContexts.empty()) { + std::size_t resultIndex = 0; + for (const auto& subBatchContext : ctx.subBatchContexts) { + for (const auto& status : subBatchContext.entryStatus) { + if (resultIndex >= result.entryStatus.size()) { break; } + result.entryStatus[resultIndex++] = status; + } + } + } + result.queryResult.reset(); - if (ctx.opType == TransportOpType::QUERY) { result.queryResult = ctx.queryResult; } + if (ctx.opType == TransportOpType::QUERY) { + result.queryResult = BuildQueryResultFromEntryStatus(result.entryStatus); + } } std::unique_ptr CreateAsuTransport() { return std::make_unique(); } diff --git a/ucm/transport/kv/asu/trans/src/asu_transport_impl.h b/ucm/transport/kv/asu/trans/src/asu_transport_impl.h index 00524fae1..1c52bd3b4 100644 --- a/ucm/transport/kv/asu/trans/src/asu_transport_impl.h +++ b/ucm/transport/kv/asu/trans/src/asu_transport_impl.h @@ -32,7 +32,10 @@ #include #include #include "asu_transport/asu_transport.h" +#include "buffer_manager.h" #include "connection_manager.h" +#include "io_scheduler.h" +#include "kv_protocol.h" #include "template/spsc_ring_queue.h" #include "transport_task_manager.h" @@ -40,6 +43,20 @@ namespace UC::ASU { using TransportTaskContextPtr = std::shared_ptr; +inline KvOpcode ToKvOpcode(TransportOpType opType) +{ + switch (opType) { + case TransportOpType::LOAD: return KvOpcode::Retrieve; + case TransportOpType::STORE: return KvOpcode::Store; + case TransportOpType::BATCH_LOAD: return KvOpcode::BatchRetrieve; + case TransportOpType::BATCH_STORE: return KvOpcode::BatchStore; + case TransportOpType::DELETE: return KvOpcode::Delete; + case TransportOpType::QUERY: return KvOpcode::Exist; + case TransportOpType::KEEP_ALIVE: return KvOpcode::KeepAlive; + } + return KvOpcode::KeepAlive; +} + class AsuTransportImpl final : public AsuTransport { public: AsuTransportImpl() = default; @@ -73,13 +90,41 @@ class AsuTransportImpl final : public AsuTransport { private: using TransportTaskContextPtr = std::shared_ptr; + std::uint16_t AllocateRequestCid(); Status SubmitAsync(std::unique_ptr ctx, TaskId& taskId); void WorkerLoop(); - void CompleteTask(const TransportTaskContextPtr& ctx); - + void CompletionLoop(); + void ProcessTask(const TransportTaskContextPtr& ctx); + Status AssignSubBatchConnections(std::vector& subBatchContexts); + Status SubmitTaskRequests(const TransportTaskContext& ctx, + std::vector& subBatchContexts); + Status BuildSubBatchSendBuffers(std::vector& subBatchContexts, + std::vector& ioBatches, + std::vector& subBatchIndexes); + Status SendSubBatchBuffers(std::vector& subBatchContexts, + const std::vector& ioBatches, + const std::vector& subBatchIndexes); + Status ValidateSqeRequestAttrs(); + Status SubmitEntrySubBatchRequest(TransportOpType opType, + const IoScheduler::ScheduledIoBatch& subBatch, + TransportSubBatchContext& subBatchContext); + Status SubmitKeySubBatchRequest(TransportOpType opType, + const IoScheduler::ScheduledKeyBatch& subBatch, + TransportSubBatchContext& subBatchContext); + Status SubmitKeepAliveRequest(TransportSubBatchContext& subBatchContext); + void ReleaseSubBatchResources(TransportSubBatchContext& subBatchContext); + void ReleaseAllSubBatchResources(std::vector& subBatchContexts); + void CompleteSubBatch(TransportTaskContext& ctx, TransportSubBatchContext& subBatchContext, + const Status& status); + + void PollTaskCompletions(const TransportTaskContextPtr& ctx); void BuildResult(const TransportTaskContext& ctx, TaskResult& result); TransportConfig config_; + IoScheduler ioScheduler_; + BufferManager sendBufferManager_; + BufferManager flagBufferManager_; + std::unique_ptr protocolManager_; ConnectionManager connManager_; TransportTaskManager taskManager_; @@ -88,7 +133,9 @@ class AsuTransportImpl final : public AsuTransport { std::mutex producerMu_; std::thread worker_; + std::thread completionWorker_; std::atomic_bool stop_{false}; + std::atomic nextRequestCid_{1}; std::mutex registeredRegionsMu_; std::atomic nextMrHandle_{1}; diff --git a/ucm/transport/kv/asu/trans/src/io_scheduler.cpp b/ucm/transport/kv/asu/trans/src/io_scheduler.cpp index c00dd64da..00ec28927 100644 --- a/ucm/transport/kv/asu/trans/src/io_scheduler.cpp +++ b/ucm/transport/kv/asu/trans/src/io_scheduler.cpp @@ -28,58 +28,68 @@ namespace UC::ASU { namespace { -std::size_t GetSubBatchCount(std::size_t total, std::size_t maxIoNum) +std::size_t GetSubBatchCount(std::size_t total, std::size_t ioNum) { - if (total == 0 || maxIoNum == 0) { return 0; } - return 1 + (total - 1) / maxIoNum; + if (total == 0 || ioNum == 0) { return 0; } + return 1 + (total - 1) / ioNum; } -} // namespace - -std::vector IoScheduler::SplitForAsu( - const BatchView& entries, std::size_t maxIoNum) const +template +std::vector SplitBatchView(const BatchView& view, std::size_t ioNum, + SetView setView) { - std::vector result; - if (entries.empty() || maxIoNum == 0) { return result; } + std::vector result; + if (view.empty() || ioNum == 0) { return result; } - const std::size_t subBatchCount = GetSubBatchCount(entries.size, maxIoNum); + const std::size_t subBatchCount = GetSubBatchCount(view.size, ioNum); result.reserve(subBatchCount); - for (std::size_t offset = 0; offset < entries.size; offset += maxIoNum) { - const std::size_t end = std::min(offset + maxIoNum, entries.size); + for (std::size_t offset = 0; offset < view.size; offset += ioNum) { + const std::size_t end = std::min(offset + ioNum, view.size); - ScheduledIoBatch batch; - batch.entries = BatchView{entries.data + offset, end - offset}; - result.push_back(batch); + auto& batch = result.emplace_back(); + setView(batch, BatchView{view.data + offset, end - offset}); } return result; } -std::vector IoScheduler::SplitForAsu( - const BatchView& keys, std::size_t maxIoNum) const -{ - std::vector result; - if (keys.empty() || maxIoNum == 0) { return result; } - - const std::size_t subBatchCount = GetSubBatchCount(keys.size, maxIoNum); - result.reserve(subBatchCount); +} // namespace - for (std::size_t offset = 0; offset < keys.size; offset += maxIoNum) { - const std::size_t end = std::min(offset + maxIoNum, keys.size); +IoScheduler::IoScheduler(const TransportConfig& config) + : batchLoadIoNum_(config.asuBatchLoadIoNum), + batchStoreIoNum_(config.asuBatchStoreIoNum), + deleteIoNum_(config.asuDeleteIoNum), + queryIoNum_(config.asuQueryIoNum) +{ +} - ScheduledKeyBatch batch; - batch.keys = BatchView{keys.data + offset, end - offset}; - result.push_back(batch); - } +std::vector IoScheduler::SplitForAsu( + const BatchView& entries, TransportOpType opType) const +{ + return SplitBatchView( + entries, GetSqeIoNum(opType), + [](ScheduledIoBatch& batch, BatchView view) { batch.entries = view; }); +} - return result; +std::vector IoScheduler::SplitForAsu( + const BatchView& keys, TransportOpType opType) const +{ + return SplitBatchView( + keys, GetSqeIoNum(opType), + [](ScheduledKeyBatch& batch, BatchView view) { batch.keys = view; }); } -std::size_t GetSqeBatchLimit(TransportOpType opType) +std::size_t IoScheduler::GetSqeIoNum(TransportOpType opType) const { if (opType == TransportOpType::LOAD || opType == TransportOpType::STORE) { return 1; } - return GetAsuMaxIoNum(opType); + switch (opType) { + case TransportOpType::BATCH_LOAD: return batchLoadIoNum_; + case TransportOpType::BATCH_STORE: return batchStoreIoNum_; + case TransportOpType::DELETE: return deleteIoNum_; + case TransportOpType::QUERY: return queryIoNum_; + default: return 0; + } } } // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/io_scheduler.h b/ucm/transport/kv/asu/trans/src/io_scheduler.h index 7ad3dc0a1..92080eabb 100644 --- a/ucm/transport/kv/asu/trans/src/io_scheduler.h +++ b/ucm/transport/kv/asu/trans/src/io_scheduler.h @@ -26,12 +26,16 @@ #include #include #include +#include "asu_transport/asu_transport.h" #include "transport_task_manager.h" namespace UC::ASU { class IoScheduler { public: + IoScheduler() = default; + explicit IoScheduler(const TransportConfig& config); + struct ScheduledIoBatch { BatchView entries; }; @@ -41,11 +45,16 @@ class IoScheduler { }; std::vector SplitForAsu(const BatchView& entries, - std::size_t maxIoNum) const; + TransportOpType opType) const; std::vector SplitForAsu(const BatchView& keys, - std::size_t maxIoNum) const; -}; + TransportOpType opType) const; + std::size_t GetSqeIoNum(TransportOpType opType) const; -std::size_t GetSqeBatchLimit(TransportOpType opType); +private: + std::size_t batchLoadIoNum_{110}; + std::size_t batchStoreIoNum_{110}; + std::size_t deleteIoNum_{254}; + std::size_t queryIoNum_{256}; +}; } // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/sqe_request.cpp b/ucm/transport/kv/asu/trans/src/sqe_request.cpp index 9f3f3266f..bdc5dcd6d 100644 --- a/ucm/transport/kv/asu/trans/src/sqe_request.cpp +++ b/ucm/transport/kv/asu/trans/src/sqe_request.cpp @@ -21,14 +21,15 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#include "sqe_request.h" #include #include #include +#include #include #include #include #include +#include "asu_transport_impl.h" #include "buffer_manager.h" #include "connection_manager.h" #include "kv_protocol.h" @@ -45,15 +46,6 @@ std::uint32_t ToSqeMrKey(MRHandle handle) return static_cast(handle); } -std::uint64_t GetResponseBufferAddr(const ScatterGatherEntry& flagBuffer) -{ - return flagBuffer.addr; -} - -std::uint32_t GetResponseMrKey(const ScatterGatherEntry& flagBuffer) { return flagBuffer.lkey; } - -std::uint16_t NextSqeCid(const SqeCidAllocator& allocateSqeCid) { return allocateSqeCid(); } - std::string ToLower(std::string value) { std::transform(value.begin(), value.end(), value.begin(), @@ -95,12 +87,82 @@ Status AllocateSubBatchFlagBuffer(std::size_t batchNum, BufferManager& flagBuffe Status SetSubBatchBuildFailed(TransportSubBatchContext& subBatchContext, const Status& status) { - subBatchContext.state = TransportSubBatchState::FAILED; + subBatchContext.state = TransportSubBatchState::COMPLETED; subBatchContext.status = status; std::fill(subBatchContext.entryStatus.begin(), subBatchContext.entryStatus.end(), status); return status; } +struct SubBatchRequestSource { + const BatchView* entries{nullptr}; + const BatchView* keys{nullptr}; + + static SubBatchRequestSource FromEntries(const BatchView& value) + { + return SubBatchRequestSource{&value, nullptr}; + } + + static SubBatchRequestSource FromKeys(const BatchView& value) + { + return SubBatchRequestSource{nullptr, &value}; + } + + static SubBatchRequestSource KeepAlive() { return SubBatchRequestSource{}; } +}; + +void ResetSubBatchContext(std::size_t batchNum, TransportSubBatchContext& subBatchContext) +{ + subBatchContext.entryStatus.assign(batchNum, Status::OK()); + subBatchContext.state = TransportSubBatchState::PENDING; + subBatchContext.status = Status::OK(); +} + +Status PrepareSubBatchRequest(TransportOpType opType, std::uint16_t cid, std::size_t batchNum, + BufferManager& flagBufferManager, + TransportSubBatchContext& subBatchContext) +{ + subBatchContext.opType = opType; + subBatchContext.cid = cid; + auto status = AllocateSubBatchFlagBuffer(batchNum, flagBufferManager, subBatchContext); + if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } + return Status::OK(); +} + +Status PackSubBatchRequest(ProtocolManager& protocolManager, BufferManager& sendBufferManager, + KvOpcode opcode, const SqeRequest& request, + TransportSubBatchContext& subBatchContext) +{ + auto packedSize = protocolManager.GetPackedSize(opcode, request); + auto status = sendBufferManager.Allocate(packedSize, subBatchContext.sendSge); + if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } + + status = protocolManager.PackRequest(reinterpret_cast(subBatchContext.sendSge.addr), + opcode, request); + if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } + + subBatchContext.status = status; + return status; +} + +Status InitializeSubBatchSubmission(TransportOpType opType, std::size_t batchNum, bool isSupported, + const std::string& unsupportedMessage, + TransportSubBatchContext& subBatchContext, KvOpcode& opcode, + bool& shouldSubmit) +{ + ResetSubBatchContext(batchNum, subBatchContext); + shouldSubmit = false; + + if (batchNum == 0) { return Status::OK(); } + + if (!isSupported) { + auto status = Status::Error(StatusCode::UNSUPPORTED, unsupportedMessage); + return SetSubBatchBuildFailed(subBatchContext, status); + } + opcode = ToKvOpcode(opType); + shouldSubmit = true; + return Status::OK(); +} + KvBatchStoreRequest BuildBatchStoreRequest( const BatchView& entries, const std::unordered_map& attrs, std::uint16_t cid, const ScatterGatherEntry& flagBuffer) @@ -110,8 +172,8 @@ KvBatchStoreRequest BuildBatchStoreRequest( request.kv_ns_id = GetTransportConfigAttr(attrs, "kv_ns_id"); request.dtype = GetTransportConfigAttr(attrs, "dtype"); request.dspec = GetTransportConfigAttr(attrs, "dspec"); - request.response_buffer_addr = GetResponseBufferAddr(flagBuffer); - request.response_mr_key = GetResponseMrKey(flagBuffer); + request.response_buffer_addr = flagBuffer.addr; + request.response_mr_key = flagBuffer.lkey; request.lr = GetTransportConfigAttr(attrs, "lr"); request.rflag = true; request.batch_number = static_cast(entries.size); @@ -135,8 +197,8 @@ KvBatchRetrieveRequest BuildBatchRetrieveRequest( KvBatchRetrieveRequest request; request.cid = cid; request.kv_ns_id = GetTransportConfigAttr(attrs, "kv_ns_id"); - request.response_buffer_addr = GetResponseBufferAddr(flagBuffer); - request.response_mr_key = GetResponseMrKey(flagBuffer); + request.response_buffer_addr = flagBuffer.addr; + request.response_mr_key = flagBuffer.lkey; request.lr = GetTransportConfigAttr(attrs, "lr"); request.rflag = true; request.batch_number = static_cast(entries.size); @@ -170,8 +232,8 @@ KvDeleteRequest BuildDeleteRequest(const BatchView& keys, KvDeleteRequest request; request.cid = cid; request.kv_ns_id = GetTransportConfigAttr(attrs, "kv_ns_id"); - request.response_buffer_addr = GetResponseBufferAddr(flagBuffer); - request.response_mr_key = GetResponseMrKey(flagBuffer); + request.response_buffer_addr = flagBuffer.addr; + request.response_mr_key = flagBuffer.lkey; request.rflag = true; request.keys = CopyKeys(keys); request.batch_number = static_cast(request.keys.size()); @@ -185,8 +247,8 @@ KvExistRequest BuildExistRequest(const BatchView& keys, KvExistRequest request; request.cid = cid; request.kv_ns_id = GetTransportConfigAttr(attrs, "kv_ns_id"); - request.response_buffer_addr = GetResponseBufferAddr(flagBuffer); - request.response_mr_key = GetResponseMrKey(flagBuffer); + request.response_buffer_addr = flagBuffer.addr; + request.response_mr_key = flagBuffer.lkey; request.rflag = true; request.sc = GetTransportConfigAttr(attrs, "sc"); request.keys = CopyKeys(keys); @@ -198,19 +260,49 @@ KvKeepAliveRequest BuildKeepAliveRequest(std::uint16_t cid, const ScatterGatherE { KvKeepAliveRequest request; request.cid = cid; - request.response_buffer_addr = GetResponseBufferAddr(flagBuffer); - request.response_mr_key = GetResponseMrKey(flagBuffer); + request.response_buffer_addr = flagBuffer.addr; + request.response_mr_key = flagBuffer.lkey; request.rflag = true; return request; } +std::unique_ptr BuildSqeRequest( + KvOpcode opcode, const SubBatchRequestSource& source, + const std::unordered_map& attrs, std::uint16_t cid, + const ScatterGatherEntry& flagBuffer, TransportSubBatchContext& subBatchContext) +{ + switch (opcode) { + case KvOpcode::BatchRetrieve: + if (source.entries == nullptr) { return nullptr; } + return std::make_unique( + BuildBatchRetrieveRequest(*source.entries, attrs, cid, flagBuffer)); + case KvOpcode::BatchStore: + if (source.entries == nullptr) { return nullptr; } + return std::make_unique( + BuildBatchStoreRequest(*source.entries, attrs, cid, flagBuffer)); + case KvOpcode::Delete: + if (source.keys == nullptr) { return nullptr; } + return std::make_unique( + BuildDeleteRequest(*source.keys, attrs, cid, flagBuffer)); + case KvOpcode::Exist: { + if (source.keys == nullptr) { return nullptr; } + auto request = BuildExistRequest(*source.keys, attrs, cid, flagBuffer); + subBatchContext.useSeekControl = request.sc; + return std::make_unique(std::move(request)); + } + case KvOpcode::KeepAlive: + return std::make_unique(BuildKeepAliveRequest(cid, flagBuffer)); + default: return nullptr; + } +} + } // namespace -Status ValidateSqeRequestAttrs(const std::unordered_map& attrs) +Status AsuTransportImpl::ValidateSqeRequestAttrs() { - const auto validateInteger = [&attrs](const std::string& name, auto maxValue) -> Status { - auto iter = attrs.find(name); - if (iter == attrs.end()) { return Status::OK(); } + const auto validateInteger = [this](const std::string& name, auto maxValue) -> Status { + auto iter = config_.attrs.find(name); + if (iter == config_.attrs.end()) { return Status::OK(); } try { const auto parsed = std::stoull(iter->second, nullptr, 0); if (parsed > maxValue) { @@ -222,10 +314,10 @@ Status ValidateSqeRequestAttrs(const std::unordered_map Status { - auto iter = attrs.find(name); - if (iter == attrs.end()) { + const auto validateRequiredPositiveInteger = [this](const std::string& name, + auto maxValue) -> Status { + auto iter = config_.attrs.find(name); + if (iter == config_.attrs.end()) { return Status::Error(StatusCode::INVALID_ARGUMENT, name + " is required"); } try { @@ -243,9 +335,9 @@ Status ValidateSqeRequestAttrs(const std::unordered_map Status { - auto iter = attrs.find(name); - if (iter == attrs.end()) { return Status::OK(); } + const auto validateBool = [this](const std::string& name) -> Status { + auto iter = config_.attrs.find(name); + if (iter == config_.attrs.end()) { return Status::OK(); } const auto value = ToLower(iter->second); if (value == "1" || value == "0" || value == "true" || value == "false") { return Status::OK(); @@ -270,132 +362,72 @@ Status ValidateSqeRequestAttrs(const std::unordered_map::max()); } -Status SubmitEntrySubBatchRequest(TransportOpType opType, - const IoScheduler::ScheduledIoBatch& subBatch, - const std::unordered_map& attrs, - const SqeCidAllocator& allocateSqeCid, - BufferManager& sendBufferManager, - BufferManager& flagBufferManager, - ProtocolManager& protocolManager, - TransportSubBatchContext& subBatchContext) +Status AsuTransportImpl::SubmitEntrySubBatchRequest(TransportOpType opType, + const IoScheduler::ScheduledIoBatch& subBatch, + TransportSubBatchContext& subBatchContext) { - subBatchContext.entryStatus.assign(subBatch.entries.size, Status::OK()); - subBatchContext.state = TransportSubBatchState::PENDING; - subBatchContext.status = Status::OK(); - - if (subBatch.entries.empty()) { return Status::OK(); } - - if (opType == TransportOpType::BATCH_LOAD) { - subBatchContext.opType = opType; - subBatchContext.cid = NextSqeCid(allocateSqeCid); - auto status = - AllocateSubBatchFlagBuffer(subBatch.entries.size, flagBufferManager, subBatchContext); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - auto request = BuildBatchRetrieveRequest(subBatch.entries, attrs, subBatchContext.cid, - subBatchContext.flagBuffer); - auto packedSize = protocolManager.GetPackedSize(KvOpcode::BatchRetrieve, request); - status = sendBufferManager.Allocate(packedSize, subBatchContext.sendSge); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - status = protocolManager.PackRequest(reinterpret_cast(subBatchContext.sendSge.addr), - KvOpcode::BatchRetrieve, request); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - subBatchContext.status = status; - return status; - } + constexpr auto kUnsupportedMessage = + "entry batch submit only supports batch store/retrieve operations"; + const auto source = SubBatchRequestSource::FromEntries(subBatch.entries); + KvOpcode opcode{}; + bool shouldSubmit = false; + auto status = + InitializeSubBatchSubmission(opType, subBatch.entries.size, IsEntryBatchOp(opType), + kUnsupportedMessage, subBatchContext, opcode, shouldSubmit); + if (!status.ok() || !shouldSubmit) { return status; } - if (opType == TransportOpType::BATCH_STORE) { - subBatchContext.opType = opType; - subBatchContext.cid = NextSqeCid(allocateSqeCid); - auto status = - AllocateSubBatchFlagBuffer(subBatch.entries.size, flagBufferManager, subBatchContext); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - auto request = BuildBatchStoreRequest(subBatch.entries, attrs, subBatchContext.cid, - subBatchContext.flagBuffer); - auto packedSize = protocolManager.GetPackedSize(KvOpcode::BatchStore, request); - status = sendBufferManager.Allocate(packedSize, subBatchContext.sendSge); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - status = protocolManager.PackRequest(reinterpret_cast(subBatchContext.sendSge.addr), - KvOpcode::BatchStore, request); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - subBatchContext.status = status; - return status; - } + status = PrepareSubBatchRequest(opType, AllocateRequestCid(), subBatch.entries.size, + flagBufferManager_, subBatchContext); + if (!status.ok()) { return status; } - auto status = Status::Error(StatusCode::UNSUPPORTED, - "entry batch submit only supports batch store/retrieve operations"); - return SetSubBatchBuildFailed(subBatchContext, status); + auto request = BuildSqeRequest(opcode, source, config_.attrs, subBatchContext.cid, + subBatchContext.flagBuffer, subBatchContext); + return PackSubBatchRequest(*protocolManager_, sendBufferManager_, opcode, *request, + subBatchContext); } -Status SubmitKeySubBatchRequest(TransportOpType opType, - const IoScheduler::ScheduledKeyBatch& subBatch, - const std::unordered_map& attrs, - const SqeCidAllocator& allocateSqeCid, - BufferManager& sendBufferManager, BufferManager& flagBufferManager, - ProtocolManager& protocolManager, - TransportSubBatchContext& subBatchContext) +Status AsuTransportImpl::SubmitKeySubBatchRequest(TransportOpType opType, + const IoScheduler::ScheduledKeyBatch& subBatch, + TransportSubBatchContext& subBatchContext) { - subBatchContext.entryStatus.assign(subBatch.keys.size, Status::OK()); - subBatchContext.state = TransportSubBatchState::PENDING; - subBatchContext.status = Status::OK(); - - if (subBatch.keys.empty()) { return Status::OK(); } - - if (opType != TransportOpType::QUERY && opType != TransportOpType::DELETE) { - auto status = - Status::Error(StatusCode::UNSUPPORTED, "key batch submit only supports query/delete"); - return SetSubBatchBuildFailed(subBatchContext, status); - } - - subBatchContext.cid = NextSqeCid(allocateSqeCid); - subBatchContext.opType = opType; + constexpr auto kUnsupportedMessage = "key batch submit only supports query/delete"; + const auto source = SubBatchRequestSource::FromKeys(subBatch.keys); + KvOpcode opcode{}; + bool shouldSubmit = false; auto status = - AllocateSubBatchFlagBuffer(subBatch.keys.size, flagBufferManager, subBatchContext); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - if (opType == TransportOpType::DELETE) { - auto request = BuildDeleteRequest(subBatch.keys, attrs, subBatchContext.cid, - subBatchContext.flagBuffer); - auto packedSize = protocolManager.GetPackedSize(KvOpcode::Delete, request); - status = sendBufferManager.Allocate(packedSize, subBatchContext.sendSge); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - status = protocolManager.PackRequest(reinterpret_cast(subBatchContext.sendSge.addr), - KvOpcode::Delete, request); - } else if (opType == TransportOpType::QUERY) { - auto request = BuildExistRequest(subBatch.keys, attrs, subBatchContext.cid, - subBatchContext.flagBuffer); - subBatchContext.useSeekControl = request.sc; - auto packedSize = protocolManager.GetPackedSize(KvOpcode::Exist, request); - status = sendBufferManager.Allocate(packedSize, subBatchContext.sendSge); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - status = protocolManager.PackRequest(reinterpret_cast(subBatchContext.sendSge.addr), - KvOpcode::Exist, request); - } - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } + InitializeSubBatchSubmission(opType, subBatch.keys.size, IsKeyBatchOp(opType), + kUnsupportedMessage, subBatchContext, opcode, shouldSubmit); + if (!status.ok() || !shouldSubmit) { return status; } - subBatchContext.status = Status::OK(); - return Status::OK(); + status = PrepareSubBatchRequest(opType, AllocateRequestCid(), subBatch.keys.size, + flagBufferManager_, subBatchContext); + if (!status.ok()) { return status; } + + auto request = BuildSqeRequest(opcode, source, config_.attrs, subBatchContext.cid, + subBatchContext.flagBuffer, subBatchContext); + return PackSubBatchRequest(*protocolManager_, sendBufferManager_, opcode, *request, + subBatchContext); } -Status SubmitKeepAliveRequest(const SqeCidAllocator& allocateSqeCid, - BufferManager& sendBufferManager, BufferManager& flagBufferManager, - ProtocolManager& protocolManager, - TransportSubBatchContext& subBatchContext) +Status AsuTransportImpl::SubmitKeepAliveRequest(TransportSubBatchContext& subBatchContext) { - subBatchContext.cid = NextSqeCid(allocateSqeCid); - subBatchContext.opType = TransportOpType::KEEP_ALIVE; - subBatchContext.state = TransportSubBatchState::PENDING; - subBatchContext.status = Status::OK(); - subBatchContext.entryStatus = {Status::OK()}; - auto status = AllocateSubBatchFlagBuffer(1, flagBufferManager, subBatchContext); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - auto request = BuildKeepAliveRequest(subBatchContext.cid, subBatchContext.flagBuffer); - auto packedSize = protocolManager.GetPackedSize(KvOpcode::KeepAlive, request); - status = sendBufferManager.Allocate(packedSize, subBatchContext.sendSge); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - status = protocolManager.PackRequest(reinterpret_cast(subBatchContext.sendSge.addr), - KvOpcode::KeepAlive, request); - if (!status.ok()) { return SetSubBatchBuildFailed(subBatchContext, status); } - subBatchContext.status = status; - return status; + constexpr auto kUnsupportedMessage = "keep alive submit only supports keep alive"; + const auto opType = TransportOpType::KEEP_ALIVE; + const auto source = SubBatchRequestSource::KeepAlive(); + KvOpcode opcode{}; + bool shouldSubmit = false; + auto status = InitializeSubBatchSubmission(opType, 1, true, kUnsupportedMessage, + subBatchContext, opcode, shouldSubmit); + if (!status.ok() || !shouldSubmit) { return status; } + + status = PrepareSubBatchRequest(opType, AllocateRequestCid(), 1, flagBufferManager_, + subBatchContext); + if (!status.ok()) { return status; } + + auto request = BuildSqeRequest(opcode, source, config_.attrs, subBatchContext.cid, + subBatchContext.flagBuffer, subBatchContext); + return PackSubBatchRequest(*protocolManager_, sendBufferManager_, opcode, *request, + subBatchContext); } } // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/sqe_request.h b/ucm/transport/kv/asu/trans/src/sqe_request.h deleted file mode 100644 index 22de8ff7f..000000000 --- a/ucm/transport/kv/asu/trans/src/sqe_request.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2026 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#pragma once - -#include -#include -#include -#include -#include -#include "asu_transport/types.h" -#include "io_scheduler.h" -#include "transport_task_manager.h" - -namespace UC::ASU { - -class BufferManager; -class ProtocolManager; - -using SqeCidAllocator = std::function; - -Status ValidateSqeRequestAttrs(const std::unordered_map& attrs); - -Status SubmitEntrySubBatchRequest(TransportOpType opType, - const IoScheduler::ScheduledIoBatch& subBatch, - const std::unordered_map& attrs, - const SqeCidAllocator& allocateSqeCid, - BufferManager& sendBufferManager, - BufferManager& flagBufferManager, - ProtocolManager& protocolManager, - TransportSubBatchContext& subBatchContext); - -Status SubmitKeySubBatchRequest(TransportOpType opType, - const IoScheduler::ScheduledKeyBatch& subBatch, - const std::unordered_map& attrs, - const SqeCidAllocator& allocateSqeCid, - BufferManager& sendBufferManager, BufferManager& flagBufferManager, - ProtocolManager& protocolManager, - TransportSubBatchContext& subBatchContext); - -Status SubmitKeepAliveRequest(const SqeCidAllocator& allocateSqeCid, - BufferManager& sendBufferManager, BufferManager& flagBufferManager, - ProtocolManager& protocolManager, - TransportSubBatchContext& subBatchContext); - -} // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/transport_config_parser.cpp b/ucm/transport/kv/asu/trans/src/transport_config_parser.cpp index b12bbfcd8..8bf7a3f8a 100644 --- a/ucm/transport/kv/asu/trans/src/transport_config_parser.cpp +++ b/ucm/transport/kv/asu/trans/src/transport_config_parser.cpp @@ -22,89 +22,12 @@ * SOFTWARE. * */ #include "transport_config_parser.h" -#include -#include #include -#include #include #include "asu_transport/asu_transport.h" +#include "config_parser_common.h" namespace UC::ASU { -namespace { - -std::string Trim(const std::string& value) -{ - const auto begin = value.find_first_not_of(" \t\r\n"); - if (begin == std::string::npos) { return ""; } - const auto end = value.find_last_not_of(" \t\r\n"); - return value.substr(begin, end - begin + 1); -} - -std::vector Split(const std::string& value, char delimiter) -{ - std::vector parts; - std::stringstream stream{value}; - std::string part; - while (std::getline(stream, part, delimiter)) { - part = Trim(part); - if (!part.empty()) { parts.emplace_back(std::move(part)); } - } - return parts; -} - -std::uint64_t ParseUint64(const std::string& value) { return std::stoull(value, nullptr, 0); } - -Protocol ParseProtocol(std::string value) -{ - std::transform(value.begin(), value.end(), value.begin(), - [](unsigned char ch) { return static_cast(std::toupper(ch)); }); - if (value == "UB" || value == "UBOE") { return Protocol::UB; } - if (value == "ROCE") { return Protocol::ROCE; } - return Protocol::TCP; -} - -void ApplyEndpointField(AsuEndpoint& endpoint, const std::string& key, const std::string& value) -{ - if (key == "ip" || key == "local.comm_id" || key == "localCommId") { - endpoint.ip = value; - } else if (key == "port") { - endpoint.port = static_cast(ParseUint64(value)); - } else if (key == "protocol") { - endpoint.protocol = ParseProtocol(value); - } else if (key == "numa_node" || key == "numaNode") { - endpoint.numaNode = static_cast(ParseUint64(value)); - } else if (key == "device_id" || key == "deviceId" || key == "local.phy_device_id" || - key == "localPhyDeviceId") { - endpoint.deviceId = static_cast(ParseUint64(value)); - } else if (key == "hca_name" || key == "hcaName") { - endpoint.hcaName = value; - } else if (key == "hca_port" || key == "hcaPort") { - endpoint.hcaPort = static_cast(ParseUint64(value)); - } else { - endpoint.attrs[key] = value; - } -} - -AsuEndpoint ParseEndpoint(const std::string& value) -{ - AsuEndpoint endpoint; - if (value.find('=') == std::string::npos) { - auto parts = Split(value, ':'); - if (!parts.empty()) { endpoint.ip = parts[0]; } - if (parts.size() > 1) { endpoint.port = static_cast(ParseUint64(parts[1])); } - if (parts.size() > 2) { endpoint.protocol = ParseProtocol(parts[2]); } - return endpoint; - } - for (const auto& item : Split(value, ',')) { - const auto pos = item.find('='); - if (pos == std::string::npos) { continue; } - ApplyEndpointField(endpoint, Trim(item.substr(0, pos)), Trim(item.substr(pos + 1))); - } - return endpoint; -} - -} // namespace - Status LoadTransportConfig(const std::string& configPath, TransportConfig& config) { std::ifstream configFile{configPath}; @@ -116,33 +39,37 @@ Status LoadTransportConfig(const std::string& configPath, TransportConfig& confi config = TransportConfig{}; std::string line; while (std::getline(configFile, line)) { - line = Trim(line); + line = TrimConfigValue(line); if (line.empty() || line[0] == '#') { continue; } const auto pos = line.find('='); if (pos == std::string::npos) { continue; } - const auto key = Trim(line.substr(0, pos)); - const auto value = Trim(line.substr(pos + 1)); + const auto key = TrimConfigValue(line.substr(0, pos)); + const auto value = TrimConfigValue(line.substr(pos + 1)); if (key == "asuName" || key == "asu_name") { config.asuName = value; } else if (key == "asuId" || key == "asu_id") { - config.asuId = ParseUint64(value); + config.asuId = ParseConfigUint64(value); } else if (key == "endpoint" || key == "endpoints") { config.endpoints.clear(); - for (const auto& endpointValue : Split(value, ';')) { - config.endpoints.emplace_back(ParseEndpoint(endpointValue)); + for (const auto& endpointValue : SplitConfigValue(value, ';')) { + config.endpoints.emplace_back(ParseTransportEndpoint(endpointValue)); } } else if (key == "queryTimeoutMs" || key == "query_timeout_ms") { - config.queryTimeoutMs = ParseUint64(value); + config.queryTimeoutMs = ParseConfigUint64(value); } else if (key == "loadTimeoutMs" || key == "load_timeout_ms") { - config.loadTimeoutMs = ParseUint64(value); + config.loadTimeoutMs = ParseConfigUint64(value); } else if (key == "storeTimeoutMs" || key == "store_timeout_ms") { - config.storeTimeoutMs = ParseUint64(value); + config.storeTimeoutMs = ParseConfigUint64(value); } else if (key == "maxInflightTasks" || key == "max_inflight_tasks") { - config.maxInflightTasks = static_cast(ParseUint64(value)); + config.maxInflightTasks = static_cast(ParseConfigUint64(value)); } else if (key == "maxInflightBytes" || key == "max_inflight_bytes") { - config.maxInflightBytes = ParseUint64(value); + config.maxInflightBytes = ParseConfigUint64(value); + } else if (ApplyTransportBufferConfigField(config, key, value)) { + continue; + } else if (ApplyTransportIoNumConfigField(config, key, value)) { + continue; } else { config.attrs[key] = value; } diff --git a/ucm/transport/kv/asu/trans/src/transport_task_completion.cpp b/ucm/transport/kv/asu/trans/src/transport_task_completion.cpp index 2946b69f6..9892d93ca 100644 --- a/ucm/transport/kv/asu/trans/src/transport_task_completion.cpp +++ b/ucm/transport/kv/asu/trans/src/transport_task_completion.cpp @@ -21,9 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#include "transport_task_completion.h" +#include "asu_transport_impl.h" #include "buffer_manager.h" #include "connection_internal.h" +#include "logger.h" namespace UC::ASU { @@ -31,7 +32,7 @@ namespace { bool IsSubBatchTerminal(TransportSubBatchState state) { - return state == TransportSubBatchState::COMPLETED || state == TransportSubBatchState::FAILED; + return state == TransportSubBatchState::COMPLETED; } Status BuildTaskFinalStatus(const TransportTaskContext& ctx) @@ -46,46 +47,37 @@ Status BuildTaskFinalStatus(const TransportTaskContext& ctx) return Status::OK(); } -TransportTaskState BuildTaskStateFromSubBatches(const TransportTaskContext& ctx) -{ - for (const auto& subBatchContext : ctx.subBatchContexts) { - if (subBatchContext.state == TransportSubBatchState::FAILED) { - return TransportTaskState::FAILED; - } - } - return TransportTaskState::COMPLETED; -} - } // namespace -void InitializeTerminalSubBatchCount(TransportTaskContext& ctx) +void TransportTaskContext::InitializeTerminalSubBatchCount() { // At submit completion time, terminal sub-batches are usually submit/send failures. - ctx.completedSubBatchCount = 0; - for (const auto& subBatchContext : ctx.subBatchContexts) { + completedSubBatchCount = 0; + for (const auto& subBatchContext : subBatchContexts) { if (!IsSubBatchTerminal(subBatchContext.state)) { continue; } - ++ctx.completedSubBatchCount; + ++completedSubBatchCount; } } -Status ReleaseSubBatchResources(TransportSubBatchContext& subBatchContext, - BufferManager& sendBufferManager, BufferManager& flagBufferManager) +void AsuTransportImpl::ReleaseSubBatchResources(TransportSubBatchContext& subBatchContext) { - Status finalStatus = Status::OK(); - if (subBatchContext.sendSge.slot_index != UINT32_MAX) { - auto status = sendBufferManager.Free(subBatchContext.sendSge.slot_index); + const auto slotIndex = subBatchContext.sendSge.slot_index; + auto status = sendBufferManager_.Free(slotIndex); if (!status.ok()) { - if (finalStatus.ok()) { finalStatus = status; } + UC_ERROR("Failed to release sub-batch send buffer slot({}): {}", slotIndex, + status.message); } subBatchContext.sendSge = {}; } if (subBatchContext.flagBuffer.slot_index != UINT32_MAX) { - auto status = flagBufferManager.Free(subBatchContext.flagBuffer.slot_index); + const auto slotIndex = subBatchContext.flagBuffer.slot_index; + auto status = flagBufferManager_.Free(slotIndex); if (!status.ok()) { - if (finalStatus.ok()) { finalStatus = status; } + UC_ERROR("Failed to release sub-batch flag buffer slot({}): {}", slotIndex, + status.message); } subBatchContext.flagBuffer = {}; } @@ -94,54 +86,37 @@ Status ReleaseSubBatchResources(TransportSubBatchContext& subBatchContext, subBatchContext.channel->ReleaseInflight(); subBatchContext.channel = nullptr; } - - return finalStatus; } -Status ReleaseAllSubBatchResources(std::vector& subBatchContexts, - BufferManager& sendBufferManager, - BufferManager& flagBufferManager) +void AsuTransportImpl::ReleaseAllSubBatchResources( + std::vector& subBatchContexts) { - Status finalStatus = Status::OK(); - for (auto& subBatchContext : subBatchContexts) { - const auto status = - ReleaseSubBatchResources(subBatchContext, sendBufferManager, flagBufferManager); - if (finalStatus.ok() && !status.ok()) { finalStatus = status; } - } - return finalStatus; + for (auto& subBatchContext : subBatchContexts) { ReleaseSubBatchResources(subBatchContext); } } -void CompleteSubBatch(TransportTaskContext& ctx, TransportSubBatchContext& subBatchContext, - TransportSubBatchState state, const Status& status, - BufferManager& sendBufferManager, BufferManager& flagBufferManager) +void AsuTransportImpl::CompleteSubBatch(TransportTaskContext& ctx, + TransportSubBatchContext& subBatchContext, + const Status& status) { if (subBatchContext.state != TransportSubBatchState::PENDING) { return; } - const auto releaseStatus = - ReleaseSubBatchResources(subBatchContext, sendBufferManager, flagBufferManager); - const auto completionStatus = status.ok() ? releaseStatus : status; - subBatchContext.state = (!completionStatus.ok() && state == TransportSubBatchState::COMPLETED) - ? TransportSubBatchState::FAILED - : state; - subBatchContext.status = completionStatus; + ReleaseSubBatchResources(subBatchContext); + subBatchContext.state = TransportSubBatchState::COMPLETED; + subBatchContext.status = status; ++ctx.completedSubBatchCount; } -void TryFinalizeTaskFromSubBatches(TransportTaskContext& ctx) +void TransportTaskContext::TryFinalizeFromSubBatches() { - if (ctx.subBatchContexts.empty()) { - ctx.state.store( - ctx.finalStatus.ok() ? TransportTaskState::COMPLETED : TransportTaskState::FAILED, - std::memory_order_release); + if (subBatchContexts.empty()) { + state.store(TransportTaskState::COMPLETED, std::memory_order_release); return; } - if (ctx.completedSubBatchCount != static_cast(ctx.subBatchContexts.size())) { - return; - } + if (completedSubBatchCount != static_cast(subBatchContexts.size())) { return; } - ctx.finalStatus = BuildTaskFinalStatus(ctx); - ctx.state.store(BuildTaskStateFromSubBatches(ctx), std::memory_order_release); + finalStatus = BuildTaskFinalStatus(*this); + state.store(TransportTaskState::COMPLETED, std::memory_order_release); } } // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/transport_task_manager.cpp b/ucm/transport/kv/asu/trans/src/transport_task_manager.cpp index d52241c51..7734abd16 100644 --- a/ucm/transport/kv/asu/trans/src/transport_task_manager.cpp +++ b/ucm/transport/kv/asu/trans/src/transport_task_manager.cpp @@ -5,8 +5,7 @@ namespace UC::ASU { bool TransportTaskContext::Done() const { auto s = state.load(std::memory_order_acquire); - return s == TransportTaskState::COMPLETED || s == TransportTaskState::FAILED || - s == TransportTaskState::CANCELED; + return s == TransportTaskState::COMPLETED || s == TransportTaskState::CANCELED; } void TransportTaskManager::CancelAll() diff --git a/ucm/transport/kv/asu/trans/src/transport_task_manager.h b/ucm/transport/kv/asu/trans/src/transport_task_manager.h index d5dab3600..3f21a44c9 100644 --- a/ucm/transport/kv/asu/trans/src/transport_task_manager.h +++ b/ucm/transport/kv/asu/trans/src/transport_task_manager.h @@ -51,32 +51,14 @@ enum class TransportTaskState { PENDING = 0, INFLIGHT = 1, COMPLETED = 2, - FAILED = 3, - CANCELED = 4, + CANCELED = 3, }; enum class TransportSubBatchState { PENDING = 0, COMPLETED = 1, - FAILED = 2, }; -constexpr std::size_t kAsuBatchLoadMaxIoNum = 110; -constexpr std::size_t kAsuBatchStoreMaxIoNum = 110; -constexpr std::size_t kAsuDeleteMaxIoNum = 254; -constexpr std::size_t kAsuQueryMaxIoNum = 256; - -inline std::size_t GetAsuMaxIoNum(TransportOpType opType) -{ - switch (opType) { - case TransportOpType::BATCH_LOAD: return kAsuBatchLoadMaxIoNum; - case TransportOpType::BATCH_STORE: return kAsuBatchStoreMaxIoNum; - case TransportOpType::DELETE: return kAsuDeleteMaxIoNum; - case TransportOpType::QUERY: return kAsuQueryMaxIoNum; - default: return 0; - } -} - inline bool IsEntryBatchOp(TransportOpType opType) { return opType == TransportOpType::BATCH_LOAD || opType == TransportOpType::BATCH_STORE; @@ -87,6 +69,8 @@ inline bool IsKeyBatchOp(TransportOpType opType) return opType == TransportOpType::DELETE || opType == TransportOpType::QUERY; } +inline bool IsKeepAliveOp(TransportOpType opType) { return opType == TransportOpType::KEEP_ALIVE; } + template struct BatchView { const T* data{nullptr}; @@ -128,6 +112,8 @@ struct TransportTaskContext { std::condition_variable cv; bool Done() const; + void InitializeTerminalSubBatchCount(); + void TryFinalizeFromSubBatches(); }; class TransportTaskManager : public TaskManagerBase { diff --git a/ucm/transport/kv/asu/trans/test/asu_submit_flow_test.cpp b/ucm/transport/kv/asu/trans/test/asu_submit_flow_test.cpp index 7b8c719a1..210f5b224 100644 --- a/ucm/transport/kv/asu/trans/test/asu_submit_flow_test.cpp +++ b/ucm/transport/kv/asu/trans/test/asu_submit_flow_test.cpp @@ -21,11 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#include "asu_submit_flow.h" #include #include #include #include +#define private public +#include "asu_transport_impl.h" +#undef private #include "buffer_manager.h" #include "connection_internal.h" @@ -55,7 +57,8 @@ TEST(AsuSubmitFlowTest, SendSubBatchBuffersReadsSendCountsFromAttrs) g_quietCount = 0; g_sendStatuses.clear(); - std::unordered_map attrs = { + AsuTransportImpl transport; + transport.config_.attrs = { {"kernel_count", "3"}, {"quiet_count", "7"}, }; @@ -69,9 +72,7 @@ TEST(AsuSubmitFlowTest, SendSubBatchBuffersReadsSendCountsFromAttrs) subBatchContexts[0].state = TransportSubBatchState::PENDING; subBatchContexts[0].entryStatus.assign(1, Status::OK()); - ConnectionManager connManager; - const auto status = - SendSubBatchBuffers(subBatchContexts, ioBatches, subBatchIndexes, attrs, connManager); + const auto status = transport.SendSubBatchBuffers(subBatchContexts, ioBatches, subBatchIndexes); EXPECT_TRUE(status.ok()) << status.message; EXPECT_EQ(g_kernelCount, std::uint32_t{3}); @@ -87,15 +88,15 @@ TEST(AsuSubmitFlowTest, SendSubBatchBuffersReportsSendFailures) Status::Error(StatusCode::CONNECTION_ERROR, "fake send failure"), }; - std::unordered_map attrs = { + AsuTransportImpl transport; + transport.config_.attrs = { {"kernel_count", "3"}, {"quiet_count", "7"}, }; - ConnectionManager connManager; - ASSERT_TRUE(connManager.AddGroup(AsuEndpoint{}, 1).ok()); - auto* channel0 = connManager.SelectConnection(); - auto* channel1 = connManager.SelectConnection(); + ASSERT_TRUE(transport.connManager_.AddGroup(AsuEndpoint{}, 1).ok()); + auto* channel0 = transport.connManager_.SelectConnection(); + auto* channel1 = transport.connManager_.SelectConnection(); ASSERT_NE(channel0, nullptr); ASSERT_EQ(channel0, channel1); @@ -115,35 +116,35 @@ TEST(AsuSubmitFlowTest, SendSubBatchBuffersReportsSendFailures) subBatchContexts[1].channel = channel1; subBatchContexts[1].entryStatus.assign(1, Status::OK()); - const auto status = - SendSubBatchBuffers(subBatchContexts, ioBatches, subBatchIndexes, attrs, connManager); + const auto status = transport.SendSubBatchBuffers(subBatchContexts, ioBatches, subBatchIndexes); EXPECT_EQ(status.code, StatusCode::CONNECTION_ERROR); EXPECT_EQ(channel0->GetState(), ChannelState::DRAINING); - EXPECT_EQ(subBatchContexts[0].state, TransportSubBatchState::FAILED); - EXPECT_EQ(subBatchContexts[1].state, TransportSubBatchState::FAILED); + EXPECT_EQ(subBatchContexts[0].state, TransportSubBatchState::COMPLETED); + EXPECT_EQ(subBatchContexts[1].state, TransportSubBatchState::COMPLETED); g_sendStatuses.clear(); } TEST(AsuSubmitFlowTest, BuildSubBatchSendBuffersReleasesPreFailedSubBatches) { - BufferManager sendBufferManager; - BufferManager flagBufferManager; - ASSERT_TRUE(sendBufferManager.Init("test send buffer", MemoryType::HOST, 4096, 1).ok()); - ASSERT_TRUE(flagBufferManager.Init("test flag buffer", MemoryType::HOST, 128, 1).ok()); + AsuTransportImpl transport; + ASSERT_TRUE( + transport.sendBufferManager_.Init("test send buffer", MemoryType::HOST, 4096, 1).ok()); + ASSERT_TRUE( + transport.flagBufferManager_.Init("test flag buffer", MemoryType::HOST, 128, 1).ok()); std::vector subBatchContexts(1); auto& subBatchContext = subBatchContexts[0]; - subBatchContext.state = TransportSubBatchState::FAILED; + subBatchContext.state = TransportSubBatchState::COMPLETED; subBatchContext.status = Status::Error(StatusCode::INVALID_ARGUMENT, "pre-send failure"); subBatchContext.entryStatus.assign(1, subBatchContext.status); - ASSERT_TRUE(sendBufferManager.Allocate(64, subBatchContext.sendSge).ok()); - ASSERT_TRUE(flagBufferManager.Allocate(64, subBatchContext.flagBuffer).ok()); + ASSERT_TRUE(transport.sendBufferManager_.Allocate(64, subBatchContext.sendSge).ok()); + ASSERT_TRUE(transport.flagBufferManager_.Allocate(64, subBatchContext.flagBuffer).ok()); std::vector ioBatches; std::vector subBatchIndexes; - const auto status = BuildSubBatchSendBuffers(subBatchContexts, ioBatches, subBatchIndexes, - sendBufferManager, flagBufferManager); + const auto status = + transport.BuildSubBatchSendBuffers(subBatchContexts, ioBatches, subBatchIndexes); EXPECT_EQ(status.code, StatusCode::PARTIAL_FAILED); EXPECT_TRUE(ioBatches.empty()); @@ -156,14 +157,14 @@ TEST(AsuSubmitFlowTest, BuildSubBatchSendBuffersReleasesPreFailedSubBatches) TEST(AsuSubmitFlowTest, BuildSubBatchSendBuffersMarksMissingFlagBufferFailed) { - BufferManager sendBufferManager; - BufferManager flagBufferManager; - ASSERT_TRUE(sendBufferManager.Init("test send buffer", MemoryType::HOST, 4096, 1).ok()); - ASSERT_TRUE(flagBufferManager.Init("test flag buffer", MemoryType::HOST, 128, 1).ok()); - - ConnectionManager connManager; - ASSERT_TRUE(connManager.AddGroup(AsuEndpoint{}, 1).ok()); - auto* channel = connManager.SelectConnection(); + AsuTransportImpl transport; + ASSERT_TRUE( + transport.sendBufferManager_.Init("test send buffer", MemoryType::HOST, 4096, 1).ok()); + ASSERT_TRUE( + transport.flagBufferManager_.Init("test flag buffer", MemoryType::HOST, 128, 1).ok()); + + ASSERT_TRUE(transport.connManager_.AddGroup(AsuEndpoint{}, 1).ok()); + auto* channel = transport.connManager_.SelectConnection(); ASSERT_NE(channel, nullptr); EXPECT_EQ(channel->GetInflightCount(), std::uint32_t{1}); @@ -172,17 +173,17 @@ TEST(AsuSubmitFlowTest, BuildSubBatchSendBuffersMarksMissingFlagBufferFailed) subBatchContext.state = TransportSubBatchState::PENDING; subBatchContext.channel = channel; subBatchContext.entryStatus.assign(2, Status::OK()); - ASSERT_TRUE(sendBufferManager.Allocate(64, subBatchContext.sendSge).ok()); + ASSERT_TRUE(transport.sendBufferManager_.Allocate(64, subBatchContext.sendSge).ok()); std::vector ioBatches; std::vector subBatchIndexes; - const auto status = BuildSubBatchSendBuffers(subBatchContexts, ioBatches, subBatchIndexes, - sendBufferManager, flagBufferManager); + const auto status = + transport.BuildSubBatchSendBuffers(subBatchContexts, ioBatches, subBatchIndexes); EXPECT_EQ(status.code, StatusCode::NOT_INITIALIZED); EXPECT_TRUE(ioBatches.empty()); EXPECT_TRUE(subBatchIndexes.empty()); - EXPECT_EQ(subBatchContext.state, TransportSubBatchState::FAILED); + EXPECT_EQ(subBatchContext.state, TransportSubBatchState::COMPLETED); EXPECT_EQ(subBatchContext.status.code, StatusCode::NOT_INITIALIZED); EXPECT_EQ(subBatchContext.channel, nullptr); EXPECT_EQ(channel->GetInflightCount(), std::uint32_t{0}); @@ -196,7 +197,8 @@ TEST(AsuSubmitFlowTest, SendSubBatchBuffersFailsAllSentSubBatchesWhenStatusCount { g_sendStatuses = {Status::OK()}; - std::unordered_map attrs = { + AsuTransportImpl transport; + transport.config_.attrs = { {"kernel_count", "3"}, {"quiet_count", "7"}, }; @@ -215,15 +217,13 @@ TEST(AsuSubmitFlowTest, SendSubBatchBuffersFailsAllSentSubBatchesWhenStatusCount subBatchContexts[1].state = TransportSubBatchState::PENDING; subBatchContexts[1].entryStatus.assign(1, Status::OK()); - ConnectionManager connManager; - const auto status = - SendSubBatchBuffers(subBatchContexts, ioBatches, subBatchIndexes, attrs, connManager); + const auto status = transport.SendSubBatchBuffers(subBatchContexts, ioBatches, subBatchIndexes); EXPECT_EQ(status.code, StatusCode::INTERNAL_ERROR); - EXPECT_EQ(subBatchContexts[0].state, TransportSubBatchState::FAILED); + EXPECT_EQ(subBatchContexts[0].state, TransportSubBatchState::COMPLETED); EXPECT_EQ(subBatchContexts[0].status.code, StatusCode::INTERNAL_ERROR); EXPECT_EQ(subBatchContexts[0].entryStatus[0].code, StatusCode::INTERNAL_ERROR); - EXPECT_EQ(subBatchContexts[1].state, TransportSubBatchState::FAILED); + EXPECT_EQ(subBatchContexts[1].state, TransportSubBatchState::COMPLETED); EXPECT_EQ(subBatchContexts[1].status.code, StatusCode::INTERNAL_ERROR); EXPECT_EQ(subBatchContexts[1].entryStatus[0].code, StatusCode::INTERNAL_ERROR); g_sendStatuses.clear(); diff --git a/ucm/transport/kv/asu/trans/test/io_scheduler_test.cpp b/ucm/transport/kv/asu/trans/test/io_scheduler_test.cpp index 77f310d50..3000bc218 100644 --- a/ucm/transport/kv/asu/trans/test/io_scheduler_test.cpp +++ b/ucm/transport/kv/asu/trans/test/io_scheduler_test.cpp @@ -36,9 +36,11 @@ TEST(IoSchedulerTest, SplitEntryBatchPreservesOrderAndUsesViews) entries[index].key = "key_" + std::to_string(index); } - IoScheduler scheduler; - const auto batches = - scheduler.SplitForAsu(BatchView{entries.data(), entries.size()}, 2); + TransportConfig config; + config.asuBatchLoadIoNum = 2; + IoScheduler scheduler(config); + const auto batches = scheduler.SplitForAsu(BatchView{entries.data(), entries.size()}, + TransportOpType::BATCH_LOAD); ASSERT_EQ(batches.size(), std::size_t{3}); EXPECT_EQ(batches[0].entries.size, std::size_t{2}); @@ -52,21 +54,49 @@ TEST(IoSchedulerTest, SplitEntryBatchPreservesOrderAndUsesViews) TEST(IoSchedulerTest, SplitKeyBatchReturnsEmptyForEmptyInputOrZeroLimit) { - IoScheduler scheduler; + TransportConfig config; + config.asuQueryIoNum = 0; + IoScheduler scheduler(config); std::vector keys = {"a", "b"}; - EXPECT_TRUE(scheduler.SplitForAsu(BatchView{keys.data(), 0}, 2).empty()); - EXPECT_TRUE(scheduler.SplitForAsu(BatchView{keys.data(), keys.size()}, 0).empty()); + EXPECT_TRUE(scheduler.SplitForAsu(BatchView{keys.data(), 0}, TransportOpType::DELETE) + .empty()); + EXPECT_TRUE( + scheduler.SplitForAsu(BatchView{keys.data(), keys.size()}, TransportOpType::QUERY) + .empty()); } -TEST(IoSchedulerTest, GetSqeBatchLimitMatchesOperationKind) +TEST(IoSchedulerTest, GetSqeIoNumMatchesOperationKind) { - EXPECT_EQ(GetSqeBatchLimit(TransportOpType::LOAD), std::size_t{1}); - EXPECT_EQ(GetSqeBatchLimit(TransportOpType::STORE), std::size_t{1}); - EXPECT_EQ(GetSqeBatchLimit(TransportOpType::BATCH_LOAD), kAsuBatchLoadMaxIoNum); - EXPECT_EQ(GetSqeBatchLimit(TransportOpType::BATCH_STORE), kAsuBatchStoreMaxIoNum); - EXPECT_EQ(GetSqeBatchLimit(TransportOpType::DELETE), kAsuDeleteMaxIoNum); - EXPECT_EQ(GetSqeBatchLimit(TransportOpType::QUERY), kAsuQueryMaxIoNum); + TransportConfig config; + config.asuBatchLoadIoNum = 3; + config.asuBatchStoreIoNum = 4; + config.asuDeleteIoNum = 5; + config.asuQueryIoNum = 6; + IoScheduler scheduler(config); + + EXPECT_EQ(scheduler.GetSqeIoNum(TransportOpType::LOAD), std::size_t{1}); + EXPECT_EQ(scheduler.GetSqeIoNum(TransportOpType::STORE), std::size_t{1}); + EXPECT_EQ(scheduler.GetSqeIoNum(TransportOpType::BATCH_LOAD), std::size_t{3}); + EXPECT_EQ(scheduler.GetSqeIoNum(TransportOpType::BATCH_STORE), std::size_t{4}); + EXPECT_EQ(scheduler.GetSqeIoNum(TransportOpType::DELETE), std::size_t{5}); + EXPECT_EQ(scheduler.GetSqeIoNum(TransportOpType::QUERY), std::size_t{6}); +} + +TEST(IoSchedulerTest, SplitByOperationUsesHeldConfig) +{ + TransportConfig config; + config.asuBatchLoadIoNum = 2; + IoScheduler scheduler(config); + std::vector entries(5); + + const auto batches = scheduler.SplitForAsu(BatchView{entries.data(), entries.size()}, + TransportOpType::BATCH_LOAD); + + ASSERT_EQ(batches.size(), std::size_t{3}); + EXPECT_EQ(batches[0].entries.size, std::size_t{2}); + EXPECT_EQ(batches[1].entries.size, std::size_t{2}); + EXPECT_EQ(batches[2].entries.size, std::size_t{1}); } } // namespace diff --git a/ucm/transport/kv/asu/trans/test/sqe_request_test.cpp b/ucm/transport/kv/asu/trans/test/sqe_request_test.cpp index a32194b34..8d4e4d491 100644 --- a/ucm/transport/kv/asu/trans/test/sqe_request_test.cpp +++ b/ucm/transport/kv/asu/trans/test/sqe_request_test.cpp @@ -21,7 +21,6 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#include "sqe_request.h" #include #include #include @@ -29,6 +28,9 @@ #include #include #include +#define private public +#include "asu_transport_impl.h" +#undef private #include "buffer_manager.h" #include "kv_protocol.h" @@ -85,39 +87,44 @@ class SqeRequestTest : public ::testing::Test { void SetUp() override { - auto status = flagBufferManager_.Init("test flag buffer", MemoryType::HOST, - kFlagBufferSlotSize, kFlagBufferSlotNum); + transport_.config_.attrs = DefaultAttrs(); + transport_.nextRequestCid_.store(1, std::memory_order_relaxed); + auto status = transport_.flagBufferManager_.Init("test flag buffer", MemoryType::HOST, + kFlagBufferSlotSize, kFlagBufferSlotNum); ASSERT_TRUE(status.ok()) << status.message; - status = sendBufferManager_.Init("test send buffer", MemoryType::HOST, - kTestSendBufferSlotSize, kTestSendBufferSlotNum); + status = transport_.sendBufferManager_.Init( + "test send buffer", MemoryType::HOST, kTestSendBufferSlotSize, kTestSendBufferSlotNum); ASSERT_TRUE(status.ok()) << status.message; - protocolManager_ = std::make_unique(); + transport_.protocolManager_ = std::make_unique(); } - BufferManager sendBufferManager_; - BufferManager flagBufferManager_; - std::unique_ptr protocolManager_; + AsuTransportImpl transport_; }; TEST_F(SqeRequestTest, ValidateSqeRequestAttrsRejectsMalformedValues) { - EXPECT_TRUE(ValidateSqeRequestAttrs(DefaultAttrs()).ok()); + transport_.config_.attrs = DefaultAttrs(); + EXPECT_TRUE(transport_.ValidateSqeRequestAttrs().ok()); auto attrs = DefaultAttrs(); attrs["dtype"] = "256"; - EXPECT_EQ(ValidateSqeRequestAttrs(attrs).code, StatusCode::INVALID_ARGUMENT); + transport_.config_.attrs = attrs; + EXPECT_EQ(transport_.ValidateSqeRequestAttrs().code, StatusCode::INVALID_ARGUMENT); attrs = DefaultAttrs(); attrs["lr"] = "maybe"; - EXPECT_EQ(ValidateSqeRequestAttrs(attrs).code, StatusCode::INVALID_ARGUMENT); + transport_.config_.attrs = attrs; + EXPECT_EQ(transport_.ValidateSqeRequestAttrs().code, StatusCode::INVALID_ARGUMENT); attrs = DefaultAttrs(); attrs.erase("kernel_count"); - EXPECT_EQ(ValidateSqeRequestAttrs(attrs).code, StatusCode::INVALID_ARGUMENT); + transport_.config_.attrs = attrs; + EXPECT_EQ(transport_.ValidateSqeRequestAttrs().code, StatusCode::INVALID_ARGUMENT); attrs = DefaultAttrs(); attrs["quiet_count"] = "0"; - EXPECT_EQ(ValidateSqeRequestAttrs(attrs).code, StatusCode::INVALID_ARGUMENT); + transport_.config_.attrs = attrs; + EXPECT_EQ(transport_.ValidateSqeRequestAttrs().code, StatusCode::INVALID_ARGUMENT); } TEST_F(SqeRequestTest, SubmitBatchStoreAllocatesFlagBufferAndBuildsRequest) @@ -127,11 +134,10 @@ TEST_F(SqeRequestTest, SubmitBatchStoreAllocatesFlagBufferAndBuildsRequest) BatchView{entries.data(), entries.size()} }; TransportSubBatchContext subBatchContext; - std::uint16_t nextCid = 41; + transport_.nextRequestCid_.store(41, std::memory_order_relaxed); - const auto status = SubmitEntrySubBatchRequest( - TransportOpType::BATCH_STORE, subBatch, DefaultAttrs(), [&nextCid] { return nextCid++; }, - sendBufferManager_, flagBufferManager_, *protocolManager_, subBatchContext); + const auto status = transport_.SubmitEntrySubBatchRequest(TransportOpType::BATCH_STORE, + subBatch, subBatchContext); EXPECT_TRUE(status.ok()) << status.message; EXPECT_EQ(subBatchContext.flagBuffer.length, kFlagBufferHeaderSize + (entries.size() + 1) / 2); @@ -151,10 +157,10 @@ TEST_F(SqeRequestTest, SubmitBatchRetrieveUsesRetrieveOpcodeAndRequest) BatchView{entries.data(), entries.size()} }; TransportSubBatchContext subBatchContext; + transport_.nextRequestCid_.store(9, std::memory_order_relaxed); - const auto status = SubmitEntrySubBatchRequest( - TransportOpType::BATCH_LOAD, subBatch, DefaultAttrs(), [] { return std::uint16_t{9}; }, - sendBufferManager_, flagBufferManager_, *protocolManager_, subBatchContext); + const auto status = transport_.SubmitEntrySubBatchRequest(TransportOpType::BATCH_LOAD, subBatch, + subBatchContext); EXPECT_TRUE(status.ok()) << status.message; EXPECT_EQ(subBatchContext.opType, TransportOpType::BATCH_LOAD); @@ -171,10 +177,10 @@ TEST_F(SqeRequestTest, SubmitDeleteCopiesKeysAndBuildsFlagBackedRequest) BatchView{keys.data(), keys.size()} }; TransportSubBatchContext subBatchContext; + transport_.nextRequestCid_.store(55, std::memory_order_relaxed); - const auto status = SubmitKeySubBatchRequest( - TransportOpType::DELETE, subBatch, DefaultAttrs(), [] { return std::uint16_t{55}; }, - sendBufferManager_, flagBufferManager_, *protocolManager_, subBatchContext); + const auto status = + transport_.SubmitKeySubBatchRequest(TransportOpType::DELETE, subBatch, subBatchContext); EXPECT_TRUE(status.ok()) << status.message; EXPECT_EQ(subBatchContext.opType, TransportOpType::DELETE); @@ -193,10 +199,10 @@ TEST_F(SqeRequestTest, SubmitExistReadsScAttribute) BatchView{keys.data(), keys.size()} }; TransportSubBatchContext subBatchContext; + transport_.nextRequestCid_.store(13, std::memory_order_relaxed); - const auto status = SubmitKeySubBatchRequest( - TransportOpType::QUERY, subBatch, DefaultAttrs(), [] { return std::uint16_t{13}; }, - sendBufferManager_, flagBufferManager_, *protocolManager_, subBatchContext); + const auto status = + transport_.SubmitKeySubBatchRequest(TransportOpType::QUERY, subBatch, subBatchContext); EXPECT_TRUE(status.ok()) << status.message; EXPECT_EQ(subBatchContext.opType, TransportOpType::QUERY); @@ -210,6 +216,8 @@ TEST_F(SqeRequestTest, SubmitExistDisablesSeekControlWhenScDisabled) { auto attrs = DefaultAttrs(); attrs["sc"] = "false"; + transport_.config_.attrs = attrs; + transport_.nextRequestCid_.store(13, std::memory_order_relaxed); std::vector keys = {"k0"}; IoScheduler::ScheduledKeyBatch subBatch{ @@ -217,9 +225,8 @@ TEST_F(SqeRequestTest, SubmitExistDisablesSeekControlWhenScDisabled) }; TransportSubBatchContext subBatchContext; - const auto status = SubmitKeySubBatchRequest( - TransportOpType::QUERY, subBatch, attrs, [] { return std::uint16_t{13}; }, - sendBufferManager_, flagBufferManager_, *protocolManager_, subBatchContext); + const auto status = + transport_.SubmitKeySubBatchRequest(TransportOpType::QUERY, subBatch, subBatchContext); EXPECT_TRUE(status.ok()) << status.message; EXPECT_FALSE(subBatchContext.useSeekControl); @@ -232,14 +239,20 @@ TEST_F(SqeRequestTest, AllocationFailureMarksWholeSubBatchFailed) BatchView{entries.data(), entries.size()} }; TransportSubBatchContext subBatchContext; - BufferManager uninitializedFlagBufferManager; - - const auto status = SubmitEntrySubBatchRequest( - TransportOpType::BATCH_STORE, subBatch, DefaultAttrs(), [] { return std::uint16_t{3}; }, - sendBufferManager_, uninitializedFlagBufferManager, *protocolManager_, subBatchContext); + AsuTransportImpl uninitializedFlagTransport; + uninitializedFlagTransport.config_.attrs = DefaultAttrs(); + uninitializedFlagTransport.nextRequestCid_.store(3, std::memory_order_relaxed); + ASSERT_TRUE(uninitializedFlagTransport.sendBufferManager_ + .Init("test send buffer", MemoryType::HOST, kTestSendBufferSlotSize, + kTestSendBufferSlotNum) + .ok()); + uninitializedFlagTransport.protocolManager_ = std::make_unique(); + + const auto status = uninitializedFlagTransport.SubmitEntrySubBatchRequest( + TransportOpType::BATCH_STORE, subBatch, subBatchContext); EXPECT_EQ(status.code, StatusCode::NOT_INITIALIZED); - EXPECT_EQ(subBatchContext.state, TransportSubBatchState::FAILED); + EXPECT_EQ(subBatchContext.state, TransportSubBatchState::COMPLETED); EXPECT_EQ(subBatchContext.status.code, StatusCode::NOT_INITIALIZED); EXPECT_EQ(subBatchContext.flagBuffer.addr, std::uint64_t{0}); EXPECT_EQ(subBatchContext.sendSge.addr, std::uint64_t{0}); @@ -252,10 +265,9 @@ TEST_F(SqeRequestTest, AllocationFailureMarksWholeSubBatchFailed) TEST_F(SqeRequestTest, SubmitKeepAliveBuildsFlagBackedRequest) { TransportSubBatchContext subBatchContext; + transport_.nextRequestCid_.store(77, std::memory_order_relaxed); - const auto status = - SubmitKeepAliveRequest([] { return std::uint16_t{77}; }, sendBufferManager_, - flagBufferManager_, *protocolManager_, subBatchContext); + const auto status = transport_.SubmitKeepAliveRequest(subBatchContext); EXPECT_TRUE(status.ok()) << status.message; EXPECT_EQ(subBatchContext.cid, std::uint16_t{77}); diff --git a/ucm/transport/kv/asu/trans/test/transport_task_completion_test.cpp b/ucm/transport/kv/asu/trans/test/transport_task_completion_test.cpp index 2a29380a8..53a0af236 100644 --- a/ucm/transport/kv/asu/trans/test/transport_task_completion_test.cpp +++ b/ucm/transport/kv/asu/trans/test/transport_task_completion_test.cpp @@ -21,11 +21,13 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#include "transport_task_completion.h" #include #include #include #include +#define private public +#include "asu_transport_impl.h" +#undef private #include "buffer_manager.h" #include "connection_internal.h" @@ -51,16 +53,15 @@ class TransportTaskCompletionTest : public ::testing::Test { void SetUp() override { - auto status = sendBufferManager_.Init("test send buffer", MemoryType::HOST, - kTestBufferSlotSize, kTestBufferSlotNum); + auto status = transport_.sendBufferManager_.Init("test send buffer", MemoryType::HOST, + kTestBufferSlotSize, kTestBufferSlotNum); ASSERT_TRUE(status.ok()) << status.message; - status = flagBufferManager_.Init("test flag buffer", MemoryType::HOST, kTestBufferSlotSize, - kTestBufferSlotNum); + status = transport_.flagBufferManager_.Init("test flag buffer", MemoryType::HOST, + kTestBufferSlotSize, kTestBufferSlotNum); ASSERT_TRUE(status.ok()) << status.message; } - BufferManager sendBufferManager_; - BufferManager flagBufferManager_; + AsuTransportImpl transport_; }; TEST_F(TransportTaskCompletionTest, InitializeCountsAlreadyTerminalSubBatches) @@ -69,9 +70,10 @@ TEST_F(TransportTaskCompletionTest, InitializeCountsAlreadyTerminalSubBatches) ctx.subBatchContexts.resize(3); ctx.subBatchContexts[0].state = TransportSubBatchState::PENDING; ctx.subBatchContexts[1].state = TransportSubBatchState::COMPLETED; - ctx.subBatchContexts[2].state = TransportSubBatchState::FAILED; + ctx.subBatchContexts[2].state = TransportSubBatchState::COMPLETED; + ctx.subBatchContexts[2].status = Status::Error(StatusCode::IO_ERROR, "fake error"); - InitializeTerminalSubBatchCount(ctx); + ctx.InitializeTerminalSubBatchCount(); EXPECT_EQ(ctx.completedSubBatchCount, std::uint32_t{2}); } @@ -82,26 +84,22 @@ TEST_F(TransportTaskCompletionTest, CompleteSubBatchOnlyCountsPendingSubBatchOnc TransportSubBatchContext subBatchContext; const auto status = Status::Error(StatusCode::IO_ERROR, "fake error"); - CompleteSubBatch(ctx, subBatchContext, TransportSubBatchState::FAILED, status, - sendBufferManager_, flagBufferManager_); - CompleteSubBatch(ctx, subBatchContext, TransportSubBatchState::FAILED, status, - sendBufferManager_, flagBufferManager_); + transport_.CompleteSubBatch(ctx, subBatchContext, status); + transport_.CompleteSubBatch(ctx, subBatchContext, status); EXPECT_EQ(ctx.completedSubBatchCount, std::uint32_t{1}); - EXPECT_EQ(subBatchContext.state, TransportSubBatchState::FAILED); + EXPECT_EQ(subBatchContext.state, TransportSubBatchState::COMPLETED); EXPECT_EQ(subBatchContext.status.code, StatusCode::IO_ERROR); } TEST_F(TransportTaskCompletionTest, ReleaseSubBatchResourcesClearsAllocatedSlots) { TransportSubBatchContext subBatchContext; - ASSERT_TRUE(sendBufferManager_.Allocate(64, subBatchContext.sendSge).ok()); - ASSERT_TRUE(flagBufferManager_.Allocate(64, subBatchContext.flagBuffer).ok()); + ASSERT_TRUE(transport_.sendBufferManager_.Allocate(64, subBatchContext.sendSge).ok()); + ASSERT_TRUE(transport_.flagBufferManager_.Allocate(64, subBatchContext.flagBuffer).ok()); - const auto status = - ReleaseSubBatchResources(subBatchContext, sendBufferManager_, flagBufferManager_); + transport_.ReleaseSubBatchResources(subBatchContext); - EXPECT_TRUE(status.ok()) << status.message; EXPECT_EQ(subBatchContext.sendSge.slot_index, UINT32_MAX); EXPECT_EQ(subBatchContext.flagBuffer.slot_index, UINT32_MAX); EXPECT_EQ(subBatchContext.sendSge.addr, std::uint64_t{0}); @@ -111,29 +109,25 @@ TEST_F(TransportTaskCompletionTest, ReleaseSubBatchResourcesClearsAllocatedSlots TEST_F(TransportTaskCompletionTest, ReleaseSubBatchResourcesPreservesSubBatchStatus) { TransportSubBatchContext subBatchContext; - subBatchContext.state = TransportSubBatchState::FAILED; + subBatchContext.state = TransportSubBatchState::COMPLETED; subBatchContext.status = Status::Error(StatusCode::IO_ERROR, "send failed"); - ASSERT_TRUE(sendBufferManager_.Allocate(64, subBatchContext.sendSge).ok()); - ASSERT_TRUE(flagBufferManager_.Allocate(64, subBatchContext.flagBuffer).ok()); + ASSERT_TRUE(transport_.sendBufferManager_.Allocate(64, subBatchContext.sendSge).ok()); + ASSERT_TRUE(transport_.flagBufferManager_.Allocate(64, subBatchContext.flagBuffer).ok()); - const auto releaseStatus = - ReleaseSubBatchResources(subBatchContext, sendBufferManager_, flagBufferManager_); + transport_.ReleaseSubBatchResources(subBatchContext); - EXPECT_TRUE(releaseStatus.ok()) << releaseStatus.message; - EXPECT_EQ(subBatchContext.state, TransportSubBatchState::FAILED); + EXPECT_EQ(subBatchContext.state, TransportSubBatchState::COMPLETED); EXPECT_EQ(subBatchContext.status.code, StatusCode::IO_ERROR); } -TEST_F(TransportTaskCompletionTest, ReleaseSubBatchResourcesReturnsFreeFailure) +TEST_F(TransportTaskCompletionTest, ReleaseSubBatchResourcesClearsSlotsAfterFreeFailure) { TransportSubBatchContext subBatchContext; subBatchContext.sendSge.slot_index = kTestBufferSlotNum; - ASSERT_TRUE(flagBufferManager_.Allocate(64, subBatchContext.flagBuffer).ok()); + ASSERT_TRUE(transport_.flagBufferManager_.Allocate(64, subBatchContext.flagBuffer).ok()); - const auto status = - ReleaseSubBatchResources(subBatchContext, sendBufferManager_, flagBufferManager_); + transport_.ReleaseSubBatchResources(subBatchContext); - EXPECT_EQ(status.code, StatusCode::INVALID_ARGUMENT); EXPECT_EQ(subBatchContext.sendSge.slot_index, UINT32_MAX); EXPECT_EQ(subBatchContext.flagBuffer.slot_index, UINT32_MAX); } @@ -149,10 +143,8 @@ TEST_F(TransportTaskCompletionTest, ReleaseSubBatchResourcesReleasesChannelInfli TransportSubBatchContext subBatchContext; subBatchContext.channel = channel; - const auto status = - ReleaseSubBatchResources(subBatchContext, sendBufferManager_, flagBufferManager_); + transport_.ReleaseSubBatchResources(subBatchContext); - EXPECT_TRUE(status.ok()) << status.message; EXPECT_EQ(channel->GetInflightCount(), std::uint32_t{0}); EXPECT_EQ(subBatchContext.channel, nullptr); } @@ -162,16 +154,16 @@ TEST_F(TransportTaskCompletionTest, TryFinalizeEmptyTaskUsesExistingFinalStatus) TransportTaskContext ctx; ctx.finalStatus = Status::OK(); - TryFinalizeTaskFromSubBatches(ctx); + ctx.TryFinalizeFromSubBatches(); EXPECT_EQ(ctx.state.load(std::memory_order_acquire), TransportTaskState::COMPLETED); ctx.state.store(TransportTaskState::PENDING, std::memory_order_release); ctx.finalStatus = Status::Error(StatusCode::UNSUPPORTED, "unsupported"); - TryFinalizeTaskFromSubBatches(ctx); + ctx.TryFinalizeFromSubBatches(); - EXPECT_EQ(ctx.state.load(std::memory_order_acquire), TransportTaskState::FAILED); + EXPECT_EQ(ctx.state.load(std::memory_order_acquire), TransportTaskState::COMPLETED); } TEST_F(TransportTaskCompletionTest, TryFinalizeWaitsUntilAllSubBatchesFinish) @@ -181,7 +173,7 @@ TEST_F(TransportTaskCompletionTest, TryFinalizeWaitsUntilAllSubBatchesFinish) ctx.completedSubBatchCount = 1; ctx.state.store(TransportTaskState::INFLIGHT, std::memory_order_release); - TryFinalizeTaskFromSubBatches(ctx); + ctx.TryFinalizeFromSubBatches(); EXPECT_EQ(ctx.state.load(std::memory_order_acquire), TransportTaskState::INFLIGHT); } @@ -192,13 +184,13 @@ TEST_F(TransportTaskCompletionTest, TryFinalizeAggregatesSuccessAndFailure) ctx.subBatchContexts.resize(2); ctx.subBatchContexts[0].state = TransportSubBatchState::COMPLETED; ctx.subBatchContexts[0].status = Status::OK(); - ctx.subBatchContexts[1].state = TransportSubBatchState::FAILED; + ctx.subBatchContexts[1].state = TransportSubBatchState::COMPLETED; ctx.subBatchContexts[1].status = Status::Error(StatusCode::IO_ERROR, "sub-batch failed"); ctx.completedSubBatchCount = 2; - TryFinalizeTaskFromSubBatches(ctx); + ctx.TryFinalizeFromSubBatches(); - EXPECT_EQ(ctx.state.load(std::memory_order_acquire), TransportTaskState::FAILED); + EXPECT_EQ(ctx.state.load(std::memory_order_acquire), TransportTaskState::COMPLETED); EXPECT_EQ(ctx.finalStatus.code, StatusCode::PARTIAL_FAILED); TransportTaskContext successCtx; @@ -207,7 +199,7 @@ TEST_F(TransportTaskCompletionTest, TryFinalizeAggregatesSuccessAndFailure) successCtx.subBatchContexts[0].status = Status::OK(); successCtx.completedSubBatchCount = 1; - TryFinalizeTaskFromSubBatches(successCtx); + successCtx.TryFinalizeFromSubBatches(); EXPECT_EQ(successCtx.state.load(std::memory_order_acquire), TransportTaskState::COMPLETED); EXPECT_TRUE(successCtx.finalStatus.ok());