Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 57 additions & 56 deletions ucm/transport/kv/asu/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,71 +1,72 @@
if(DEFINED ASCEND_ROOT AND NOT DEFINED Ascend_ROOT)
set(Ascend_ROOT "${ASCEND_ROOT}" CACHE PATH "Path to Ascend root directory")
if(NOT RUNTIME_ENVIRONMENT STREQUAL "ascend")
message(FATAL_ERROR "BUILD_UCM_ASU requires RUNTIME_ENVIRONMENT=ascend. Current value: ${RUNTIME_ENVIRONMENT}")
endif()
if(NOT DEFINED Ascend_ROOT)
set(Ascend_ROOT "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Path to Ascend root directory")

if(NOT DEFINED ASCEND_ROOT)
set(ASCEND_ROOT "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Path to Ascend root directory")
endif()

find_path(ASU_ASCEND_INCLUDE_DIR
NAMES acl/acl.h
HINTS
${Ascend_ROOT}/include
${Ascend_ROOT}/aarch64-linux/include
${Ascend_ROOT}/arm64-linux/include
${ASCEND_ROOT}/include
${ASCEND_ROOT}/aarch64-linux/include
${ASCEND_ROOT}/arm64-linux/include
NO_DEFAULT_PATH
)
if(NOT ASU_ASCEND_INCLUDE_DIR)
message(FATAL_ERROR "Cannot find acl/acl.h under Ascend_ROOT=${Ascend_ROOT}")
message(FATAL_ERROR "Cannot find acl/acl.h under ASCEND_ROOT=${ASCEND_ROOT}")
endif()
if(RUNTIME_ENVIRONMENT STREQUAL "ascend")
file(GLOB ASU_TRANSPORT_SOURCES CONFIGURE_DEPENDS trans/src/*.cpp)
add_library(asu_transport SHARED ${ASU_TRANSPORT_SOURCES})
target_include_directories(asu_transport
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/trans/include
${ASU_ASCEND_INCLUDE_DIR}
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/common
${CMAKE_CURRENT_SOURCE_DIR}/trans/src
${UCM_ROOT_DIR}/ucm/shared/infra
)
target_link_libraries(asu_transport PUBLIC infra_logger pthread)
file(GLOB ASU_COMMON_SOURCES CONFIGURE_DEPENDS common/*.cpp)
file(GLOB ASU_TRANSPORT_SOURCES CONFIGURE_DEPENDS trans/src/*.cpp)
list(APPEND ASU_TRANSPORT_SOURCES ${ASU_COMMON_SOURCES})
add_library(asu_transport SHARED ${ASU_TRANSPORT_SOURCES})
target_include_directories(asu_transport
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/trans/include
${ASU_ASCEND_INCLUDE_DIR}
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/common
${CMAKE_CURRENT_SOURCE_DIR}/trans/src
${UCM_ROOT_DIR}/ucm/shared/infra
)
target_link_libraries(asu_transport PUBLIC infra_logger pthread)

target_link_libraries(asu_transport PUBLIC trans)
target_link_libraries(asu_transport PUBLIC trans)

file(GLOB ASU_CLIENT_SOURCES CONFIGURE_DEPENDS client/src/*.cpp)
add_library(asu_client SHARED ${ASU_CLIENT_SOURCES})
target_include_directories(asu_client
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/client/include
file(GLOB ASU_CLIENT_SOURCES CONFIGURE_DEPENDS client/src/*.cpp)
add_library(asu_client SHARED ${ASU_CLIENT_SOURCES})
target_include_directories(asu_client
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/client/include
)
target_include_directories(asu_client
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/common
${CMAKE_CURRENT_SOURCE_DIR}/client/src
${UCM_ROOT_DIR}/ucm/shared/infra
)
target_link_libraries(asu_client PUBLIC asu_transport kv_common pthread)

if(BUILD_UNIT_TESTS)
target_compile_definitions(asu_transport PRIVATE ASU_BUILD_TESTS)
include(GoogleTest)
file(GLOB_RECURSE ASU_TEST_SOURCE_FILES CONFIGURE_DEPENDS test/case/*.cc)
file(GLOB_RECURSE ASU_TRANSPORT_TEST_SOURCE_FILES CONFIGURE_DEPENDS trans/test/*.cpp)
list(APPEND ASU_TEST_SOURCE_FILES ${ASU_TRANSPORT_TEST_SOURCE_FILES})
file(GLOB_RECURSE ASU_CLIENT_TEST_SOURCE_FILES CONFIGURE_DEPENDS client/test/*.cpp)
list(APPEND ASU_TEST_SOURCE_FILES ${ASU_CLIENT_TEST_SOURCE_FILES})
add_executable(asu.test ${ASU_TEST_SOURCE_FILES})
target_compile_definitions(asu.test PRIVATE ASU_BUILD_TESTS)
target_include_directories(asu.test PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/common
${CMAKE_CURRENT_SOURCE_DIR}/client/src
${CMAKE_CURRENT_SOURCE_DIR}/trans/src
${UCM_ROOT_DIR}/ucm/shared/infra
)
target_include_directories(asu_client
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/common
${CMAKE_CURRENT_SOURCE_DIR}/client/src
${UCM_ROOT_DIR}/ucm/shared/infra
target_link_libraries(asu.test PRIVATE
asu_client
gtest_main gtest
)
target_link_libraries(asu_client PUBLIC asu_transport kv_common pthread)

if(BUILD_UNIT_TESTS)
target_compile_definitions(asu_transport PRIVATE ASU_BUILD_TESTS)
include(GoogleTest)
file(GLOB_RECURSE ASU_TEST_SOURCE_FILES CONFIGURE_DEPENDS test/case/*.cc)
file(GLOB_RECURSE ASU_TRANSPORT_TEST_SOURCE_FILES CONFIGURE_DEPENDS trans/test/*.cpp)
list(APPEND ASU_TEST_SOURCE_FILES ${ASU_TRANSPORT_TEST_SOURCE_FILES})
file(GLOB_RECURSE ASU_CLIENT_TEST_SOURCE_FILES CONFIGURE_DEPENDS client/test/*.cpp)
list(APPEND ASU_TEST_SOURCE_FILES ${ASU_CLIENT_TEST_SOURCE_FILES})
add_executable(asu.test ${ASU_TEST_SOURCE_FILES})
target_compile_definitions(asu.test PRIVATE ASU_BUILD_TESTS)
target_include_directories(asu.test PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/common
${CMAKE_CURRENT_SOURCE_DIR}/client/src
${CMAKE_CURRENT_SOURCE_DIR}/trans/src
${UCM_ROOT_DIR}/ucm/shared/infra
)
target_link_libraries(asu.test PRIVATE
asu_client
gtest_main gtest
)
gtest_discover_tests(asu.test)
endif()
gtest_discover_tests(asu.test)
endif()
6 changes: 2 additions & 4 deletions ucm/transport/kv/asu/client/src/asu_client_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,7 @@ bool AsuClientImpl::PollTask(const ClientTaskContextPtr& ctx)
ctx->finalStatus =
anyFailed ? Status::Error(StatusCode::PARTIAL_FAILED, "client task partially failed")
: Status::OK();
ctx->state.store(anyFailed ? ClientTaskState::FAILED : ClientTaskState::COMPLETED,
std::memory_order_release);
ctx->state.store(ClientTaskState::COMPLETED, std::memory_order_release);
ctx->cv.notify_all();
return true;
}
Expand Down Expand Up @@ -664,8 +663,7 @@ Status AsuClientImpl::WaitTaskContext(const ClientTaskContextPtr& ctx, std::uint
ctx->finalStatus = anyFailed ? Status::Error(StatusCode::PARTIAL_FAILED,
"client task partially failed")
: Status::OK();
ctx->state.store(anyFailed ? ClientTaskState::FAILED : ClientTaskState::COMPLETED,
std::memory_order_release);
ctx->state.store(ClientTaskState::COMPLETED, std::memory_order_release);
ctx->cv.notify_all();
break;
}
Expand Down
143 changes: 29 additions & 114 deletions ucm/transport/kv/asu/client/src/client_config_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,121 +25,21 @@
#include <algorithm>
#include <cctype>
#include <fstream>
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>
#include "asu_client/asu_client.h"
#include "config_parser_common.h"
#include "view_server.h"

namespace UC::ASU {
namespace {

// Returns a copy of value with leading and trailing spaces, tabs, carriage
// returns and newlines removed. Yields an empty string when value contains
// nothing but those characters (or is empty).
std::string Trim(const std::string& value)
{
    constexpr const char* kBlankChars = " \t\r\n";
    const std::string::size_type first = value.find_first_not_of(kBlankChars);
    if (first == std::string::npos) { return std::string{}; }
    // A non-blank character exists, so the backward search always succeeds.
    const std::string::size_type last = value.find_last_not_of(kBlankChars);
    return std::string(value, first, last - first + 1);
}

// Cuts value at every occurrence of delimiter, trims each piece with Trim,
// and returns the pieces that remain non-empty, in their original order.
std::vector<std::string> Split(const std::string& value, char delimiter)
{
    std::vector<std::string> tokens;
    std::stringstream input{value};
    for (std::string raw; std::getline(input, raw, delimiter);) {
        std::string token = Trim(raw);
        if (token.empty()) { continue; }  // drop blank pieces such as "a,,b" or trailing ","
        tokens.emplace_back(std::move(token));
    }
    return tokens;
}

// Converts value to an unsigned 64-bit integer using std::stoull with base 0,
// which lets the numeric prefix select the base ("0x"/"0X" hexadecimal,
// leading "0" octal, otherwise decimal). Exceptions raised by std::stoull
// propagate to the caller.
std::uint64_t ParseUint64(const std::string& value)
{
    constexpr int kAutoDetectBase = 0;
    const auto parsed = std::stoull(value, nullptr, kAutoDetectBase);
    return parsed;
}

// Maps a protocol name to a Protocol value, ignoring letter case.
// "UB" and "UBOE" yield Protocol::UB and "ROCE" yields Protocol::ROCE.
// Every other name, including "TCP" itself, yields Protocol::TCP.
Protocol ToTransportProtocol(const std::string& value)
{
    std::string upper;
    upper.reserve(value.size());
    for (const unsigned char ch : value) { upper.push_back(static_cast<char>(std::toupper(ch))); }

    if (upper == "ROCE") { return Protocol::ROCE; }
    const bool isUb = (upper == "UB") || (upper == "UBOE");
    // TCP doubles as the fallback for names that are not recognised.
    return isUb ? Protocol::UB : Protocol::TCP;
}

// Recognises keys of the form "asuInfo.<id>" or "asu_info.<id>".
// When key starts with one of those prefixes, the text after the prefix is
// converted with std::stoull, stored in asuId, and true is returned.
// Returns false and leaves asuId untouched when neither prefix matches.
// NOTE(review): std::stoull throws when the suffix is not a number, so a
// malformed key escapes as an exception rather than a false return — confirm
// that callers expect this.
bool TryParseAsuInfoKey(const std::string& key, AsuId& asuId)
{
    const auto parseAfter = [&key, &asuId](const std::string& prefix) {
        // compare() on the leading prefix.size() characters is a starts-with test.
        if (key.compare(0, prefix.size(), prefix) != 0) { return false; }
        asuId = std::stoull(key.substr(prefix.size()));
        return true;
    };
    return parseAfter("asuInfo.") || parseAfter("asu_info.");
}

// Records a single endpoint attribute by assigning value to
// endpoint.attrs[key].
void SetEndpointAttr(AsuEndpoint& endpoint, const std::string& key, const std::string& value)
{
endpoint.attrs[key] = value;
}

// Builds an AsuEndpoint from one endpoint description. Two textual forms are
// accepted:
//   * positional, used when value contains no '=': "ip[:port[:protocol]]";
//   * key/value, used otherwise: comma-separated "key=value" items.
// In the key/value form, items without '=' and items whose key is not
// recognised are ignored.
// NOTE(review): port values are narrowed to 16 bits and device ids to 32 bits
// with static_cast, so out-of-range numbers are truncated silently.
AsuEndpoint ParseAsuEndpoint(const std::string& value)
{
    AsuEndpoint endpoint;

    const bool isKeyValueForm = value.find('=') != std::string::npos;
    if (!isKeyValueForm) {
        const auto fields = Split(value, ':');
        const auto count = fields.size();
        if (count >= 1) { endpoint.ip = fields[0]; }
        if (count >= 2) { endpoint.port = static_cast<std::uint16_t>(ParseUint64(fields[1])); }
        if (count >= 3) {
            endpoint.protocol = ToTransportProtocol(fields[2]);
            SetEndpointAttr(endpoint, "protocol", fields[2]);
        }
        return endpoint;
    }

    // Keys whose value is copied verbatim into endpoint.attrs. The camelCase
    // spellings are aliases stored under the snake_case attribute name.
    struct AttrAlias {
        const char* key;
        const char* attr;
    };
    static constexpr AttrAlias kPassThroughAttrs[] = {
        {"placement", "placement"},
        {"tc", "tc"},
        {"sl", "sl"},
        {"send_size", "send_size"},
        {"sendSize", "send_size"},
        {"flag_size", "flag_size"},
        {"flagSize", "flag_size"},
        {"remote_send_addr", "remote_send_addr"},
        {"remoteSendAddr", "remote_send_addr"},
        {"remote_flag_addr", "remote_flag_addr"},
        {"remoteFlagAddr", "remote_flag_addr"},
    };

    for (const auto& entry : Split(value, ',')) {
        const auto eq = entry.find('=');
        if (eq == std::string::npos) { continue; }

        const auto name = Trim(entry.substr(0, eq));
        const auto text = Trim(entry.substr(eq + 1));

        if (name == "protocol") {
            endpoint.protocol = ToTransportProtocol(text);
            SetEndpointAttr(endpoint, "protocol", text);
            continue;
        }
        if (name == "port") {
            endpoint.port = static_cast<std::uint16_t>(ParseUint64(text));
            continue;
        }
        if (name == "local.comm_id" || name == "localCommId") {
            endpoint.ip = text;
            continue;
        }
        if (name == "local.phy_device_id" || name == "localPhyDeviceId") {
            endpoint.deviceId = static_cast<std::int32_t>(ParseUint64(text));
            continue;
        }
        for (const auto& alias : kPassThroughAttrs) {
            if (name == alias.key) {
                SetEndpointAttr(endpoint, alias.attr, text);
                break;
            }
        }
    }
    return endpoint;
}

AsuInfo ParseAsuInfo(const std::string& value)
{
AsuInfo info;
for (const auto& endpointValue : Split(value, ';')) {
info.endpoints.emplace_back(ParseAsuEndpoint(endpointValue));
for (const auto& endpointValue : SplitConfigValue(value, ';')) {
info.endpoints.emplace_back(ParseClientViewEndpoint(endpointValue));
}
return info;
}
Expand All @@ -156,22 +56,23 @@ Status LoadAsuClientConfig(const std::string& configPath, AsuClientConfig& confi

config = AsuClientConfig{};
std::unordered_map<AsuId, AsuInfo> asuInfos;
std::vector<std::pair<std::string, std::string>> transportFields;
std::string line;
while (std::getline(configFile, line)) {
line = Trim(line);
line = TrimConfigValue(line);
if (line.empty() || line[0] == '#') { continue; }

const auto pos = line.find('=');
if (pos == std::string::npos) { continue; }

const auto key = Trim(line.substr(0, pos));
const auto value = Trim(line.substr(pos + 1));
const auto key = TrimConfigValue(line.substr(0, pos));
const auto value = TrimConfigValue(line.substr(pos + 1));
if (key == "clientId" || key == "client_id") {
config.clientId = value;
} else if (key == "viewServiceAddrs" || key == "view_service_addrs") {
config.viewServiceAddrs = Split(value, ',');
config.viewServiceAddrs = SplitConfigValue(value, ',');
} else if (key == "defaultWaitTimeoutMs" || key == "default_wait_timeout_ms") {
config.defaultWaitTimeoutMs = ParseUint64(value);
config.defaultWaitTimeoutMs = ParseConfigUint64(value);
} else if (key == "hashTable.type" || key == "hash_table.type") {
auto type = value;
std::transform(type.begin(), type.end(), type.begin(),
Expand All @@ -190,20 +91,34 @@ Status LoadAsuClientConfig(const std::string& configPath, AsuClientConfig& confi
config.attrs["ring_hash.virtual_node_count"] = value;
} else if (key == "hashTable.maglev.tableSize" || key == "maglev.table_size") {
config.attrs["maglev.table_size"] = value;
} else if (key == "transport.asuIds" || key == "transport_asu_ids" || key == "asuIds" ||
key == "asu_ids") {
for (const auto& asuIdText : Split(value, ',')) {
} else if (key == "transport.asuIds" || key == "asuIds" || key == "asu_ids") {
for (const auto& asuIdText : SplitConfigValue(value, ',')) {
TransportConfig transportConfig;
transportConfig.asuId = ParseUint64(asuIdText);
transportConfig.asuId = ParseConfigUint64(asuIdText);
config.transportConfigs.emplace_back(std::move(transportConfig));
}
} else {
AsuId asuId{0};
if (TryParseAsuInfoKey(key, asuId)) { asuInfos[asuId] = ParseAsuInfo(value); }
std::string attrKey;
if (TryParseAsuInfoKey(key, asuId)) {
asuInfos[asuId] = ParseAsuInfo(value);
} else if (TryGetTransportAttrKey(key, attrKey)) {
transportFields.emplace_back(attrKey, value);
}
}
}

for (auto& transportConfig : config.transportConfigs) {
for (const auto& field : transportFields) {
if (ApplyTransportBufferConfigField(transportConfig, field.first, field.second)) {
continue;
}
if (ApplyTransportIoNumConfigField(transportConfig, field.first, field.second)) {
continue;
}
transportConfig.attrs.emplace(field);
}

auto iter = asuInfos.find(transportConfig.asuId);
if (iter == asuInfos.end()) { continue; }
ApplyAsuInfoToTransportConfig(iter->second, transportConfig);
Expand Down
6 changes: 2 additions & 4 deletions ucm/transport/kv/asu/client/src/client_task_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ enum class ClientTaskState {
PENDING = 0,
INFLIGHT = 1,
COMPLETED = 2,
FAILED = 3,
CANCELED = 4,
CANCELED = 3,
};

enum class ClientOpType {
Expand All @@ -80,8 +79,7 @@ struct ClientTaskContext {
bool Done() const
{
auto s = state.load(std::memory_order_acquire);
return s == ClientTaskState::COMPLETED || s == ClientTaskState::FAILED ||
s == ClientTaskState::CANCELED;
return s == ClientTaskState::COMPLETED || s == ClientTaskState::CANCELED;
}
};

Expand Down
Loading
Loading