Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .package/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ services:
- SERVICE_NAME=global_ldap_server
volumes:
- ./certs:/certs
- ./logs:/app/logs
- ldap_keytab:/LDAP_keytab/
env_file:
- .env
Expand Down
49 changes: 48 additions & 1 deletion app/ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
"""

from typing import AsyncIterator, NewType
from typing import TYPE_CHECKING, AsyncIterator, NewType

import httpx
import redis.asyncio as redis
Expand Down Expand Up @@ -155,6 +155,15 @@
MFAHTTPClient = NewType("MFAHTTPClient", httpx.AsyncClient)
DHCPManagerHTTPClient = NewType("DHCPManagerHTTPClient", httpx.AsyncClient)

if TYPE_CHECKING:
from loguru import Logger

LDAPLogger = Logger
GlobalCatalogLogger = Logger
else:
LDAPLogger = NewType("LDAPLogger", type[logger]) # type: ignore
GlobalCatalogLogger = NewType("GlobalCatalogLogger", type[logger]) # type: ignore


class MainProvider(Provider):
"""Provider for ldap."""
Expand Down Expand Up @@ -693,6 +702,7 @@ class LDAPServerProvider(LDAPContextProvider):
"""Provider with session scope."""

scope = Scope.SESSION
_ldap_logger_handler_id: int | None = None

network_policy_validator_gateway = provide(
NetworkPolicyValidatorGateway,
Expand Down Expand Up @@ -720,11 +730,30 @@ async def get_session(
yield session
await session.disconnect()

@provide(scope=Scope.APP, provides=LDAPLogger)
def get_ldap_logger(self) -> LDAPLogger:
"""Get LDAP logger."""
log = logger.bind(name="ldap")

# Add handler only once to prevent duplicate log entries
if self._ldap_logger_handler_id is None:
self._ldap_logger_handler_id = log.add(
"logs/ldap_{time:DD-MM-YYYY}.log",
filter=lambda rec: rec["extra"].get("name") == "ldap",
retention="10 days",
rotation="1d",
colorize=False,
enqueue=True,
)

return log


class GlobalLDAPServerProvider(Provider):
"""Provider with session scope."""

scope = Scope.SESSION
_global_catalog_logger_handler_id: int | None = None

@provide(scope=Scope.SESSION, provides=LDAPSession)
async def get_session(
Expand Down Expand Up @@ -760,6 +789,24 @@ async def get_session(
scope=Scope.REQUEST,
)

@provide(scope=Scope.APP, provides=GlobalCatalogLogger)
def get_global_catalog_logger(self) -> GlobalCatalogLogger:
"""Get Global Catalog logger."""
log = logger.bind(name="global_catalog")

if self._global_catalog_logger_handler_id is None:
self._global_catalog_logger_handler_id = log.add(
"logs/global_catalog_{time:DD-MM-YYYY}.log",
filter=lambda rec: rec["extra"].get("name")
== "global_catalog",
retention="10 days",
rotation="1d",
colorize=False,
enqueue=True,
)

return log


class MFACredsProvider(Provider):
"""Creds provider."""
Expand Down
47 changes: 22 additions & 25 deletions app/ldap_protocol/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,18 @@
from typing import Literal, cast, overload

from dishka import AsyncContainer, Scope
from loguru import logger
from proxyprotocol import ProxyProtocolIncompleteError
from proxyprotocol.v2 import ProxyProtocolV2
from pydantic import ValidationError

from config import Settings
from ioc import GlobalCatalogLogger, LDAPLogger
from ldap_protocol import LDAPRequestMessage, LDAPSession
from ldap_protocol.ldap_requests.bind_methods import GSSAPISL
from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase

from .data_logger import DataLogger

log = logger.bind(name="ldap")
log.add(
"logs/ldap_{time:DD-MM-YYYY}.log",
filter=lambda rec: rec["extra"].get("name") == "ldap",
retention="10 days",
rotation="1d",
colorize=False,
enqueue=True,
)

infinity = cast("int", math.inf)
pp_v2 = ProxyProtocolV2()

Expand All @@ -53,14 +43,20 @@ class PoolClientHandler:

ssl_context: ssl.SSLContext | None = None

def __init__(self, settings: Settings, container: AsyncContainer):
def __init__(
self,
settings: Settings,
container: AsyncContainer,
log: LDAPLogger | GlobalCatalogLogger,
):
"""Set workers number for single client concurrent handling."""
self.container = container
self.settings = settings
self.num_workers = self.settings.COROUTINES_NUM_PER_CLIENT
self._size = self.settings.TCP_PACKET_SIZE

self.logger = DataLogger(log, is_full=self.settings.DEBUG)
self.log = log
self.logger = DataLogger(self.log, is_full=self.settings.DEBUG)

self._load_ssl_context()

Expand All @@ -79,7 +75,7 @@ async def __call__(
)
ldap_session.ip = addr

logger.info(f"Connection {addr} opened")
self.log.info(f"Connection {addr} opened")

try:
async with session_scope(scope=Scope.REQUEST) as r:
Expand All @@ -92,7 +88,7 @@ async def __call__(
network_policy_use_case,
)
except PermissionError:
log.warning(f"Whitelist violation from {addr}")
self.log.warning(f"Whitelist violation from {addr}")
return

async with asyncio.TaskGroup() as tg:
Expand All @@ -117,7 +113,9 @@ async def __call__(
)

except* RuntimeError as err:
log.error(f"Response handling error {err}: {format_exc()}")
self.log.error(
f"Response handling error {err}: {format_exc()}",
)

finally:
await session_scope.close()
Expand All @@ -126,18 +124,18 @@ async def __call__(
writer.close()
await writer.wait_closed()

logger.info(f"Connection {addr} closed")
self.log.info(f"Connection {addr} closed")

def _load_ssl_context(self) -> None:
"""Load SSL context for LDAPS."""
if self.settings.USE_CORE_TLS and self.settings.LDAP_LOAD_SSL_CERT:
if not self.settings.check_certs_exist():
log.critical("Certs not found, exiting...")
self.log.critical("Certs not found, exiting...")
raise SystemExit(1)

cert_name = self.settings.SSL_CERT
key_name = self.settings.SSL_KEY
log.success("Found existing cert and key, loading...")
self.log.success("Found existing cert and key, loading...")
self.ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
self.ssl_context.load_cert_chain(cert_name, key_name)

Expand Down Expand Up @@ -166,7 +164,7 @@ def _extract_proxy_protocol_address(
header_length = int.from_bytes(data[14:16], "big")
return addr, data[16 + header_length :]
except (ValueError, ProxyProtocolIncompleteError) as err:
log.error(f"Proxy Protocol processing error: {err}")
self.log.error(f"Proxy Protocol processing error: {err}")
return peer_addr, data

@overload
Expand Down Expand Up @@ -279,7 +277,7 @@ async def _handle_request(
request = LDAPRequestMessage.from_bytes(data)

except (ValidationError, IndexError, KeyError, ValueError) as err:
log.error(f"Invalid schema {format_exc()}")
self.log.error(f"Invalid schema {format_exc()}")

writer.write(LDAPRequestMessage.from_err(data, err).encode())
await writer.drain()
Expand Down Expand Up @@ -440,15 +438,14 @@ async def _run_server(server: asyncio.base_events.Server) -> None:
async with server:
await server.serve_forever()

@staticmethod
def log_addrs(server: asyncio.base_events.Server) -> None:
def log_addrs(self, server: asyncio.base_events.Server) -> None:
addrs = ", ".join(str(sock.getsockname()) for sock in server.sockets)
log.info(f"Server on {addrs}")
self.log.info(f"Server on {addrs}")

async def start(self) -> None:
"""Run and log tcp server."""
server = await self._get_server()
log.info(
self.log.info(
f"started {'DEBUG' if self.settings.DEBUG else 'PROD'} "
f"{'LDAPS' if self.settings.USE_CORE_TLS else 'LDAP'} server",
)
Expand Down
8 changes: 6 additions & 2 deletions app/multidirectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@
from extra.dump_acme_certs import dump_acme_cert
from ioc import (
EventSenderProvider,
GlobalCatalogLogger,
GlobalLDAPServerProvider,
HTTPProvider,
LDAPLogger,
LDAPServerProvider,
MainProvider,
MFACredsProvider,
Expand Down Expand Up @@ -199,7 +201,8 @@ async def ldap_factory(settings: Settings) -> None:
)

settings = await container.get(Settings)
servers.append(PoolClientHandler(settings, container).start())
log = await container.get(LDAPLogger)
servers.append(PoolClientHandler(settings, container, log).start())

await asyncio.gather(*servers)

Expand Down Expand Up @@ -234,7 +237,8 @@ async def global_ldap_server_factory(settings: Settings) -> None:
)

settings = await container.get(Settings)
servers.append(PoolClientHandler(settings, container).start())
log = await container.get(GlobalCatalogLogger)
servers.append(PoolClientHandler(settings, container, log).start())

await asyncio.gather(*servers)

Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from dishka.integrations.fastapi import setup_dishka
from fastapi import FastAPI, Request, Response
from loguru import logger
from multidirectory import _create_basic_app
from sqlalchemy import schema, text
from sqlalchemy.ext.asyncio import (
Expand Down Expand Up @@ -1079,8 +1080,9 @@ async def handler(
) -> AsyncIterator[PoolClientHandler]:
"""Create test handler."""
settings.set_test_port()
test_log = logger.bind(name="ldap_test")
async with container(scope=Scope.APP) as app_scope:
yield PoolClientHandler(settings, app_scope)
yield PoolClientHandler(settings, app_scope, test_log)


@pytest_asyncio.fixture(scope="function")
Expand Down