diff --git a/.gitignore b/.gitignore index 9eea4ed..1b9dda6 100644 --- a/.gitignore +++ b/.gitignore @@ -3,10 +3,13 @@ .DS_Cache .DS_Store +__debug_bin* + /devdata /vendor /example/e2e/logs +/example/e2e/data/triton_model_repository # binaries /example/client/mlyc/mlyc @@ -14,4 +17,4 @@ /toolsv2/aerospike/aerospike /toolsv2/smasher/cmd/cmd -/toolsv2/toolsv2 \ No newline at end of file +/toolsv2/toolsv2 diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..9c34e38 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,193 @@ +# Code Architecture + +mly consists of the following large conceptual (not strictly programmatically linked) steps and sub-steps: + +1. Configuration processing +2. Model initialization + 1. Platform evaluator creation + 2. Caching support +3. Request handling + 1. Input processing + 2. Model inference + 3. Post-prediction processing + 4. Prediction logging +4. Model reloading + +--- + +## 1 Configuration processing + +This is a standard step in many services. This will read configuration files and populate defaults. +Details are mainly in [CONFIG.md](CONFIG.md). + +*Quirk*: The configuration struct contains both the read configuration and follow-up processed configuration values. For example, `Modified` is populated during model loading, and `DictMeta` is updated when the dictionary is loaded. + +--- + +## 2 Model Initialization + +Model initialization occurs in `service.New()` which orchestrates the creation of platform-specific evaluators and supporting infrastructure. + +### 2.1 Evaluator Creation + +A core concept in mly is the "Evaluator." +An Evaluator is essentially something that can provide some kind of model inference. +All Evaluators implement `platform.PlatformEvaluator`: + +```go +type PlatformEvaluator interface { + Predict(ctx context.Context, params []interface{}) ([]interface{}, error) + Signature() *domain.Signature + Dictionary() *common.Dictionary + Inputs() map[string]*domain.Input + ReloadIfNeeded(ctx context.Context) error + Close() error +} +``` + +There are currently 3 Evaluators: + +1. TensorFlow - this Evaluator operates with a `libtensorflow` backend, and has additional logic that supports timeout-based batching. +2. Triton - this Evaluator supports sending prediction requests to a single Triton server via HTTP or gRPC. +3. Router - this Evaluator does not generate any prediction but enables rows in a prediction request to be sent to other Evaluators based on the input. + +*Potential design issue*: Evaluator overloading and over-abstraction - the Router operates on the same interface as the TensorFlow and Triton evaluators, but vary in their behavioral labels. + +### 2.2 Caching Support + +Caching is implemented via `shared/datastore.Service`, which provides a multi-layer cache: + +1. **Local in-memory cache** ([`scache`](https://github.com/viant/scache)): Fast local cache with TTL expiration +2. **L1 cache** (Aerospike): Primary distributed cache +3. **L2 cache** (Aerospike): Secondary distributed cache for cache warming + +An important concept in mly caching is the *Dictionary hash*. +This is stored with cached values, and is intended to invalidate entries when the model changes (e.g., there is a model weights update). + +*Quirk*: Client-read, server-write - based on the observation that if a client does not find a cache entry, then the server is unlikely to also find a cache entry, and to skip the latency overhead from a remote cache check, the server does not check for a cache entry. + +*Design debt*: The current client-read, server-write introduces a case when multiple clients concurrently find that a cache entry is missing, and sends the same request to potentially the same mly server, causing the same server to run the same prediction multiple times. This should be controlled on the server side, to avoid unnecessary compute. + +*Design debt*: Aerospike coupling - the current implementation depends on Aerospike constructs. + +--- + +## 3 Request Handling + +The mly service occupies most of its lifetime serving this purpose. Currently, mly is designed to focus around HTTP requests, using HTTP/2. + +Data flow: + +``` +HTTP Request +→ service.Handler.ServeHTTP() +→ service.Service.Do() +→ service.Evaluator.Predict() +→ service.domain.Transformer +→ service.Response +``` + +### 3.1 Input Processing + +The Input processing step is primarily focused around logic of pulling data from an HTTP-compliant, JSON or URL-based payload and pushing it into a Go (and CGo) compatible data structure for model inference. + +This step revolves mainly around the `service/request.Request` struct. + +Key components: +- **Feeds**: `[]interface{}` shaped as `[numInputs]([batchSize][1]T)` for model consumption +- **Input**: `*transfer.Input` for Transformer support + +The `UnmarshalJSONObject()` method implements `gojay.UnmarshalerJSONObject` for high-performance JSON parsing. + +The interaction with *Model inference* involves the `Feeds` field. + +*Quirk*: client batching payload reduction - mly provides a convenience / payload reduction feature that permits payloads to have both inputs with a list of 1 values as well as inputs with a list of batch size of values. The server will expand the payload to fit the expected batch size times inputs matrix for the Evaluators. + +*Quirk*: payload reading order - the JSON payload must have the `batch_size` key existing before other input keys, as that is required to know if the parser should be expecting a list of values or scalar values. + +*Design debt*: `Feeds` type - most of the requests are tracked via input names than offsets; the intermediate data form should be a `map[string]interface{}` (or even `map[string][]interface{}` to capture a potential batch layer), and the conversion to an offset-based slice should be isolated to TensorFlow graph related code. + +### 3.2 Model Inference + +Model inference is delegated to the platform-specific evaluator via `Predict()`: + +**TensorFlow** (`service/tfmodel`): +- Optional batching via `service/tfmodel/batcher.Service` aggregates concurrent requests +- Direct evaluation via `service/tfmodel/evaluator.Service` runs TensorFlow session +- Semaphore-controlled concurrency prevents overload + +For Triton, [concurrency and timeout-based batching is controlled via Triton](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html). + +**Triton** (`service/triton`): +- HTTP or gRPC call to Triton Inference Server +- Input tensors serialized per Triton protocol +- Timeout-controlled requests + +The router is a control layer that can route to various model inference services. +Each downstream evaluator is responsible for controlling their own lifetime and batching concerns. +In theory, a router can also route to another router, but no such implementation yet exists. + +**Router** (`service/platform/router`): +- Extracts routing key from input +- Groups rows by target model +- Parallel dispatch to downstream evaluators +- Result reassembly in original order + +### 3.3 Post-Prediction Processing + +After inference, `Service.buildResponse()` handles output transformation: + +**Transformer execution** +The configured `domain.Transformer` function transforms raw model output into a `common.Storable` for serialization. +The default transformer extracts values keyed by output tensor names. + +The `domain.Transformer` signature: + +```go +type Transformer func(ctx context.Context, signature *Signature, input *gtly.Object, output interface{}) (common.Storable, error) +``` + +**Cache storage** +If caching is enabled, transformed results are stored asynchronously via `datastore.Put()`. + +*Design debt*: Batch-based Transformer - the current Transformer API operates at the request level outputs but at row-level inputs, and is invoked per row. + +### 3.4 Prediction Logging + +If `Stream` is configured, the `stream.Service` logs requests for analytics: +- Request body +- Model output +- Inference duration + +Logging uses `github.com/viant/tapper` for configurable output destinations. + +--- + +## 4 Model Reloading + +Model reloading runs continuously in a background goroutine (`Service.pollModelReload()`), checking for updates at configurable intervals (`ReloadPollIntervalSeconds`). + +The `ReloadIfNeeded()` implementation is platform-specific, and varies similarly to model prediction in how much is implemented vs. delegated: + +**TensorFlow** (`service/tfmodel.Service`): +1. Check file modification times at `URL` +2. If changed: copy model files to `Location`, load SavedModel +3. Extract signature and dictionary from graph +4. Create new `service/tfmodel/evaluator.Service` and optionally `service/tfmodel/batcher.Service` +5. Atomically swap Evaluators under mutex protection + +**Triton** (`service/triton.TritonEvaluator`): +1. Check model health via `ModelReady()` API +2. If not ready and in EXPLICIT mode: call `ModelLoad()` +3. Refresh metadata via `ModelMetadata()` if signature not yet captured + +**Router** (`service/platform/router.Router`): +1. Check routing configuration file modification +2. Reload routing table if changed +3. Create/destroy downstream Evaluators as needed +4. Atomically swap routing table under mutex protection +5. Unload unused models from Triton via Model Control API + +Reload health is tracked via `Service.ReloadOK` for centralized health reporting. + +*Design issue*: Over-abstraction of `ReloadIfNeeded()` - we note that this is a very high-level abstraction that could be broken down into separate concerns e.g., check health, load model, check if reload needed, etc. \ No newline at end of file diff --git a/CONFIG.md b/CONFIG.md index 2256d13..8118bc6 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -19,10 +19,10 @@ Properties: * to use GCS, set environment variable `GOOGLE_APPLICATION_CREDENTIALS=true` - `Location`: `string` - optional - where a copy of the models will be stored when loading the model. Defaults to the system temporary directory. - `Dir`: `string` - optional - any further path elements in `Location`. Mainly used if using a ZIP file with additional directories. -- `DataStore`: `string` - optional - name of Datastore to cache, should match `Datastores[].ID`. +- `DataStore`: `string` - optional - name of Datastore to use for caching, should match `Datastores[].ID`. Server-side datastore writes are enabled only when `UseDict` is `true` or unset. - `Transformer`: `string` - optional - name of model output transformer. See [#Transformer](#Transformer). - `Batch`: optional - enables or overrides server-side batching configuration. See [`service/tfmodel/batcher/config/config.go`](service/tfmodel/batcher/config/config.go). -- `UseDict`: `bool` - optional - if true, enables capabilities designed to shrink the cache key space by replacing out-of-vocabulary inputs from cache keys with a special token. +- `UseDict`: `bool` - optional - if true or unset, enables dictionary-based cache behavior, including replacing out-of-vocabulary inputs in cache keys with a special token and allowing the server to generate datastore cache entries when `DataStore` is configured. If false, the server will not generate new datastore cache entries for the model. - `Inputs`: used to further provide or define inputs, a list of `shared.Field`. For TensorFlow models, this is automatically populated, but further caching configurations need to be specified. * `Name`: `string` - required - input name, only required if an entry is provided. * `Index`: `int` - optional - used to maintain cache key ordering. @@ -73,7 +73,7 @@ Can be empty - represent a list of caching data stores. Properties: -- `ID`: `string` - required - datastore ID (to be matched with `Models[].DataStores[].ID`) +- `ID`: `string` - required - datastore ID (to be matched with `Models[].DataStore`) - `Connection`: `string` - optional - connection ID - `Namespace`: `string` - optional - Aerospike namespace - `Dataset`: `string` - optional - Aerospike dataset @@ -109,9 +109,10 @@ mly := client.New("$modelID", []*client.Host{client.NewHost("mlServiceHost", mlS ``` Where optional `options` can be of, but not limited to, the following: - * `NewCacheSize(sizeOption)` - * `NewCacheScope(CacheScopeLocal|CacheScopeL1|CacheScopeL2)` - * `NewGmetric()` - custom instance of `gmetric` service + * `WithCacheSize(sizeOption)` + * `WithCacheScope(CacheScopeLocal|CacheScopeL1|CacheScopeL2)` + * `WithGmetrics()` - custom instance of `gmetric` service + * `WithHashValidation(true)` - enables client-side rejection of cached entries with a non-zero hash that differs from the client's current dictionary hash See [`shared/client/option.go`](shared/client/option.go) for more options. diff --git a/README.md b/README.md index 7756def..81e4e8a 100644 --- a/README.md +++ b/README.md @@ -60,9 +60,9 @@ By default, the client will configure itself using the web service cache setting This enables the `mly` client to handle key generation without additional configuration or code. The library supports 3 types of caching: -- in-(process) memory -- external Aerospike cache -- hybrid +- in-(process) memory using [scache](https://github.com/viant/scache) +- external Aerospike cache (supports L1/L2 tiered caching for larger key spaces) +- hybrid (in-memory + external) The in-memory cache uses [scache](https://github.com/viant/scache)'s most-recently-used implementation. @@ -77,6 +77,7 @@ In this scenario, the L2 cache can be a very large SSD-backed Aerospike instance In this case, when we look for a cached value, first the in-memory cache is checked, followed by L1, then L2. Then with a cache miss, the value is calculated then copied to L2 - then from L2 to L1 and L1 to local memory. + **Example of `config.yaml` with both an in-memory and an Aerospike cache** ```yaml @@ -104,10 +105,11 @@ See [WORKFLOW.md](WORKFLOW.md) for Mermaid diagrams explaining the Client and mo ## Dictionary hash code In caching mode, in order to manage cache and client/server consistency every time a model/dictionary gets re/loaded, `mly` computes a dictionary hash code. -This hash code gets stored in the cache along with model prediction and is passed to the client in every response. -Once a client detects a change in dictionary hash code, it automatically initiates a dictionary reload and invalidates cache entries. +This hash code gets stored in the cache along with model prediction and is passed to the client in every non-cached response. +Once a client detects a change in the dictionary hash code, it will initiate a dictionary reload and, if `client.WithHashValidation(true)` was an option on client initialization, reject any cache entry with a non-zero, different hash code. -Note: The dictionary hash code is stored under a special key in Aerospike defined in `shared/common.HashBin`. To prevent conflicts, do not use that same key name for storing your own model predictions. +**Note** The dictionary hash code is stored under a bin in Aerospike defined in `shared/common.HashBin`. +To prevent conflicts, do not use that same bin name for storing your own model predictions. # Configuration @@ -263,21 +265,99 @@ In all these, `%s` is `Model[].ID` (i.e. from `config.yaml`) - `/v1/api/metric/operation/%sDictMeta` - Records metrics to client dictionary fetch. - `/v1/api/metric/operation/%sCfgMeta` - Records metrics to client configuration fetch. - `/v1/api/metric/operation/%sMetaHandler` - Records server-side metrics to client set up. +- `/v1/api/metric/operation/%sHTTPHandler` - Records per-request HTTP handler metrics. Includes `responseMarshalError` (response struct could not be marshaled; server returned 500) and `responseCommittedError` (body write failed after status was committed; client sees `200 OK` with a body shorter than `Content-Length`). ## `/v1/api/debug` Requires `EnableMemProf` and / or `EnableCPUProf` to be enabled. See [`service/endpoint/prof.go`](service/endpoint/prof.go) for details - otherwise, refer to `pprof` documentation. -## `/v1/api/model` +In the following sections, `%s` is `Model[].ID` (i.e. from `config.yaml`). -Model operations. +## `/v1/api/model/%s/eval` -In all these, `%s` is `Model[].ID` (i.e. from `config.yaml`) +Runs a model prediction. This is the primary data-plane endpoint and the only one with a structured request and response shape; the others are administrative or metadata. + +### Methods + +- `GET` - input values supplied as URL query parameters. Single-prediction mode only. +- `POST` - JSON body containing input values plus optional `batch_size` and `cache_key`. Supports both single-prediction and batch mode. + +### Request body (POST) + +A JSON object whose keys are model input names (as defined by the model's signature). Values are either scalars (single mode) or arrays of length `batch_size` or `1` (batch mode). Two reserved optional keys: + +- `batch_size` (integer, optional) - if present and `> 0`, switches the request to batch mode. Other input values must then be arrays. +- `cache_key` (string in single mode, string array of length `batch_size` in batch mode, optional) - explicit cache key(s) to use instead of letting the server derive one from the input values. + +A minimal single-mode request: + +```json +{"input1": "value1", "input2": 42} +``` + +A batch-mode request with two predictions and explicit cache keys: + +```json +{ + "batch_size": 2, + "cache_key": ["k1", "k2"], + "input1": ["value1", "value2"], + "input2": [42, 43] +} +``` + +### Successful response + +`200 OK` with `Content-Type: application/json` and an explicit `Content-Length`. The body is a JSON object: + +```json +{"status": "ok", "dictHash": 12345, "data": {...}, "serviceTimeMcs": 1100} +``` + +- `status` - always `"ok"` on success. +- `dictHash` - hash of the dictionary the prediction was made against. Clients use this to detect dictionary changes and trigger a reload (see [Dictionary hash code](#dictionary-hash-code)). +- `data` - the model output, shape determined by the model and any registered transformer. +- `serviceTimeMcs` - server-side time spent on this request in microseconds. + +A short read against the declared `Content-Length` indicates a transport failure (peer closed mid-response, broken pipe, etc.), not a successful empty response. Clients should surface short reads as errors rather than treating them as empty bodies. + +### Error response + +Errors return a non-2xx HTTP status code with the same `Content-Type: application/json` and explicit `Content-Length` as the success response. The body is the same `Response` JSON object with `status` set to `"error"`, the `error` field populated, and `data` omitted: + +```json +{"status": "error", "error": "", "serviceTimeMcs": 1100} +``` + +| status | cause | +| ------ | ----- | +| `400 Bad Request` | malformed query string, malformed JSON body, type mismatch on an input value, or any client-side input error | +| `413 Request Entity Too Large` | POST body exceeds the server's request buffer | +| `429 Too Many Requests` | server is overloaded (evaluator queue rejected the request) | +| `500 Internal Server Error` | prediction failure, server-side encoding failure, or any other server-side error | + +Clients can therefore parse the response body the same way regardless of HTTP status — the only differences are the status code and which fields are populated. As a fallback for the rare case where the server cannot encode an error response, a plain-text body may be returned with the same status code. + +### Example + +```bash +# GET, single prediction +curl 'http://localhost:8086/v1/api/model/ml0/eval?input1=value1&input2=42' + +# POST, batch prediction +curl -X POST 'http://localhost:8086/v1/api/model/ml0/eval' \ + -H 'Content-Type: application/json' \ + -d '{"batch_size":2,"input1":["a","b"],"input2":[1,2]}' +``` + +## `/v1/api/model/%s/meta/config` + +Returns the client configuration derived from the model (cache settings, input/output schema, etc.). Used by `mly` clients to bootstrap. + +## `/v1/api/model/%s/meta/dictionary` -- `/v1/api/model/%s/eval` - runs `GET` / `POST` model prediction. -- `/v1/api/model/%s/meta/config` - provides configuration for client related to model -- `/v1/api/model/%s/meta/dictionary` - provides current dictionary +Returns the current dictionary (categorical input vocabularies + dictionary hash). Used by `mly` clients to populate the local cache and detect dictionary changes. # Client Metrics (`gmetric`) @@ -299,6 +379,7 @@ all compatible with Apache License, Version 2. Please see individual files for d # Versioning Notes +- `v0.20.0` - error responses from `/v1/api/model/%s/eval` are now JSON-encoded `Response` objects (same `Content-Type` / `Content-Length` contract as success responses) instead of plain text; HTTP status codes are unchanged. The client populates `response.Error` from the parsed body, so consumers can rely on either the `err` return value or `response.Error` as the error signal. - `v0.14.1` last support for go 1.17 - `v0.8.0` - numeric features are supported. diff --git a/WORKFLOW.md b/WORKFLOW.md index d3fd475..03fc0dc 100644 --- a/WORKFLOW.md +++ b/WORKFLOW.md @@ -51,7 +51,7 @@ sequenceDiagram ```mermaid sequenceDiagram participant client as client.Service - + participant mlyserver as mly Server participant serverds as Server datastore.Service @@ -74,7 +74,7 @@ sequenceDiagram aerospike-->>datastore: KEY_NOT_FOUND_ERROR Note over datastore: L1NoSuchKey - alt L2 is configured + alt L2 is configured datastore->>aerospikel2: Get() aerospikel2-->>datastore: KEY_NOT_FOUND_ERROR Note over datastore: L2NoSuchKey @@ -94,14 +94,14 @@ sequenceDiagram alt mly Prediction Required client->>mlyserver: postRequest() - + activate mlyserver Note over mlyserver: Run TensorFlow model graph - par + par mlyserver->>serverds: Put() serverds->>aerospike: Put() - and + and mlyserver-->>client: response end @@ -143,6 +143,6 @@ sequenceDiagram client->>datastore: Put() datastore->>scache: Put() - + deactivate client ``` \ No newline at end of file diff --git a/example/client/option.go b/example/client/option.go index a28d251..2d3f045 100644 --- a/example/client/option.go +++ b/example/client/option.go @@ -37,14 +37,22 @@ type Options struct { SkipError bool `long:"skiperrs"` - NoOutput bool `long:"noout"` - Metrics bool `long:"metrics"` + // NoOutput suppresses model outputs. + NoOutput bool `long:"noout"` + + Metrics bool `long:"metrics" description:"print gmetric metrics"` + Prometheus bool `long:"prometheus" description:"print prometheus metrics"` ErrorHistory bool `long:"errhist"` // Report forces NoOutput and SkipError true, Metrics and ErrorHistory false. // Will generate a final JSON object as its only output to stdout. // stderr may have other output if Debug is true or there are other errors. Report bool `long:"report"` + + // OutputFile redirects result output to a file instead of stdout. + // When set, all model output, reports, metrics, and error history + // are written to this file path. + OutputFile string `short:"o" long:"output" description:"write results to file instead of stdout"` } type C uint8 diff --git a/example/client/payload.go b/example/client/payload.go index 27426a3..b3e504e 100644 --- a/example/client/payload.go +++ b/example/client/payload.go @@ -70,11 +70,19 @@ func Parse(p string, cp *CliPayload) error { for _, chunk := range chunks { def := strings.Split(chunk, ":") if len(def) != 2 { - return fmt.Errorf("chunk \"%s\" missing or has more than one \":\"", chunk) + return fmt.Errorf("chunk \"%s\" has more than one \":\"", chunk) + } + + valStr := def[1] + var vals []string + var err error + if valStr == "" { + vals = []string{""} + } else { + vals, err = csv.NewReader(strings.NewReader(valStr)).Read() } field := def[0] - vals, err := csv.NewReader(strings.NewReader(def[1])).Read() if err != nil { return fmt.Errorf("csv error for field %s: %v", field, err) } diff --git a/example/client/runner.go b/example/client/runner.go index 463ff33..55032e6 100644 --- a/example/client/runner.go +++ b/example/client/runner.go @@ -3,10 +3,14 @@ package client import ( "context" "fmt" + "io" "log" + "os" "sync" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/common/expfmt" "github.com/viant/gmetric" "github.com/viant/mly/service/endpoint/checker" "github.com/viant/mly/shared/client" @@ -18,6 +22,14 @@ import ( // Use CustomMakerRegistry with --maker to use a specific entity for Response.Data. var CustomMakerRegistry *customMakerRegistry = new(customMakerRegistry) +func dumpTo(w io.Writer, data interface{}) { + text, err := toolbox.AsJSONText(data) + if err != nil { + return + } + fmt.Fprintf(w, "%v\n", text) +} + func RunWithOptions(runOpts *Options) error { runOpts.Init() if err := runOpts.Validate(); err != nil { @@ -28,6 +40,16 @@ func RunWithOptions(runOpts *Options) error { return fmt.Errorf("could not determine model") } + output := io.Writer(os.Stdout) + if runOpts.OutputFile != "" { + f, err := os.Create(runOpts.OutputFile) + if err != nil { + return fmt.Errorf("failed to open output file %s: %w", runOpts.OutputFile, err) + } + defer f.Close() + output = f + } + payloads, err := runOpts.Payloads() if err != nil { return err @@ -132,7 +154,7 @@ func RunWithOptions(runOpts *Options) error { for i, pload := range payloads { rs.WPayloads[i] = WorkerPayload{Payload: pload} rd := &rs.WPayloads[i] - payloadedRunner := makePayloadRunner(cli, pload, runOpts, dataSetter) + payloadedRunner := makePayloadRunner(cli, pload, runOpts, dataSetter, output) fchan <- runContext{ WP: rd, @@ -180,18 +202,32 @@ func RunWithOptions(runOpts *Options) error { report.Metrics = opcs if runOpts.Metrics { - toolbox.Dump(opcs) + dumpTo(output, opcs) } if runOpts.ErrorHistory { tops := cli.ErrorHistory.TopK() for _, t := range tops { - fmt.Printf("%d %s\n", t.Count, string(t.Data)) + fmt.Fprintf(output, "%d %s\n", t.Count, string(t.Data)) + } + } + + if runOpts.Prometheus { + mfs, err := prometheus.DefaultGatherer.Gather() + if err != nil { + return fmt.Errorf("failed to gather prometheus metrics: %w", err) + } + + encoder := expfmt.NewEncoder(output, expfmt.FmtText) + for _, mf := range mfs { + if err := encoder.Encode(mf); err != nil { + return fmt.Errorf("failed to encode metric family %s: %w", mf.GetName(), err) + } } } if runOpts.Report { - toolbox.Dump(report) + dumpTo(output, report) } return err @@ -237,7 +273,7 @@ func worker(worker int, echan chan error, fchan chan runContext, closed chan str } func makePayloadRunner(cli *client.Service, pl *CliPayload, runOpts *Options, - builder func(int) func() interface{}) func() (*client.Response, error) { + builder func(int) func() interface{}, output io.Writer) func() (*client.Response, error) { maker := builder(pl.Batch) @@ -267,7 +303,7 @@ func makePayloadRunner(cli *client.Service, pl *CliPayload, runOpts *Options, } if !runOpts.NoOutput { - toolbox.Dump(response) + dumpTo(output, response) } return response, nil diff --git a/example/e2e/check-port.sh b/example/e2e/check-port.sh index dbd1de0..d76b0f5 100644 --- a/example/e2e/check-port.sh +++ b/example/e2e/check-port.sh @@ -4,26 +4,62 @@ set -x -e ADDR=$1 if [ -z "$ADDR" ]; then - echo "usage: $0 ADDRESS [TIMES] [SLEEP]" + echo "usage: $0 ADDRESS [TIMES] [SLEEP] [PIDFILE]" echo "address required" exit 2 fi TIMES=${2:-30} SLEEP=${3:-1} +PIDFILE=${4:-} + +# Check if process from PID file is still running +# Returns 0 if no pidfile specified, process is running, or pidfile doesn't exist yet +# Returns 1 if pidfile exists but process is dead +check_pid() { + if [ -z "$PIDFILE" ]; then + return 0 + fi + if [ ! -f "$PIDFILE" ]; then + # PID file doesn't exist yet - service may still be starting + return 0 + fi + local pid + pid=$(cat "$PIDFILE" 2>/dev/null) + if [ -z "$pid" ]; then + return 0 + fi + if kill -0 "$pid" 2>/dev/null; then + return 0 + fi + echo "process $pid from $PIDFILE is no longer running" + return 1 +} -LOOPS=0 ERROR=1 set +x for i in $(seq $TIMES); do - sleep 1 + sleep "$SLEEP" + + # Early exit if monitored process died + if ! check_pid; then + echo "service process terminated before becoming ready" + exit 3 + fi + set +e - curl $1 &>/dev/null + curl "$ADDR" &>/dev/null ERROR=$? set -e if [ $ERROR -eq 0 ]; then break fi -done +done + +# Final PID check - ensure process is still alive even if curl succeeded +if [ $ERROR -eq 0 ] && ! check_pid; then + echo "service process terminated" + exit 3 +fi exit $ERROR diff --git a/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/saved_model.pb b/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/saved_model.pb deleted file mode 100644 index cb4a79b..0000000 Binary files a/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/saved_model.pb and /dev/null differ diff --git a/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/variables/variables.data-00000-of-00001 b/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/variables/variables.data-00000-of-00001 deleted file mode 100644 index 7d66744..0000000 Binary files a/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/variables/variables.data-00000-of-00001 and /dev/null differ diff --git a/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/variables/variables.index b/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/variables/variables.index deleted file mode 100644 index 32856f2..0000000 Binary files a/example/e2e/data/triton_model_repository/sli_1/1/model.savedmodel/variables/variables.index and /dev/null differ diff --git a/example/e2e/data/triton_model_repository/sli_1/config.pbtxt b/example/e2e/data/triton_model_repository/sli_1/config.pbtxt deleted file mode 100644 index 566ef4d..0000000 --- a/example/e2e/data/triton_model_repository/sli_1/config.pbtxt +++ /dev/null @@ -1,29 +0,0 @@ - -backend: "tensorflow" -max_batch_size: 32 - -# Input specifications -input [ - { - name: "sa" - data_type: TYPE_STRING - dims: [ 1 ] - allow_ragged_batch: false - }, - { - name: "sl" - data_type: TYPE_STRING - dims: [ 1 ] - allow_ragged_batch: false - } -] - -# Output specifications -output [ - { - name: "expand" - data_type: TYPE_INT64 - dims: [ 1 ] - } -] - diff --git a/example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/saved_model.pb b/example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/saved_model.pb deleted file mode 100644 index cb4a79b..0000000 Binary files a/example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/saved_model.pb and /dev/null differ diff --git a/example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/variables/variables.data-00000-of-00001 b/example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/variables/variables.data-00000-of-00001 deleted file mode 100644 index 7d66744..0000000 Binary files a/example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/variables/variables.data-00000-of-00001 and /dev/null differ diff --git a/example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/variables/variables.index b/example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/variables/variables.index deleted file mode 100644 index 32856f2..0000000 Binary files a/example/e2e/data/triton_model_repository/sli_2/1/model.savedmodel/variables/variables.index and /dev/null differ diff --git a/example/e2e/data/triton_model_repository/sli_2/config.pbtxt b/example/e2e/data/triton_model_repository/sli_2/config.pbtxt deleted file mode 100644 index 49f049e..0000000 --- a/example/e2e/data/triton_model_repository/sli_2/config.pbtxt +++ /dev/null @@ -1,27 +0,0 @@ - -backend: "tensorflow" -max_batch_size: 32 - -# Input specifications -input [ - { - name: "sa" - data_type: TYPE_STRING - dims: [ 1 ] - }, - { - name: "sl" - data_type: TYPE_STRING - dims: [ 1 ] - } -] - -# Output specifications -output [ - { - name: "expand" - data_type: TYPE_INT64 - dims: [ 1 ] - } -] - diff --git a/example/e2e/deps.yaml b/example/e2e/deps.yaml index d17e368..d29b893 100644 --- a/example/e2e/deps.yaml +++ b/example/e2e/deps.yaml @@ -3,10 +3,10 @@ init: pipeline: deploy: - # set_sdk: - # action: sdk.set - # target: $target - # sdk: go:${goVersion} + set_sdk: + action: sdk.set + target: $target + sdk: go:${goVersion} install_dependencies: action: exec:run diff --git a/example/e2e/regression/app.yaml b/example/e2e/regression/app.yaml index 50d0e23..9ba9e83 100644 --- a/example/e2e/regression/app.yaml +++ b/example/e2e/regression/app.yaml @@ -18,7 +18,7 @@ pipeline: directory: /tmp/e2e checkError: true immuneToHangups: true - command: ./mly-endly -c=/tmp/e2e/config.yaml 2>&1 >/tmp/e2e/mly-endly.log + command: bash ${appPath}/example/e2e/start-with-pid.sh /tmp/e2e/mly-endly.pid ./mly-endly -c=/tmp/e2e/config.yaml >/tmp/e2e/mly-endly.log 2>&1 env: DEBUG: 'true' LD_LIBRARY_PATH: /usr/local/lib @@ -28,5 +28,4 @@ pipeline: target: $target checkError: true commands: - - bash ${appPath}/example/e2e/check-port.sh localhost:8086 30 - + - bash ${appPath}/example/e2e/check-port.sh localhost:8086 30 1s /tmp/e2e/mly-endly.pid diff --git a/example/e2e/regression/cases/001_e2e_client/expect-slft_batch-maker.json b/example/e2e/regression/cases/001_e2e_client/expect-slft_batch-maker.json new file mode 100644 index 0000000..18922f2 --- /dev/null +++ b/example/e2e/regression/cases/001_e2e_client/expect-slft_batch-maker.json @@ -0,0 +1,5 @@ +{ + "data": { + "Class": "five" + } +} diff --git a/example/e2e/regression/cases/001_e2e_client/expect-slft_batch.json b/example/e2e/regression/cases/001_e2e_client/expect-slft_batch.json new file mode 100644 index 0000000..fa6a876 --- /dev/null +++ b/example/e2e/regression/cases/001_e2e_client/expect-slft_batch.json @@ -0,0 +1,10 @@ +{ + "data": [ + { + "Class": "five" + }, + { + "Class": "five" + } + ] +} diff --git a/example/e2e/regression/cases/001_e2e_client/expect-sli.json b/example/e2e/regression/cases/001_e2e_client/expect-sli.json new file mode 100644 index 0000000..0012540 --- /dev/null +++ b/example/e2e/regression/cases/001_e2e_client/expect-sli.json @@ -0,0 +1,7 @@ +{ + "data": { + "D": { + "expand": 25 + } + } +} diff --git a/example/e2e/regression/cases/001_e2e_client/test.yaml b/example/e2e/regression/cases/001_e2e_client/test.yaml index c49a39d..000dc15 100644 --- a/example/e2e/regression/cases/001_e2e_client/test.yaml +++ b/example/e2e/regression/cases/001_e2e_client/test.yaml @@ -13,43 +13,15 @@ pipeline: assert: action: validator:assert - init: - actual: $AsJSON($test.Cmd[0].Stdout) - actual: $actual - expect: | - { - "data": { - "D": { - "expand": 25 - } - } - } + actual: $AsJSON($test.Cmd[0].Stdout) + expect: $LoadJSON('${parentPath}/expect-sli.json') assertStorable: action: validator:assert - init: - actual: $AsJSON($test.Cmd[1].Stdout) - actual: $actual - expect: | - { - "data": [ - { - "Class": "five" - }, - { - "Class": "five" - } - ] - } + actual: $AsJSON($test.Cmd[1].Stdout) + expect: $LoadJSON('${parentPath}/expect-slft_batch.json') assertCustom: action: validator:assert - init: - actual: $AsJSON($test.Cmd[2].Stdout) - actual: $actual - expect: | - { - "data": { - "Class": "five" - } - } + actual: $AsJSON($test.Cmd[2].Stdout) + expect: $LoadJSON('${parentPath}/expect-slft_batch-maker.json') diff --git a/example/e2e/regression/cases/002_sls_client/test.yaml b/example/e2e/regression/cases/002_sls_client/test.yaml index 63592b5..e7bd343 100644 --- a/example/e2e/regression/cases/002_sls_client/test.yaml +++ b/example/e2e/regression/cases/002_sls_client/test.yaml @@ -12,14 +12,10 @@ pipeline: assert: action: validator:assert - init: - actual: $AsJSON($test.Cmd[0].Stdout) - actual: $actual + actual: $AsJSON($test.Cmd[0].Stdout) expect: $LoadJSON('${parentPath}/expect.json') assertCache: action: validator:assert - init: - actual: $AsJSON($test.Cmd[1].Stdout) - actual: $actual + actual: $AsJSON($test.Cmd[1].Stdout) expect: $LoadJSON('${parentPath}/expect-cached.json') diff --git a/example/e2e/regression/cases/003_vec_client/test.yaml b/example/e2e/regression/cases/003_vec_client/test.yaml index 5064519..9d00e80 100644 --- a/example/e2e/regression/cases/003_vec_client/test.yaml +++ b/example/e2e/regression/cases/003_vec_client/test.yaml @@ -12,16 +12,10 @@ pipeline: assert: action: validator:assert - init: - actual: $AsJSON($test.Cmd[0].Stdout) - expect: expect - actual: $actual - expect: $LoadJSON('${parentPath}/${expect}.json') + actual: $AsJSON($test.Cmd[0].Stdout) + expect: $LoadJSON('${parentPath}/expect.json') assertCache: action: validator:assert - init: - actual: $AsJSON($test.Cmd[1].Stdout) - expect: expect-cache - actual: $actual - expect: $LoadJSON('${parentPath}/${expect}.json') + actual: $AsJSON($test.Cmd[1].Stdout) + expect: $LoadJSON('${parentPath}/expect-cache.json') diff --git a/example/e2e/regression/cases/004_slf_transform_batch/expect-batch-cache.json b/example/e2e/regression/cases/004_slf_transform_batch/expect-batch-cache.json index 789e6f1..ef16faf 100644 --- a/example/e2e/regression/cases/004_slf_transform_batch/expect-batch-cache.json +++ b/example/e2e/regression/cases/004_slf_transform_batch/expect-batch-cache.json @@ -3,7 +3,7 @@ "data": [ { "Class": "five" - }, + }, { "Class": "two" } diff --git a/example/e2e/regression/cases/004_slf_transform_batch/expect-batch.json b/example/e2e/regression/cases/004_slf_transform_batch/expect-batch.json index b19d1e7..3cdca68 100644 --- a/example/e2e/regression/cases/004_slf_transform_batch/expect-batch.json +++ b/example/e2e/regression/cases/004_slf_transform_batch/expect-batch.json @@ -2,10 +2,10 @@ "status": "ok", "data": [ { - "Class": "five" - }, - { - "Class": "two" - } - ] + "Class": "five" + }, + { + "Class": "two" + } + ] } diff --git a/example/e2e/regression/cases/004_slf_transform_batch/expect.json b/example/e2e/regression/cases/004_slf_transform_batch/expect.json index b7e651e..9de27c0 100644 --- a/example/e2e/regression/cases/004_slf_transform_batch/expect.json +++ b/example/e2e/regression/cases/004_slf_transform_batch/expect.json @@ -1,6 +1,6 @@ { "status": "ok", "data": { - "Class": "five" + "Class": "five" } } diff --git a/example/e2e/regression/cases/004_slf_transform_batch/test.yaml b/example/e2e/regression/cases/004_slf_transform_batch/test.yaml index ba5b711..4cb4fed 100644 --- a/example/e2e/regression/cases/004_slf_transform_batch/test.yaml +++ b/example/e2e/regression/cases/004_slf_transform_batch/test.yaml @@ -13,21 +13,16 @@ pipeline: assert: action: validator:assert - init: - expect: expect actual: $AsJSON($test.Cmd[0].Stdout) - expect: $LoadJSON('${parentPath}/${expect}.json') + expect: $LoadJSON('${parentPath}/expect.json') assertBatch: action: validator:assert init: - expect: expect-batch actual: $AsJSON($test.Cmd[1].Stdout) - expect: $LoadJSON('${parentPath}/${expect}.json') + expect: $LoadJSON('${parentPath}/expect-batch.json') assertBatchCache: action: validator:assert - init: - expect: expect-batch-cache actual: $AsJSON($test.Cmd[2].Stdout) - expect: $LoadJSON('${parentPath}/${expect}.json') + expect: $LoadJSON('${parentPath}/expect-batch-cache.json') diff --git a/example/e2e/regression/cases/005_lookup_transform/test.yaml b/example/e2e/regression/cases/005_lookup_transform/test.yaml index 66f485a..e223c48 100644 --- a/example/e2e/regression/cases/005_lookup_transform/test.yaml +++ b/example/e2e/regression/cases/005_lookup_transform/test.yaml @@ -12,16 +12,10 @@ pipeline: assert: action: validator:assert - init: - actual: $AsJSON($test.Cmd[0].Stdout) - expect: expect-ll - actual: $actual - expect: $LoadJSON('${parentPath}/${expect}.json') + actual: $AsJSON($test.Cmd[0].Stdout) + expect: $LoadJSON('${parentPath}/expect-ll.json') assertCache: action: validator:assert - init: - actual: $AsJSON($test.Cmd[1].Stdout) - expect: expect-ko - actual: $actual - expect: $LoadJSON('${parentPath}/${expect}.json') + actual: $AsJSON($test.Cmd[1].Stdout) + expect: $LoadJSON('${parentPath}/expect-ko.json') diff --git a/example/e2e/regression/cases/006_metrics/test.yaml b/example/e2e/regression/cases/006_metrics/test.yaml index c9c5cc6..71111b4 100644 --- a/example/e2e/regression/cases/006_metrics/test.yaml +++ b/example/e2e/regression/cases/006_metrics/test.yaml @@ -14,16 +14,12 @@ pipeline: assertMetricsInvalid: action: validator:assert - init: - actual: $AsJSON($test.Cmd[3].Stdout) - actual: $actual + actual: $AsJSON($test.Cmd[3].Stdout) expect: $LoadJSON('${parentPath}/metrics.json') assertMetricsJSONError: action: validator:assert - init: - actual: $AsJSON($test.Cmd[4].Stdout) - actual: $actual + actual: $AsJSON($test.Cmd[4].Stdout) expect: $LoadJSON('${parentPath}/metrics-json.json') diff --git a/example/e2e/regression/cases/007_health/metrics.json b/example/e2e/regression/cases/007_health/metrics.json deleted file mode 100644 index 5f8b2e0..0000000 --- a/example/e2e/regression/cases/007_health/metrics.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "sli": 1 -} diff --git a/example/e2e/regression/cases/007_health/test.yaml b/example/e2e/regression/cases/007_health/test.yaml index de464d2..f1f3e36 100644 --- a/example/e2e/regression/cases/007_health/test.yaml +++ b/example/e2e/regression/cases/007_health/test.yaml @@ -10,8 +10,8 @@ pipeline: assertMetrics: action: validator:assert - init: - actual: $AsJSON($test.Cmd[0].Stdout) - actual: $actual - expect: $LoadJSON('${parentPath}/metrics.json') - + actual: $AsJSON($test.Cmd[0].Stdout) + expect: | + { + "sli": 1 + } \ No newline at end of file diff --git a/example/e2e/regression/cases/010_triton_sli/expect.json b/example/e2e/regression/cases/010_triton/expect.json similarity index 94% rename from example/e2e/regression/cases/010_triton_sli/expect.json rename to example/e2e/regression/cases/010_triton/expect.json index 1c0f261..b956c2c 100644 --- a/example/e2e/regression/cases/010_triton_sli/expect.json +++ b/example/e2e/regression/cases/010_triton/expect.json @@ -2,4 +2,4 @@ "data": { "expand": 25 } -} \ No newline at end of file +} diff --git a/example/e2e/regression/cases/010_triton/request.json b/example/e2e/regression/cases/010_triton/request.json new file mode 100644 index 0000000..892d267 --- /dev/null +++ b/example/e2e/regression/cases/010_triton/request.json @@ -0,0 +1,6 @@ +{ + "cache_key": "test-a-a-10", + "sa": "a", + "sl": "a", + "aux": "test" +} diff --git a/example/e2e/regression/cases/010_triton/test.yaml b/example/e2e/regression/cases/010_triton/test.yaml new file mode 100644 index 0000000..2a2b5b1 --- /dev/null +++ b/example/e2e/regression/cases/010_triton/test.yaml @@ -0,0 +1,22 @@ +init: + parentPath: $parent.path + +pipeline: + test: + action: http/runner:send + requests: + - Method: POST + URL: http://127.0.0.1:8086/v1/api/model/sli_triton/eval + JSONBody: $LoadJSON('${parentPath}/request.json') + Expect: + Code: 200 + Body: $LoadJSON('${parent.path}/expect.json') + + testMlyCli: + action: exec:run + target: $target + checkError: true + commands: + - /tmp/e2e/mlyc -m sli_triton -a 'sa:a;sl:a;aux:test' + - /tmp/e2e/mlyc -m sli_triton -a 'sa:b;sl:a;aux:test' + - /tmp/e2e/mlyc -m sli_triton_grpc_explicit -a 'sa:c;sl:a;aux:test' diff --git a/example/e2e/regression/cases/010_triton_sli/test.yaml b/example/e2e/regression/cases/010_triton_sli/test.yaml deleted file mode 100644 index 116f5bb..0000000 --- a/example/e2e/regression/cases/010_triton_sli/test.yaml +++ /dev/null @@ -1,13 +0,0 @@ -init: - parentPath: $parent.path - -pipeline: - test: - action: http/runner:send - requests: - - Method: POST - URL: http://127.0.0.1:8086/v1/api/model/sli_triton/eval - JSONBody: {"batch_size": 0, "cache_key": "test-a-a-10", "sa": "a", "sl":"a", "aux": "test"} - Expect: - Code: 200 - Body: $LoadJSON('${parent.path}/expect.json') diff --git a/example/e2e/regression/cases/011_aux_cache/expect-cache.json b/example/e2e/regression/cases/011_aux_cache/expect-cache.json new file mode 100644 index 0000000..6dc5fe2 --- /dev/null +++ b/example/e2e/regression/cases/011_aux_cache/expect-cache.json @@ -0,0 +1,6 @@ +{ + "status": "cached", + "data": { + "Class": "five" + } +} \ No newline at end of file diff --git a/example/e2e/regression/cases/011_aux_cache/expect.json b/example/e2e/regression/cases/011_aux_cache/expect.json new file mode 100644 index 0000000..94f5bdb --- /dev/null +++ b/example/e2e/regression/cases/011_aux_cache/expect.json @@ -0,0 +1,6 @@ +{ + "status": "ok", + "data": { + "Class": "five" + } +} \ No newline at end of file diff --git a/example/e2e/regression/cases/011_aux_cache/test.yaml b/example/e2e/regression/cases/011_aux_cache/test.yaml new file mode 100644 index 0000000..9206be7 --- /dev/null +++ b/example/e2e/regression/cases/011_aux_cache/test.yaml @@ -0,0 +1,23 @@ +init: + parentPath: $parent.path + +pipeline: + test: + action: exec:run + target: $target + checkError: true + commands: + - /tmp/e2e/mlyc -m slf -s slft -a 'sa:b;sl:a;aux:c1' + - /tmp/e2e/mlyc -m slf -s slft -a 'sa:b;sl:a;aux:c2' + + assert: + action: validator:assert + actual: $AsJSON($test.Cmd[0].Stdout) + expect: $LoadJSON('${parentPath}/expect.json') + + assertCache: + action: validator:assert + actual: $AsJSON($test.Cmd[1].Stdout) + expect: $LoadJSON('${parentPath}/expect-cache.json') + + diff --git a/example/e2e/regression/cases/012_router/test.yaml b/example/e2e/regression/cases/012_router/test.yaml new file mode 100644 index 0000000..c3ee331 --- /dev/null +++ b/example/e2e/regression/cases/012_router/test.yaml @@ -0,0 +1,21 @@ +init: + parentPath: $parent.path + +pipeline: + test: + action: exec:run + target: $target + checkError: true + commands: + - /tmp/e2e/mlyc -m sli_router -a 'routing_id|int:1;sa:b;sl:a;aux:1' + - /tmp/e2e/mlyc -m sli_router -a 'routing_id|int:2;sa:b;sl:a;aux:2' + + # assert: + # action: validator:assert + # actual: $AsJSON($test.Cmd[0].Stdout) + # expect: $LoadJSON('${parentPath}/expect.json') + + # assertCache: + # action: validator:assert + # actual: $AsJSON($test.Cmd[1].Stdout) + # expect: $LoadJSON('${parentPath}/expect-cache.json') diff --git a/example/e2e/regression/regression.yaml b/example/e2e/regression/regression.yaml index c80c2b9..297fc5a 100644 --- a/example/e2e/regression/regression.yaml +++ b/example/e2e/regression/regression.yaml @@ -16,7 +16,7 @@ pipeline: subPath: 'cases/${index}_*' - range: 1..010 + range: 1..012 template: checkSkip: action: nop diff --git a/example/e2e/regression/reset/fli.json b/example/e2e/regression/reset/fli.json new file mode 100644 index 0000000..ee1aac4 --- /dev/null +++ b/example/e2e/regression/reset/fli.json @@ -0,0 +1 @@ +[{}] \ No newline at end of file diff --git a/example/e2e/regression/reset/flinc.json b/example/e2e/regression/reset/flinc.json new file mode 100644 index 0000000..ee1aac4 --- /dev/null +++ b/example/e2e/regression/reset/flinc.json @@ -0,0 +1 @@ +[{}] \ No newline at end of file diff --git a/example/e2e/run.yaml b/example/e2e/run.yaml index e84c146..807012c 100644 --- a/example/e2e/run.yaml +++ b/example/e2e/run.yaml @@ -24,13 +24,18 @@ pipeline: request: '@build' tasks: '*' + test: + action: run + description: run regression test + request: '@regression/regression' + stop: stop: action: run request: '@regression/app' tasks: stop - test: - action: run - description: run regression test - request: '@regression/regression' + stopSystem: + action: run + request: '@system' + tasks: stop diff --git a/example/e2e/start-with-pid.sh b/example/e2e/start-with-pid.sh new file mode 100644 index 0000000..b98e52b --- /dev/null +++ b/example/e2e/start-with-pid.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Starts a command in the background and writes its PID to a file. +# Useful for monitoring service startup with check-port.sh +# +# Usage: start-with-pid.sh PIDFILE COMMAND [ARGS...] +# +# Example: +# ./start-with-pid.sh /tmp/myservice.pid ./myservice --port 8080 +# ./check-port.sh http://localhost:8080 30 1 /tmp/myservice.pid + +set -e + +PIDFILE=$1 +if [ -z "$PIDFILE" ]; then + echo "usage: $0 PIDFILE COMMAND [ARGS...]" + echo "pidfile path required" + exit 2 +fi +shift + +if [ $# -eq 0 ]; then + echo "usage: $0 PIDFILE COMMAND [ARGS...]" + echo "command required" + exit 2 +fi + +# Ensure parent directory exists +PIDDIR=$(dirname "$PIDFILE") +if [ ! -d "$PIDDIR" ]; then + mkdir -p "$PIDDIR" +fi + +# Remove stale PID file if it exists +rm -f "$PIDFILE" + +# Start the command in the background +"$@" & +PID=$! + +# Write PID to file +echo "$PID" > "$PIDFILE" + +echo "started process $PID, pidfile: $PIDFILE" +echo "command: $*" diff --git a/example/e2e/system.yaml b/example/e2e/system.yaml index 215ec65..c9eb299 100644 --- a/example/e2e/system.yaml +++ b/example/e2e/system.yaml @@ -1,12 +1,36 @@ init: modelRepo: ${appPath}/example/e2e/data/triton_model_repository + pipeline: stop: - services: + mly_aero: action: docker:stop - images: - - aerospike-server - - tritonserver + names: + - mly_aero + mly_triton: + action: docker:stop + names: + - mly_triton + mly_triton_grpc_explicit: + action: docker:stop + names: + - mly_triton_grpc_explicit + + copyModels: + action: exec:run + target: $target + checkError: true + commands: + - mkdir -p ${modelRepo}/sli_1/1/model.savedmodel + - cp -r ${appPath}/example/model/string_lookups_int_model/* ${modelRepo}/sli_1/1/model.savedmodel + - mkdir -p ${modelRepo}/sli_2/1/model.savedmodel + - cp -r ${appPath}/example/model/string_lookups_int_model/* ${modelRepo}/sli_2/1/model.savedmodel + - mkdir -p ${modelRepo}/sli_exp/1/model.savedmodel + - cp -r ${appPath}/example/model/string_lookups_int_model/* ${modelRepo}/sli_exp/1/model.savedmodel + - mkdir -p ${modelRepo}/r2/1/model.savedmodel + - cp -r ${appPath}/example/model/r2_output_model/* ${modelRepo}/r2/1/model.savedmodel + - mkdir -p ${modelRepo}/ko/1/model.savedmodel + - cp -r ${appPath}/example/model/keyed_out_model/* ${modelRepo}/ko/1/model.savedmodel start: aerospike: @@ -23,15 +47,36 @@ pipeline: name: mly_triton image: 'nvcr.io/nvidia/tritonserver:25.02-py3' platform: linux/arm64 - port: "8000/tcp" + port: "8070/tcp" + ports: + 8070: 8000 + 8071: 8001 + 8072: 8002 + mount: + ${modelRepo}: /models + cmd: + - tritonserver + - --model-repository=/models + - --backend-config=tensorflow,default-max-batch-size=32 + entrypoint: + - /opt/nvidia/nvidia_entrypoint.sh + + triton_grpc_explicit: + action: docker:run + name: mly_triton_grpc_explicit + image: 'nvcr.io/nvidia/tritonserver:25.02-py3' + platform: linux/arm64 + port: "8101/tcp" ports: - 8000: 8000 - 8001: 8001 - 8002: 8002 + 8100: 8000 + 8101: 8001 + 8102: 8002 mount: ${modelRepo}: /models cmd: - tritonserver - --model-repository=/models + - --model-control-mode=explicit + - --backend-config=tensorflow,default-max-batch-size=32 entrypoint: - /opt/nvidia/nvidia_entrypoint.sh \ No newline at end of file diff --git a/example/server/etc/config.yaml b/example/server/etc/config.yaml index 7467fbf..2d9164d 100644 --- a/example/server/etc/config.yaml +++ b/example/server/etc/config.yaml @@ -3,7 +3,7 @@ Endpoint: models: - id: sli_triton - url: http://localhost:8000 + url: http://localhost:8070 platform: triton triton: modelName: sli_2 @@ -11,19 +11,100 @@ models: useDict: true debug: false datastore: sli_triton + inputs: + - name: sa + wildcard: true + - name: sl + wildcard: true + - name: aux + wildcard: true + auxiliary: true + outputs: + - name: expand + dataType: int + + - id: sli_triton_grpc_explicit + platform: triton + triton: + modelName: sli_exp + serverID: triton_grpc_explicit + useDict: true + debug: false + datastore: sli_triton keyfields: - sa - sl - aux inputs: - name: sa + wildcard: true - name: sl + wildcard: true - name: aux + wildcard: true auxiliary: true outputs: - name: expand dataType: int + - id: sli_router + platform: triton + mode: router + triton: + serverID: triton_grpc_explicit + Router: + ConfigURL: ${appPath}/example/server/etc/router.yaml + InputName: routing_id + Workers: 2 + Global: + Exists: false + PredictionReplacements: + - Name: expand + Type: int64 + Value: 0 + Output: + FieldName: model_id + Datastore: sli_triton + useDict: true + Inputs: + - name: routing_id + datatype: int64 + wildcard: true + - name: sl + datatype: string + wildcard: true + - name: sa + datatype: string + wildcard: true + - name: aux + datatype: string + auxiliary: true + Outputs: + - name: expand + datatype: int64 + - name: model_id + datatype: string + + Test: + SingleBatch: true + + - id: ko_triton + platform: triton + triton: + modelName: ko + serverID: triton_grpc_explicit + debug: false + inputs: + - name: buggy_aux + auxiliary: true + - name: i + dataType: int + outputs: + - name: i0_copy + dataType: int + - name: i1_copy + dataType: int + - id: sli url: ${appPath}/example/model/string_lookups_int_model useDict: true @@ -108,6 +189,15 @@ connections: port: 3000 timeout: 300 +tritonServers: + - id: triton_grpc_explicit + grpcBaseURL: localhost:8101 + grpcConnectParams: + baseDelayMs: 10 + multiplier: 2 + jitter: 0.1 + maxDelayMs: 150 + datastores: - id: sli_triton connection: localL1 diff --git a/example/server/etc/router.yaml b/example/server/etc/router.yaml new file mode 100644 index 0000000..9d06850 --- /dev/null +++ b/example/server/etc/router.yaml @@ -0,0 +1,5 @@ +entityMapping: + - entityID: 1 + modelName: sli_1 + - entityID: 2 + modelName: sli_2 \ No newline at end of file diff --git a/service/config/model.go b/service/config/model.go index 445aaa8..282c3a0 100644 --- a/service/config/model.go +++ b/service/config/model.go @@ -160,6 +160,10 @@ func (m *Model) Init(globalBatchConfig *batchconfig.BatcherConfig) { if m.Router != nil { m.Router.Init() } + + if m.Triton != nil { + m.Triton.Init(m.IsRouter()) + } } func (m *Model) Validate() error { @@ -183,7 +187,7 @@ func (m *Model) Validate() error { return fmt.Errorf("tensorflow model %s requires URL", m.ID) } - if m.Mode == "router" { + if m.IsRouter() { return fmt.Errorf("tensorflow model %s is not supported in router mode", m.ID) } case "triton": @@ -191,18 +195,23 @@ func (m *Model) Validate() error { return fmt.Errorf("triton model %s requires Triton configuration", m.ID) } - if err := m.Triton.Validate(m.Mode == "router", m.URL != ""); err != nil { + if err := m.Triton.Validate(m.IsRouter(), m.URL != ""); err != nil { return fmt.Errorf("triton model %s config invalid: %w", m.ID, err) } default: return fmt.Errorf("unsupported platform '%s' for model %s (supported: tensorflow, triton)", platform, m.ID) } - if m.Mode == "router" { + if m.IsRouter() { if m.Router == nil { return fmt.Errorf("router model %s requires Router configuration", m.ID) } + if m.Triton == nil { + // TODO support TensorFlow + return fmt.Errorf("router model %s requires Triton configuration", m.ID) + } + if err := m.Router.Validate(); err != nil { return fmt.Errorf("router model %s config invalid: %w", m.ID, err) } @@ -211,6 +220,10 @@ func (m *Model) Validate() error { return nil } +func (m *Model) IsRouter() bool { + return m.Mode == "router" +} + // ConfigCheck is a path to validate relationships with other config entities. func (m *Model) ConfigCheck(validDatastoreIDs map[string]struct{}, validTritonServerIDs map[string]struct{}) error { if m.DataStore != "" { @@ -259,10 +272,14 @@ type TritonConfig struct { Timeout int `json:",omitempty" yaml:",omitempty"` } -func (t *TritonConfig) Init() { +func (t *TritonConfig) Init(isRouter bool) { if t.Timeout == 0 { t.Timeout = 100 } + + if isRouter { + t.ModelName = "" + } } func (t *TritonConfig) Validate(isRouter bool, urlPresent bool) error { @@ -270,6 +287,10 @@ func (t *TritonConfig) Validate(isRouter bool, urlPresent bool) error { return fmt.Errorf("triton ModelName is required") } + if isRouter && t.ServerID == "" { + return fmt.Errorf("triton ServerID is required for router mode") + } + if t.ServerID == "" && !urlPresent { return fmt.Errorf("triton ServerID or Model.URL is required") } diff --git a/service/config/router.go b/service/config/router.go index d7cdd16..82638bd 100644 --- a/service/config/router.go +++ b/service/config/router.go @@ -6,18 +6,20 @@ type RouterConfig struct { // Required if Model.Mode is "router". ConfigURL string - // Required + // Required name of the input that will route the request to the backend. InputName string `json:",omitempty" yaml:",omitempty"` - // Unimplemented. - // If true, the router will batch the requests to the backend. - BatchBackend bool `json:",omitempty" yaml:",omitempty"` + // ForceBatchSize1 controls whether the router sends individual samples or batches by model. + // When false (default), requests within a single Predict() call that route to the + // same model evaluator are grouped into a single batched prediction call. + // When true, each sample is sent as an individual prediction request with batch size 1. + ForceBatchSize1 bool `json:",omitempty" yaml:",omitempty"` - // The maximum number of concurrent requests to the backend. + // The maximum number of concurrent batches dispatched to model evaluators. // Defaults to 50. Workers int `json:",omitempty" yaml:",omitempty"` - // The maximum number of requests to queue. + // The maximum number of batches to queue before rejecting. // Defaults to 1000. MaxQueueSize int `json:",omitempty" yaml:",omitempty"` @@ -64,6 +66,10 @@ func (o *RouterConfig) Init() { if o.MaxQueueSize == 0 { o.MaxQueueSize = 1000 } + + if o.Output.NoModelID == "" { + o.Output.NoModelID = "none" + } } func (o *RouterConfig) Validate() error { @@ -87,9 +93,5 @@ func (o *RouterConfig) Validate() error { return fmt.Errorf("global model does not exist but no prediction replacements were provided") } - if o.Output.NoModelID == "" { - o.Output.NoModelID = "none" - } - return nil } diff --git a/service/domain/input.go b/service/domain/input.go index 1dfe28e..f1101b9 100644 --- a/service/domain/input.go +++ b/service/domain/input.go @@ -21,6 +21,7 @@ type Input struct { Vocab bool // Auxiliary is true if this input isn't part of the model + // TODO redesign model IO vs server IO concerns Auxiliary bool Type reflect.Type diff --git a/service/domain/output.go b/service/domain/output.go index 7156336..2b7b38c 100644 --- a/service/domain/output.go +++ b/service/domain/output.go @@ -8,14 +8,19 @@ import ( // Output represents model output type Output struct { + // Used in Stream Name string - // Primarily shown in config + // Shown in config DataType string - // DataTypeKind is used only for GBQ tool + // Only for GBQ tool DataTypeKind reflect.Kind - Index int + + // Used to extract output from *tf.Operation. + // Eventually becomes part of tf.Session.Run() parameter fetches ([]tf.Output). + Index int + *tf.Operation goType reflect.Type diff --git a/service/domain/signature.go b/service/domain/signature.go index 12745b7..22a9f21 100644 --- a/service/domain/signature.go +++ b/service/domain/signature.go @@ -4,6 +4,7 @@ package domain // Contains information required to extract vocabularies, unmarshal requests, and validate request inputs. // TODO document and address issues if reloaded model IO changes. type Signature struct { + // Method is unused. Method string Inputs []Input diff --git a/service/endpoint/checker/self.go b/service/endpoint/checker/self.go index e48dbdc..2e5ebd2 100644 --- a/service/endpoint/checker/self.go +++ b/service/endpoint/checker/self.go @@ -14,7 +14,7 @@ import ( ) func SelfTest(host []*client.Host, timeout time.Duration, modelID string, usesTransformer bool, tp config.TestPayload, debug bool) error { - cli, err := client.New(modelID, host, client.WithDebug(true)) + cli, err := client.New(modelID, host, client.WithDebug(true), client.WithPrometheusMetrics(false)) if err != nil { return fmt.Errorf("%s:%w", modelID, err) } diff --git a/service/endpoint/model.go b/service/endpoint/model.go index 6afe7c3..b065443 100644 --- a/service/endpoint/model.go +++ b/service/endpoint/model.go @@ -52,7 +52,7 @@ func Build( mux *http.ServeMux, config *Config, datastores map[string]*datastore.Service, - tritonClients map[string]triton.TritonClient, + tritonServices map[string]*triton.Service, hooks []Hook, metrics *gmetric.Service, promReg prometheus.Registerer, @@ -84,16 +84,26 @@ func Build( Namespace: "mly", Subsystem: "model", Name: "idletime", - - Help: "measured time between requests in nanoseconds", - - Buckets: buckets, + Help: "measured time between requests in nanoseconds", + Buckets: buckets, }, []string{"model"}) var err error err = promReg.Register(obsv) if err != nil { - return err + return fmt.Errorf("failed to register idletime histogram: %w", err) + } + + healthGauge := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "mly", + Subsystem: "model", + Name: "reload_success", + Help: "successfully reloaded model", + }, []string{"model"}) + + err = promReg.Register(healthGauge) + if err != nil { + return fmt.Errorf("failed to register health gauge: %w", err) } serviceOpts := make([]service.Option, 0) @@ -129,7 +139,13 @@ func Build( var modelSrv *service.Service var err error - modelSrv, err = service.New(context.Background(), model, fs, metrics, datastores, tritonClients, sema, cfge.MaxEvaluatorWait, serviceOpts...) + modelSrv, err = service.NewV2(context.Background(), model, fs, metrics, service.NewArgs{ + Datastores: datastores, + TritonServices: tritonServices, + Semaphore: sema, + MaxEvaluatorWait: cfge.MaxEvaluatorWait, + HealthGauge: healthGauge, + }, serviceOpts...) if err != nil { return fmt.Errorf("failed to create service for model:%v, err:%w", model.ID, err) diff --git a/service/endpoint/service.go b/service/endpoint/service.go index f191c9e..67eb051 100644 --- a/service/endpoint/service.go +++ b/service/endpoint/service.go @@ -207,7 +207,7 @@ func New(cfg *Config) (*Service, error) { return nil, fmt.Errorf("failed to create datastores: %w", err) } - tritonClients := make(map[string]triton.TritonClient) + tritonServices := make(map[string]*triton.Service) for _, server := range cfg.TritonServers { tritonClient, err := triton.NewClient(server) if err != nil { @@ -223,14 +223,14 @@ func New(cfg *Config) (*Service, error) { return nil, fmt.Errorf("failed to check triton server %s health: %w", server.ID, err) } - tritonClients[server.ID] = tritonClient + tritonServices[server.ID] = triton.NewService(tritonClient) } hooks := []Hook{ healthHandler, } - err = Build(mux, cfg, datastores, tritonClients, hooks, metrics, promReg) + err = Build(mux, cfg, datastores, tritonServices, hooks, metrics, promReg) if err != nil { return nil, err } diff --git a/service/handler.go b/service/handler.go index 5713ea3..45dadae 100644 --- a/service/handler.go +++ b/service/handler.go @@ -5,10 +5,10 @@ import ( "encoding/json" "errors" "fmt" - "io" "log" "net/http" "reflect" + "strconv" "strings" "sync" "time" @@ -25,6 +25,29 @@ import ( "github.com/viant/mly/shared/stat" ) +// responseMarshalError signals that gojay.Marshal of the Response struct +// failed during writeResponse. The HTTP response is NOT yet committed, +// so ServeHTTP can still emit an explicit 5xx with a meaningful body. +// Surfaced as a typed error so it can be routed to its own metric bucket +// (sstat.ResponseMarshalError) and distinguished from upstream errors. +type responseMarshalError struct{ err error } + +func (e *responseMarshalError) Error() string { return e.err.Error() } +func (e *responseMarshalError) Unwrap() error { return e.err } + +// responseCommittedError signals that the HTTP response status line and +// headers have already been flushed to the client when the wrapped error +// occurred. The caller MUST NOT attempt to send a different status code: +// net/http will drop the second WriteHeader and emit a "superfluous +// response.WriteHeader call" warning, while the client still observes the +// original (200) status. Surfaced so ServeHTTP can log + exit instead of +// trying to overwrite the status line, and so the failure can be routed +// to its own metric bucket (sstat.ResponseCommittedError). +type responseCommittedError struct{ err error } + +func (e *responseCommittedError) Error() string { return e.err.Error() } +func (e *responseCommittedError) Unwrap() error { return e.err } + // Handler converts a model prediction HTTP request to its internal calls. type Handler struct { maxDuration time.Duration @@ -62,7 +85,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques if httpRequest.Method == http.MethodGet { request = h.service.NewRequest() if err := h.buildRequestFromQuery(httpRequest, request); err != nil { - http.Error(writer, err.Error(), http.StatusBadRequest) + h.writeError(writer, response, hStats, http.StatusBadRequest, err) return } } else { @@ -76,7 +99,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques defer func() { onDone(time.Now(), stats.Values()...) }() if err != nil { - stats.Append(sstat.ReadError{err}) + stats.Append(sstat.ReadError{Error: err}) if isDebug { log.Printf("[%v http] read error: %v\n", h.service.config.ID, err) } @@ -86,7 +109,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques code = http.StatusRequestEntityTooLarge } - http.Error(writer, err.Error(), code) + h.writeError(writer, response, hStats, code, err) return err } @@ -101,14 +124,14 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques err = gojay.Unmarshal(data[:size], request) if err != nil { werr := fmt.Errorf("unmarshal error: %w data: %s", err, string(data[:size])) - stats.Append(sstat.UnmarshalError{werr}) + stats.Append(sstat.UnmarshalError{Error: werr}) if isDebug { log.Printf("[%v http] unmarshal error: %v\n", h.service.config.ID, err) } - rmsg := fmt.Sprintf("%s (are your input types correct?)", err.Error()) - http.Error(writer, rmsg, http.StatusBadRequest) + displayErr := fmt.Errorf("%s (are your input types correct?)", err.Error()) + h.writeError(writer, response, hStats, http.StatusBadRequest, displayErr) return err } @@ -123,11 +146,17 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques if request == nil { // This isn't a particularly helpful message. // Currently, the only case this handles is if the request is too large. - http.Error(writer, "no request", http.StatusBadRequest) + h.writeError(writer, response, hStats, http.StatusBadRequest, errors.New("no request")) return } - err := h.handleAppRequest(ctx, writer, request, response) + err := h.service.Do(ctx, request, response) + if err != nil { + response.SetError(err) + } else { + err = h.writeResponse(writer, response, http.StatusOK) + } + if isDebug { data, merr := json.Marshal(response.Data) @@ -144,6 +173,30 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques } if err != nil { + // If the response was already committed (status + headers flushed), + // the wire status code is fixed at 200 and cannot be changed. Calling + // writeError here would log "superfluous WriteHeader" and silently + // drop the new status — the client still sees 200 + truncated body. + // Log unconditionally so this defect is visible in production, and + // emit a dedicated metric so it can be alerted independently of + // the generic ErrorKey bucket. + var committed *responseCommittedError + if errors.As(err, &committed) { + hStats.Append(sstat.ResponseCommittedError{Error: err}) + log.Printf("[%v http] response committed but write failed: %v", h.service.config.ID, err) + return + } + + // Marshal failure: response NOT committed; we will emit an explicit + // 5xx below. Track it in its own metric bucket so the operator can + // distinguish "we never sent anything" from "we sent something we + // shouldn't have". writeError clears response.Data before retrying, + // so the second marshal cannot fail for the same reason. + var marshal *responseMarshalError + if errors.As(err, &marshal) { + hStats.Append(sstat.ResponseMarshalError{Error: err}) + } + var status int if _, ok := err.(*clienterr.ClientError); ok { status = http.StatusBadRequest @@ -157,7 +210,7 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, httpRequest *http.Reques log.Printf("[%v http] status:%d error:%v", h.service.config.ID, status, err) } - http.Error(writer, err.Error(), status) + h.writeError(writer, response, hStats, status, err) } } @@ -175,27 +228,93 @@ func (h *Handler) buildRequestFromQuery(httpRequest *http.Request, request *requ return nil } -func (h *Handler) handleAppRequest(ctx context.Context, writer io.Writer, request *request.Request, response *Response) error { - if err := h.service.Do(ctx, request, response); err != nil { - response.SetError(err) - return err +// writeResponse marshals appResponse and emits it with explicit-commit +// semantics: +// +// - Marshal first; on failure return a typed responseMarshalError -- the +// response is NOT yet committed and the caller can still set a different +// status (typically a 5xx). +// - Set Content-Length explicitly so a truncated body is detectable on +// the client side as io.ErrUnexpectedEOF (without it, the client cannot +// distinguish "done" from "connection broke mid-body" on a 200 OK). +// - Call WriteHeader(status) explicitly so the status line is committed +// in a known order, not as a side effect of the first Write. +// - On Write failure return responseCommittedError so the caller knows +// the status code can no longer be changed. +// +// status is typically http.StatusOK for success responses; the writeError +// helper passes the appropriate 4xx/5xx for error responses so the wire +// shape is uniform across success and failure paths. +// +// This addresses the silent "200 OK + empty body" failure mode where a +// canceled connection caused the implicit auto-200 from Write to flush +// headers while the body bytes were lost. +func (h *Handler) writeResponse(writer http.ResponseWriter, appResponse *Response, status int) error { + appResponse.ServiceTimeMcs = int(time.Since(appResponse.started).Microseconds()) + + data, err := gojay.Marshal(appResponse) + if err != nil { + return &responseMarshalError{err: fmt.Errorf("marshal response: %w", err)} } - if err := h.writeResponse(writer, response); err != nil { - return err + if h.service.config.Debug { + log.Printf("[%v write] output:%s", h.service.config.ID, data) + } + + writer.Header().Set("Content-Type", "application/json") + writer.Header().Set("Content-Length", strconv.Itoa(len(data))) + writer.WriteHeader(status) + + if _, err := writer.Write(data); err != nil { + return &responseCommittedError{err: fmt.Errorf("write response body: %w", err)} } return nil } -func (h *Handler) writeResponse(writer io.Writer, appResponse *Response) error { - appResponse.ServiceTimeMcs = int(time.Now().Sub(appResponse.started).Microseconds()) - data, err := gojay.Marshal(appResponse) - if h.service.config.Debug { - log.Printf("[%v write] output:%s", h.service.config.ID, data) +// writeError emits an error response with the given HTTP status code as +// a JSON-encoded Response object (status="error", populated error +// message). It is the error-path counterpart to writeResponse and shares +// the same explicit-commit contract so clients always see Content-Length +// and a parseable JSON body regardless of success or failure. +// +// Side-effects on the response struct: +// +// - response.SetError(err) populates response.Error and sets +// response.Status = "error". +// - response.Data is cleared. This guarantees the marshal will succeed +// regardless of the prior state of Data, which matters when the +// original failure was itself a marshal error on a populated Data +// value. +// +// On a post-commit write failure (responseCommittedError) the status is +// already on the wire; we only log + emit the dedicated metric. +// +// On a marshal failure of the (cleared) error response (essentially +// impossible -- the struct now contains only string + int fields), we +// fall back to http.Error so the client at least receives a status code. +func (h *Handler) writeError(writer http.ResponseWriter, response *Response, hStats *stat.Values, status int, err error) { + response.SetError(err) + response.Data = nil + + werr := h.writeResponse(writer, response, status) + if werr == nil { + return + } + + var committed *responseCommittedError + if errors.As(werr, &committed) { + hStats.Append(sstat.ResponseCommittedError{Error: werr}) + log.Printf("[%v http] error response committed but write failed: %v (original error: %v)", h.service.config.ID, werr, err) + return + } + + var marshal *responseMarshalError + if errors.As(werr, &marshal) { + hStats.Append(sstat.ResponseMarshalError{Error: werr}) } - _, err = writer.Write(data) - return err + log.Printf("[%v http] failed to write error response: %v (original error: %v)", h.service.config.ID, werr, err) + http.Error(writer, err.Error(), status) } func (h *Handler) trackIdle() { @@ -224,6 +343,6 @@ func NewHandler(service *Service, pool *buffer.Pool, maxDuration time.Duration, lrObserver: lrOV.With(prometheus.Labels{"model": modelID}), overheadMetrics: m.MultiOperationCounter(location, modelID+"HTTPOverhead", modelID+" server HTTP startup overhead", time.Microsecond, time.Minute, 2, sstat.NewHttp()), - httpContextMetrics: m.MultiOperationCounter(location, modelID+"HTTPHandler", modelID+" server HTTP handler", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()), + httpContextMetrics: m.MultiOperationCounter(location, modelID+"HTTPHandler", modelID+" server HTTP handler", time.Microsecond, time.Minute, 2, sstat.NewHandler()), } } diff --git a/service/handler_test.go b/service/handler_test.go new file mode 100644 index 0000000..a33e59d --- /dev/null +++ b/service/handler_test.go @@ -0,0 +1,308 @@ +package service + +import ( + "errors" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/mly/service/config" + sstat "github.com/viant/mly/service/stat" + "github.com/viant/mly/shared/stat" +) + +// writeFailingResponseWriter wraps httptest.ResponseRecorder so that +// Write returns a configurable error after a configurable number of +// bytes. Used to simulate a broken-pipe condition where the client has +// already closed the connection (e.g. the bidder's 40 ms client timeout +// fired while MLY was mid-response). httptest.ResponseRecorder by itself +// never errors on Write. +// +// failAfter == 0 → the very first Write call errors after writing zero +// body bytes (the headers have still been committed at that point by +// writer.WriteHeader, which is the precise condition that produced the +// observed 200-OK-with-empty-body wire trace). +type writeFailingResponseWriter struct { + *httptest.ResponseRecorder + failAfter int // bytes written successfully before Write starts erroring + written int // total successful body bytes so far + failErr error +} + +func newWriteFailingResponseWriter(failAfter int) *writeFailingResponseWriter { + return &writeFailingResponseWriter{ + ResponseRecorder: httptest.NewRecorder(), + failAfter: failAfter, + failErr: errors.New("simulated broken pipe"), + } +} + +func (w *writeFailingResponseWriter) Write(p []byte) (int, error) { + remaining := w.failAfter - w.written + if remaining <= 0 { + return 0, w.failErr + } + if len(p) <= remaining { + n, err := w.ResponseRecorder.Write(p) + w.written += n + return n, err + } + n, _ := w.ResponseRecorder.Write(p[:remaining]) + w.written += n + return n, w.failErr +} + +// newTestHandler constructs a Handler with the minimum scaffolding +// required to exercise writeResponse. The metric Operations are left +// nil — writeResponse does not touch them, only ServeHTTP does. Tests +// that exercise ServeHTTP need a different fixture (not provided here +// because Service.Do depends on a fully wired tfmodel.Service which is +// out of scope for a unit test). +func newTestHandler(modelID string, debug bool) *Handler { + return &Handler{ + service: &Service{ + config: &config.Model{ID: modelID, Debug: debug}, + }, + } +} + +// TestWriteResponse_Success verifies the happy path: a fully populated +// Response is marshaled, headers are set explicitly (Content-Type, +// Content-Length), the status is committed at 200, and the body bytes +// match the marshaled JSON. This locks in the explicit-commit contract +// that lets clients detect truncation via Content-Length mismatch. +func TestWriteResponse_Success(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{ + Status: "ok", + DictHash: 42, + started: time.Now().Add(-time.Millisecond), + } + + rec := httptest.NewRecorder() + err := h.writeResponse(rec, resp, http.StatusOK) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type"), + "Content-Type must be set explicitly") + + cl, atoiErr := strconv.Atoi(rec.Header().Get("Content-Length")) + require.NoError(t, atoiErr, "Content-Length must be a parseable integer") + assert.Equal(t, rec.Body.Len(), cl, + "Content-Length must match actual body length so clients can detect truncation") + + body := rec.Body.String() + assert.Contains(t, body, `"status":"ok"`) + assert.Contains(t, body, `"dictHash":42`) + assert.Contains(t, body, `"serviceTimeMcs":`, + "serviceTimeMcs must be present so the bidder can record mly_eval_duration_us") +} + +// TestWriteResponse_WriteFailureReturnsCommittedError simulates the +// failure mode that drives the bidder's invalid_json class on the wire: +// the body Write fails (broken pipe) AFTER WriteHeader has already +// committed the 200 status line. The post-condition is that: +// - writeResponse returns *responseCommittedError so the caller knows +// the status code can no longer be changed, +// - the status code on the wire is the originally-committed 200 (NOT +// the 500 we would otherwise want to send), +// - Content-Length was set, so a downstream client correctly checking +// it would observe an early-EOF / unexpected-EOF condition rather +// than silently treating the empty body as a valid response. +func TestWriteResponse_WriteFailureReturnsCommittedError(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{Status: "ok", started: time.Now()} + + rec := newWriteFailingResponseWriter(0) + err := h.writeResponse(rec, resp, http.StatusOK) + + require.Error(t, err) + + var committed *responseCommittedError + require.True(t, errors.As(err, &committed), + "expected *responseCommittedError, got %T: %v", err, err) + assert.ErrorIs(t, err, rec.failErr, + "wrapped error chain must reach the underlying broken-pipe error") + + assert.Equal(t, http.StatusOK, rec.Code, + "status was committed before Write failed; explicit-commit contract") + assert.NotEmpty(t, rec.Header().Get("Content-Length"), + "Content-Length must be set BEFORE Write so client can detect truncation") + assert.Equal(t, 0, rec.Body.Len(), + "no body bytes should have been written on the failAfter=0 case") +} + +// TestWriteResponse_PartialWriteReturnsCommittedError covers the +// truncated-body case: the headers + status flush, then a few body +// bytes succeed, then the connection breaks. The committed-error type +// must still surface so ServeHTTP's error branch knows not to call +// http.Error (which would emit "superfluous WriteHeader" log noise). +func TestWriteResponse_PartialWriteReturnsCommittedError(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{Status: "ok", started: time.Now()} + + rec := newWriteFailingResponseWriter(5) + err := h.writeResponse(rec, resp, http.StatusOK) + + require.Error(t, err) + var committed *responseCommittedError + require.True(t, errors.As(err, &committed), + "partial write must also yield *responseCommittedError") + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 5, rec.Body.Len(), + "exactly 5 body bytes should have been flushed before failure") +} + +// TestWriteResponse_HasContentLengthMatchingBody locks in the invariant +// that Content-Length declared in the header equals the bytes the +// handler intends to write. Without this, a client cannot distinguish +// "done" from "connection broke mid-body" on a 200 OK response — which +// is the root mechanism that allowed the bidder-side io.ReadAll swallow +// (shared/client/service.go) to silently produce empty-body +// invalid_json events. +func TestWriteResponse_HasContentLengthMatchingBody(t *testing.T) { + h := newTestHandler("test", false) + + cases := []struct { + name string + resp *Response + }{ + {"empty", &Response{started: time.Now()}}, + {"with-status", &Response{Status: "ok", started: time.Now()}}, + {"with-error", &Response{Status: "error", Error: "something failed", started: time.Now()}}, + {"with-dict-hash", &Response{Status: "ok", DictHash: 0xdeadbeef, started: time.Now()}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + require.NoError(t, h.writeResponse(rec, tc.resp, http.StatusOK)) + + cl, atoiErr := strconv.Atoi(rec.Header().Get("Content-Length")) + require.NoError(t, atoiErr) + assert.Equal(t, rec.Body.Len(), cl, + "declared Content-Length must equal actual body bytes") + }) + } +} + +// TestResponseCommittedError_Unwrap verifies that the typed error +// participates correctly in errors.Is / errors.As chains. ServeHTTP +// relies on errors.As to detect committed-error from arbitrarily-deep +// wrappings. +func TestResponseCommittedError_Unwrap(t *testing.T) { + inner := errors.New("underlying broken pipe") + wrapped := &responseCommittedError{err: inner} + + var target *responseCommittedError + assert.True(t, errors.As(wrapped, &target)) + assert.True(t, errors.Is(wrapped, inner), + "errors.Is must traverse Unwrap to the underlying cause") +} + +// TestResponseMarshalError_Unwrap is the symmetric assertion for the +// marshal-failure sentinel, used by ServeHTTP to route marshal failures +// to their own metric bucket while still emitting an HTTP 5xx response. +func TestResponseMarshalError_Unwrap(t *testing.T) { + inner := errors.New("malformed Response struct") + wrapped := &responseMarshalError{err: inner} + + var target *responseMarshalError + assert.True(t, errors.As(wrapped, &target)) + assert.True(t, errors.Is(wrapped, inner)) +} + +// TestWriteResponse_HonorsStatusParam verifies that writeResponse +// commits the supplied status code rather than always 200. This is the +// foundation for the unified error-response wire format: writeError +// uses writeResponse with 4xx/5xx so success and error responses share +// shape (Content-Type, Content-Length, JSON body) and only differ in +// status line + populated fields. +func TestWriteResponse_HonorsStatusParam(t *testing.T) { + h := newTestHandler("test", false) + cases := []int{ + http.StatusOK, + http.StatusBadRequest, + http.StatusRequestEntityTooLarge, + http.StatusTooManyRequests, + http.StatusInternalServerError, + } + for _, status := range cases { + t.Run(http.StatusText(status), func(t *testing.T) { + resp := &Response{Status: "ok", started: time.Now()} + rec := httptest.NewRecorder() + require.NoError(t, h.writeResponse(rec, resp, status)) + assert.Equal(t, status, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + assert.NotEmpty(t, rec.Header().Get("Content-Length")) + }) + } +} + +// TestWriteError_EmitsJSONErrorWithStatus locks in the wire shape +// promised to clients on the error path: 4xx/5xx + JSON Response body +// with status="error", populated error message, and serviceTimeMcs. +// This is what makes a defensive consumer-side check +// (e.g. mediator's `if response.Error != "" { ... }`) actually fire on +// real predict-time errors instead of silently no-op'ing. +func TestWriteError_EmitsJSONErrorWithStatus(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{Status: "ok", started: time.Now(), Data: "leftover-data"} + rec := httptest.NewRecorder() + hStats := stat.NewValues() + + h.writeError(rec, resp, hStats, http.StatusInternalServerError, errors.New("upstream blew up")) + + assert.Equal(t, http.StatusInternalServerError, rec.Code, + "error status must reach the wire") + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + cl, atoiErr := strconv.Atoi(rec.Header().Get("Content-Length")) + require.NoError(t, atoiErr) + assert.Equal(t, rec.Body.Len(), cl) + + body := rec.Body.String() + assert.Contains(t, body, `"status":"error"`, + "writeError must populate response.Status as error") + assert.Contains(t, body, `"error":"upstream blew up"`, + "writeError must populate response.Error from the supplied error") + assert.NotContains(t, body, "leftover-data", + "writeError must clear response.Data so the original Data does not leak into the error body") +} + +// TestWriteError_FallsBackToHTTPErrorOnCommittedFailure verifies that +// when the error response's body Write fails after status commit, the +// fallback path appends a metric and returns without panic. The status +// is already on the wire so http.Error inside the fallback is a no-op, +// but the metric attribution and log line are what matter for +// diagnosing the cliff scenario where both the success and error +// responses fail to flush. +func TestWriteError_FallsBackToHTTPErrorOnCommittedFailure(t *testing.T) { + h := newTestHandler("test", false) + resp := &Response{Status: "ok", started: time.Now()} + rec := newWriteFailingResponseWriter(0) + hStats := stat.NewValues() + + h.writeError(rec, resp, hStats, http.StatusInternalServerError, errors.New("upstream blew up")) + + assert.Equal(t, http.StatusInternalServerError, rec.Code, + "status was committed before the body write failed") + assert.NotEmpty(t, hStats.Values(), + "hStats must record the post-commit failure for metric attribution") + + var sawCommitted bool + for _, v := range hStats.Values() { + if _, ok := v.(sstat.ResponseCommittedError); ok { + sawCommitted = true + break + } + } + assert.True(t, sawCommitted, + "hStats must include a sstat.ResponseCommittedError marker") +} diff --git a/service/new.go b/service/new.go new file mode 100644 index 0000000..98f54bf --- /dev/null +++ b/service/new.go @@ -0,0 +1,105 @@ +package service + +import ( + "context" + "fmt" + "reflect" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/viant/afs" + "github.com/viant/gmetric" + "github.com/viant/mly/service/config" + "github.com/viant/mly/service/platform/factory" + "github.com/viant/mly/service/stat" + "github.com/viant/mly/service/triton" + "github.com/viant/mly/shared/datastore" + sstat "github.com/viant/mly/shared/stat" + "golang.org/x/sync/semaphore" +) + +// NewArgs is an open-to-extension approach to keeping the NewV2() API invariant +// Potential tech debt: most likely we should be encapsulating parameters more appropriately. +// Natural encapsulation boundaries will emerge when we start seeing what New() also initializes along with just Service. +type NewArgs struct { + Datastores map[string]*datastore.Service + TritonServices map[string]*triton.Service + Semaphore *semaphore.Weighted + MaxEvaluatorWait time.Duration + + HealthGauge *prometheus.GaugeVec +} + +// New creates a service with platform router support +func New( + ctx context.Context, + cfg *config.Model, + fs afs.Service, + metrics *gmetric.Service, + datastores map[string]*datastore.Service, + tritonServices map[string]*triton.Service, + sema *semaphore.Weighted, + maxEvaluatorWait time.Duration, + options ...Option, +) (*Service, error) { + return NewV2(ctx, cfg, fs, metrics, NewArgs{ + Datastores: datastores, + TritonServices: tritonServices, + Semaphore: sema, + MaxEvaluatorWait: maxEvaluatorWait, + }, options...) +} + +// New creates a service with platform router support +func NewV2( + ctx context.Context, + cfg *config.Model, + fs afs.Service, + metrics *gmetric.Service, + args NewArgs, + options ...Option, +) (*Service, error) { + + if metrics == nil { + metrics = gmetric.New() + } + + location := reflect.TypeOf(Service{}).PkgPath() + + cfg.Init(nil) + + // Create platform evaluator context + evaluatorContext, err := factory.CreateEvaluator(cfg, fs, metrics, args.Semaphore, args.MaxEvaluatorWait, args.TritonServices) + if err != nil { + return nil, fmt.Errorf("failed to create platform evaluator for model %s: %w", cfg.ID, err) + } + + srv := &Service{ + config: cfg, + evaluator: evaluatorContext, + useDatastore: cfg.UseDictionary() && cfg.DataStore != "", + serviceMetric: metrics.MultiOperationCounter(location, cfg.ID+"Perf", cfg.ID+" service performance", time.Microsecond, time.Minute, 2, stat.NewProvider()), + reloadPollTicker: time.NewTicker(time.Duration(cfg.ReloadPollIntervalSeconds) * time.Second), + reloadTimeout: time.Duration(cfg.ReloadTimeoutSeconds) * time.Second, + } + + if args.HealthGauge != nil { + srv.healthGauge = args.HealthGauge.With(prometheus.Labels{"model": cfg.ID}) + } + + // Set up reload metrics for platforms that support reloading + srv.reloadMetric = metrics.MultiOperationCounter(location, cfg.ID+"Reload", cfg.ID+" reloading", time.Microsecond, time.Minute, 1, sstat.NewCtxErrOnly()) + + for _, opt := range options { + opt.Apply(srv) + } + + err = srv.initializeService(ctx, cfg, fs, metrics, args.Datastores) + if err != nil { + return nil, err + } + + go srv.pollModelReload() + + return srv, err +} diff --git a/service/platform/evaluator.go b/service/platform/evaluator.go index 0ebf621..0c88ca7 100644 --- a/service/platform/evaluator.go +++ b/service/platform/evaluator.go @@ -19,24 +19,26 @@ const ( type PlatformEvaluator interface { Predictor - // Signature returns underlying model's signature + // Signature returns underlying model's signature. + // This is expected to return non-nil after ReloadIfNeeded() succeeds. Signature() *domain.Signature - // Dictionary returns vocabulary if available + // Dictionary returns vocabulary if available. Dictionary() *common.Dictionary - // Inputs returns the model input definitions for request validation + // Inputs returns the model input definitions for request validation. + // This will be invoked after at least 1 ReloadIfNeeded() succeeds. Inputs() map[string]*domain.Input - // Stats returns platform-specific live metrics, for debugging + // Deprecated: Do not use or implement. Stats returns platform-specific live metrics, for debugging purposes Stats(stats map[string]interface{}) - // Close releases resources Close() error - // ReloadIfNeeded will update models as needed, and check their health. + // ReloadIfNeeded will update models as needed, check their health, and consolidate signatures, if implemented. + // This can also be named EnsurePredictionPossible() or EnsureReady() or the like. // For in-process models (TensorFlow), this will check if the underlying models need to be updated. - // For external models (Triton), this will check Triton models' health. + // For external models (Triton), this will use the Model Control API to load, unload, and check the health of Triton models. ReloadIfNeeded(ctx context.Context) error } diff --git a/service/platform/factory/factory.go b/service/platform/factory/factory.go index b9e8c37..d176739 100644 --- a/service/platform/factory/factory.go +++ b/service/platform/factory/factory.go @@ -22,10 +22,10 @@ func CreateEvaluator( metrics *gmetric.Service, sema *semaphore.Weighted, maxEvaluatorWait time.Duration, - tritonClients map[string]triton.TritonClient, + tritonServices map[string]*triton.Service, ) (platform.PlatformEvaluator, error) { p := cfg.GetPlatform() - isRouter := cfg.Mode == "router" + isRouter := cfg.IsRouter() switch p { case "tensorflow": @@ -35,10 +35,14 @@ func CreateEvaluator( case "triton": if isRouter { - return router.NewRouter(cfg, fs, tritonClients) + makeEvaluator := func(modelName string) (platform.PlatformEvaluator, error) { + return triton.NewRoutedTritonEvaluator(modelName, cfg, tritonServices) + } + + return router.NewRouter(cfg, fs, tritonServices, makeEvaluator) } - return triton.NewTritonEvaluator(cfg, tritonClients) + return triton.NewTritonEvaluator(cfg, tritonServices) default: return nil, fmt.Errorf("unsupported platform: %s for model %s", p, cfg.ID) } diff --git a/service/platform/router/batchkey_bench_test.go b/service/platform/router/batchkey_bench_test.go new file mode 100644 index 0000000..41407d7 --- /dev/null +++ b/service/platform/router/batchkey_bench_test.go @@ -0,0 +1,137 @@ +package router + +import ( + "fmt" + "strconv" + "testing" +) + +// Benchmark different approaches to generating batch keys +// These benchmarks compare key generation strategies for the ForceBatchSize1 path + +var ( + sampleModelName = "model_name_example" + sampleOffset = 12345 + result string // prevent compiler optimization +) + +func BenchmarkBatchKey_Sprintf(b *testing.B) { + var r string + for i := 0; i < b.N; i++ { + r = fmt.Sprintf("%s#%d", sampleModelName, sampleOffset) + } + result = r +} + +func BenchmarkBatchKey_StrconcatStrconv(b *testing.B) { + var r string + for i := 0; i < b.N; i++ { + r = sampleModelName + "#" + strconv.Itoa(sampleOffset) + } + result = r +} + +func BenchmarkBatchKey_StrconcatFormatInt(b *testing.B) { + var r string + for i := 0; i < b.N; i++ { + r = sampleModelName + "#" + strconv.FormatInt(int64(sampleOffset), 10) + } + result = r +} + +// Benchmark with varying model name lengths +func BenchmarkBatchKey_ShortName_Sprintf(b *testing.B) { + name := "m1" + var r string + for i := 0; i < b.N; i++ { + r = fmt.Sprintf("%s#%d", name, i%1000) + } + result = r +} + +func BenchmarkBatchKey_ShortName_Strconcat(b *testing.B) { + name := "m1" + var r string + for i := 0; i < b.N; i++ { + r = name + "#" + strconv.Itoa(i%1000) + } + result = r +} + +func BenchmarkBatchKey_LongName_Sprintf(b *testing.B) { + name := "very_long_model_name_with_many_characters_for_testing" + var r string + for i := 0; i < b.N; i++ { + r = fmt.Sprintf("%s#%d", name, i%1000) + } + result = r +} + +func BenchmarkBatchKey_LongName_Strconcat(b *testing.B) { + name := "very_long_model_name_with_many_characters_for_testing" + var r string + for i := 0; i < b.N; i++ { + r = name + "#" + strconv.Itoa(i%1000) + } + result = r +} + +// Benchmark the map lookup with generated keys (more realistic scenario) +func BenchmarkBatchKey_MapLookup_Sprintf(b *testing.B) { + m := make(map[string]int) + names := []string{"model1", "model2", "model3", "model4", "model5"} + + // Pre-populate map + for _, name := range names { + for j := 0; j < 100; j++ { + m[fmt.Sprintf("%s#%d", name, j)] = j + } + } + + b.ResetTimer() + var sum int + for i := 0; i < b.N; i++ { + key := fmt.Sprintf("%s#%d", names[i%5], i%100) + sum += m[key] + } + _ = sum +} + +func BenchmarkBatchKey_MapLookup_Strconcat(b *testing.B) { + m := make(map[string]int) + names := []string{"model1", "model2", "model3", "model4", "model5"} + + // Pre-populate map + for _, name := range names { + for j := 0; j < 100; j++ { + m[name+"#"+strconv.Itoa(j)] = j + } + } + + b.ResetTimer() + var sum int + for i := 0; i < b.N; i++ { + key := names[i%5] + "#" + strconv.Itoa(i%100) + sum += m[key] + } + _ = sum +} + +// Benchmark allocations +func BenchmarkBatchKey_Allocs_Sprintf(b *testing.B) { + b.ReportAllocs() + var r string + for i := 0; i < b.N; i++ { + r = fmt.Sprintf("%s#%d", sampleModelName, i%1000) + } + result = r +} + +func BenchmarkBatchKey_Allocs_Strconcat(b *testing.B) { + b.ReportAllocs() + var r string + for i := 0; i < b.N; i++ { + r = sampleModelName + "#" + strconv.Itoa(i%1000) + } + result = r +} diff --git a/service/platform/router/fixed.go b/service/platform/router/fixed.go index 7fdc0fd..3971ad3 100644 --- a/service/platform/router/fixed.go +++ b/service/platform/router/fixed.go @@ -1,17 +1,30 @@ package router import ( - "context" "fmt" "github.com/viant/mly/service/config" - "github.com/viant/mly/service/request/shape" ) +// preparedReplacement holds a pre-parsed replacement value for fixed evaluator outputs +type preparedReplacement struct { + name string + typ string + value interface{} +} + type fixedEvaluator struct { prepared []preparedReplacement } +func (f *fixedEvaluator) OutputNames() []string { + names := make([]string, len(f.prepared)) + for i, p := range f.prepared { + names[i] = p.name + } + return names +} + func newFixedEvaluator(repls []config.PredictionReplacement) (*fixedEvaluator, error) { prepared := make([]preparedReplacement, 0, len(repls)) @@ -71,7 +84,7 @@ func newFixedEvaluator(repls []config.PredictionReplacement) (*fixedEvaluator, e default: return fmt.Errorf("router replacement %q: value %T not coercible to int64", r.Name, r.Value) } - case "float", "float32": + case "float32": switch n := r.Value.(type) { case int: pr = preparedReplacement{typ: "float32", value: float32(n)} @@ -105,6 +118,7 @@ func newFixedEvaluator(repls []config.PredictionReplacement) (*fixedEvaluator, e return fmt.Errorf("unsupported router replacement type %q for %q", r.Type, r.Name) } + pr.name = r.Name prepared = append(prepared, pr) } @@ -117,12 +131,7 @@ func newFixedEvaluator(repls []config.PredictionReplacement) (*fixedEvaluator, e return &fixedEvaluator{prepared: prepared}, nil } -func (f *fixedEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { - batchSize, err := shape.DetermineBatchSize(params) - if err != nil { - return nil, err - } - +func (f *fixedEvaluator) Predict(batchSize int) ([]interface{}, error) { makeString := func(v string) [][]string { out := make([][]string, batchSize) for i := 0; i < batchSize; i++ { @@ -182,8 +191,6 @@ func (f *fixedEvaluator) Predict(ctx context.Context, params []interface{}) ([]i results[i] = makeInt32(repl.value.(int32)) case "int64": results[i] = makeInt64(repl.value.(int64)) - case "float": - results[i] = makeFloat32(repl.value.(float32)) case "float32": results[i] = makeFloat32(repl.value.(float32)) case "float64": diff --git a/service/platform/router/prometheus.go b/service/platform/router/prometheus.go index 9809a5d..5598690 100644 --- a/service/platform/router/prometheus.go +++ b/service/platform/router/prometheus.go @@ -28,17 +28,6 @@ var ( []string{"router", "fixed_only"}, ) - routerWorkerChannelQueuedSummary = prometheus.NewSummaryVec( - prometheus.SummaryOpts{ - Namespace: "mly", - Subsystem: "router", - Name: "worker_channel_queued_summary", - Help: "Number of router predictions queued in the worker channel.", - Objectives: buckets.CommonSummaryObjectives, - }, - []string{"router"}, - ) - routerPredictDroppedCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: "mly", @@ -49,6 +38,16 @@ var ( []string{"router"}, ) + routerQueueDurationMicrosSummary = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "router", + Name: "queue_duration_summary_us", + Help: "Duration of router queueing.", + }, + []string{"router"}, + ) + routerModelUnloadGauge = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: "mly", @@ -64,6 +63,6 @@ func init() { prometheus.MustRegister(routerPredictDurationMicrosSummary) prometheus.MustRegister(routerReloadDurationMicrosSummary) prometheus.MustRegister(routerModelUnloadGauge) + prometheus.MustRegister(routerQueueDurationMicrosSummary) prometheus.MustRegister(routerPredictDroppedCounter) - prometheus.MustRegister(routerWorkerChannelQueuedSummary) } diff --git a/service/platform/router/reload.go b/service/platform/router/reload.go new file mode 100644 index 0000000..2b0b9bd --- /dev/null +++ b/service/platform/router/reload.go @@ -0,0 +1,528 @@ +package router + +import ( + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "log" + "reflect" + "strings" + "sync" + "time" + + "github.com/viant/mly/service/config" + "github.com/viant/mly/service/domain" + "github.com/viant/mly/service/files" + "github.com/viant/mly/service/platform" + "github.com/viant/mly/shared/config/router" + "gopkg.in/yaml.v2" +) + +type modelSignature struct { + name string + signature *domain.Signature +} + +func (r *Router) ReloadIfNeeded(ctx context.Context) error { + start := time.Now() + isFullReload := false + defer func() { + var mode string + if isFullReload { + mode = "full" + } else { + mode = "checks" + } + + routerReloadDurationMicrosSummary.WithLabelValues(r.routerName, mode).Observe(float64(time.Since(start).Microseconds())) + }() + + // fetch and check router configuration file + snapshot, err := files.ModifiedSnapshot(ctx, r.fs, r.configURL, nil) + if err != nil { + return fmt.Errorf("failed to check router configuration file: %w", err) + } + + if !r.isModified(snapshot) { + // check health of all underlying models + var wg sync.WaitGroup + + r.configLock.RLock() + errChannels := len(r.routingTable) + if r.globalModel != nil { + errChannels++ + } + + errCh := make(chan error, errChannels) + + if r.globalModel != nil { + wg.Add(1) + go func() { + defer wg.Done() + err := r.globalModel.ReloadIfNeeded(ctx) + if err != nil { + errCh <- fmt.Errorf("failed to reload global model: %w", err) + } + }() + } + + for m, p := range r.routingTable { + wg.Add(1) + go func(m string, p platform.PlatformEvaluator) { + defer wg.Done() + err := p.ReloadIfNeeded(ctx) + if err != nil { + errCh <- fmt.Errorf("failed to reload model %s: %w", m, err) + } + }(m, p) + } + + wg.Wait() + close(errCh) + + if len(errCh) > 0 { + var errStrings []string + for err := range errCh { + errStrings = append(errStrings, err.Error()) + } + + err = fmt.Errorf("reloading errors: %s", strings.Join(errStrings, "; ")) + } + + r.configLock.RUnlock() + return err + } + + // see defer above + isFullReload = true + + // otherwise just abandon the routing table status checks + + r.configLock.Lock() + defer r.configLock.Unlock() + + r.configModified = snapshot + + // load router configuration file + rawReader, err := r.fs.OpenURL(ctx, r.configURL) + if err != nil { + return fmt.Errorf("failed to open router configuration file: %w", err) + } + + defer rawReader.Close() + var reader io.Reader = rawReader + if strings.HasSuffix(r.configURL, ".gz") { + if reader, err = gzip.NewReader(rawReader); err != nil { + return fmt.Errorf("failed to create gzip reader for router configuration file: %w", err) + } + } + + newConfig := new(router.RoutingConfig) + + // TODO move this check earlier + if strings.Contains(r.configURL, ".yaml") { + decoder := yaml.NewDecoder(reader) + err = decoder.Decode(newConfig) + } else if strings.Contains(r.configURL, ".json") { + err = json.NewDecoder(reader).Decode(newConfig) + } else { + return fmt.Errorf("unsupported router configuration file type: %s", r.configURL) + } + + if err != nil { + return fmt.Errorf("failed to decode router configuration file: %w", err) + } + + if err := r.applyRouterConfig(ctx, newConfig); err != nil { + return err + } + + return nil +} + +// applyRouterConfig will both update evaluators to new configuration state and verify and build the signature +func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.RoutingConfig) error { + modelsToUnload := make(map[string]struct{}) + reuseEvaluators := make(map[string]platform.PlatformEvaluator) + var reuseGlobal platform.PlatformEvaluator + + // copy members to local scope + var finalSignature *domain.Signature + var oldConfig *router.RoutingConfig + func() { + r.routingTableLock.RLock() + defer r.routingTableLock.RUnlock() + + if r.ioState != nil { + finalSignature = r.ioState.signature + } + + reuseGlobal = r.globalModel + oldConfig = r.routingConfig + }() + + if oldConfig != nil { + for _, entity := range oldConfig.EntityMapping { + modelsToUnload[entity.ModelName] = struct{}{} + if evaluator, ok := r.routingTable[entity.ModelName]; ok { + reuseEvaluators[entity.ModelName] = evaluator + } + } + + if oldConfig.GlobalModelName != "" { + modelsToUnload[oldConfig.GlobalModelName] = struct{}{} + } + } + + newModelMapping := make(map[int]string) + for _, entity := range newConfig.EntityMapping { + r.debugLogf("add mapping: %d -> %s", entity.EntityID, entity.ModelName) + + newModelMapping[entity.EntityID] = entity.ModelName + delete(modelsToUnload, entity.ModelName) + } + + globalModelName := newConfig.GlobalModelName + if globalModelName == "" && r.hasGlobalModel { + return fmt.Errorf("global model name is missing") + } + + if globalModelName != "" { + r.debugLogf("global model: %s", globalModelName) + delete(modelsToUnload, globalModelName) + } + + newRoutingTable := make(map[string]platform.PlatformEvaluator) + for _, entity := range newConfig.EntityMapping { + model := entity.ModelName + if _, ok := newRoutingTable[model]; ok { + continue + } + + if evaluator, ok := reuseEvaluators[model]; ok { + newRoutingTable[model] = evaluator + continue + } + + evaluator, err := r.makeRoutedEvaluator(model) + + if err != nil { + return fmt.Errorf("failed to create Routed Evaluator for model %s: %w", model, err) + } + + newRoutingTable[model] = evaluator + } + + var globalEvaluator platform.PlatformEvaluator + if globalModelName != "" { + if oldConfig != nil && globalModelName == oldConfig.GlobalModelName && reuseGlobal != nil { + globalEvaluator = reuseGlobal + } else if evaluator, ok := newRoutingTable[globalModelName]; ok { + globalEvaluator = evaluator + } else if evaluator, ok := reuseEvaluators[globalModelName]; ok { + globalEvaluator = evaluator + } else { + var err error + globalEvaluator, err = r.makeRoutedEvaluator(globalModelName) + if err != nil { + return fmt.Errorf("failed to create Routed Evaluator for global model %s: %w", globalModelName, err) + } + } + } + + wg := sync.WaitGroup{} + + numWorkers := len(newRoutingTable) + if globalEvaluator != nil { + numWorkers++ + } + + errCh := make(chan error, numWorkers) + signatureCh := make(chan modelSignature, numWorkers) + + if globalEvaluator != nil { + wg.Add(1) + go func() { + defer wg.Done() + if err := globalEvaluator.ReloadIfNeeded(ctx); err != nil { + errCh <- fmt.Errorf("failed to reload global model %s: %w", globalModelName, err) + } + }() + } + + for model := range newRoutingTable { + wg.Add(1) + go func(model string) { + defer wg.Done() + + r.debugLogf("reload model: %s", model) + + modelEvaluator := newRoutingTable[model] + if err := modelEvaluator.ReloadIfNeeded(ctx); err != nil { + r.debugLogf("failed to reload model: %s: %v", model, err) + errCh <- fmt.Errorf("failed to reload model %s: %w", model, err) + } + + evalSig := modelEvaluator.Signature() + + if evalSig == nil { + errCh <- fmt.Errorf("model %s signature is nil", model) + return + } + + signatureCh <- modelSignature{ + name: model, + signature: evalSig, + } + }(model) + } + + r.debugLogf("wait for reloads") + + wg.Wait() + close(errCh) + close(signatureCh) + + if len(errCh) > 0 { + var errStrings []string + for err := range errCh { + errStrings = append(errStrings, err.Error()) + } + return fmt.Errorf("one or more model reloading errors: %s", strings.Join(errStrings, "; ")) + } + + sigInputMap := make(map[string]*domain.Input) + + // sigOutputMap is for validating output consistency + sigOutputMap := make(map[string]*domain.Output) + + // we only create ioState on the first reload + var ioState *IOState = new(IOState) + if finalSignature != nil { + for _, input := range finalSignature.Inputs { + sigInputMap[input.Name] = &input + } + + for _, output := range finalSignature.Outputs { + sigOutputMap[output.Name] = &output + } + } + + // dsmi stands for DownStream Model Information + for dsmi := range signatureCh { + // accept first available signature as the final signature + if finalSignature == nil { + srcSig := dsmi.signature + + // copy signature from downstream + finalSignature = &domain.Signature{ + Inputs: make([]domain.Input, len(srcSig.Inputs), len(srcSig.Inputs)+1), + Outputs: make([]domain.Output, len(srcSig.Outputs)), + } + copy(finalSignature.Inputs, srcSig.Inputs) + copy(finalSignature.Outputs, srcSig.Outputs) + + // add router input + inputOffset := len(finalSignature.Inputs) + routerInput := domain.Input{ + Name: r.routerInputFieldName, + Type: reflect.TypeOf(int64(0)), + Index: inputOffset, + } + + ioState.routerInputOffset = inputOffset + + finalSignature.Inputs = append(finalSignature.Inputs, routerInput) + + // sigInputMap is for Request validation + for _, input := range finalSignature.Inputs { + sigInputMap[input.Name] = &input + } + + // add configured (aux) inputs to signature + for _, input := range r.configuredInputs { + _, ok := sigInputMap[input.Name] + + if ok { + // the input is configured and already in the self-reported signature + continue + } + + if !input.Auxiliary { + return fmt.Errorf("non-auxiliary input %s for model %s was not in model inputs", input.Name, dsmi.name) + } + + sigInputMap[input.Name] = &domain.Input{ + Name: input.Name, + Type: input.RawType(), + Auxiliary: input.Auxiliary, + } + } + + if r.modelOutputName != "" { + // add the selected model output + modelOutput := domain.Output{ + Name: r.modelOutputName, + Index: len(finalSignature.Outputs), + DataType: "string", + } + + finalSignature.Outputs = append(finalSignature.Outputs, modelOutput) + } + + for _, output := range finalSignature.Outputs { + sigOutputMap[output.Name] = &output + } + + continue + } + + dsSignature := dsmi.signature + // validate signature consistency + // Note: Index differences are permitted - IOs are matched by name + + // check that the new signature has no new outputs + thisSignatureOutputMap := make(map[string]*domain.Output) + for _, output := range dsSignature.Outputs { + oldOutput, ok := sigOutputMap[output.Name] + if !ok { + return fmt.Errorf("signature output %s for model %s not found in the previous signature", output.Name, dsmi.name) + } + + thisSignatureOutputMap[output.Name] = &output + + if oldOutput.DataType != output.DataType { + return fmt.Errorf("signature output %s for model %s has data type %s, and the previous signature has data type %s", output.Name, dsmi.name, output.DataType, oldOutput.DataType) + } + } + + // check that the new signature has no new outputs except the model name output + for expectedOutput := range sigOutputMap { + if _, ok := thisSignatureOutputMap[expectedOutput]; !ok && expectedOutput != r.modelOutputName { + return fmt.Errorf("signature output %s for was not found in model %s signature", expectedOutput, dsmi.name) + } + } + + // check that the new signature has no new inputs + thisSignatureInputMap := make(map[string]*domain.Input) + for _, input := range dsSignature.Inputs { + oldInput, ok := sigInputMap[input.Name] + if !ok { + return fmt.Errorf("signature input %s for model %s not found in the previous signature", input.Name, dsmi.name) + } + + thisSignatureInputMap[input.Name] = &input + + if !oldInput.Type.ConvertibleTo(input.Type) { + return fmt.Errorf("signature input %s for model %s has data type %s, and the previous signature has data type %s", input.Name, dsmi.name, input.Type.String(), oldInput.Type.String()) + } + } + + // check that the new signature has all expected inputs except for the routing and auxiliary inputs + for expectedInput := range sigInputMap { + if sigInputMap[expectedInput].Auxiliary { + continue + } + + if expectedInput == r.routerInputFieldName { + continue + } + + if _, ok := thisSignatureInputMap[expectedInput]; !ok { + return fmt.Errorf("signature input %s for was not found in model %s signature", expectedInput, dsmi.name) + } + } + } + + if r.fixedEvaluatorFields != nil { + // TODO this is actually an acceptable case, we can simply ignore fixed evaluator fields that aren't applicable + for field := range r.fixedEvaluatorFields { + if _, ok := sigOutputMap[field]; !ok { + return fmt.Errorf("fixed evaluator field: %s was not found in any model outputs", field) + } + } + + // check that the fixed evaluator fields have all expected outputs + for _, field := range sigOutputMap { + if _, ok := r.fixedEvaluatorFields[field.Name]; !ok && field.Name != r.modelOutputName { + return fmt.Errorf("signature output %s is not replaced", field.Name) + } + } + } + + ioState.signature = finalSignature + ioState.inputs = sigInputMap + + if globalEvaluator != nil { + if _, exists := newRoutingTable[globalModelName]; !exists { + newRoutingTable[globalModelName] = globalEvaluator + } + } + + func() { + r.routingTableLock.Lock() + defer r.routingTableLock.Unlock() + + r.routingConfig = newConfig + + r.routingMap = newModelMapping + r.routingTable = newRoutingTable + + r.globalModel = globalEvaluator + + if r.ioState == nil { + r.ioState = ioState + } + }() + + for model := range modelsToUnload { + r.unloadGauge.Inc() + + go func(modelName string) { + defer r.unloadGauge.Dec() + + ctxTo, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + r.debugLogf("request to unload model: %s", modelName) + + if err := r.unloadModel(ctxTo, modelName); err != nil { + r.debugLogf("failed to unload model %s: %v\n", modelName, err) + } + }(model) + } + + return nil +} + +func (r *Router) debugLogf(format string, args ...interface{}) { + if r.debug { + prefix := "[%s Router] " + log.Printf(prefix+format, append([]interface{}{r.routerName}, args...)...) + } +} + +// TODO refactor with service/tfmodel/service.isModified()? +func (r *Router) isModified(snapshot *config.Modified) bool { + if r.routingConfig == nil || r.configModified == nil { + return true + } + + if snapshot.Max.IsZero() { + return false + } + + r.configLock.RLock() + modified := r.configModified + r.configLock.RUnlock() + + return !(modified.Max.Equal(snapshot.Max) && modified.Min.Equal(snapshot.Min)) +} + +func (r *Router) unloadModel(ctx context.Context, modelName string) error { + if err := r.unloader.UnloadModel(ctx, r.routerName, modelName); err != nil { + return fmt.Errorf("failed to unload model %s: %w", modelName, err) + } + return nil +} diff --git a/service/platform/router/reload_test.go b/service/platform/router/reload_test.go new file mode 100644 index 0000000..3c91dbc --- /dev/null +++ b/service/platform/router/reload_test.go @@ -0,0 +1,507 @@ +package router + +import ( + "context" + "errors" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/viant/mly/service/config" + "github.com/viant/mly/service/domain" + "github.com/viant/mly/service/platform" + "github.com/viant/mly/service/triton" + sharedrouter "github.com/viant/mly/shared/config/router" +) + +type mockUnloader struct { + tritonServer *mockTritonServer + unloadCh chan string +} + +func (m *mockUnloader) ModelUnload(ctx context.Context, tritonModelName string) error { + if m.tritonServer != nil { + m.tritonServer.mu.Lock() + defer m.tritonServer.mu.Unlock() + + if m.tritonServer.readyState == nil { + m.tritonServer.readyState = make(map[string]bool) + } + + m.tritonServer.readyState[tritonModelName] = false + } + + ch := m.unloadCh + + if ch != nil { + ch <- tritonModelName + } + + return nil +} + +type wrappedUnloader struct { + tritonService *triton.Service + + wg *sync.WaitGroup +} + +func (w *wrappedUnloader) UnloadModel(ctx context.Context, mlyModelID string, tritonModelName string) error { + defer w.wg.Done() + return w.tritonService.UnloadModel(ctx, mlyModelID, tritonModelName) +} + +type mockTritonServer struct { + mu sync.Mutex + + readyState map[string]bool + modelLoadErr map[string]error +} + +func (m *mockTritonServer) ModelLoad(ctx context.Context, modelName string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.modelLoadErr[modelName]; err != nil { + return err + } + + if m.readyState == nil { + m.readyState = make(map[string]bool) + } + + m.readyState[modelName] = true + + return nil +} + +func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { + ctx := context.Background() + mockClient := &mockUnloader{ + unloadCh: make(chan string, 2), + } + + oldConfig := &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelA"}, + {EntityID: 2, ModelName: "modelB"}, + }, + GlobalModelName: "global-old", + } + + makeSig := func() *domain.Signature { + return &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + } + + router := &Router{ + unloader: &triton.Service{Unloader: mockClient, Repository: triton.NewRepository()}, + routingConfig: oldConfig, + routingMap: map[int]string{ + 1: "modelA", + 2: "modelB", + }, + routingTable: map[string]platform.PlatformEvaluator{ + "modelA": &mockEvaluator{signature: makeSig}, + "modelB": &mockEvaluator{signature: makeSig}, + }, + globalModel: &mockEvaluator{}, + debug: true, + makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { + return &mockEvaluator{signature: makeSig}, nil + }, + unloadGauge: routerModelUnloadGauge.WithLabelValues("test_router"), + } + + reusedModelB := router.routingTable["modelB"] + + newConfig := &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelB"}, + {EntityID: 3, ModelName: "modelC"}, + }, + GlobalModelName: "global-new", + } + + if err := router.applyRouterConfig(ctx, newConfig); err != nil { + t.Fatalf("applyRouterConfig returned error: %v", err) + } + + waitForCalls(t, mockClient.unloadCh, 2) + + if router.routingConfig != newConfig { + t.Fatalf("routerConfig pointer not updated") + } + + expectedRouting := map[int]string{ + 1: "modelB", + 3: "modelC", + } + + if !reflect.DeepEqual(router.routingMap, expectedRouting) { + t.Fatalf("routingMap mismatch, got %#v", router.routingMap) + } + + if router.globalModel == nil { + t.Fatalf("globalModel was not set") + } + + if _, ok := router.routingTable["modelB"]; !ok { + t.Fatalf("routingTable missing modelB") + } + + if router.routingTable["modelB"] != reusedModelB { + t.Fatalf("modelB evaluator was not reused") + } + + if _, ok := router.routingTable["modelC"]; !ok { + t.Fatalf("routingTable missing modelC") + } + + if _, ok := router.routingTable["modelA"]; ok { + t.Fatalf("routingTable still contains modelA") + } +} + +func TestRouter_applyRouterConfig_LoadError(t *testing.T) { + ctx := context.Background() + + loadErr := errors.New("load failure") + + tritonServer := new(mockTritonServer) + tritonServer.modelLoadErr = map[string]error{ + "modelX": loadErr, + } + + mockClient := &mockUnloader{} + + oldConfig := &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelA"}, + }, + } + + signature := &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + + router := &Router{ + debug: true, + routerName: "load_error", + + unloader: &triton.Service{Unloader: mockClient}, + routingConfig: oldConfig, + routingMap: map[int]string{ + 1: "modelA", + }, + + makeRoutedEvaluator: func(modelName string) (platform.PlatformEvaluator, error) { + return &mockEvaluator{ + tritonServer: tritonServer, + modelName: modelName, + signature: func() *domain.Signature { return signature }, + }, nil + }, + unloadGauge: routerModelUnloadGauge.WithLabelValues("load_error"), + } + + newConfig := &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 2, ModelName: "modelX"}, + }, + } + + err := router.applyRouterConfig(ctx, newConfig) + if err == nil { + t.Fatalf("expected error but got nil") + } + + if !strings.Contains(err.Error(), "modelX") { + t.Fatalf("expected error mentioning modelX, got %v", err) + } + + if router.routingConfig != oldConfig { + t.Fatalf("routerConfig should remain unchanged on error") + } + + if !reflect.DeepEqual(router.routingMap, map[int]string{1: "modelA"}) { + t.Fatalf("routingMap should remain unchanged on error") + } + + if router.routingTable != nil { + t.Fatalf("routingTable should not be replaced on error") + } +} + +func TestRouter_applyRouterConfig_signature(t *testing.T) { + ctx := context.Background() + mockClient := &mockUnloader{ + unloadCh: make(chan string, 1), + tritonServer: &mockTritonServer{}, + } + + makeSig := func() *domain.Signature { + return &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + } + + cfg := &config.Model{ + ID: "test_signature", + Debug: true, + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + InputName: "router_id", + Global: config.GlobalModelConfig{ + Exists: true, + }, + Output: config.OutputConfig{ + FieldName: "model_id", + }, + }, + Triton: &config.TritonConfig{ + ServerID: "test_server", + }, + } + + cfg.Init(nil) + + var router *Router + var err error + + router, err = newRouter(cfg, nil, map[string]UnloadService{ + "test_server": &triton.Service{Unloader: mockClient}, + }, func(modelName string) (platform.PlatformEvaluator, error) { + return &mockEvaluator{signature: makeSig}, nil + }) + + if err != nil { + t.Fatalf("NewRouter error: %v", err) + } + + newConfig := &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "model1"}, + {EntityID: 2, ModelName: "model2"}, + }, + GlobalModelName: "global-model", + } + + if err := router.applyRouterConfig(ctx, newConfig); err != nil { + t.Fatalf("applyRouterConfig returned error: %v", err) + } + + detectedInputs := map[string]struct{}{} + for _, input := range router.ioState.signature.Inputs { + detectedInputs[input.Name] = struct{}{} + } + + expectedInputs := map[string]struct{}{ + "text": {}, + "router_id": {}, + } + + assert.Equal(t, expectedInputs, detectedInputs) + assert.Equal(t, 1, router.ioState.routerInputOffset) + + params := []interface{}{ + [][]string{{"a"}, {"abcd"}}, // text + [][]int64{{1}, {2}}, // router_id + } + + results, err := router.Predict(ctx, params) + if err != nil { + t.Fatalf("Predict error: %v", err) + } + + assert.Equal(t, 2, len(results)) +} + +func TestRouter_applyRouterConfig_sharedTritonServer(t *testing.T) { + tritonServer := &mockTritonServer{} + modelUnloader := &mockUnloader{ + tritonServer: tritonServer, + } + + repository := triton.NewRepository() + tritonService := &triton.Service{ + Unloader: modelUnloader, + Repository: repository, + } + + wrappedService := &wrappedUnloader{ + tritonService: tritonService, + wg: &sync.WaitGroup{}, + } + + unloaders := map[string]UnloadService{ + "test_server": wrappedService, + } + + cfgA := &config.Model{ + ID: "test_shared_a", + Debug: true, + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + Global: config.GlobalModelConfig{ + PredictionReplacements: []config.PredictionReplacement{ + { + Name: "score", + Type: "float32", + Value: 0.0, + }, + }, + }, + }, + Triton: &config.TritonConfig{ + ServerID: "test_server", + }, + } + + cfgB := &config.Model{ + ID: "test_shared_b", + Debug: true, + Mode: "router", + Platform: "triton", + Router: &config.RouterConfig{ + ConfigURL: "memory://router-config", + Global: config.GlobalModelConfig{ + PredictionReplacements: []config.PredictionReplacement{ + { + Name: "score", + Type: "float32", + Value: 0.0, + }, + }, + }, + }, + Triton: &config.TritonConfig{ + ServerID: "test_server", + }, + } + + cfgA.Init(nil) + cfgB.Init(nil) + + newSig := func() *domain.Signature { + return &domain.Signature{ + Inputs: []domain.Input{ + {Name: "text", Index: 0, Type: reflect.TypeOf("")}, + }, + Outputs: []domain.Output{ + {Name: "score", Index: 0, DataType: "float32"}, + }, + } + } + + routerA, err := newRouter(cfgA, nil, unloaders, func(modelName string) (platform.PlatformEvaluator, error) { + tritonService.RegisterUsage(cfgA.ID, modelName) + return &mockEvaluator{modelName: modelName, signature: newSig, tritonServer: tritonServer}, nil + }) + + if err != nil { + t.Fatalf("newRouter returned error: %v", err) + } + + routerB, err := newRouter(cfgB, nil, unloaders, func(modelName string) (platform.PlatformEvaluator, error) { + tritonService.RegisterUsage(cfgB.ID, modelName) + return &mockEvaluator{modelName: modelName, signature: newSig, tritonServer: tritonServer}, nil + }) + + if err != nil { + t.Fatalf("newRouter returned error: %v", err) + } + + ctx := context.Background() + + // establish mappings for A and B using the same models + + if err = routerA.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelA"}, + {EntityID: 2, ModelName: "modelC"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig A initial returned error: %v", err) + } + + if err = routerB.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelA"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig B inital returned error: %v", err) + } + + // we expect model A and model C to be attempted to be unloaded + wrappedService.wg.Add(2) + + // routerA will now get a different mapping + if err = routerA.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelB"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig A reload returned error: %v", err) + } + + wrappedService.wg.Wait() + + assert.True(t, tritonServer.readyState["modelA"], "modelA should still be loaded") + assert.False(t, tritonServer.readyState["modelC"], "modelC should be unloaded") + + // we expect model A to be attempted to be unloaded + wrappedService.wg.Add(1) + + if err = routerB.applyRouterConfig(ctx, &sharedrouter.RoutingConfig{ + EntityMapping: []sharedrouter.EntityKV{ + {EntityID: 1, ModelName: "modelB"}, + {EntityID: 2, ModelName: "modelC"}, + }, + }); err != nil { + t.Fatalf("applyRouterConfig B reload returned error: %v", err) + } + + wrappedService.wg.Wait() + + assert.False(t, tritonServer.readyState["modelA"], "modelA should be unloaded") + assert.True(t, tritonServer.readyState["modelB"], "modelB should still be loaded") + assert.True(t, tritonServer.readyState["modelC"], "modelC should be loaded") +} + +func waitForCalls(t *testing.T, ch <-chan string, count int) []string { + t.Helper() + var out []string + for i := 0; i < count; i++ { + select { + case v := <-ch: + out = append(out, v) + case <-time.After(time.Second): + t.Fatalf("timeout waiting for call %d/%d", i+1, count) + } + } + return out +} diff --git a/service/platform/router/router.go b/service/platform/router/router.go index f4d1fc6..64305cd 100644 --- a/service/platform/router/router.go +++ b/service/platform/router/router.go @@ -1,236 +1,212 @@ package router import ( - "compress/gzip" "context" - "encoding/json" "fmt" - "io" - "log" - "reflect" - "strings" + "strconv" "sync" + "sync/atomic" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/viant/afs" "github.com/viant/mly/service/config" "github.com/viant/mly/service/domain" - "github.com/viant/mly/service/files" "github.com/viant/mly/service/platform" "github.com/viant/mly/service/request/shape" tricli "github.com/viant/mly/service/triton" + "github.com/viant/mly/shared" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/config/router" - "gopkg.in/yaml.v2" ) +const queueSizeExceededError = "queue size exceeded" + +type IOState struct { + inputs map[string]*domain.Input + signature *domain.Signature + + // router input offset is the index of the routing input in the inputs array + routerInputOffset int +} + +type UnloadService interface { + UnloadModel(ctx context.Context, mlyModelID string, tritonModelName string) error +} + // Router implements the PlatformEvaluator interface for router mode. type Router struct { - configURL string - fs afs.Service + configURL string + fs afs.Service + + // config lock only protects the configModified field configLock sync.RWMutex configModified *config.Modified - routerConfig *router.RouterConfig + // routingTableLock protects: + // - routerConfig + // - routingMap + // - routingTable + // - globalModel + // - ioState routingTableLock sync.RWMutex - routingMap map[int]string - routingTable map[string]platform.PlatformEvaluator - workCh chan *workRequest + // routingConfig contains the last loaded routing configuration + routingConfig *router.RoutingConfig - globalModel platform.PlatformEvaluator - fixedEvaluator platform.Predictor - modelOutputName string + hasGlobalModel bool + makeRoutedEvaluator func(modelName string) (platform.PlatformEvaluator, error) - modelConfig *config.Model - routerName string - tritonClient tricli.TritonClient + routerInputFieldName string - signature *domain.Signature - indexToName map[int]string - inputs map[string]*domain.Input + routingMap map[int]string + routingTable map[string]platform.PlatformEvaluator - // router input offset is the index of the router input in the inputs array - routerInputOffset int -} + // TODO see if this can be removed, may just need to map via model name + globalModel platform.PlatformEvaluator -func NewRouter(cfg *config.Model, fs afs.Service, tritonClients map[string]tricli.TritonClient) (*Router, error) { - if cfg.Router == nil { - return nil, fmt.Errorf("router configuration is required") - } + // fixedEvaluator is non-nil IFF there is no global model configured + fixedEvaluator *fixedEvaluator - if err := cfg.Router.Validate(); err != nil { - return nil, fmt.Errorf("router configuration is invalid: %w", err) + // fixedEvaluatorFields is for checking all outputs in the signature are replaced + fixedEvaluatorFields map[string]struct{} + outputConfig config.OutputConfig - } + modelOutputName string - tritonClient, ok := tritonClients[cfg.Triton.ServerID] - if !ok { - return nil, fmt.Errorf("triton client not found for server ID: %s", cfg.Triton.ServerID) - } + routerName string + debug bool + unloader UnloadService + unloadGauge prometheus.Gauge - r := &Router{ - configURL: cfg.Router.ConfigURL, - fs: fs, - routerName: cfg.ID, - modelConfig: cfg, - tritonClient: tritonClient, - } + configuredInputs []*shared.Field + ioState *IOState - if err := r.handleIO(cfg); err != nil { - return nil, fmt.Errorf("failed to handle IO: %w", err) - } + // forceBatchSize1 when true uses legacy per-sample dispatch; when false (default) uses batched dispatch + forceBatchSize1 bool - r.workCh = make(chan *workRequest, cfg.Router.MaxQueueSize) - for i := 0; i < cfg.Router.Workers; i++ { - go handleWorkRequests(r.workCh, routerWorkerChannelQueuedSummary.WithLabelValues(r.routerName)) - } + // workerSemaphore limits concurrent model evaluations + workerSemaphore chan struct{} - return r, nil -} + // maxQueueSize limits queued batches before rejection + maxQueueSize uint64 -type preparedReplacement struct { - typ string - value interface{} + queued *atomic.Uint64 + queueDurationObserver prometheus.Observer + droppedCounter prometheus.Counter } -func (t *Router) handleIO(cfg *config.Model) error { - io := &cfg.MetaInput - - if len(io.Inputs) == 0 { - return fmt.Errorf("input configuration is required for a router") +// NewRouter creates a new Router instance. +// cfg is expected to be Init()'d and Validate()'d before calling this function. +// makeEvaluator is expected to register usage for every created Evaluator. +func NewRouter(cfg *config.Model, fs afs.Service, tritonServices map[string]*tricli.Service, makeEvaluator func(modelName string) (platform.PlatformEvaluator, error)) (*Router, error) { + unloaders := make(map[string]UnloadService) + for serverID, tritonService := range tritonServices { + unloaders[serverID] = tritonService } - if len(io.Outputs) == 0 { - return fmt.Errorf("output configuration is required for a router") - } - - var inputs []domain.Input - - // for declaring the router's inputs - mappedInputs := make(map[string]*domain.Input) + return newRouter(cfg, fs, unloaders, makeEvaluator) +} - // generate backend input - indexToName := make(map[int]string) +// newRouter uses a map[string]ModelUnloader, where ModelUnloader is-a triton.TritonClient, for testing. +func newRouter(cfg *config.Model, fs afs.Service, unloaders map[string]UnloadService, makeEvaluator func(modelName string) (platform.PlatformEvaluator, error)) (*Router, error) { + rtCfg := cfg.Router + if rtCfg == nil { + return nil, fmt.Errorf("router configuration is required") + } - i := 0 - for _, input := range io.Inputs { - if !input.Auxiliary && input.Name != cfg.Router.InputName { - inputs = append(inputs, domain.Input{ - Name: input.Name, - Index: input.Index, - }) + unloader, ok := unloaders[cfg.Triton.ServerID] + if !ok { + return nil, fmt.Errorf("triton client not found for server ID: %s", cfg.Triton.ServerID) + } - indexToName[i] = input.Name - i++ + var fixedEvaluator *fixedEvaluator + var fixedEvaluatorFields map[string]struct{} + if !rtCfg.Global.Exists { + replacementsByName := make(map[string]config.PredictionReplacement) + for _, repl := range rtCfg.Global.PredictionReplacements { + replacementsByName[repl.Name] = repl } - inputType := reflect.TypeOf("") - if input.DataType != "" { - switch input.DataType { - case "string": - inputType = reflect.TypeOf("") - case "int": - inputType = reflect.TypeOf(0) - case "int32": - inputType = reflect.TypeOf(int32(0)) - case "int64": - inputType = reflect.TypeOf(int64(0)) - case "float32", "float": - inputType = reflect.TypeOf(float32(0)) - case "float64": - inputType = reflect.TypeOf(float64(0)) - } + var err error + + fixedEvaluator, err = newFixedEvaluator(rtCfg.Global.PredictionReplacements) + if err != nil { + return nil, fmt.Errorf("failed to create fixed evaluator: %w", err) } - mappedInputs[input.Name] = &domain.Input{ - Name: input.Name, - Index: len(inputs), - Type: inputType, - Vocab: false, - Auxiliary: input.Auxiliary, + fixedEvaluatorFields = make(map[string]struct{}, len(replacementsByName)) + for name := range replacementsByName { + fixedEvaluatorFields[name] = struct{}{} } } - var outputs []domain.Output - outputByName := make(map[string]domain.Output) - - for i, output := range io.Outputs { - outputs = append(outputs, domain.Output{ - Name: output.Name, - Index: i, - DataType: output.DataType, - }) - - outputByName[output.Name] = outputs[i] - } + routerName := cfg.ID + r := &Router{ + debug: cfg.Debug, + routerName: routerName, - modelOutputName := cfg.Router.Output.FieldName - hasModelOutputName := modelOutputName != "" + configURL: rtCfg.ConfigURL, + fs: fs, + makeRoutedEvaluator: makeEvaluator, - if !cfg.Router.Global.Exists { - replacementsByName := make(map[string]config.PredictionReplacement) - for _, repl := range cfg.Router.Global.PredictionReplacements { - replacementsByName[repl.Name] = repl - } + unloader: unloader, + unloadGauge: routerModelUnloadGauge.WithLabelValues(routerName), - replacementOutputs := make([]config.PredictionReplacement, 0, len(outputs)) - for _, output := range outputs { - if hasModelOutputName && output.Name == modelOutputName { - // model-used output field name is handled in a different way - continue - } + outputConfig: rtCfg.Output, + hasGlobalModel: rtCfg.Global.Exists, - if _, ok := replacementsByName[output.Name]; !ok { - return fmt.Errorf("replacement for output %s not found", output.Name) - } + modelOutputName: rtCfg.Output.FieldName, - replacementOutputs = append(replacementOutputs, replacementsByName[output.Name]) - } + fixedEvaluator: fixedEvaluator, + fixedEvaluatorFields: fixedEvaluatorFields, - fixedEvaluator, err := newFixedEvaluator(replacementOutputs) - if err != nil { - return fmt.Errorf("failed to create fixed evaluator: %w", err) - } + configuredInputs: cfg.Inputs, + routerInputFieldName: rtCfg.InputName, - t.fixedEvaluator = fixedEvaluator - } + forceBatchSize1: rtCfg.ForceBatchSize1, - var modelOutputInOutputs bool = !hasModelOutputName - if hasModelOutputName { - _, modelOutputInOutputs = outputByName[modelOutputName] - } + workerSemaphore: make(chan struct{}, rtCfg.Workers), + maxQueueSize: uint64(rtCfg.MaxQueueSize), + queued: &atomic.Uint64{}, - if !modelOutputInOutputs { - outputs = append(outputs, domain.Output{ - Name: modelOutputName, - DataType: "string", - Index: len(outputs), - }) + queueDurationObserver: routerQueueDurationMicrosSummary.WithLabelValues(routerName), + droppedCounter: routerPredictDroppedCounter.WithLabelValues(routerName), } - t.modelOutputName = modelOutputName - - t.indexToName = indexToName - - t.signature = &domain.Signature{ - Inputs: inputs, - Outputs: outputs, - Output: outputs[0], - } + return r, nil +} - t.inputs = mappedInputs +// modelBatch holds accumulated rows destined for a single model evaluator +type modelBatch struct { + evaluator platform.PlatformEvaluator // need Signature() for input reordering + isFixedEval bool // true skips input reordering + modelName string + inputsByName map[string]interface{} // keyed by input name - accumulated batched inputs + rowOffsets []int // original positions in the incoming batch +} - return nil +// batchResult holds the result from a batched model prediction +type batchResult struct { + modelName string + results []interface{} + offsets []int + err error + outputNames []string // output names in the order returned by evaluator (for reordering) } -// Predict performs model inference with the given parameters -// params is expected to be [numInputs]([batchSize][1]T) (see service/request.Request.Feeds) +// Predict performs model inference with the given parameters. +// params is expected to be [numInputs]([batchSize][1]T) (see service/request.Request.Feeds). +// +// Rows are grouped into batches based on their target model evaluator. +// When forceBatchSize1 is true, each row forms its own batch (batch size 1). +// When forceBatchSize1 is false (default), rows destined for the same model are batched together. func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { if len(params) == 0 { return nil, fmt.Errorf("no input parameters provided") } + // metricFixedOnly is true if the request is only using the fixedEvaluator metricFixedOnly := true start := time.Now() defer func() { @@ -240,7 +216,6 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface } else { fos = "false" } - routerPredictDurationMicrosSummary.WithLabelValues(r.routerName, fos).Observe(float64(time.Since(start).Microseconds())) }() @@ -249,439 +224,302 @@ func (r *Router) Predict(ctx context.Context, params []interface{}) ([]interface return nil, err } - numInputs := len(params) - r.routingTableLock.RLock() defer r.routingTableLock.RUnlock() - globalExists := r.modelConfig.Router.Global.Exists - reportedGlobalModelName := r.modelConfig.Router.Output.GlobalModelOverride - noModelName := r.modelConfig.Router.Output.NoModelID + var signature *domain.Signature + batches := make(map[string]*modelBatch) - predictWaitGroup := sync.WaitGroup{} - predictWaitGroup.Add(expectedBatchSize) + // Phase 1: Group rows into batches by model name + err = func() error { + if r.ioState == nil { + return fmt.Errorf("ioState was not initialized") + } - errCh := make(chan error, expectedBatchSize) - resultsCh := make(chan offsetResults, expectedBatchSize) + signature = r.ioState.signature + routerInputOffset := r.ioState.routerInputOffset - for batchOffset := range expectedBatchSize { - // 1 input is reserved for the router input - request := make([]interface{}, numInputs-1) + hasFixedEvaluator := r.fixedEvaluator != nil + reportedGlobalModelName := r.outputConfig.GlobalModelOverride + noModelName := r.outputConfig.NoModelID - var routingValueBatched interface{} + numInputs := len(params) - for inputOffset := range numInputs { - debatched, err := shape.Debatch(params[inputOffset], batchOffset) + routerInputBatch := params[routerInputOffset] + for batchOffset := range expectedBatchSize { + // Extract routing value for this row + routingValueBatched, err := shape.Debatch(routerInputBatch, batchOffset) if err != nil { - return nil, fmt.Errorf("failed to debatch for row %d and input %d: %w", batchOffset, inputOffset, err) + return fmt.Errorf("failed to debatch routing value for row %d: %w", batchOffset, err) } - if inputOffset < r.routerInputOffset { - request[inputOffset] = debatched - } else if inputOffset == r.routerInputOffset { - routingValueBatched = debatched - } else { - request[inputOffset-1] = debatched + routingValue, err := shape.SqueezeBatch(routingValueBatched) + if err != nil { + return fmt.Errorf("failed to extract routing value for row %d: %w", batchOffset, err) } - } - - routingValue, err := shape.SqueezeBatch(routingValueBatched) - if err != nil { - return nil, fmt.Errorf("failed to extract from batch for row %d: %w", batchOffset, err) - } - - var ok bool = true - var routingValueInt int - switch routingValue := routingValue.(type) { - case int: - routingValueInt = routingValue - case int32: - routingValueInt = int(routingValue) - case int64: - routingValueInt = int(routingValue) - default: - ok = false - } - if !ok { - return nil, fmt.Errorf("routing value is not an int: %v, is %T, for row %d", routingValue, routingValue, batchOffset) - } - - routingValueString, ok := r.routingMap[routingValueInt] + var routingValueInt int + switch rv := routingValue.(type) { + case int: + routingValueInt = rv + case int32: + routingValueInt = int(rv) + case int64: + routingValueInt = int(rv) + default: + return fmt.Errorf("routing value is not an int: %v, is %T, for row %d", routingValue, routingValue, batchOffset) + } - var evaluator platform.Predictor - if !ok { - if globalExists { - metricFixedOnly = false - // fallback to global model - evaluator = r.globalModel + routingValueString, ok := r.routingMap[routingValueInt] - // override model name - if reportedGlobalModelName != "" { - routingValueString = reportedGlobalModelName + var evaluator platform.PlatformEvaluator + isFixedEval := false + if !ok { + if hasFixedEvaluator { + // No global model, use fixed evaluator + routingValueString = noModelName + isFixedEval = true + } else { + metricFixedOnly = false + evaluator = r.globalModel + if reportedGlobalModelName != "" { + routingValueString = reportedGlobalModelName + } } } else { - routingValueString = noModelName - evaluator = r.fixedEvaluator + metricFixedOnly = false + evaluator, ok = r.routingTable[routingValueString] + if !ok { + return fmt.Errorf("no evaluator found for routing value: %v", routingValue) + } } - } else { - metricFixedOnly = false - var ok bool - evaluator, ok = r.routingTable[routingValueString] - if !ok { - return nil, fmt.Errorf("no evaluator found for routing value: %v", routingValue) + // Determine batch key: unique per row when forceBatchSize1, otherwise by model name + batchKey := routingValueString + if r.forceBatchSize1 { + batchKey = strconv.Itoa(batchOffset) } - } - select { - case r.workCh <- &workRequest{ - wg: &predictWaitGroup, + batch, exists := batches[batchKey] + if !exists { + batch = &modelBatch{ + evaluator: evaluator, + isFixedEval: isFixedEval, + modelName: routingValueString, + inputsByName: make(map[string]interface{}), + rowOffsets: make([]int, 0, 1), + } - predictor: evaluator, - ctx: ctx, - request: request, + batches[batchKey] = batch + } + + // Append this row's inputs to the batch (excluding router input) + for paramOffset := range numInputs { + if paramOffset == routerInputOffset { + continue + } - queuedTime: time.Now(), - offset: batchOffset, - modelOutputEnabled: r.modelOutputName != "", - routingValueString: routingValueString, + inputName := signature.Inputs[paramOffset].Name + debatched, err := shape.Debatch(params[paramOffset], batchOffset) + if err != nil { + return fmt.Errorf("failed to debatch for row %d, input %s: %w", batchOffset, inputName, err) + } - responseCh: resultsCh, - errCh: errCh, - }: + batch.inputsByName[inputName], err = shape.AppendRowToBatch(batch.inputsByName[inputName], debatched) + if err != nil { + return fmt.Errorf("failed to append row %d to batch for input %s: %w", batchOffset, inputName, err) + } + } - // continue - default: - routerPredictDroppedCounter.WithLabelValues(r.routerName).Inc() - return nil, fmt.Errorf("work channel is full") + batch.rowOffsets = append(batch.rowOffsets, batchOffset) } - } - - predictWaitGroup.Wait() - close(errCh) - close(resultsCh) + return nil + }() - for err := range errCh { + if err != nil { return nil, err } - allResults := make([][]interface{}, expectedBatchSize) - for results := range resultsCh { - allResults[results.offset] = results.results - } - - endResults := make([]interface{}, len(r.signature.Outputs)) - for i, results := range allResults { - endResults, err = shape.ConcatAxis0(endResults, results) - if err != nil { - return nil, fmt.Errorf("failed to concatenate results for row %d: %w", i, err) - } + // early queue size check + currentQ := r.queued.Load() + if uint64(len(batches))+currentQ > r.maxQueueSize { + r.droppedCounter.Inc() + return nil, fmt.Errorf(queueSizeExceededError) } - return endResults, nil -} - -func (r *Router) Signature() *domain.Signature { - return r.signature -} - -func (r *Router) Dictionary() *common.Dictionary { - return nil -} + // Phase 2: Execute predictions in parallel with bounded concurrency + resultCh := make(chan batchResult, len(batches)) + var wg sync.WaitGroup -func (r *Router) Inputs() map[string]*domain.Input { - return r.inputs -} - -func (r *Router) Stats(stats map[string]interface{}) { - -} + for _, batch := range batches { + wg.Add(1) -func (r *Router) Close() error { - return nil -} + // this must be decremented if queue is full and once no longer in queue + nowQueued := r.queued.Add(1) + startQueueTime := time.Now() -// TODO refactor with service/tfmodel/service.isModified()? -func (r *Router) isModified(snapshot *config.Modified) bool { - if r.routerConfig == nil || r.configModified == nil { - return true - } + if nowQueued > r.maxQueueSize { + r.queued.Add(^uint64(0)) + r.droppedCounter.Inc() + return nil, fmt.Errorf(queueSizeExceededError) + } - if snapshot.Max.IsZero() { - return false - } + go func(b *modelBatch) { + defer wg.Done() - r.configLock.RLock() - modified := r.configModified - r.configLock.RUnlock() + // Acquire semaphore slot + r.workerSemaphore <- struct{}{} - return !(modified.Max.Equal(snapshot.Max) && modified.Min.Equal(snapshot.Min)) -} + r.queued.Add(^uint64(0)) + r.queueDurationObserver.Observe(float64(time.Since(startQueueTime).Microseconds())) -func (r *Router) ReloadIfNeeded(ctx context.Context) error { - start := time.Now() - isFullReload := false - defer func() { - var mode string - if isFullReload { - mode = "full" - } else { - mode = "checks" - } - routerReloadDurationMicrosSummary.WithLabelValues(r.routerName, mode).Observe(float64(time.Since(start).Microseconds())) - }() + defer func() { + <-r.workerSemaphore + }() - // fetch and check router configuration file - snapshot, err := files.ModifiedSnapshot(ctx, r.fs, r.configURL, nil) - if err != nil { - return fmt.Errorf("failed to check router configuration file: %w", err) - } + // Reorder inputs to match each evaluator's expected order before calling Predict + var results []interface{} + var err error - if !r.isModified(snapshot) { - // check health of all underlying models - var wg sync.WaitGroup + // Capture output names for reordering in Phase 3 + var outputNames []string + bs := len(b.rowOffsets) - r.configLock.RLock() - errChannels := len(r.routingTable) - if r.globalModel != nil { - errChannels++ - } + if b.isFixedEval { + results, err = r.fixedEvaluator.Predict(bs) + outputNames = r.fixedEvaluator.OutputNames() + } else { + // Reorder inputs to match this evaluator's expected order + evalSig := b.evaluator.Signature() + orderedInputs := make([]interface{}, len(evalSig.Inputs)) + for i, sigInput := range evalSig.Inputs { + inputData, exists := b.inputsByName[sigInput.Name] + if !exists { + err = fmt.Errorf("input %s not found in batch for model %s", sigInput.Name, b.modelName) + break + } + orderedInputs[i] = inputData + } - errCh := make(chan error, errChannels) + if err == nil { + // Rely on downstream for timeouts + results, err = b.evaluator.Predict(ctx, orderedInputs) - if r.globalModel != nil { - wg.Add(1) - go func() { - defer wg.Done() - err := r.globalModel.ReloadIfNeeded(ctx) - if err != nil { - errCh <- fmt.Errorf("failed to reload global model: %w", err) + outputNames = make([]string, len(evalSig.Outputs)) + for i, out := range evalSig.Outputs { + outputNames[i] = out.Name + } } - }() - } + } - for m, p := range r.routingTable { - wg.Add(1) - go func(m string, p platform.PlatformEvaluator) { - defer wg.Done() - err := p.ReloadIfNeeded(ctx) - if err != nil { - errCh <- fmt.Errorf("failed to reload model %s: %w", m, err) + // Append model name to results if configured + if r.modelOutputName != "" && err == nil { + modelNames := make([][]string, bs) + for i := range modelNames { + modelNames[i] = []string{b.modelName} } - }(m, p) - } - - wg.Wait() - close(errCh) - if len(errCh) > 0 { - var errStrings []string - for err := range errCh { - errStrings = append(errStrings, err.Error()) + results = append(results, modelNames) + outputNames = append(outputNames, r.modelOutputName) } - err = fmt.Errorf("one or more model reloading errors: %s", strings.Join(errStrings, "; ")) - } - - r.configLock.RUnlock() - return err - } - - isFullReload = true - - // otherwise just abandon the routing table status checks - - r.configLock.Lock() - defer r.configLock.Unlock() - - r.configModified = snapshot - - // load router configuration file - rawReader, err := r.fs.OpenURL(ctx, r.configURL) - if err != nil { - return fmt.Errorf("failed to open router configuration file: %w", err) - } - - defer rawReader.Close() - var reader io.Reader = rawReader - if strings.HasSuffix(r.configURL, ".gz") { - if reader, err = gzip.NewReader(rawReader); err != nil { - return fmt.Errorf("failed to create gzip reader for router configuration file: %w", err) - } - } - - newConfig := new(router.RouterConfig) - - // TODO move this check earlier - if strings.Contains(r.configURL, ".yaml") { - decoder := yaml.NewDecoder(reader) - err = decoder.Decode(newConfig) - } else if strings.Contains(r.configURL, ".json") { - err = json.NewDecoder(reader).Decode(newConfig) - } else { - return fmt.Errorf("unsupported router configuration file type: %s", r.configURL) - } - - if err != nil { - return fmt.Errorf("failed to decode router configuration file: %w", err) - } - - if err := r.applyRouterConfig(ctx, newConfig); err != nil { - return err - } - - return nil -} - -func (r *Router) applyRouterConfig(ctx context.Context, newConfig *router.RouterConfig) error { - modelsToUnload := make(map[string]struct{}) - reuseEvaluators := make(map[string]platform.PlatformEvaluator) - var reuseGlobal platform.PlatformEvaluator - - oldConfig := r.routerConfig - if oldConfig != nil { - for _, entity := range oldConfig.EntityMapping { - modelsToUnload[entity.ModelName] = struct{}{} - if evaluator, ok := r.routingTable[entity.ModelName]; ok { - reuseEvaluators[entity.ModelName] = evaluator + resultCh <- batchResult{ + modelName: b.modelName, + results: results, + offsets: b.rowOffsets, + err: err, + outputNames: outputNames, } - } - - if oldConfig.GlobalModelName != "" { - modelsToUnload[oldConfig.GlobalModelName] = struct{}{} - reuseGlobal = r.globalModel - } + }(batch) } - newModelMapping := make(map[int]string) - for _, entity := range newConfig.EntityMapping { - newModelMapping[entity.EntityID] = entity.ModelName - delete(modelsToUnload, entity.ModelName) - } + wg.Wait() + close(resultCh) - globalModelName := newConfig.GlobalModelName - if globalModelName != "" { - delete(modelsToUnload, globalModelName) + // Phase 3: Reassemble results in original order + // Build router output name -> index mapping for reordering + // TODO see if memoizing this provides material performance boosts + routerOutputIndex := make(map[string]int, len(signature.Outputs)) + for i, out := range signature.Outputs { + routerOutputIndex[out.Name] = i } - newRoutingTable := make(map[string]platform.PlatformEvaluator) - for _, entity := range newConfig.EntityMapping { - model := entity.ModelName - if _, ok := newRoutingTable[model]; ok { - continue - } - - if evaluator, ok := reuseEvaluators[model]; ok { - newRoutingTable[model] = evaluator - continue - } - - evaluator, err := tricli.NewRoutedTritonEvaluator( - model, - r.tritonClient, - r.modelConfig.Triton.Timeout, - r.indexToName, - ) + // allResults will be [expectedBatchSize][len(signature.Outputs)] + allResults := make([][]interface{}, expectedBatchSize) - if err != nil { - return fmt.Errorf("failed to create Triton evaluator for model %s: %w", model, err) + for res := range resultCh { + if res.err != nil { + return nil, fmt.Errorf("prediction failed for model %s: %w", res.modelName, res.err) } - newRoutingTable[model] = evaluator - } + // Extract individual rows from the batched result and place at original offsets + for evalOffset, originalOffset := range res.offsets { + rowResult := make([]interface{}, len(signature.Outputs)) - var globalEvaluator platform.PlatformEvaluator - if globalModelName != "" { - if oldConfig != nil && globalModelName == oldConfig.GlobalModelName && reuseGlobal != nil { - globalEvaluator = reuseGlobal - } else if evaluator, ok := newRoutingTable[globalModelName]; ok { - globalEvaluator = evaluator - } else if evaluator, ok := reuseEvaluators[globalModelName]; ok { - globalEvaluator = evaluator - } else { - var err error - globalEvaluator, err = tricli.NewRoutedTritonEvaluator( - globalModelName, - r.tritonClient, - r.modelConfig.Triton.Timeout, - r.indexToName, - ) + // Reorder outputs to match router's expected output order + for evalOutputIdx, outputBatch := range res.results { + extracted, err := shape.ExtractRowFromBatch(outputBatch, evalOffset) + if err != nil { + return nil, fmt.Errorf("failed to extract row %d from model %s output index %d: %w", + evalOffset, res.modelName, evalOutputIdx, err) + } - if err != nil { - return fmt.Errorf("failed to create Triton evaluator for global model %s: %w", globalModelName, err) - } - } - } + // Map evaluator output index to router output index by name + var originalOutputIdx int + if res.outputNames == nil { + // Fallback: assume same order (shouldn't happen in normal operation) + originalOutputIdx = evalOutputIdx + } else { + outputName := res.outputNames[evalOutputIdx] + + var exists bool + originalOutputIdx, exists = routerOutputIndex[outputName] + if !exists { + return nil, fmt.Errorf("output %s from model %s not found in router signature", + outputName, res.modelName) + } + } - wg := sync.WaitGroup{} - errCh := make(chan error, 1) - if globalEvaluator != nil { - wg.Add(1) - go func() { - defer wg.Done() - if err := globalEvaluator.ReloadIfNeeded(ctx); err != nil { - errCh <- fmt.Errorf("failed to reload global model %s: %w", globalModelName, err) + rowResult[originalOutputIdx] = extracted } - }() - } - for model := range newRoutingTable { - wg.Add(1) - go func(model string) { - defer wg.Done() - if err := newRoutingTable[model].ReloadIfNeeded(ctx); err != nil { - errCh <- fmt.Errorf("failed to reload model %s: %w", model, err) - } - }(model) + allResults[originalOffset] = rowResult + } } - wg.Wait() - close(errCh) - if len(errCh) > 0 { - var errStrings []string - for err := range errCh { - errStrings = append(errStrings, err.Error()) + // Reshape all values into [outputs][batch][M] + endResults := make([]interface{}, len(signature.Outputs)) + for i, results := range allResults { + endResults, err = shape.ConcatAxis0(endResults, results) + if err != nil { + return nil, fmt.Errorf("failed to concatenate results for row %d: %w", i, err) } - return fmt.Errorf("one or more model reloading errors: %s", strings.Join(errStrings, "; ")) } - func() { - r.routingTableLock.Lock() - defer r.routingTableLock.Unlock() - if globalEvaluator != nil { - if _, exists := newRoutingTable[globalModelName]; !exists { - newRoutingTable[globalModelName] = globalEvaluator - } - } - r.globalModel = globalEvaluator - r.routingMap = newModelMapping - r.routerConfig = newConfig - r.routingTable = newRoutingTable - }() + return endResults, nil +} - for model := range modelsToUnload { - routerModelUnloadGauge.WithLabelValues(r.routerName).Inc() +func (r *Router) Signature() *domain.Signature { + r.routingTableLock.RLock() + defer r.routingTableLock.RUnlock() + return r.ioState.signature +} - go func(modelName string) { - defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() +func (r *Router) Dictionary() *common.Dictionary { + return nil +} - ctxTo, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := r.unloadModel(ctxTo, modelName); err != nil { - log.Printf("failed to unload model %s: %v\n", modelName, err) - } - }(model) - } +func (r *Router) Inputs() map[string]*domain.Input { + r.routingTableLock.RLock() + defer r.routingTableLock.RUnlock() + return r.ioState.inputs +} - return nil +func (r *Router) Stats(stats map[string]interface{}) { + // do nothing } -func (r *Router) unloadModel(ctx context.Context, modelName string) error { - defer routerModelUnloadGauge.WithLabelValues(r.routerName).Dec() - if err := r.tritonClient.ModelUnload(ctx, modelName); err != nil { - return fmt.Errorf("failed to unload model %s: %w", modelName, err) - } +func (r *Router) Close() error { return nil } diff --git a/service/platform/router/router_test.go b/service/platform/router/router_test.go index a5baa16..6de05c7 100644 --- a/service/platform/router/router_test.go +++ b/service/platform/router/router_test.go @@ -2,550 +2,791 @@ package router import ( "context" - "errors" + "fmt" + "log" "reflect" + "slices" + "strconv" "strings" "sync" + "sync/atomic" "testing" - "time" "github.com/viant/mly/service/config" "github.com/viant/mly/service/domain" "github.com/viant/mly/service/platform" - tricli "github.com/viant/mly/service/triton" - "github.com/viant/mly/shared" + "github.com/viant/mly/service/request/shape" + "github.com/viant/mly/service/triton" "github.com/viant/mly/shared/common" - sharedrouter "github.com/viant/mly/shared/config/router" ) -// --- Router Predict scaffolds --- +// mockEvaluator currently handles all test case behaviors. +// This may be a sign that the router itself needs to be refactored into separate parts. +type mockEvaluator struct { + // reload-related objects + modelName string -type mockPredictOnly struct{} + // see reload_test.go + tritonServer *mockTritonServer -func (m *mockPredictOnly) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { - // params is expected to have a single non-router input in this test: [][]string with shape [1][1] - var v string - switch typed := params[0].(type) { - case [][]string: - v = typed[0][0] - default: - tval := reflect.TypeOf(params[0]) - panic("unexpected input type in mockPredictOnly: " + tval.String()) - } - // simple function: length of string as float32 - out := [][]float32{{float32(len(v))}} - return []interface{}{out}, nil -} + // used in batching tests + predictCalls int + mu sync.Mutex -func (m *mockPredictOnly) Signature() *domain.Signature { return nil } -func (m *mockPredictOnly) Dictionary() *common.Dictionary { return nil } -func (m *mockPredictOnly) Inputs() map[string]*domain.Input { return nil } -func (m *mockPredictOnly) Stats(map[string]interface{}) {} -func (m *mockPredictOnly) Close() error { return nil } -func (m *mockPredictOnly) ReloadIfNeeded(ctx context.Context) error { - return nil -} + // force an error + err error -type mockTritonClient struct { - mu sync.Mutex - loadCalls []string - unloadCalls []string - readyCalls []string - readyState map[string]bool - unloadCh chan string - modelLoadErr map[string]error -} + // for queueing tests + waitFor *sync.WaitGroup + doneGroup *sync.WaitGroup -func (m *mockTritonClient) ServerReady(ctx context.Context) error { return nil } -func (m *mockTritonClient) ModelInfer(ctx context.Context, modelName string, inputs []interface{}, indexToName map[int]string) ([]interface{}, error) { - return nil, nil + predictor func(params []interface{}, signature *domain.Signature) ([]interface{}, error) + signature func() *domain.Signature } -func (m *mockTritonClient) ModelReady(ctx context.Context, modelName string) (bool, error) { + +func (m *mockEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { m.mu.Lock() - defer m.mu.Unlock() - m.readyCalls = append(m.readyCalls, modelName) - if ready, ok := m.readyState[modelName]; ok { - return ready, nil + m.predictCalls++ + m.mu.Unlock() + + if m.waitFor != nil { + m.waitFor.Wait() + m.waitFor = nil } - return true, nil -} -func (m *mockTritonClient) ModelLoad(ctx context.Context, modelName string) error { - m.mu.Lock() - defer m.mu.Unlock() - m.loadCalls = append(m.loadCalls, modelName) - if err := m.modelLoadErr[modelName]; err != nil { - return err + + if m.doneGroup != nil { + defer m.doneGroup.Done() + m.doneGroup = nil } - if m.readyState == nil { - m.readyState = make(map[string]bool) + + if m.err != nil { + return nil, m.err } - m.readyState[modelName] = true - return nil -} -func (m *mockTritonClient) ModelUnload(ctx context.Context, modelName string) error { - m.mu.Lock() - m.unloadCalls = append(m.unloadCalls, modelName) - ch := m.unloadCh - m.mu.Unlock() - if ch != nil { - ch <- modelName + + sig := m.signature() + inputs := sig.Inputs + if len(inputs) != len(params) { + return nil, fmt.Errorf("mock error: expected %d inputs, got %d", len(inputs), len(params)) } - return nil -} -func (m *mockTritonClient) Close() error { return nil } -func (m *mockTritonClient) snapshotLoadCalls() []string { - m.mu.Lock() - defer m.mu.Unlock() - return append([]string(nil), m.loadCalls...) -} + if m.predictor == nil { + // for test cases that do not validate results like reload-centric cases + return nil, nil + } -func (m *mockTritonClient) snapshotUnloadCalls() []string { - m.mu.Lock() - defer m.mu.Unlock() - return append([]string(nil), m.unloadCalls...) + return m.predictor(params, sig) } -func (m *mockTritonClient) snapshotReadyCalls() []string { - m.mu.Lock() - defer m.mu.Unlock() - return append([]string(nil), m.readyCalls...) -} +func (m *mockEvaluator) Signature() *domain.Signature { return m.signature() } +func (m *mockEvaluator) Dictionary() *common.Dictionary { return nil } +func (m *mockEvaluator) Inputs() map[string]*domain.Input { return nil } +func (m *mockEvaluator) Stats(map[string]interface{}) {} +func (m *mockEvaluator) Close() error { return nil } -func waitForCalls(t *testing.T, ch <-chan string, count int) []string { - t.Helper() - var out []string - for i := 0; i < count; i++ { - select { - case v := <-ch: - out = append(out, v) - case <-time.After(time.Second): - t.Fatalf("timeout waiting for call %d/%d", i+1, count) - } +func (m *mockEvaluator) ReloadIfNeeded(ctx context.Context) error { + if m.tritonServer == nil { + return nil } - return out + + return m.tritonServer.ModelLoad(ctx, m.modelName) } -func TestRouter_Predict_RoutesAndConcats(t *testing.T) { - ctx := context.Background() +func (m *mockEvaluator) getPredictCalls() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.predictCalls +} - tests := []struct { - name string - routerConfig *config.RouterConfig - verifier func(t *testing.T, results []interface{}) - }{ - { - name: "with global model", - routerConfig: &config.RouterConfig{ - ConfigURL: "memory://router-config", - InputName: "router_id", - Global: config.GlobalModelConfig{ - Exists: true, // avoid fixed replacements path - }, - }, - verifier: func(t *testing.T, results []interface{}) { - if len(results) != 1 { - t.Fatalf("expected 1 output, got %d", len(results)) - } - out, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32, got %T", results[0]) - } - want := [][]float32{{1}, {4}} - if !reflect.DeepEqual(out, want) { - t.Errorf("output mismatch: got %#v, want %#v", out, want) - } - }, - }, - { - name: "without global model", - routerConfig: &config.RouterConfig{ - ConfigURL: "memory://router-config", - InputName: "router_id", - Global: config.GlobalModelConfig{ - PredictionReplacements: []config.PredictionReplacement{ - { - Name: "score", - Type: "float32", - Value: 1.0, - }, - }, - }, - }, - verifier: func(t *testing.T, results []interface{}) { - if len(results) != 1 { - t.Fatalf("expected 1 output, got %d", len(results)) - } - out, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32, got %T", results[0]) - } - want := [][]float32{{1}, {4}} - if !reflect.DeepEqual(out, want) { - t.Errorf("output mismatch: got %#v, want %#v", out, want) - } - }, - }, - { - name: "with model output name", - routerConfig: &config.RouterConfig{ - ConfigURL: "memory://router-config", - InputName: "router_id", - Global: config.GlobalModelConfig{ - Exists: true, // avoid fixed replacements path - }, - Output: config.OutputConfig{ - FieldName: "model_output", - }, - }, - verifier: func(t *testing.T, results []interface{}) { - if len(results) != 2 { - t.Fatalf("expected 1 output, got %d", len(results)) +// appendFloatPredict expects []interface{[][]string} and returns []interface{}{[][]float32} +// The output field names are expected to be an integer in string form, which will be parsed and appended. +func appendFloatPredict(params []interface{}, signature *domain.Signature) ([]interface{}, error) { + batchSize, err := shape.BatchSize(params[0]) + if err != nil { + return nil, fmt.Errorf("could not determine batch size: %w", err) + } + + batchConcats := make([]*strings.Builder, batchSize) + for _, param := range params { + switch batch := param.(type) { + case [][]string: + for bi, s := range batch { + if batchConcats[bi] == nil { + batchConcats[bi] = &strings.Builder{} } - func() { - out, ok := results[0].([][]float32) - if !ok { - t.Fatalf("expected [][]float32, got %T", results[0]) - } - want := [][]float32{{1}, {4}} - if !reflect.DeepEqual(out, want) { - t.Errorf("output mismatch: got %#v, want %#v", out, want) - } - }() - - func() { - out, ok := results[1].([][]string) - if !ok { - t.Fatalf("expected [][]string, got %T", results[1]) - } - want := [][]string{{"model1"}, {"model2"}} - if !reflect.DeepEqual(out, want) { - t.Errorf("output mismatch: got %#v, want %#v", out, want) - } - }() - }, - }, + sb := batchConcats[bi] + sb.WriteString(s[0]) + } + default: + return nil, fmt.Errorf("unexpected input type: %T", param) + } } - for _, test := range tests { - - t.Run(test.name, func(t *testing.T) { - cfg := &config.Model{ - ID: "router_test", - Mode: "router", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - // router input first (default offset 0) - {Name: "router_id", Index: 0, DataType: "int64"}, - // single backend input - {Name: "text", Index: 1, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "score", Index: 0, DataType: "float32"}, - }, - }, - Router: test.routerConfig, - Triton: &config.TritonConfig{ - ServerID: "test_server", - }, + numOutputs := len(signature.Outputs) + retVal := make([]interface{}, numOutputs) + for oi := range numOutputs { + outputBatch := make([][]float32, batchSize) + for i, v := range batchConcats { + outputName := signature.Outputs[oi].Name + floatStr := outputName + v.String() + floatVal, err := strconv.ParseFloat(floatStr, 32) + if err != nil { + return nil, fmt.Errorf("could not parse float: %v", floatStr) } - cfg.Init(nil) - cfg.Router.MaxQueueSize = 1000 - cfg.Router.Workers = 3 - - router, err := NewRouter(cfg, nil, map[string]tricli.TritonClient{ - "test_server": &mockTritonClient{}, - }) + outputBatch[i] = []float32{float32(floatVal)} + } - if err != nil { - t.Fatalf("NewRouter error: %v", err) - } + retVal[oi] = outputBatch + } - router.routingMap = map[int]string{ - 1: "model1", - 2: "model2", - } - mockEval := &mockPredictOnly{} - router.routingTable = map[string]platform.PlatformEvaluator{ - "model1": mockEval, - "model2": mockEval, - } + return retVal, nil +} - // batch of 2 - params := []interface{}{ - [][]int64{{1}, {2}}, // router id - [][]string{{"a"}, {"abcd"}}, // backend input - } +type configVariant struct { + numOutputs int + hasGlobalModel bool + hasOutputName bool + forceBatchSize1 bool + reverseFPOutputs bool +} - results, err := router.Predict(ctx, params) - if err != nil { - t.Fatalf("Predict error: %v", err) +func makeConfig(cv configVariant) *config.RouterConfig { + var prs []config.PredictionReplacement + if !cv.hasGlobalModel { + prs = make([]config.PredictionReplacement, cv.numOutputs) + for i := range cv.numOutputs { + prs[i] = config.PredictionReplacement{ + Name: strconv.Itoa(i), + Type: "float32", + Value: float32(i) + 0.1, } + } + } - test.verifier(t, results) - }) + if cv.reverseFPOutputs { + slices.Reverse(prs) } -} -func TestRouter_applyRouterConfig_LoadsAndSwaps(t *testing.T) { - ctx := context.Background() - mockClient := &mockTritonClient{ - unloadCh: make(chan string, 2), - readyState: map[string]bool{ - "modelC": false, - "global-new": false, - }, + var outputName = "" + if cv.hasOutputName { + outputName = "model_id" } - oldConfig := &sharedrouter.RouterConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelA"}, - {EntityID: 2, ModelName: "modelB"}, - }, - GlobalModelName: "global-old", + gmo := "" + if cv.hasGlobalModel { + gmo = "global" } - router := &Router{ - tritonClient: mockClient, - modelConfig: &config.Model{ - Triton: &config.TritonConfig{ - Timeout: 100, - }, - }, - indexToName: map[int]string{ - 0: "text", + rc := &config.RouterConfig{ + InputName: "router_id", + ConfigURL: "memory://router-config", + ForceBatchSize1: cv.forceBatchSize1, + Global: config.GlobalModelConfig{ + Exists: cv.hasGlobalModel, + PredictionReplacements: prs, }, - routerConfig: oldConfig, - routingMap: map[int]string{ - 1: "modelA", - 2: "modelB", + Output: config.OutputConfig{ + FieldName: outputName, + GlobalModelOverride: gmo, }, - routingTable: map[string]platform.PlatformEvaluator{ - "modelA": &mockPredictOnly{}, - "modelB": &mockPredictOnly{}, - }, - globalModel: &mockPredictOnly{}, + MaxQueueSize: 100, + Workers: 3, } - reusedModelB := router.routingTable["modelB"] + return rc +} - newConfig := &sharedrouter.RouterConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelB"}, - {EntityID: 3, ModelName: "modelC"}, - }, - GlobalModelName: "global-new", - } +type predictTestCase struct { + name string + routerConfig *config.RouterConfig - if err := router.applyRouterConfig(ctx, newConfig); err != nil { - t.Fatalf("applyRouterConfig returned error: %v", err) - } + reverseInputs bool + reverseOutputs bool - loadCalls := mockClient.snapshotLoadCalls() - if len(loadCalls) != 2 { - // global-new and modelC - t.Fatalf("expected 2 model loads, got %d (%v), ready=%v", len(loadCalls), loadCalls, mockClient.snapshotReadyCalls()) - } + stringInputs [][]string + routerInputs []int - expectedLoads := map[string]bool{ - "modelC": false, - "global-new": false, - } - for _, call := range loadCalls { - if _, ok := expectedLoads[call]; ok { - expectedLoads[call] = true - } - } - for model, seen := range expectedLoads { - if !seen { - t.Fatalf("expected load for %s was not observed; calls=%v", model, loadCalls) - } - } + // outputs + expectedOutputs [][]float32 + expectCallCounts map[string]int +} - readyCalls := mockClient.snapshotReadyCalls() - expectedReady := map[string]bool{ - "modelC": false, - "global-new": false, - } - for _, call := range readyCalls { - if _, ok := expectedReady[call]; ok { - expectedReady[call] = true +// predictTest creates the signature and evaluators, runs Predict, and runs the tests +func predictTest(t *testing.T, test predictTestCase) { + routerInputs, router, mockEvaluators, globalModelName, params := prepareTestRouter(t, test) + + dl, _ := t.Deadline() + ctx, cancel := context.WithDeadline(context.Background(), dl) + defer cancel() + results, err := router.Predict(ctx, params) + if err != nil { + t.Fatalf("Predict error: %v", err) + } + + // model name is appended at the end, first check normal outputs + for oi, outputBatch := range test.expectedOutputs { + actualOutput := results[oi] + switch aob := actualOutput.(type) { + case [][]float32: + for obi, ov := range outputBatch { + actualValue := aob[obi][0] + log.Printf("model output %d expected:%f actual:%f", obi, ov, actualValue) + + if actualValue != ov { + t.Fatalf("input %d offset %d expected %f, got %f", oi, obi, ov, actualValue) + } + } + default: + t.Fatalf("input %d expected [][]float32, got %T", oi, actualOutput) } } - for model, seen := range expectedReady { - if !seen { - t.Fatalf("expected readiness check for %s was not observed; calls=%v", model, readyCalls) + + // then check model name outputs + if test.routerConfig.Output.FieldName != "" { + actualOutput := results[len(test.expectedOutputs)] + switch aob := actualOutput.(type) { + case [][]string: + for obi, ov := range aob { + routedNumber := routerInputs[obi] + routedModel, ok := router.routingMap[routedNumber] + if !ok { + if router.globalModel == nil { + routedModel = test.routerConfig.Output.NoModelID + } else { + routedModel = globalModelName + } + } + + log.Printf("model name expected:%s actual:%s", routedModel, ov[0]) + + if ov[0] != routedModel { + t.Fatalf("model name output expected %s, got %s", "model"+strconv.Itoa(obi), ov[0]) + } + } + default: + t.Fatalf("model name output expected [][]string, got %T", actualOutput) } } - waitForCalls(t, mockClient.unloadCh, 2) - unloadCalls := mockClient.snapshotUnloadCalls() - expectedUnloads := map[string]bool{ - "modelA": false, - "global-old": false, - } - for _, call := range unloadCalls { - if _, ok := expectedUnloads[call]; ok { - expectedUnloads[call] = true + // check number of predict calls + for _, evaluator := range mockEvaluators { + expectCallCount, hasExpect := test.expectCallCounts[evaluator.modelName] + if !hasExpect { + continue } - } - for model, seen := range expectedUnloads { - if !seen { - t.Fatalf("expected unload for %s was not observed; calls=%v", model, unloadCalls) + + predictCalls := evaluator.getPredictCalls() + log.Printf("predict calls %s expected:%d, actual:%d", evaluator.modelName, expectCallCount, predictCalls) + + if predictCalls != expectCallCount { + t.Fatalf("predict calls expected %d, got %d", expectCallCount, predictCalls) } } +} - if router.routerConfig != newConfig { - t.Fatalf("routerConfig pointer not updated") +func prepareTestRouter(t *testing.T, test predictTestCase) ([]int, *Router, map[string]*mockEvaluator, string, []interface{}) { + tritonServerID := "test_server" + cfg := &config.Model{ + ID: test.name, + Mode: "router", + Platform: "triton", + Router: test.routerConfig, + Triton: &config.TritonConfig{ + ServerID: tritonServerID, + }, } - expectedRouting := map[int]string{ - 1: "modelB", - 3: "modelC", - } - if !reflect.DeepEqual(router.routingMap, expectedRouting) { - t.Fatalf("routingMap mismatch, got %#v", router.routingMap) + cfg.Init(nil) + + signature := &domain.Signature{} + for i := range test.stringInputs { + signature.Inputs = append(signature.Inputs, domain.Input{ + Name: strconv.Itoa(i), + Index: i, + Type: reflect.TypeOf(""), + }) } - if router.globalModel == nil { - t.Fatalf("globalModel was not set") + for i := range test.expectedOutputs { + signature.Outputs = append(signature.Outputs, domain.Output{ + Name: strconv.Itoa(i), + Index: i, + DataType: "float32", + }) } - if _, ok := router.routingTable["modelB"]; !ok { - t.Fatalf("routingTable missing modelB") + var routerInputs []int = test.routerInputs + if routerInputs == nil { + sampledInput := test.stringInputs[0] + routerInputs = make([]int, len(sampledInput)) + for j := range len(sampledInput) { + routerInputs[j] = j + } } - if router.routingTable["modelB"] != reusedModelB { - t.Fatalf("modelB evaluator was not reused") + + router, err := newRouter(cfg, nil, map[string]UnloadService{ + tritonServerID: &triton.Service{}, + }, nil) + + if err != nil { + t.Fatalf("NewRouter error: %v", err) } - if _, ok := router.routingTable["modelC"]; !ok { - t.Fatalf("routingTable missing modelC") + + // manually generate router configuration + + // see how model names are constructed later + model0Name := "model0" + model1Name := "model1" + router.routingMap = map[int]string{ + 0: model0Name, + 1: model1Name, } - if _, ok := router.routingTable["modelA"]; ok { - t.Fatalf("routingTable still contains modelA") + + routerInputName := cfg.Router.InputName + routerInput := domain.Input{Name: routerInputName, Index: 0, Type: reflect.TypeOf(int64(0))} + + routerOutputs := make([]domain.Output, len(signature.Outputs)) + copy(routerOutputs, signature.Outputs) + + if test.routerConfig.Output.FieldName != "" { + routerOutputs = append(routerOutputs, domain.Output{ + Name: test.routerConfig.Output.FieldName, + Index: len(routerOutputs), + DataType: "string", + }) } -} -func TestRouter_applyRouterConfig_LoadError(t *testing.T) { - ctx := context.Background() - loadErr := errors.New("load failure") - mockClient := &mockTritonClient{ - readyState: map[string]bool{ - "modelX": false, + // initialize ioState with base outputs + router.ioState = &IOState{ + inputs: map[string]*domain.Input{ + routerInputName: &routerInput, }, - modelLoadErr: map[string]error{ - "modelX": loadErr, + signature: &domain.Signature{ + Inputs: []domain.Input{ + routerInput, + }, + Outputs: routerOutputs, }, + routerInputOffset: 0, } - oldConfig := &sharedrouter.RouterConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelA"}, - }, + // add inputs + for i, sigInput := range signature.Inputs { + router.ioState.inputs[sigInput.Name] = &signature.Inputs[i] + router.ioState.signature.Inputs = append(router.ioState.signature.Inputs, sigInput) } - router := &Router{ - tritonClient: mockClient, - modelConfig: &config.Model{ - Triton: &config.TritonConfig{ - Timeout: 50, - }, - }, - indexToName: map[int]string{}, - routerConfig: oldConfig, - routingMap: map[int]string{ - 1: "modelA", - }, + model1Signature := &domain.Signature{ + Inputs: make([]domain.Input, len(signature.Inputs)), + Outputs: make([]domain.Output, len(signature.Outputs)), } - newConfig := &sharedrouter.RouterConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 2, ModelName: "modelX"}, - }, + copy(model1Signature.Inputs, signature.Inputs) + copy(model1Signature.Outputs, signature.Outputs) + + if test.reverseInputs { + slices.Reverse(model1Signature.Inputs) } - err := router.applyRouterConfig(ctx, newConfig) - if err == nil { - t.Fatalf("expected error but got nil") + if test.reverseOutputs { + slices.Reverse(model1Signature.Outputs) } - if !strings.Contains(err.Error(), "modelX") { - t.Fatalf("expected error mentioning modelX, got %v", err) + + mockEvaluators := map[string]*mockEvaluator{ + model0Name: { + signature: func() *domain.Signature { return signature }, + predictor: appendFloatPredict, + modelName: model0Name, + }, + model1Name: { + signature: func() *domain.Signature { return model1Signature }, + predictor: appendFloatPredict, + modelName: model1Name, + }, } - if router.routerConfig != oldConfig { - t.Fatalf("routerConfig should remain unchanged on error") + router.routingTable = make(map[string]platform.PlatformEvaluator) + for modelName, evaluator := range mockEvaluators { + router.routingTable[modelName] = evaluator } - if !reflect.DeepEqual(router.routingMap, map[int]string{1: "modelA"}) { - t.Fatalf("routingMap should remain unchanged on error") + globalModelName := cfg.Router.Output.GlobalModelOverride + if router.hasGlobalModel { + router.globalModel = &mockEvaluator{ + signature: func() *domain.Signature { return signature }, + predictor: appendFloatPredict, + modelName: globalModelName, + } + + router.routingTable[globalModelName] = router.globalModel } - if router.routingTable != nil { - t.Fatalf("routingTable should not be replaced on error") + // reshape inputs + params := []interface{}{} + + paramRouterInputs := make([][]int64, len(routerInputs)) + for pri, ri := range routerInputs { + paramRouterInputs[pri] = []int64{int64(ri)} } + params = append(params, paramRouterInputs) - loadCalls := mockClient.snapshotLoadCalls() - if len(loadCalls) != 1 || loadCalls[0] != "modelX" { - t.Fatalf("expected single load attempt for modelX, got %v", loadCalls) + for _, input := range test.stringInputs { + inputVals := [][]string{} + for _, inputVal := range input { + inputVals = append(inputVals, []string{inputVal}) + } + + params = append(params, inputVals) } + return routerInputs, router, mockEvaluators, globalModelName, params } -func TestRouter_applyRouterConfig_SkipsLoadWhenReady(t *testing.T) { - ctx := context.Background() - mockClient := &mockTritonClient{ - readyState: map[string]bool{ - "modelC": true, +func TestRouter_Predict_GlobalModel(t *testing.T) { + tests := []predictTestCase{ + { + name: "with_global_model", + routerConfig: makeConfig(configVariant{numOutputs: 1, hasGlobalModel: true, hasOutputName: false}), + stringInputs: [][]string{ + {"1", "2", "3"}, + {"4", "5", "6"}, + }, + expectedOutputs: [][]float32{ + {14, 25, 36}, + }, + }, + { + name: "without_global_model", + routerConfig: makeConfig(configVariant{numOutputs: 1, hasGlobalModel: false, hasOutputName: false}), + stringInputs: [][]string{ + {"1", "2", "3"}, + {"4", "5", "6"}, + }, + expectedOutputs: [][]float32{ + // third offset should get fixed prediction for output 0 + {14, 25, 0.1}, + }, }, } - oldConfig := &sharedrouter.RouterConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelA"}, - }, + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + predictTest(t, test) + }) } +} - router := &Router{ - tritonClient: mockClient, - modelConfig: &config.Model{ - Triton: &config.TritonConfig{ - Timeout: 10, +func TestRouter_Predict_ModelName(t *testing.T) { + t.Run("no global model", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{numOutputs: 1, hasGlobalModel: false, hasOutputName: true}), + stringInputs: [][]string{ + {"1", "2", "3"}, + {"4", "5", "6"}, }, - }, - indexToName: map[int]string{ - 0: "text", - }, - routerConfig: oldConfig, - routingMap: map[int]string{ - 1: "modelA", - }, - routingTable: map[string]platform.PlatformEvaluator{ - "modelA": &mockPredictOnly{}, - }, + expectedOutputs: [][]float32{ + {14, 25, 0.1}, + }, + }) + }) + + t.Run("with global model", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{numOutputs: 1, hasGlobalModel: true, hasOutputName: true}), + stringInputs: [][]string{ + {"1", "2", "3"}, + {"4", "5", "6"}, + }, + expectedOutputs: [][]float32{ + {14, 25, 36}, + }, + }) + }) +} + +func TestRouter_Predict_BatchingBehavior(t *testing.T) { + t.Run("batch to model", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 1, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 44, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + }) + }) + + t.Run("force batch size 1", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + forceBatchSize1: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 1, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 44, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 3, + "model1": 3, + }, + }) + }) + + t.Run("batched with no global", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 2, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 0.1, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + }) + }) +} + +func TestRouter_Predict_SignatureReordering(t *testing.T) { + t.Run("reverse inputs", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 1, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 44, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + reverseInputs: true, + }) + }) + + t.Run("reverse outputs", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 1, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 44, 55, 66}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + reverseOutputs: true, + }) + }) + + t.Run("reversed fixed evaluator", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 2, + hasGlobalModel: false, + hasOutputName: true, + reverseFPOutputs: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 2, 0, 1, + }, + expectedOutputs: [][]float32{ + {11, 22, 33, 0.1, 55, 66}, + {111, 122, 133, 1.1, 155, 166}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + }) + }) + + t.Run("reversed everything", func(t *testing.T) { + predictTest(t, predictTestCase{ + routerConfig: makeConfig(configVariant{ + numOutputs: 2, + hasGlobalModel: false, + hasOutputName: true, + reverseFPOutputs: true, + }), + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + routerInputs: []int{ + 0, 1, 0, 2, 0, 1, + }, + reverseInputs: true, + reverseOutputs: true, + expectedOutputs: [][]float32{ + {11, 22, 33, 0.1, 55, 66}, + {111, 122, 133, 1.1, 155, 166}, + }, + expectCallCounts: map[string]int{ + "model0": 1, + "model1": 1, + }, + }) + }) +} + +func TestRouter_Predict_Queuing(t *testing.T) { + rtCfg := makeConfig(configVariant{ + numOutputs: 1, + hasGlobalModel: false, + hasOutputName: true, + }) + + rtCfg.MaxQueueSize = 5 + rtCfg.Workers = 1 + + _, router, mockEvaluators, _, params := prepareTestRouter(t, + predictTestCase{ + routerConfig: rtCfg, + stringInputs: [][]string{ + {"1", "2", "3", "4", "5", "6"}, + {"1", "2", "3", "4", "5", "6"}, + }, + // hack to work around how output signature is built + expectedOutputs: [][]float32{ + {1}, + }, + }) + + doneGroup := &sync.WaitGroup{} + for _, evaluator := range mockEvaluators { + doneGroup.Add(1) + + evaluator.waitFor = &sync.WaitGroup{} + evaluator.doneGroup = doneGroup + evaluator.waitFor.Add(1) } - newConfig := &sharedrouter.RouterConfig{ - EntityMapping: []sharedrouter.EntityKV{ - {EntityID: 1, ModelName: "modelA"}, - {EntityID: 2, ModelName: "modelC"}, - }, + errCh := make(chan error, 10) + + var foundError uint32 + foundErrorLock := &sync.WaitGroup{} + foundErrorLock.Add(1) + + ctx := context.Background() + runPredictWG := &sync.WaitGroup{} + for pi := range 10 { + runPredictWG.Add(1) + go func() { + defer runPredictWG.Done() + _, err := router.Predict(ctx, params) + log.Printf("predict %d error: %v", pi, err) + if err != nil { + if atomic.CompareAndSwapUint32(&foundError, 0, 1) { + foundErrorLock.Done() + } + + errCh <- err + } + }() + } + + unlockedCh := make(chan struct{}, 1) + + go func() { + foundErrorLock.Wait() + unlockedCh <- struct{}{} + }() + + dl, ok := t.Deadline() + boundCtx := ctx + if ok { + var cancel context.CancelFunc + boundCtx, cancel = context.WithDeadline(ctx, dl) + defer cancel() } - if err := router.applyRouterConfig(ctx, newConfig); err != nil { - t.Fatalf("applyRouterConfig returned error: %v", err) + select { + case <-boundCtx.Done(): + t.Fatalf("test timed out") + + case <-unlockedCh: + // positive case } - if loads := mockClient.snapshotLoadCalls(); len(loads) != 0 { - t.Fatalf("expected no loads when model is ready, got %v", loads) + for _, evaluator := range mockEvaluators { + // unblock evaluators + evaluator.waitFor.Done() } - readyCalls := mockClient.snapshotReadyCalls() - if len(readyCalls) != 1 || readyCalls[0] != "modelC" { - t.Fatalf("expected readiness check for modelC, got %v", readyCalls) + // wait for evaluators to finish + doneGroup.Wait() + runPredictWG.Wait() + + close(errCh) + + foundQueueSizeError := false + for e := range errCh { + if e != nil { + if strings.Contains(e.Error(), queueSizeExceededError) { + foundQueueSizeError = true + } else { + t.Fatalf("Predict error: %v", e) + } + } } - if _, ok := router.routingTable["modelC"]; !ok { - t.Fatalf("routingTable missing modelC after reload") + if !foundQueueSizeError { + t.Fatalf("queue size exceeded not found") } } diff --git a/service/platform/router/worker.go b/service/platform/router/worker.go deleted file mode 100644 index 4acc056..0000000 --- a/service/platform/router/worker.go +++ /dev/null @@ -1,62 +0,0 @@ -package router - -import ( - "context" - "fmt" - "log" - "sync" - "time" - - "github.com/prometheus/client_golang/prometheus" - "github.com/viant/mly/service/platform" -) - -type workRequest struct { - wg *sync.WaitGroup - - predictor platform.Predictor - ctx context.Context - request []interface{} - - queuedTime time.Time - offset int - modelOutputEnabled bool - routingValueString string - - responseCh chan offsetResults - errCh chan error -} - -type offsetResults struct { - offset int - results []interface{} -} - -func handleWorkRequests(workCh chan *workRequest, observer prometheus.Observer) { - for request := range workCh { - if request == nil { - log.Println("work request is nil, stopping") - break - } - - func(request workRequest) { - observer.Observe(float64(time.Since(request.queuedTime).Microseconds())) - - defer request.wg.Done() - results, err := request.predictor.Predict(request.ctx, request.request) - if err != nil { - request.errCh <- fmt.Errorf("failed to predict for row %d: %w", request.offset, err) - return - } - - if request.modelOutputEnabled { - results = append(results, [][]string{{request.routingValueString}}) - } - - request.responseCh <- offsetResults{ - offset: request.offset, - results: results, - } - }(*request) - } -} diff --git a/service/request/request.go b/service/request/request.go index 92182f4..2225070 100644 --- a/service/request/request.go +++ b/service/request/request.go @@ -21,14 +21,19 @@ type Request struct { // Passed through to Evaluator. // This is expected to be [numInputs]([batchSize][1]T). + // TODO: This shape is fixed, and should be addressed. + // Also, the fact that it is a []interface{} is a TensorFlow concern; ideally it should be map[string]interface{}. Feeds []interface{} supplied map[string]struct{} // used to check if the required inputs were provided - Input *transfer.Input // cache metadata + // Input is primarily used with Transformer. + Input *transfer.Input // type metadata from service/tfservice.Service.inputs // see service/tfmodel.(*Service).reconcileIOFromSignature + // The key is expected to be the name of the input. + // Uses fields Index, Auxiliary, Type, and Name (for debugging) inputs map[string]*domain.Input } @@ -57,45 +62,47 @@ func (r *Request) Put(key string, value string) error { return nil } + inputIndex := input.Index + switch input.Type.Kind() { case reflect.String: - r.Feeds[input.Index] = [][]string{{value}} + r.Feeds[inputIndex] = [][]string{{value}} case reflect.Bool: val, err := strconv.ParseBool(value) if err != nil { return fmt.Errorf("failed to parse bool: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]bool{{val}} + r.Feeds[inputIndex] = [][]bool{{val}} case reflect.Int: val, err := strconv.ParseInt(value, 10, 64) if err != nil { return fmt.Errorf("failed to parse int: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]int{{int(val)}} + r.Feeds[inputIndex] = [][]int{{int(val)}} case reflect.Int32: val, err := strconv.ParseInt(value, 10, 32) if err != nil { return fmt.Errorf("failed to parse int32: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]int32{{int32(val)}} + r.Feeds[inputIndex] = [][]int32{{int32(val)}} case reflect.Int64: val, err := strconv.ParseInt(value, 10, 64) if err != nil { return fmt.Errorf("failed to parse int64: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]int64{{val}} + r.Feeds[inputIndex] = [][]int64{{val}} case reflect.Float64: val, err := strconv.ParseFloat(value, 64) if err != nil { return fmt.Errorf("failed to parse float64: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]float64{{val}} + r.Feeds[inputIndex] = [][]float64{{val}} case reflect.Float32: val, err := strconv.ParseFloat(value, 32) if err != nil { return fmt.Errorf("failed to parse float32: '%v' for %v, %w", val, key, err) } - r.Feeds[input.Index] = [][]float32{{float32(val)}} + r.Feeds[inputIndex] = [][]float32{{float32(val)}} default: // TODO add more type support return fmt.Errorf("unsupported input type: %T", reflect.New(input.Type).Interface()) @@ -143,7 +150,14 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { } r.supplied[key] = exists - inputValue, err := r.Input.SetAt(input.Index, input.Name, input.Type.Kind()) + + inputIndex := input.Index + + if inputIndex >= len(r.Feeds) && !input.Auxiliary { + return fmt.Errorf("non-aux input %s index %d is out of range for %d feeds", input.Name, inputIndex, len(r.Feeds)) + } + + inputValue, err := r.Input.SetAt(inputIndex, input.Name, input.Type.Kind()) if err != nil { return err } @@ -153,7 +167,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = inputValue.Feed(r.Input.BatchSize) + r.Feeds[inputIndex] = inputValue.Feed(r.Input.BatchSize) } return nil } @@ -168,7 +182,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { } _ = inputValue.Set(value) if !input.Auxiliary { - r.Feeds[input.Index] = [][]string{{value}} + r.Feeds[inputIndex] = [][]string{{value}} } case reflect.Bool: var value bool @@ -176,7 +190,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]bool{{value}} + r.Feeds[inputIndex] = [][]bool{{value}} } _ = inputValue.Set(value) case reflect.Int: @@ -185,7 +199,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]int{{value}} + r.Feeds[inputIndex] = [][]int{{value}} } _ = inputValue.Set(value) case reflect.Int32: @@ -194,7 +208,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]int32{{value}} + r.Feeds[inputIndex] = [][]int32{{value}} } _ = inputValue.Set(value) case reflect.Int64: @@ -203,7 +217,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]int64{{value}} + r.Feeds[inputIndex] = [][]int64{{value}} } _ = inputValue.Set(value) case reflect.Float64: @@ -212,7 +226,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]float64{{value}} + r.Feeds[inputIndex] = [][]float64{{value}} } _ = inputValue.Set(value) case reflect.Float32: @@ -221,7 +235,7 @@ func (r *Request) UnmarshalJSONObject(dec *gojay.Decoder, key string) error { return err } if !input.Auxiliary { - r.Feeds[input.Index] = [][]float32{{float32(value)}} + r.Feeds[inputIndex] = [][]float32{{float32(value)}} } _ = inputValue.Set(value) default: diff --git a/service/request/request_test.go b/service/request/request_test.go index 76f9eee..055097a 100644 --- a/service/request/request_test.go +++ b/service/request/request_test.go @@ -21,15 +21,14 @@ func TestDecode(t *testing.T) { }, } - inputs := make(map[string]*domain.Input, len(modelInputs)) + numInputs := len(modelInputs) + inputs := make(map[string]*domain.Input, numInputs) for i, modelInput := range modelInputs { modelInput.Index = i inputs[modelInput.Name] = modelInput } - numInputs := len(modelInputs) - testCases := []struct { desc string requestEnc string @@ -41,7 +40,7 @@ func TestDecode(t *testing.T) { requestEnc: `{ "batch_size": 1, "a2": ["a2_0"], - "a1": ["a1_0"], + "a1": ["a1_0"], "a3": ["a3_0"], "cache_key": ["ck1"], }`, @@ -61,7 +60,7 @@ func TestDecode(t *testing.T) { desc: "invalid", requestEnc: `{ "batch_size": 1, - "a1": ["a1_0"], + "a1": ["a1_0"], "a3": ["a3_0"], "cache_key": ["ck1"], }`, @@ -71,9 +70,9 @@ func TestDecode(t *testing.T) { desc: "duplicate_aux", requestEnc: `{ "batch_size": 1, - "a1": ["a1_0"], - "a2": ["a1_0"], - "a3": ["a3_0"], + "a1": ["a1_0"], + "a2": ["a1_0"], + "a3": ["a3_0"], "a3": ["a3_1"], "cache_key": ["ck1"], }`, @@ -83,9 +82,9 @@ func TestDecode(t *testing.T) { desc: "duplicate_input", requestEnc: `{ "batch_size": 1, - "a1": ["a1_0"], - "a2": ["a2_0"], - "a2": ["a2_1"], + "a1": ["a1_0"], + "a2": ["a2_0"], + "a2": ["a2_1"], "a3": ["a3_0"], "cache_key": ["ck1"], }`, @@ -101,8 +100,8 @@ func TestDecode(t *testing.T) { desc: "bad_batch_expansion", requestEnc: `{ "batch_size": 2, - "a1": ["a1_0"], - "a2": ["a2_0", "a2_1"], + "a1": ["a1_0"], + "a2": ["a2_0", "a2_1"], "a3": ["a3_0", "a3_1"], "cache_key": ["ck1", "ck2"], }`, @@ -130,10 +129,7 @@ func TestDecode(t *testing.T) { } for _, tc := range testCases { - r := &Request{ - inputs: inputs, - Feeds: make([]interface{}, numInputs, numInputs), - } + r := NewRequest(numInputs, inputs) err := gojay.Unmarshal([]byte(tc.requestEnc), r) diff --git a/service/request/shape/batch.go b/service/request/shape/batch.go index 3f62604..4351b4a 100644 --- a/service/request/shape/batch.go +++ b/service/request/shape/batch.go @@ -1,6 +1,9 @@ package shape -import "fmt" +import ( + "fmt" + "reflect" +) // DetermineBatchSize determines batch size from a service.Request.Feeds slice. func DetermineBatchSize(inputs []interface{}) (int, error) { @@ -66,7 +69,96 @@ func Debatch(untypedBatch interface{}, i int) (interface{}, error) { return nil, fmt.Errorf("unexpected batch type: %T", untypedBatch) } -// concatAxis0 concatenates two tensors along axis 0 (batch dimension). +// AppendRowToBatch appends a single debatched row to an accumulating batch. +// If the accumulator is nil, it initializes it with the row's type. +// The row is expected to be in debatched form: [][]T with shape [1][1]. +// The accumulator will grow to shape [N][1] after N appends. +func AppendRowToBatch(accumulator interface{}, row interface{}) (interface{}, error) { + if accumulator == nil { + // Initialize with the row (already in correct shape [1][1]) + return row, nil + } + + switch accTyped := accumulator.(type) { + case [][]int32: + rowTyped, ok := row.([][]int32) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]int32, row is %T", row) + } + return append(accTyped, rowTyped...), nil + case [][]int64: + rowTyped, ok := row.([][]int64) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]int64, row is %T", row) + } + return append(accTyped, rowTyped...), nil + case [][]float32: + rowTyped, ok := row.([][]float32) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]float32, row is %T", row) + } + return append(accTyped, rowTyped...), nil + case [][]float64: + rowTyped, ok := row.([][]float64) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]float64, row is %T", row) + } + return append(accTyped, rowTyped...), nil + case [][]string: + rowTyped, ok := row.([][]string) + if !ok { + return nil, fmt.Errorf("type mismatch: accumulator is [][]string, row is %T", row) + } + return append(accTyped, rowTyped...), nil + default: + return nil, fmt.Errorf("unsupported accumulator type: %T", accumulator) + } +} + +// ExtractRowFromBatch extracts a single row from a batch at the given index. +// The batch is expected to have shape [N][M] and the result will have shape [1][M]. +func ExtractRowFromBatch(batch interface{}, index int) (interface{}, error) { + switch typedBatch := batch.(type) { + case [][]int32: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]int32{typedBatch[index]}, nil + case [][]int64: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]int64{typedBatch[index]}, nil + case [][]float32: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]float32{typedBatch[index]}, nil + case [][]float64: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]float64{typedBatch[index]}, nil + case [][]string: + if index >= len(typedBatch) { + return nil, fmt.Errorf("index %d out of range for batch of size %d", index, len(typedBatch)) + } + return [][]string{typedBatch[index]}, nil + default: + return nil, fmt.Errorf("unsupported batch type: %T", batch) + } +} + +// BatchSize returns the batch size (first dimension) of a batch tensor. +func BatchSize(batch interface{}) (int, error) { + val := reflect.ValueOf(batch) + if val.Kind() != reflect.Slice { + return 0, fmt.Errorf("expected slice, got %T", batch) + } + return val.Len(), nil +} + +// ConcatAxis0 concatenates two tensors along axis 0 (batch dimension). func ConcatAxis0(x []interface{}, y []interface{}) ([]interface{}, error) { if len(x) != len(y) { return nil, fmt.Errorf("x and y must have the same length: %d vs %d", len(x), len(y)) diff --git a/service/request/shape/batch_test.go b/service/request/shape/batch_test.go index 9b43628..82706d5 100644 --- a/service/request/shape/batch_test.go +++ b/service/request/shape/batch_test.go @@ -88,3 +88,132 @@ func TestDebatchAndSqueezeBatch_Int64(t *testing.T) { t.Errorf("squeezeBatch() got %v (%T), want 20 (int64)", scalar, scalar) } } + +func TestAppendRowToBatch_String(t *testing.T) { + // Start with nil accumulator + row1 := [][]string{{"hello"}} + acc, err := AppendRowToBatch(nil, row1) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + want1 := [][]string{{"hello"}} + if !reflect.DeepEqual(acc, want1) { + t.Errorf("after first append: got %#v, want %#v", acc, want1) + } + + // Append second row + row2 := [][]string{{"world"}} + acc, err = AppendRowToBatch(acc, row2) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + want2 := [][]string{{"hello"}, {"world"}} + if !reflect.DeepEqual(acc, want2) { + t.Errorf("after second append: got %#v, want %#v", acc, want2) + } + + // Append third row + row3 := [][]string{{"foo"}} + acc, err = AppendRowToBatch(acc, row3) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + want3 := [][]string{{"hello"}, {"world"}, {"foo"}} + if !reflect.DeepEqual(acc, want3) { + t.Errorf("after third append: got %#v, want %#v", acc, want3) + } +} + +func TestAppendRowToBatch_Float32(t *testing.T) { + row1 := [][]float32{{1.5}} + acc, err := AppendRowToBatch(nil, row1) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + row2 := [][]float32{{2.5}} + acc, err = AppendRowToBatch(acc, row2) + if err != nil { + t.Fatalf("AppendRowToBatch returned error: %v", err) + } + + want := [][]float32{{1.5}, {2.5}} + if !reflect.DeepEqual(acc, want) { + t.Errorf("got %#v, want %#v", acc, want) + } +} + +func TestAppendRowToBatch_TypeMismatch(t *testing.T) { + acc := [][]string{{"hello"}} + row := [][]int64{{123}} + + _, err := AppendRowToBatch(acc, row) + if err == nil { + t.Fatal("expected type mismatch error, got nil") + } +} + +func TestExtractRowFromBatch_String(t *testing.T) { + batch := [][]string{{"a"}, {"b"}, {"c"}} + + row, err := ExtractRowFromBatch(batch, 1) + if err != nil { + t.Fatalf("ExtractRowFromBatch returned error: %v", err) + } + + want := [][]string{{"b"}} + if !reflect.DeepEqual(row, want) { + t.Errorf("got %#v, want %#v", row, want) + } +} + +func TestExtractRowFromBatch_Float32(t *testing.T) { + batch := [][]float32{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}} + + row, err := ExtractRowFromBatch(batch, 2) + if err != nil { + t.Fatalf("ExtractRowFromBatch returned error: %v", err) + } + + want := [][]float32{{5.0, 6.0}} + if !reflect.DeepEqual(row, want) { + t.Errorf("got %#v, want %#v", row, want) + } +} + +func TestExtractRowFromBatch_OutOfRange(t *testing.T) { + batch := [][]int64{{1}, {2}} + + _, err := ExtractRowFromBatch(batch, 5) + if err == nil { + t.Fatal("expected out of range error, got nil") + } +} + +func TestBatchSize(t *testing.T) { + tests := []struct { + name string + batch interface{} + want int + }{ + {"string batch", [][]string{{"a"}, {"b"}, {"c"}}, 3}, + {"float32 batch", [][]float32{{1}, {2}}, 2}, + {"int64 batch", [][]int64{{1}}, 1}, + {"empty batch", [][]string{}, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := BatchSize(tt.batch) + if err != nil { + t.Fatalf("BatchSize returned error: %v", err) + } + if got != tt.want { + t.Errorf("got %d, want %d", got, tt.want) + } + }) + } +} diff --git a/service/service.go b/service/service.go index a21ff1c..3b94133 100644 --- a/service/service.go +++ b/service/service.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/viant/afs" "github.com/viant/gmetric" "github.com/viant/gtly" @@ -18,19 +19,16 @@ import ( serrs "github.com/viant/mly/service/errors" "github.com/viant/mly/service/gtlyop" "github.com/viant/mly/service/platform" - "github.com/viant/mly/service/platform/factory" "github.com/viant/mly/service/request" "github.com/viant/mly/service/stat" "github.com/viant/mly/service/stream" "github.com/viant/mly/service/transform" - "github.com/viant/mly/service/triton" "github.com/viant/mly/shared" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/common/storable" "github.com/viant/mly/shared/datastore" sstat "github.com/viant/mly/shared/stat" "github.com/viant/xunsafe" - "golang.org/x/sync/semaphore" ) // Service serves as the entrypoint for using the ML model. @@ -45,13 +43,15 @@ type Service struct { // continueOnRecover if false, will re-panic on recover continueOnRecover bool - // TODO how does this interact with Service.inputs + // inputProvider is used in transformer.Transform() inputProvider *gtly.Provider // health status for centralized health reporting - // Deprecated: use GetHealth() instead + // Deprecated: use GetHealth() or healthGauge ReloadOK int32 + healthGauge prometheus.Gauge + reloadPollTicker *time.Ticker reloadTimeout time.Duration @@ -64,7 +64,6 @@ type Service struct { // outputs transformer domain.Transformer - newStorable func() common.Storable // serviceMetric measures validate + model + transformer serviceMetric *gmetric.Operation @@ -93,6 +92,8 @@ func (s *Service) Config() *config.Model { return s.config } +// Signature is invoked after at least 1 successful ReloadIfNeeded(). +// Considered hot path. func (s *Service) Signature() *domain.Signature { return s.evaluator.Signature() } @@ -306,6 +307,14 @@ func (s *Service) initializeService(ctx context.Context, cfg *config.Model, fs a } atomic.StoreInt32(&s.ReloadOK, 1) + if s.healthGauge != nil { + s.healthGauge.Set(1) + } + + signature := s.Signature() + if signature == nil { + return fmt.Errorf("signature could not be determined") + } s.transformer, err = transform.Get(cfg.Transformer) if err != nil { @@ -313,7 +322,7 @@ func (s *Service) initializeService(ctx context.Context, cfg *config.Model, fs a } if err = s.initDatastore(cfg, datastores); err != nil { - return err + return fmt.Errorf("failed to initialize datastore: %w", err) } if cfg.Stream != nil { @@ -336,59 +345,6 @@ func (s *Service) initializeService(ctx context.Context, cfg *config.Model, fs a return nil } -// New creates a service with platform router support -func New( - ctx context.Context, - cfg *config.Model, - fs afs.Service, - metrics *gmetric.Service, - datastores map[string]*datastore.Service, - tritonClients map[string]triton.TritonClient, - sema *semaphore.Weighted, - maxEvaluatorWait time.Duration, - options ...Option, -) (*Service, error) { - - if metrics == nil { - metrics = gmetric.New() - } - - location := reflect.TypeOf(Service{}).PkgPath() - - cfg.Init(nil) - - // Create platform evaluator context - evaluatorContext, err := factory.CreateEvaluator(cfg, fs, metrics, sema, maxEvaluatorWait, tritonClients) - if err != nil { - return nil, fmt.Errorf("failed to create platform evaluator for model %s: %w", cfg.ID, err) - } - - srv := &Service{ - config: cfg, - evaluator: evaluatorContext, - useDatastore: cfg.UseDictionary() && cfg.DataStore != "", - serviceMetric: metrics.MultiOperationCounter(location, cfg.ID+"Perf", cfg.ID+" service performance", time.Microsecond, time.Minute, 2, stat.NewProvider()), - reloadPollTicker: time.NewTicker(time.Duration(cfg.ReloadPollIntervalSeconds) * time.Second), - reloadTimeout: time.Duration(cfg.ReloadTimeoutSeconds) * time.Second, - } - - // Set up reload metrics for platforms that support reloading - srv.reloadMetric = metrics.MultiOperationCounter(location, cfg.ID+"Reload", cfg.ID+" reloading", time.Microsecond, time.Minute, 1, sstat.NewCtxErrOnly()) - - for _, opt := range options { - opt.Apply(srv) - } - - err = srv.initializeService(ctx, cfg, fs, metrics, datastores) - if err != nil { - return nil, err - } - - go srv.pollModelReload() - - return srv, err -} - // NewRequest should be used for Do() func (s *Service) NewRequest() *request.Request { numKeyInputs := s.config.KeysLen() @@ -407,10 +363,11 @@ func (s *Service) initDatastore(cfg *config.Model, datastores map[string]*datast signature := s.Signature() if signature == nil { - return fmt.Errorf("signature was emtpy") + return fmt.Errorf("signature was not provided") } if len(cfg.KeyFields) == 0 { + // add all inputs from model signature as a key field for _, input := range signature.Inputs { cfg.KeyFields = append(cfg.KeyFields, input.Name) } @@ -433,10 +390,6 @@ func (s *Service) initDatastore(cfg *config.Model, datastores map[string]*datast _ = datastoreConfig.FieldsDescriptor(fields) } - if s.newStorable == nil { - s.newStorable = getStorable(datastoreConfig) - } - return nil } @@ -448,33 +401,40 @@ func (s *Service) GetHealth() int32 { func (s *Service) pollModelReload() { for range s.reloadPollTicker.C { - ctx, cancel := context.WithTimeout(context.Background(), s.reloadTimeout) - defer cancel() - - stats := sstat.NewValues() - if s.reloadMetric != nil { - onDone := s.reloadMetric.Begin(time.Now()) - defer func() { - onDone(time.Now(), stats.Values()...) - }() - } - - var reloadOK int32 - err := s.evaluator.ReloadIfNeeded(ctx) - if err != nil { - stats.AppendError(err) - log.Printf("[%s reload] failed to reload model:%v", s.config.ID, err) - - reloadOK = 0 - } else { - reloadOK = 1 - } - - atomic.StoreInt32(&s.ReloadOK, reloadOK) if atomic.LoadInt32(&s.closed) != 0 { log.Printf("[%s reload] shutting down, stopping reload loop", s.config.ID) return } + + func() { + ctx, cancel := context.WithTimeout(context.Background(), s.reloadTimeout) + defer cancel() + + stats := sstat.NewValues() + if s.reloadMetric != nil { + onDone := s.reloadMetric.Begin(time.Now()) + defer func() { + onDone(time.Now(), stats.Values()...) + }() + } + + var reloadOK int32 + err := s.evaluator.ReloadIfNeeded(ctx) + if err != nil { + stats.AppendError(err) + log.Printf("[%s reload] failed to reload model:%v", s.config.ID, err) + + reloadOK = 0 + } else { + reloadOK = 1 + } + + if s.healthGauge != nil { + s.healthGauge.Set(float64(reloadOK)) + } + + atomic.StoreInt32(&s.ReloadOK, reloadOK) + }() } } diff --git a/service/stat/handler.go b/service/stat/handler.go new file mode 100644 index 0000000..e62e98a --- /dev/null +++ b/service/stat/handler.go @@ -0,0 +1,106 @@ +package stat + +import ( + "github.com/viant/gmetric/counter" + "github.com/viant/mly/shared/stat" +) + +const ( + // ResponseMarshalErrorKey counts requests where the prediction succeeded + // but gojay.Marshal of the Response struct failed. The HTTP response was + // NOT committed: ServeHTTP recovers by emitting an explicit 500. + ResponseMarshalErrorKey = "responseMarshalError" + + // ResponseCommittedErrorKey counts requests where status + headers were + // already flushed to the client (200 OK) when the body Write failed. + // This is the server-side counterpart to the bidder-observed + // "200 OK + empty/truncated body → invalid_json" failure mode. + // A non-zero rate here indicates either: + // - clients are closing the connection mid-response (most common + // under load when client deadline < server response time), or + // - HTTP/1.1 keepalive desync producing broken pipes on reuse. + // Distinct from ErrorKey so it can be alerted independently. + ResponseCommittedErrorKey = "responseCommittedError" +) + +// ResponseMarshalError is a stat marker for the gmetric provider. The +// embedded error is retained for top-K error sampling; the struct itself +// is intentionally NOT an `error` so the type-switch in Map can route it +// to its own bucket without colliding with the generic error case. +type ResponseMarshalError struct{ Error error } + +// String implements fmt.Stringer (used by gmetric top-K error sampling). +func (r ResponseMarshalError) String() string { return r.Error.Error() } + +// Aggregate implements github.com/viant/gmetric/counter.CustomCounter. +func (r ResponseMarshalError) Aggregate(interface{}) {} + +// ResponseCommittedError is the analogous stat marker for post-commit +// write failures. See ResponseCommittedErrorKey for the operational +// significance. +type ResponseCommittedError struct{ Error error } + +func (r ResponseCommittedError) String() string { return r.Error.Error() } +func (r ResponseCommittedError) Aggregate(interface{}) {} + +// handler is the gmetric counter.Provider for service.Handler.ServeHTTP. +// It is a strict superset of shared/stat.NewCtxErrOnly(): the first three +// keys (ErrorKey, Canceled, DeadlineExceeded) preserve their indices so +// any existing Prometheus dashboards/alerts on those buckets continue to +// emit at the same labels. Two new keys are appended for the explicit +// response-write failure classes introduced by the explicit-commit +// refactor of writeResponse. +type handler struct{} + +// Keys returns the stat key labels in stable index order. Order matters: +// gmetric's counter buckets are addressed by index, and changing the +// order of the first three would silently re-label existing series. +func (h handler) Keys() []string { + return []string{ + stat.ErrorKey, // 0 + stat.Canceled, // 1 + stat.DeadlineExceeded, // 2 + ResponseMarshalErrorKey, // 3 + ResponseCommittedErrorKey, // 4 + } +} + +// Map routes a value to its key index. Concrete struct cases come BEFORE +// the generic `error` case to ensure typed stat markers route to their +// dedicated buckets even if a future change makes them satisfy `error`. +func (h handler) Map(value interface{}) int { + if value == nil { + return -1 + } + + if _, ok := value.(ResponseMarshalError); ok { + return 3 + } + if _, ok := value.(ResponseCommittedError); ok { + return 4 + } + + switch v := value.(type) { + case error: + return 0 + case string: + switch v { + case stat.Canceled: + return 1 + case stat.DeadlineExceeded: + return 2 + case ResponseMarshalErrorKey: + return 3 + case ResponseCommittedErrorKey: + return 4 + } + } + + return -1 +} + +// NewHandler returns the counter.Provider used by Handler.ServeHTTP's +// per-request httpContextMetrics. +func NewHandler() counter.Provider { + return handler{} +} diff --git a/service/storable.go b/service/storable.go deleted file mode 100644 index 0d8f3c9..0000000 --- a/service/storable.go +++ /dev/null @@ -1,18 +0,0 @@ -package service - -import ( - "github.com/viant/mly/shared/common" - "github.com/viant/mly/shared/common/storable" - "github.com/viant/mly/shared/config" -) - -func getStorable(cfg *config.Datastore) func() common.Storable { - result, err := storable.Singleton().Lookup(cfg.Storable) - if err == nil && result != nil { - return result - } //otherwise return default storable - - return func() common.Storable { - return storable.New(cfg.Fields) - } -} diff --git a/service/transform/transformer.go b/service/transform/transformer.go index 4120c4b..8437842 100644 --- a/service/transform/transformer.go +++ b/service/transform/transformer.go @@ -19,7 +19,3 @@ func Get(name string) (domain.Transformer, error) { // otherwise return default transformer return domain.Transform, nil } - -func ExecuteTransform() interface{} { - return nil -} diff --git a/service/triton/client.go b/service/triton/client.go index 3d01add..f19502a 100644 --- a/service/triton/client.go +++ b/service/triton/client.go @@ -12,21 +12,42 @@ import ( "google.golang.org/grpc/credentials/insecure" ) +// A TritonClient represents a client to a single Triton server. type TritonClient interface { ServerReady(ctx context.Context) error // inputs is expected to be [numInputs]([batchSize][1]T) (see service/request.Request.Feeds) + // inputs will never be empty ModelInfer(ctx context.Context, modelName string, inputs []interface{}, indexToName map[int]string) ([]interface{}, error) ModelReady(ctx context.Context, modelName string) (bool, error) ModelLoad(ctx context.Context, modelName string) error - ModelUnload(ctx context.Context, modelName string) error + ModelUnloader + + ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) Close() error } +type ModelUnloader interface { + ModelUnload(ctx context.Context, modelName string) error +} + +// https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md#model-metadata-response-json-object `$metadata_tensor` +type MetadataTensor struct { + Name string `json:"name"` + Datatype string `json:"datatype"` + Shape []int64 `json:"shape"` +} + +// stripped down version of https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md#model-metadata-response-json-object +type ModelMetadata struct { + Inputs []MetadataTensor `json:"inputs"` + Outputs []MetadataTensor `json:"outputs"` +} + // NewClient creates either an HTTP or gRPC client. func NewClient(server config.TritonServer) (TritonClient, error) { if server.GRPCBaseURL != "" { @@ -53,7 +74,6 @@ func NewClient(server config.TritonServer) (TritonClient, error) { } // HTTP options seem a bit bare - // TODO see if DRY with return NewMeteredTritonClient(&HTTPClient{ httpClient: &http.Client{ Timeout: time.Duration(server.HTTPClientTimeoutMs) * time.Millisecond, diff --git a/service/triton/evaluator.go b/service/triton/evaluator.go new file mode 100644 index 0000000..5831bed --- /dev/null +++ b/service/triton/evaluator.go @@ -0,0 +1,283 @@ +package triton + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + "github.com/viant/mly/service/config" + "github.com/viant/mly/service/domain" + "github.com/viant/mly/shared" + "github.com/viant/mly/shared/common" +) + +// TritonEvaluator implements service/platform.PlatformEvaluator. +type TritonEvaluator struct { + service *Service + modelName string + + // if true, this client is used only for this instance + isPrivateClient bool + repositoryExplicit bool + + timeout time.Duration + + modelID string + debug bool + + signature *domain.Signature + + // maps Feeds index to input name + indexToName map[int]string + + configuredInputs []*shared.Field + + inputs map[string]*domain.Input +} + +// NewTritonEvaluator creates a new Triton evaluator +func NewTritonEvaluator(config *config.Model, tritonClients map[string]*Service) (*TritonEvaluator, error) { + evaluator, err := createEvaluator(config, tritonClients) + if err != nil { + return nil, err + } + err = evaluator.registerUsage() + if err != nil { + return nil, fmt.Errorf("failed to register usage for Triton evaluator: %w", err) + } + + return evaluator, nil +} + +func createEvaluator(config *config.Model, tritonClients map[string]*Service) (*TritonEvaluator, error) { + var service *Service + + isPrivateClient := config.URL != "" + timeout := time.Duration(config.Triton.Timeout) * time.Millisecond + + if isPrivateClient { + // "Private" URL configuration will only support HTTP + client := &HTTPClient{ + httpClient: &http.Client{ + Timeout: timeout, + }, + serverURL: config.URL, + debug: config.Debug, + } + + service = &Service{ + Client: client, + } + } else { + service = tritonClients[config.Triton.ServerID] + if service == nil { + return nil, fmt.Errorf("client not found for Triton, server ID: %s", config.Triton.ServerID) + } + } + + evaluator := &TritonEvaluator{ + service: service, + + modelName: config.Triton.ModelName, + timeout: timeout, + + isPrivateClient: isPrivateClient, + + // clients defined in TritonServers are assumed to be in EXPLICIT mode + repositoryExplicit: !isPrivateClient || config.Triton.RepositoryExplicit, + + configuredInputs: config.MetaInput.Inputs, + + modelID: config.ID, + debug: config.Debug, + } + + return evaluator, nil + +} + +// Upward dependency. +// Provides Evaluators as needed for the service/platform/router module. +func NewRoutedTritonEvaluator(modelName string, config *config.Model, tritonClients map[string]*Service) (*TritonEvaluator, error) { + evaluator, err := createEvaluator(config, tritonClients) + if err != nil { + return nil, fmt.Errorf("failed to create Triton Routed evaluator: %w", err) + } + + evaluator.modelName = modelName + evaluator.configuredInputs = nil // routed evaluators must not have any additional inputs + + err = evaluator.registerUsage() + if err != nil { + return nil, fmt.Errorf("failed to register usage for Triton Routed evaluator: %w", err) + } + + return evaluator, nil +} + +func (t *TritonEvaluator) registerUsage() error { + if t.modelName == "" { + return fmt.Errorf("model name is required for registering usage") + } + + t.service.RegisterUsage(t.modelID, t.modelName) + return nil +} + +// Predict performs inference via Triton Inference Server +func (t *TritonEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { + if len(params) == 0 { + return nil, fmt.Errorf("no input parameters") + } + + requestCtx := ctx + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + requestCtx, cancel = context.WithTimeout(ctx, t.timeout) + defer cancel() + } + + return t.service.Client.ModelInfer(requestCtx, t.modelName, params, t.indexToName) +} + +func (t *TritonEvaluator) Signature() *domain.Signature { + return t.signature +} + +func (t *TritonEvaluator) Dictionary() *common.Dictionary { + // no dictionary + return nil +} + +func (t *TritonEvaluator) Stats(stats map[string]interface{}) { + // no stats +} + +func (t *TritonEvaluator) Inputs() map[string]*domain.Input { + return t.inputs +} + +// Close releases Triton client resources and stops health monitoring +func (t *TritonEvaluator) Close() error { + if t.isPrivateClient { + return t.service.Client.Close() + } + + return nil +} + +// For independent Triton server models, reloading is not supported. +func (t *TritonEvaluator) ReloadIfNeeded(ctx context.Context) error { + ready, err := t.service.Client.ModelReady(ctx, t.modelName) + if err != nil { + return fmt.Errorf("failed to check Triton model %s health: %w", t.modelName, err) + } + + if ready && t.signature != nil { + // only a health check + return nil + } + + if !ready { + if !t.repositoryExplicit { + return fmt.Errorf("model %s not ready and Triton is not in EXPLICIT Model Control Mode: %w", t.modelName, err) + } + + err = t.service.Client.ModelLoad(ctx, t.modelName) + if err != nil { + return fmt.Errorf("failed to load Triton model %s: %w", t.modelName, err) + } + + ready, err = t.service.Client.ModelReady(ctx, t.modelName) + if err != nil { + return fmt.Errorf("failed to check Triton model %s health after loading: %w", t.modelName, err) + } + } + + if !ready { + return fmt.Errorf("model %s is not ready after loading", t.modelName) + } + + // we need to get the model metadata and consolidate the signature + metadata, err := t.service.Client.ModelMetadata(ctx, t.modelName) + if err != nil || metadata == nil { + return fmt.Errorf("failed to get Triton model %s metadata: %w", t.modelName, err) + } + + mappedInputs := make(map[string]*domain.Input) + indexedInputNames := make(map[int]string) + + signatureInputs := make([]domain.Input, len(metadata.Inputs)) + for i, input := range metadata.Inputs { + goType := TritonToGoType(input.Datatype) + di := domain.Input{ + Name: input.Name, + // for now, since the request provides a []interface{}, we populate the Index + Index: i, + Type: goType, + Vocab: false, + Auxiliary: false, + } + + if t.debug { + log.Printf("[%s] Triton[%s] input:%s index:%d datatype:%s goType:%s", + t.modelID, t.modelName, input.Name, di.Index, input.Datatype, goType.Name()) + } + + signatureInputs[i] = di + mappedInputs[input.Name] = &di + indexedInputNames[i] = input.Name + } + + t.indexToName = indexedInputNames + + for _, input := range t.configuredInputs { + iName := input.Name + if _, ok := mappedInputs[iName]; !ok { + goType, err := common.DataType(input.DataType) + if err != nil { + return fmt.Errorf("failed to get data type for %s: %w", iName, err) + } + + mappedInputs[iName] = &domain.Input{ + Name: iName, + Index: len(mappedInputs), + Type: goType, + Vocab: false, + Auxiliary: true, + } + + if t.debug { + log.Printf("[%s] Triton[%s] auxiliary input:%s goType:%s", + t.modelID, t.modelName, iName, goType.Name()) + } + } + } + + t.inputs = mappedInputs + + outputs := make([]domain.Output, len(metadata.Outputs)) + for i, output := range metadata.Outputs { + o := domain.Output{ + Name: output.Name, + Index: len(outputs), + } + + goType := TritonToGoType(output.Datatype) + o.SetType(goType) + o.DataType = goType.Name() + o.DataTypeKind = goType.Kind() + + outputs[i] = o + } + + t.signature = &domain.Signature{ + Inputs: signatureInputs, + Outputs: outputs, + Output: outputs[0], + } + + return nil +} diff --git a/service/triton/evaluator_test.go b/service/triton/evaluator_test.go new file mode 100644 index 0000000..b138556 --- /dev/null +++ b/service/triton/evaluator_test.go @@ -0,0 +1,206 @@ +package triton + +import ( + "context" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/mly/service/config" + "github.com/viant/mly/shared" +) + +type mockTritonClient struct { + mu sync.Mutex + + readyState map[string]bool + + unloadCh chan string + modelLoadErr map[string]error + + metadata *ModelMetadata +} + +func (m *mockTritonClient) ServerReady(ctx context.Context) error { return nil } + +func (m *mockTritonClient) ModelInfer(ctx context.Context, modelName string, inputs []interface{}, indexToName map[int]string) ([]interface{}, error) { + return nil, nil +} + +func (m *mockTritonClient) ModelReady(ctx context.Context, modelName string) (bool, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if ready, ok := m.readyState[modelName]; ok { + return ready, nil + } + + return true, nil +} + +func (m *mockTritonClient) ModelLoad(ctx context.Context, modelName string) error { + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.modelLoadErr[modelName]; err != nil { + return err + } + + if m.readyState == nil { + m.readyState = make(map[string]bool) + } + + m.readyState[modelName] = true + + return nil +} + +func (m *mockTritonClient) ModelUnload(ctx context.Context, modelName string) error { + ch := m.unloadCh + if ch != nil { + ch <- modelName + } + return nil +} + +func (m *mockTritonClient) ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) { + if m.metadata == nil { + m.metadata = &ModelMetadata{ + Inputs: []MetadataTensor{ + {Name: "input1", Datatype: "BYTES"}, + {Name: "input2", Datatype: "BYTES"}, + }, + Outputs: []MetadataTensor{ + {Name: "output1", Datatype: "FP32"}, + }, + } + } + + return m.metadata, nil +} + +func (m *mockTritonClient) Close() error { return nil } + +func newTritonEvaluator(cfg *config.Model, mockClient *mockTritonClient) *TritonEvaluator { + cfg.Triton.Init(cfg.IsRouter()) + + evaluator := &TritonEvaluator{ + modelName: cfg.Triton.ModelName, + isPrivateClient: true, + repositoryExplicit: false, + service: &Service{Client: mockClient}, + configuredInputs: cfg.MetaInput.Inputs, + } + + evaluator.ReloadIfNeeded(context.Background()) + + return evaluator +} + +func TestTritonEvaluator_Signature(t *testing.T) { + cfg := &config.Model{ + ID: "test_model", + Triton: &config.TritonConfig{ + ModelName: "test_model", + }, + } + + evaluator := newTritonEvaluator(cfg, &mockTritonClient{}) + + defer evaluator.Close() + + sig := evaluator.Signature() + require.NotNil(t, sig) + assert.Equal(t, 2, len(sig.Inputs)) + assert.Equal(t, 1, len(sig.Outputs)) + assert.Equal(t, "input1", sig.Inputs[0].Name) + assert.Equal(t, "output1", sig.Outputs[0].Name) +} + +func TestTritonEvaluator_Inputs(t *testing.T) { + cfg := &config.Model{ + ID: "test_model", + Triton: &config.TritonConfig{ + ModelName: "test_model", + }, + MetaInput: shared.MetaInput{ + Inputs: []*shared.Field{ + {Name: "input1", Index: 0, DataType: "string"}, + {Name: "input2", Index: 0, DataType: "string"}, + {Name: "input_aux", Index: 0, DataType: "string", Auxiliary: true}, + }, + }, + } + + evaluator := newTritonEvaluator(cfg, &mockTritonClient{}) + + defer evaluator.Close() + + inputMap := evaluator.Inputs() + require.NotNil(t, inputMap) + assert.Equal(t, 3, len(inputMap)) + + i1 := inputMap["input1"] + assert.Equal(t, "input1", i1.Name) + assert.Equal(t, false, i1.Auxiliary) + assert.Equal(t, 0, i1.Index) + + i2 := inputMap["input2"] + assert.Equal(t, "input2", i2.Name) + assert.Equal(t, false, i2.Auxiliary) + assert.Equal(t, 1, i2.Index) + + iAux := inputMap["input_aux"] + assert.Equal(t, "input_aux", iAux.Name) + assert.Equal(t, true, iAux.Auxiliary) + assert.Equal(t, 2, iAux.Index) +} + +func TestTritonEvaluator_SignatureWithAuxiliaryInputs(t *testing.T) { + cfg := &config.Model{ + ID: "test_model", + Triton: &config.TritonConfig{ + ModelName: "test_model", + }, + MetaInput: shared.MetaInput{ + Inputs: []*shared.Field{ + {Name: "input_aux", Index: 0, DataType: "string", Auxiliary: true}, + }, + }, + } + + evaluator := newTritonEvaluator(cfg, &mockTritonClient{ + metadata: &ModelMetadata{ + Inputs: []MetadataTensor{ + {Name: "input1", Datatype: "BYTES"}, + }, + Outputs: []MetadataTensor{ + {Name: "output1", Datatype: "FP32"}, + }, + }, + }) + + defer evaluator.Close() + + sig := evaluator.Signature() + require.NotNil(t, sig) + assert.Equal(t, 1, len(sig.Inputs)) + assert.Equal(t, "input1", sig.Inputs[0].Name) +} + +func TestTritonEvaluator_PredictEmptyBatch(t *testing.T) { + cfg := &config.Model{ + ID: "test_model", + Triton: &config.TritonConfig{ + ModelName: "test_model", + }, + } + + evaluator := newTritonEvaluator(cfg, &mockTritonClient{}) + defer evaluator.Close() + + _, err := evaluator.Predict(context.Background(), []interface{}{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no input parameters") +} diff --git a/service/triton/grpc.go b/service/triton/grpc.go index 820238b..0b92cdf 100644 --- a/service/triton/grpc.go +++ b/service/triton/grpc.go @@ -24,26 +24,13 @@ func NewGRPCClient(grpcConn *grpc.ClientConn) *GRPCClient { } } -// preparedInput represents processed input data ready for gRPC transport -type preparedInput struct { - name string - datatype string // Triton datatype: "BYTES", "INT32", "INT64", "FP32", "FP64" - shape []int64 // Shape in int64 for gRPC compatibility - data interface{} // Flattened data: []string, []int32, []int64, []float32, []float64 -} - func (c *GRPCClient) ServerReady(ctx context.Context) error { _, err := c.grpcClient.ServerReady(ctx, &triton.ServerReadyRequest{}) return err } func (c *GRPCClient) ModelInfer(ctx context.Context, modelName string, inputs []interface{}, indexToName map[int]string) ([]interface{}, error) { - preparedInputs, err := prepareInputs(indexToName, inputs) - if err != nil { - return nil, err - } - - grpcRequest, err := buildGRPCRequest(modelName, preparedInputs) + grpcRequest, err := toGRPCRequest(modelName, inputs, indexToName) if err != nil { return nil, err } @@ -94,48 +81,144 @@ func (c *GRPCClient) ModelUnload(ctx context.Context, modelName string) error { return nil } +func (c *GRPCClient) ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) { + grpcResponse, err := c.grpcClient.ModelMetadata(ctx, &triton.ModelMetadataRequest{ + Name: modelName, + }) + if err != nil { + return nil, err + } + return convertGRPCModelMetadataResponse(grpcResponse), nil +} + +func convertGRPCModelMetadataResponse(response *triton.ModelMetadataResponse) *ModelMetadata { + inputs := make([]MetadataTensor, len(response.Inputs)) + for i, input := range response.Inputs { + inputs[i] = MetadataTensor{ + Name: input.Name, + Datatype: input.Datatype, + Shape: input.Shape, + } + } + + outputs := make([]MetadataTensor, len(response.Outputs)) + for i, output := range response.Outputs { + outputs[i] = MetadataTensor{ + Name: output.Name, + Datatype: output.Datatype, + Shape: output.Shape, + } + } + + return &ModelMetadata{ + Inputs: inputs, + Outputs: outputs, + } +} + func (c *GRPCClient) Close() error { return c.grpcConn.Close() } -func buildGRPCRequest(modelName string, preparedInputs []preparedInput) (*triton.ModelInferRequest, error) { +func toGRPCRequest(modelName string, params []interface{}, indexToName map[int]string) (*triton.ModelInferRequest, error) { req := &triton.ModelInferRequest{ ModelName: modelName, - Inputs: make([]*triton.ModelInferRequest_InferInputTensor, len(preparedInputs)), + Inputs: make([]*triton.ModelInferRequest_InferInputTensor, len(params)), } - for i, input := range preparedInputs { - tensor := &triton.ModelInferRequest_InferInputTensor{ - Name: input.name, - Datatype: input.datatype, - Shape: input.shape, - Contents: &triton.InferTensorContents{}, + for i, param := range params { + inputName, exists := indexToName[i] + if !exists { + return nil, fmt.Errorf("no input name found for index %d", i) + } + + inputContents := &triton.InferTensorContents{} + + inputTensor := &triton.ModelInferRequest_InferInputTensor{ + Name: inputName, + Contents: inputContents, } - switch data := input.data.(type) { - case []string: - tensor.Contents.BytesContents = make([][]byte, len(data)) - for j, s := range data { - tensor.Contents.BytesContents[j] = []byte(s) + var batchSize int + var datatype string + + switch v := param.(type) { + case [][]string: + if len(v) > 0 { + batchSize = len(v) + datatype = "BYTES" + + inputContents.BytesContents = make([][]byte, batchSize) + for j := range batchSize { + inputContents.BytesContents[j] = []byte(v[j][0]) + } + } + case [][]int: + if len(v) > 0 { + batchSize := len(v) + datatype = "INT32" + + inputContents.IntContents = make([]int32, batchSize) + for j := range batchSize { + inputContents.IntContents[j] = int32(v[j][0]) + } + } + case [][]int32: + if len(v) > 0 { + batchSize := len(v) + datatype = "INT32" + + inputContents.IntContents = make([]int32, batchSize) + for j := range batchSize { + inputContents.IntContents[j] = v[j][0] + } + + } + case [][]int64: + if len(v) > 0 { + batchSize := len(v) + datatype = "INT64" + + inputContents.Int64Contents = make([]int64, batchSize) + for j := range batchSize { + inputContents.Int64Contents[j] = v[j][0] + } + + } + case [][]float32: + if len(v) > 0 { + batchSize := len(v) + datatype = "FP32" + + inputContents.Fp32Contents = make([]float32, batchSize) + for j := range batchSize { + inputContents.Fp32Contents[j] = v[j][0] + } + } + case [][]float64: + if len(v) > 0 { + batchSize := len(v) + datatype = "FP64" + + inputContents.Fp64Contents = make([]float64, batchSize) + for j := range batchSize { + inputContents.Fp64Contents[j] = v[j][0] + } } - case []int32: - tensor.Contents.IntContents = data - case []int64: - tensor.Contents.Int64Contents = data - case []float32: - tensor.Contents.Fp32Contents = data - case []float64: - tensor.Contents.Fp64Contents = data default: - return nil, fmt.Errorf("unsupported input data type %T for %s", data, input.name) + return nil, fmt.Errorf("unsupported input type for %s at index %d: %T", inputName, i, param) } - req.Inputs[i] = tensor + inputTensor.Datatype = datatype + inputTensor.Shape = []int64{int64(batchSize), 1} + + req.Inputs[i] = inputTensor } return req, nil } +// parseRawOutput if output is provided in raw format. func parseRawOutput(rawData []byte, datatype string, batchSize int) (interface{}, error) { switch datatype { case "INT64": @@ -207,124 +290,6 @@ func parseRawOutput(rawData []byte, datatype string, batchSize int) (interface{} } } -func prepareInputs(indexToName map[int]string, params []interface{}) ([]preparedInput, error) { - if len(params) == 0 { - return nil, fmt.Errorf("no input parameters provided") - } - - var inputs []preparedInput - - for i, param := range params { - inputName, exists := indexToName[i] - if !exists { - return nil, fmt.Errorf("no input name found for index %d", i) - } - - switch v := param.(type) { - case [][]string: - if len(v) > 0 { - batchSize := len(v) - data := make([]string, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "BYTES", - data: data, - }) - } - case [][]int: - if len(v) > 0 { - batchSize := len(v) - data := make([]int32, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = int32(v[j][0]) - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "INT32", - data: data, - }) - } - case [][]int32: - if len(v) > 0 { - batchSize := len(v) - data := make([]int32, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "INT32", - data: data, - }) - } - case [][]int64: - if len(v) > 0 { - batchSize := len(v) - data := make([]int64, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "INT64", - data: data, - }) - } - case [][]float32: - if len(v) > 0 { - batchSize := len(v) - data := make([]float32, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "FP32", - data: data, - }) - } - case [][]float64: - if len(v) > 0 { - batchSize := len(v) - data := make([]float64, batchSize) - for j := 0; j < batchSize; j++ { - if len(v[j]) > 0 { - data[j] = v[j][0] - } - } - inputs = append(inputs, preparedInput{ - name: inputName, - shape: []int64{int64(batchSize), 1}, - datatype: "FP64", - data: data, - }) - } - default: - return nil, fmt.Errorf("unsupported input type for %s at index %d: %T", inputName, i, param) - } - } - - return inputs, nil -} - func convertGRPCResponse(response *triton.ModelInferResponse) ([]interface{}, error) { if len(response.Outputs) == 0 { return nil, fmt.Errorf("no outputs in response") @@ -358,90 +323,40 @@ func convertGRPCResponse(response *triton.ModelInferResponse) ([]interface{}, er switch output.Datatype { case "FP32": - if len(output.Contents.Fp32Contents) == 0 { - converted := make([][]float32, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []float32{0.0} - } - result[i] = converted - } else { - converted := make([][]float32, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.Fp32Contents); j++ { - converted[j] = []float32{output.Contents.Fp32Contents[j]} - } - result[i] = converted + converted := make([][]float32, batchSize) + for j := 0; j < batchSize && j < len(output.Contents.Fp32Contents); j++ { + converted[j] = []float32{output.Contents.Fp32Contents[j]} } - + result[i] = converted case "FP64": - if len(output.Contents.Fp64Contents) == 0 { - converted := make([][]float64, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []float64{0.0} - } - result[i] = converted - } else { - converted := make([][]float64, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.Fp64Contents); j++ { - converted[j] = []float64{output.Contents.Fp64Contents[j]} - } - result[i] = converted + converted := make([][]float64, batchSize) + for j := 0; j < batchSize && j < len(output.Contents.Fp64Contents); j++ { + converted[j] = []float64{output.Contents.Fp64Contents[j]} } + result[i] = converted case "INT32": - if len(output.Contents.IntContents) == 0 { - converted := make([][]int32, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []int32{0} - } - result[i] = converted - } else { - converted := make([][]int32, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.IntContents); j++ { - converted[j] = []int32{output.Contents.IntContents[j]} - } - result[i] = converted + converted := make([][]int32, batchSize) + for j := 0; j < batchSize && j < len(output.Contents.IntContents); j++ { + converted[j] = []int32{output.Contents.IntContents[j]} } - + result[i] = converted case "INT64": - if len(output.Contents.Int64Contents) == 0 { - converted := make([][]int64, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []int64{0} - } - result[i] = converted - } else { - converted := make([][]int64, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.Int64Contents); j++ { - converted[j] = []int64{output.Contents.Int64Contents[j]} - } - result[i] = converted + converted := make([][]int64, batchSize) + for j := 0; j < batchSize && j < len(output.Contents.Int64Contents); j++ { + converted[j] = []int64{output.Contents.Int64Contents[j]} } + result[i] = converted case "BYTES": - if len(output.Contents.BytesContents) == 0 { - converted := make([][]string, batchSize) - for j := 0; j < batchSize; j++ { - converted[j] = []string{""} - } - result[i] = converted - } else { - converted := make([][]string, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.BytesContents); j++ { - converted[j] = []string{string(output.Contents.BytesContents[j])} - } - result[i] = converted + converted := make([][]string, batchSize) + for j := 0; j < batchSize && j < len(output.Contents.BytesContents); j++ { + converted[j] = []string{string(output.Contents.BytesContents[j])} } + result[i] = converted default: - if len(output.Contents.Fp32Contents) > 0 { - converted := make([][]float32, batchSize) - for j := 0; j < batchSize && j < len(output.Contents.Fp32Contents); j++ { - converted[j] = []float32{output.Contents.Fp32Contents[j]} - } - result[i] = converted - } else { - return nil, fmt.Errorf("unsupported output datatype %s for %s", output.Datatype, output.Name) - } + return nil, fmt.Errorf("unsupported output datatype %s for %s", output.Datatype, output.Name) } } diff --git a/service/triton/grpc_test.go b/service/triton/grpc_test.go index 57e9c0c..836defc 100644 --- a/service/triton/grpc_test.go +++ b/service/triton/grpc_test.go @@ -9,8 +9,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" triton "github.com/viant/mly/proto/triton" - "github.com/viant/mly/service/config" - "github.com/viant/mly/shared" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" @@ -28,73 +26,90 @@ func createMockTritonConn(ctx context.Context, t *testing.T, listener *bufconn.L return conn } -func createMockGRPCTritonEvaluator(t *testing.T, cfg *config.Model, grpcConn *grpc.ClientConn) *TritonEvaluator { - grpcClient := triton.NewGRPCInferenceServiceClient(grpcConn) - cfg.Triton.Init() - evaluator, err := NewTritonEvaluator(cfg, map[string]TritonClient{ - cfg.Triton.ServerID: &GRPCClient{ - grpcConn: grpcConn, - grpcClient: grpcClient, - }, - }) +// mockTritonServer implements triton.GRPCInferenceServiceServer for testing +type mockTritonServer struct { + triton.UnimplementedGRPCInferenceServiceServer - require.NoError(t, err) - return evaluator + modelReady bool + responses map[string]*triton.ModelInferResponse } -func TestTritonEvaluator_ReloadAndSupportsReload(t *testing.T) { - ctx := context.Background() +func (m *mockTritonServer) RepositoryModelLoad(ctx context.Context, req *triton.RepositoryModelLoadRequest) (*triton.RepositoryModelLoadResponse, error) { + return &triton.RepositoryModelLoadResponse{}, nil +} - // Set up mock server - mock := &mockTritonServer{ - modelReady: true, - responses: map[string]*triton.ModelInferResponse{ - "test_model": { - ModelName: "test_model", - Outputs: []*triton.ModelInferResponse_InferOutputTensor{ - { - Name: "output", - Datatype: "INT64", - Shape: []int64{2, 1}, - Contents: &triton.InferTensorContents{ - Int64Contents: []int64{42, 100}, - }, - }, +func (m *mockTritonServer) ModelReady(ctx context.Context, req *triton.ModelReadyRequest) (*triton.ModelReadyResponse, error) { + return &triton.ModelReadyResponse{Ready: m.modelReady}, nil +} + +func (m *mockTritonServer) ModelInfer(ctx context.Context, req *triton.ModelInferRequest) (*triton.ModelInferResponse, error) { + if resp, ok := m.responses[req.ModelName]; ok { + return resp, nil + } + + // Return a default response + return &triton.ModelInferResponse{ + ModelName: req.ModelName, + Outputs: []*triton.ModelInferResponse_InferOutputTensor{ + { + Name: "output", + Datatype: "FP32", + Shape: []int64{1, 1}, + Contents: &triton.InferTensorContents{ + Fp32Contents: []float32{0.5}, }, }, }, - } + }, nil +} +// startMockGRPCServer starts an in-memory gRPC server for testing +func startMockGRPCServer(t *testing.T, mock *mockTritonServer) (*grpc.Server, *bufconn.Listener) { + buffer := 1024 * 1024 + listener := bufconn.Listen(buffer) + + server := grpc.NewServer() + triton.RegisterGRPCInferenceServiceServer(server, mock) + + go func() { + if err := server.Serve(listener); err != nil { + t.Fatalf("Server exited with error: %v", err) + } + }() + + return server, listener +} + +func createClient(t *testing.T, ctx context.Context, mock *mockTritonServer) (func(), *GRPCClient) { server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output1", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + client := &GRPCClient{ + grpcConn: grpcConn, + grpcClient: triton.NewGRPCInferenceServiceClient(grpcConn), + } - // ReloadIfNeeded should be a no-op - err := evaluator.ReloadIfNeeded(context.Background()) - assert.NoError(t, err) + return func() { + server.Stop() + }, client } -func TestTritonEvaluator_PredictWithMockServer(t *testing.T) { +func TestGRPCClient_ModelLoad(t *testing.T) { + ctx := context.Background() + + // Set up mock server + mock := &mockTritonServer{ + modelReady: true, + } + + stopper, client := createClient(t, ctx, mock) + defer stopper() + + err := client.ModelLoad(ctx, "test_model") + require.NoError(t, err) +} + +func TestGRPCClient_ModelInfer(t *testing.T) { ctx := context.Background() // Set up mock server @@ -117,36 +132,15 @@ func TestTritonEvaluator_PredictWithMockServer(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - // Create evaluator with mock connection - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "int64"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() // Test prediction params := []interface{}{ [][]string{{"value1"}, {"value2"}}, // 2 batch items } - results, err := evaluator.Predict(ctx, params) + results, err := client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) @@ -158,7 +152,7 @@ func TestTritonEvaluator_PredictWithMockServer(t *testing.T) { assert.Equal(t, []int64{100}, output[1]) } -func TestTritonEvaluator_PredictWithRawOutputContents(t *testing.T) { +func TestGRPCClient_ModelInferWithRawOutputContents(t *testing.T) { ctx := context.Background() // Set up mock server with raw output contents @@ -181,35 +175,14 @@ func TestTritonEvaluator_PredictWithRawOutputContents(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() params := []interface{}{ [][]string{{"value1"}, {"value2"}}, } - results, err := evaluator.Predict(ctx, params) + results, err := client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) @@ -220,7 +193,7 @@ func TestTritonEvaluator_PredictWithRawOutputContents(t *testing.T) { assert.InDelta(t, 50.0, output[1][0], 0.01) } -func TestTritonEvaluator_PredictAllInputTypes(t *testing.T) { +func TestGRPCClient_ModelInferAllInputTypes(t *testing.T) { ctx := context.Background() testCases := []struct { @@ -315,31 +288,10 @@ func TestTritonEvaluator_PredictAllInputTypes(t *testing.T) { responses: map[string]*triton.ModelInferResponse{"test_model": tc.expectedResp}, } - server, listener := startMockGRPCServer(t, mock) - grpcConn := createMockTritonConn(ctx, t, listener) - defer server.Stop() + stopper, client := createClient(t, ctx, mock) + defer stopper() - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: tc.inputType}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: tc.inputType}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() - - results, err := evaluator.Predict(ctx, []interface{}{tc.inputData}) + results, err := client.ModelInfer(ctx, "test_model", []interface{}{tc.inputData}, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) assert.NotNil(t, results[0]) @@ -347,55 +299,7 @@ func TestTritonEvaluator_PredictAllInputTypes(t *testing.T) { } } -// mockTritonServer implements triton.GRPCInferenceServiceServer for testing -type mockTritonServer struct { - triton.UnimplementedGRPCInferenceServiceServer - modelReady bool - responses map[string]*triton.ModelInferResponse -} - -func (m *mockTritonServer) ModelReady(ctx context.Context, req *triton.ModelReadyRequest) (*triton.ModelReadyResponse, error) { - return &triton.ModelReadyResponse{Ready: m.modelReady}, nil -} - -func (m *mockTritonServer) ModelInfer(ctx context.Context, req *triton.ModelInferRequest) (*triton.ModelInferResponse, error) { - if resp, ok := m.responses[req.ModelName]; ok { - return resp, nil - } - // Return a default response - return &triton.ModelInferResponse{ - ModelName: req.ModelName, - Outputs: []*triton.ModelInferResponse_InferOutputTensor{ - { - Name: "output", - Datatype: "FP32", - Shape: []int64{1, 1}, - Contents: &triton.InferTensorContents{ - Fp32Contents: []float32{0.5}, - }, - }, - }, - }, nil -} - -// startMockGRPCServer starts an in-memory gRPC server for testing -func startMockGRPCServer(t *testing.T, mock *mockTritonServer) (*grpc.Server, *bufconn.Listener) { - buffer := 1024 * 1024 - listener := bufconn.Listen(buffer) - - server := grpc.NewServer() - triton.RegisterGRPCInferenceServiceServer(server, mock) - - go func() { - if err := server.Serve(listener); err != nil { - t.Logf("Server exited with error: %v", err) - } - }() - - return server, listener -} - -func TestTritonEvaluator_PredictBytesOutput(t *testing.T) { +func TestGRPCClient_ModelInferBytesOutput(t *testing.T) { ctx := context.Background() mock := &mockTritonServer{ @@ -417,35 +321,14 @@ func TestTritonEvaluator_PredictBytesOutput(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "string"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() params := []interface{}{ [][]string{{"input1"}, {"input2"}}, } - results, err := evaluator.Predict(ctx, params) + results, err := client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) @@ -456,7 +339,7 @@ func TestTritonEvaluator_PredictBytesOutput(t *testing.T) { assert.Equal(t, []string{"result2"}, output[1]) } -func TestTritonEvaluator_PredictDifferentBatchSizes(t *testing.T) { +func TestGRPCClient_ModelInferDifferentBatchSizes(t *testing.T) { ctx := context.Background() testCases := []struct { @@ -496,29 +379,8 @@ func TestTritonEvaluator_PredictDifferentBatchSizes(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "int64"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() // Generate input batch inputBatch := make([][]string, tc.batchSize) @@ -526,7 +388,7 @@ func TestTritonEvaluator_PredictDifferentBatchSizes(t *testing.T) { inputBatch[i] = []string{fmt.Sprintf("input_%d", i)} } - results, err := evaluator.Predict(ctx, []interface{}{inputBatch}) + results, err := client.ModelInfer(ctx, "test_model", []interface{}{inputBatch}, map[int]string{0: "input1"}) require.NoError(t, err) require.Len(t, results, 1) @@ -537,7 +399,7 @@ func TestTritonEvaluator_PredictDifferentBatchSizes(t *testing.T) { } } -func TestTritonEvaluator_PredictUnsupportedType(t *testing.T) { +func TestGRPCClient_ModelInferUnsupportedType(t *testing.T) { ctx := context.Background() mock := &mockTritonServer{ @@ -560,66 +422,19 @@ func TestTritonEvaluator_PredictUnsupportedType(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "string"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() params := []interface{}{ [][]string{{"test"}}, } - _, err := evaluator.Predict(ctx, params) + _, err := client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.Error(t, err) assert.Contains(t, err.Error(), "unsupported") } -func TestTritonEvaluator_PredictEmptyBatch(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "int64"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - evaluator := createMockGRPCTritonEvaluator(t, cfg, nil) - defer evaluator.Close() - - _, err := evaluator.Predict(context.Background(), []interface{}{}) - require.Error(t, err) - assert.Contains(t, err.Error(), "no input parameters") -} - -func TestTritonEvaluator_PredictMissingOutput(t *testing.T) { +func TestGRPCClient_ModelInferMissingOutput(t *testing.T) { ctx := context.Background() mock := &mockTritonServer{ @@ -639,35 +454,14 @@ func TestTritonEvaluator_PredictMissingOutput(t *testing.T) { }, } - server, listener := startMockGRPCServer(t, mock) - defer server.Stop() - - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "int64"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - ServerID: "test_server", - }, - } - - grpcConn := createMockTritonConn(ctx, t, listener) - evaluator := createMockGRPCTritonEvaluator(t, cfg, grpcConn) - defer evaluator.Close() + stopper, client := createClient(t, ctx, mock) + defer stopper() params := []interface{}{ [][]string{{"test"}}, } - _, err := evaluator.Predict(ctx, params) + _, err := client.ModelInfer(ctx, "test_model", params, map[int]string{0: "input1"}) require.Error(t, err) assert.Contains(t, err.Error(), "missing contents") } diff --git a/service/triton/http.go b/service/triton/http.go index 4195a81..6de7227 100644 --- a/service/triton/http.go +++ b/service/triton/http.go @@ -22,30 +22,6 @@ type HTTPClient struct { debug bool } -func (c *HTTPClient) sendRequestCheckStatus(ctx context.Context, method, path string) (*http.Response, error) { - if c.debug { - log.Printf("Sending request %s %s\n", method, path) - } - - httpReq, err := http.NewRequestWithContext(ctx, method, c.serverURL+path, nil) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP request: %w", err) - } - - resp, err := c.handleRequestWithRetry(ctx, httpReq, nil) - - if err != nil { - return nil, err - } - - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return resp, fmt.Errorf("triton server http status code: %d for %s %s", resp.StatusCode, method, path) - } - - return resp, nil -} - func (c *HTTPClient) ServerReady(ctx context.Context) error { path := "/v2/health/ready" _, err := c.sendRequestCheckStatus(ctx, "GET", path) @@ -58,7 +34,7 @@ func (c *HTTPClient) ModelInfer(ctx context.Context, modelName string, inputs [] return nil, err } - tritonResponse, err := c.sendRequest(ctx, modelName, tritonRequest) + tritonResponse, err := c.sendInferRequest(ctx, modelName, tritonRequest) if err != nil { return nil, err } @@ -96,94 +72,61 @@ func (c *HTTPClient) ModelUnload(ctx context.Context, modelName string) error { return err } -func (c *HTTPClient) Close() error { - return nil -} +func (c *HTTPClient) ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) { + url := c.serverURL + "/v2/models/" + modelName -// TritonInput represents a single input tensor for Triton -type TritonInput struct { - Name string `json:"name"` - Shape []int `json:"shape"` - DataType string `json:"datatype"` - Data interface{} `json:"data"` -} - -type TritonOutput struct { - Name string `json:"name"` - Data interface{} `json:"data"` -} + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } -// TritonRequest represents the HTTP request format to Triton -type TritonRequest struct { - Inputs []TritonInput `json:"inputs"` -} + resp, err := c.handleRequestWithRetry(ctx, req, nil) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } -// TritonResponse represents the HTTP response format from Triton -type TritonResponse struct { - Outputs []TritonOutput `json:"outputs"` -} + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("triton server http status code: %d for %s", resp.StatusCode, url) + } -func (t *TritonInput) MarshalJSONObject(enc *gojay.Encoder) { - enc.StringKey("name", t.Name) - enc.ArrayKey("shape", gojay.EncodeArrayFunc(func(enc *gojay.Encoder) { - for _, v := range t.Shape { - enc.AddInt(v) - } - })) - enc.StringKey("datatype", t.DataType) + modelMetadata := new(ModelMetadata) + if err := json.NewDecoder(resp.Body).Decode(modelMetadata); err != nil { + return nil, fmt.Errorf("failed to parse Triton response: %w", err) + } - enc.ArrayKey("data", gojay.EncodeArrayFunc(func(enc *gojay.Encoder) { - switch data := t.Data.(type) { - case []string: - for _, v := range data { - enc.AddString(v) - } - case []int: - for _, v := range data { - enc.AddInt(v) - } - case []float32: - for _, v := range data { - enc.AddFloat32(v) - } - case []float64: - for _, v := range data { - enc.AddFloat64(v) - } - default: - for i := 0; i < reflect.ValueOf(data).Len(); i++ { - val := reflect.ValueOf(data).Index(i).Interface() - enc.AddInterface(val) - } - } - })) + return modelMetadata, nil } -func (t *TritonInput) IsNil() bool { - return t == nil +func (c *HTTPClient) Close() error { + return nil } -func (t *TritonRequest) MarshalJSONObject(enc *gojay.Encoder) { - enc.ArrayKey("inputs", (*TritonInputs)(&t.Inputs)) -} +func (c *HTTPClient) sendRequestCheckStatus(ctx context.Context, method, path string) (*http.Response, error) { + if c.debug { + log.Printf("Sending request %s %s\n", method, path) + } -func (t *TritonRequest) IsNil() bool { - return t == nil -} + httpReq, err := http.NewRequestWithContext(ctx, method, c.serverURL+path, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } -type TritonInputs []TritonInput + resp, err := c.handleRequestWithRetry(ctx, httpReq, nil) -func (t *TritonInputs) MarshalJSONArray(enc *gojay.Encoder) { - for i := range *t { - enc.AddObject(&(*t)[i]) + if err != nil { + return nil, err } -} -func (t *TritonInputs) IsNil() bool { - return t == nil || len(*t) == 0 + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return resp, fmt.Errorf("triton server http status code: %d for %s %s", resp.StatusCode, method, path) + } + + return resp, nil } -func (c *HTTPClient) sendRequest(ctx context.Context, modelName string, request *TritonRequest) (*TritonResponse, error) { +func (c *HTTPClient) sendInferRequest(ctx context.Context, modelName string, request *TritonRequest) (*TritonResponse, error) { url := c.serverURL + "/v2/models/" + modelName + "/infer" buf := bytes.NewBuffer(make([]byte, 0, 1024)) @@ -193,7 +136,6 @@ func (c *HTTPClient) sendRequest(ctx context.Context, modelName string, request } jsonData := buf.Bytes() - // Create HTTP request httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %w", err) @@ -256,12 +198,91 @@ func (c *HTTPClient) handleRequestWithRetry(ctx context.Context, httpReq *http.R return resp, nil } -// convertToTritonRequest converts Feeds ([numInputs]([batchSize][1]T)) to Triton request format -func convertToTritonRequest(params []interface{}, indexToName map[int]string) (*TritonRequest, error) { - if len(params) == 0 { - return nil, fmt.Errorf("no input parameters provided") +// TritonInput represents a single input tensor for Triton +type TritonInput struct { + Name string `json:"name"` + Shape []int `json:"shape"` + DataType string `json:"datatype"` + Data interface{} `json:"data"` +} + +type TritonOutput struct { + Name string `json:"name"` + Data interface{} `json:"data"` +} + +// TritonRequest represents the HTTP request format to Triton +type TritonRequest struct { + Inputs []TritonInput `json:"inputs"` +} + +// TritonResponse represents the HTTP response format from Triton +type TritonResponse struct { + Outputs []TritonOutput `json:"outputs"` +} + +func (t *TritonInput) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("name", t.Name) + enc.ArrayKey("shape", gojay.EncodeArrayFunc(func(enc *gojay.Encoder) { + for _, v := range t.Shape { + enc.AddInt(v) + } + })) + enc.StringKey("datatype", t.DataType) + + enc.ArrayKey("data", gojay.EncodeArrayFunc(func(enc *gojay.Encoder) { + switch data := t.Data.(type) { + case []string: + for _, v := range data { + enc.AddString(v) + } + case []int: + for _, v := range data { + enc.AddInt(v) + } + case []float32: + for _, v := range data { + enc.AddFloat32(v) + } + case []float64: + for _, v := range data { + enc.AddFloat64(v) + } + default: + for i := 0; i < reflect.ValueOf(data).Len(); i++ { + val := reflect.ValueOf(data).Index(i).Interface() + enc.AddInterface(val) + } + } + })) +} + +func (t *TritonInput) IsNil() bool { + return t == nil +} + +func (t *TritonRequest) MarshalJSONObject(enc *gojay.Encoder) { + enc.ArrayKey("inputs", (*TritonInputs)(&t.Inputs)) +} + +func (t *TritonRequest) IsNil() bool { + return t == nil +} + +type TritonInputs []TritonInput + +func (t *TritonInputs) MarshalJSONArray(enc *gojay.Encoder) { + for i := range *t { + enc.AddObject(&(*t)[i]) } +} +func (t *TritonInputs) IsNil() bool { + return t == nil || len(*t) == 0 +} + +// convertToTritonRequest converts Feeds ([numInputs]([batchSize][1]T)) to Triton request format +func convertToTritonRequest(params []interface{}, indexToName map[int]string) (*TritonRequest, error) { var inputs []TritonInput // Convert each parameter to Triton input format diff --git a/service/triton/metrics.go b/service/triton/metrics.go index 73cfe49..22043a0 100644 --- a/service/triton/metrics.go +++ b/service/triton/metrics.go @@ -76,6 +76,16 @@ var ( []string{"model"}, ) + modelMetadataDurationMicrosSummary = prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "triton", + Name: "model_metadata_duration_summary_us", + Help: "Duration of Triton ModelMetadata RPCs, labeled by model name, successful only.", + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) ) func init() { @@ -88,6 +98,7 @@ func init() { prometheus.MustRegister(modelReadyDurationMicrosSummary) prometheus.MustRegister(modelLoadDurationMicrosSummary) prometheus.MustRegister(modelUnloadDurationMicrosSummary) + prometheus.MustRegister(modelMetadataDurationMicrosSummary) } type MeteredTritonClient struct { @@ -169,6 +180,19 @@ func (c *MeteredTritonClient) ModelUnload(ctx context.Context, modelName string) return err } +func (c *MeteredTritonClient) ModelMetadata(ctx context.Context, modelName string) (*ModelMetadata, error) { + var metadata *ModelMetadata + var err error + err = withGatherers(func() error { + metadata, err = c.client.ModelMetadata(ctx, modelName) + return err + }, func(duration float64) { + modelMetadataDurationMicrosSummary.WithLabelValues(modelName).Observe(duration) + }) + + return metadata, err +} + func (c *MeteredTritonClient) Close() error { return c.client.Close() } diff --git a/service/triton/repository.go b/service/triton/repository.go new file mode 100644 index 0000000..a28e975 --- /dev/null +++ b/service/triton/repository.go @@ -0,0 +1,48 @@ +package triton + +import ( + "sync" +) + +type mlyModelID string +type tritonModelName string + +type Repository struct { + mu sync.Mutex + usage map[tritonModelName]map[mlyModelID]struct{} +} + +func (r *Repository) RegisterUsage(mlyID mlyModelID, tritonName tritonModelName) { + r.mu.Lock() + defer r.mu.Unlock() + mlyUsages, ok := r.usage[tritonName] + if !ok { + mlyUsages = make(map[mlyModelID]struct{}) + r.usage[tritonName] = mlyUsages + } + + mlyUsages[mlyID] = struct{}{} +} + +// UnregisterUsage returns true if all usages of a model have been unregistered. +// The TritonClient should then actual unload the model on the server. +func (r *Repository) UnregisterUsage(mlyID mlyModelID, tritonName tritonModelName) bool { + r.mu.Lock() + defer r.mu.Unlock() + + mlyUsages, ok := r.usage[tritonName] + if !ok { + // this was never registered, so this is considered having been unregistered. + return true + } + + delete(mlyUsages, mlyID) + return len(mlyUsages) == 0 +} + +func NewRepository() *Repository { + return &Repository{ + usage: make(map[tritonModelName]map[mlyModelID]struct{}), + mu: sync.Mutex{}, + } +} diff --git a/service/triton/service.go b/service/triton/service.go new file mode 100644 index 0000000..b5c3f16 --- /dev/null +++ b/service/triton/service.go @@ -0,0 +1,46 @@ +package triton + +import ( + "context" +) + +// Service is a container for a client and a representation of model repository management. +type Service struct { + Client TritonClient + + Unloader ModelUnloader + Repository *Repository +} + +func (s *Service) RegisterUsage(mlyID string, tritonName string) { + if s.Repository == nil { + return + } + + s.Repository.RegisterUsage(mlyModelID(mlyID), tritonModelName(tritonName)) +} + +func NewService(client TritonClient) *Service { + return &Service{ + Client: client, + Unloader: client, + Repository: NewRepository(), + } +} + +func (s *Service) UnloadModel(ctx context.Context, mlyID string, tritonName string) error { + if s.Repository == nil { + return nil + } + + if s.Unloader == nil { + return nil + } + + shouldUnload := s.Repository.UnregisterUsage(mlyModelID(mlyID), tritonModelName(tritonName)) + if shouldUnload { + return s.Unloader.ModelUnload(ctx, tritonName) + } + + return nil +} diff --git a/service/triton/triton.go b/service/triton/triton.go deleted file mode 100644 index a4c8347..0000000 --- a/service/triton/triton.go +++ /dev/null @@ -1,228 +0,0 @@ -package triton - -import ( - "context" - "fmt" - "net/http" - "reflect" - "time" - - "github.com/viant/mly/service/config" - "github.com/viant/mly/service/domain" - "github.com/viant/mly/shared" - "github.com/viant/mly/shared/common" -) - -// TritonEvaluator implements PlatformEvaluator for Triton Inference Server via gRPC -type TritonEvaluator struct { - client TritonClient - - modelName string - - // if true, this client is used only for this instance - isPrivateClient bool - repositoryExplicit bool - - timeout time.Duration - - signature *domain.Signature - indexToName map[int]string - - inputs map[string]*domain.Input -} - -// NewTritonEvaluator creates a new Triton evaluator -func NewTritonEvaluator(config *config.Model, tritonClients map[string]TritonClient) (*TritonEvaluator, error) { - var client TritonClient - - isPrivateClient := config.URL != "" - timeout := time.Duration(config.Triton.Timeout) * time.Millisecond - - if isPrivateClient { - // "Private" URL configuration will only support HTTP - client = &HTTPClient{ - httpClient: &http.Client{ - Timeout: timeout, - }, - serverURL: config.URL, - debug: config.Debug, - } - } else { - client = tritonClients[config.Triton.ServerID] - if client == nil { - return nil, fmt.Errorf("client not found for Triton, server ID: %s", config.Triton.ServerID) - } - } - - evaluator := &TritonEvaluator{ - client: client, - modelName: config.Triton.ModelName, - timeout: timeout, - - isPrivateClient: isPrivateClient, - - // clients defined in TritonServers are assumed to be in EXPLICIT mode - repositoryExplicit: !isPrivateClient || config.Triton.RepositoryExplicit, - } - - if err := evaluator.handleIO(&config.MetaInput); err != nil { - return nil, err - } - - return evaluator, nil -} - -func NewRoutedTritonEvaluator(modelName string, client TritonClient, timeoutMs int, indexToName map[int]string) (*TritonEvaluator, error) { - return &TritonEvaluator{ - client: client, - modelName: modelName, - timeout: time.Duration(timeoutMs) * time.Millisecond, - repositoryExplicit: true, - indexToName: indexToName, - }, nil -} - -// Predict performs inference via Triton Inference Server -func (t *TritonEvaluator) Predict(ctx context.Context, params []interface{}) ([]interface{}, error) { - requestCtx := ctx - if _, hasDeadline := ctx.Deadline(); !hasDeadline { - var cancel context.CancelFunc - requestCtx, cancel = context.WithTimeout(ctx, t.timeout) - defer cancel() - } - - return t.client.ModelInfer(requestCtx, t.modelName, params, t.indexToName) -} - -func (t *TritonEvaluator) handleIO(io *shared.MetaInput) error { - var inputs []domain.Input - var outputs []domain.Output - - indexToName := make(map[int]string) - - mappedInputs := make(map[string]*domain.Input) - - if len(io.Inputs) > 0 { - for _, input := range io.Inputs { - if !input.Auxiliary { - inputs = append(inputs, domain.Input{ - Name: input.Name, - Index: input.Index, - }) - - indexToName[input.Index] = input.Name - } - - inputType := reflect.TypeOf("") - if input.DataType != "" { - switch input.DataType { - case "string": - inputType = reflect.TypeOf("") - case "int": - inputType = reflect.TypeOf(0) - case "int32": - inputType = reflect.TypeOf(int32(0)) - case "int64": - inputType = reflect.TypeOf(int64(0)) - case "float32": - inputType = reflect.TypeOf(float32(0)) - case "float64": - inputType = reflect.TypeOf(float64(0)) - } - } - - mappedInputs[input.Name] = &domain.Input{ - Name: input.Name, - Index: input.Index, - Type: inputType, - Vocab: false, - Auxiliary: input.Auxiliary, - } - } - } else { - return fmt.Errorf("missing input configuration for Triton evaluator. " + - "Add 'inputs' section to your model configuration YAML with field definitions") - } - - if len(io.Outputs) > 0 { - for i, output := range io.Outputs { - outputs = append(outputs, domain.Output{ - Name: output.Name, - Index: i, - DataType: output.DataType, - }) - } - } else { - return fmt.Errorf("missing output configuration for Triton evaluator. " + - "Add 'outputs' section to your model configuration YAML with field definitions") - } - - t.indexToName = indexToName - - t.signature = &domain.Signature{ - Inputs: inputs, - Outputs: outputs, - Output: outputs[0], - } - - t.inputs = mappedInputs - - return nil -} - -func (t *TritonEvaluator) Signature() *domain.Signature { - return t.signature -} - -func (t *TritonEvaluator) Dictionary() *common.Dictionary { - return nil -} - -func (t *TritonEvaluator) Stats(stats map[string]interface{}) { - // no stats -} - -func (t *TritonEvaluator) Inputs() map[string]*domain.Input { - return t.inputs -} - -// Close releases Triton client resources and stops health monitoring -func (t *TritonEvaluator) Close() error { - if t.isPrivateClient { - return t.client.Close() - } - - return nil -} - -// ReloadIfNeeded for independent Triton models, reloading is not supported. -func (t *TritonEvaluator) ReloadIfNeeded(ctx context.Context) error { - ready, err := t.client.ModelReady(ctx, t.modelName) - if err != nil { - return fmt.Errorf("failed to check Triton model %s health: %w", t.modelName, err) - } - - if ready { - return nil - } - - if !t.repositoryExplicit { - return fmt.Errorf("model %s not ready and Triton is not in EXPLICIT Model Control Mode: %w", t.modelName, err) - } - - err = t.client.ModelLoad(ctx, t.modelName) - if err != nil { - return fmt.Errorf("failed to load Triton model %s: %w", t.modelName, err) - } - - ready, err = t.client.ModelReady(ctx, t.modelName) - if err != nil { - return fmt.Errorf("failed to check Triton model %s health after loading: %w", t.modelName, err) - } - - if !ready { - return fmt.Errorf("model %s is not ready after loading", t.modelName) - } - - return nil -} diff --git a/service/triton/triton_test.go b/service/triton/triton_test.go deleted file mode 100644 index cd16312..0000000 --- a/service/triton/triton_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package triton - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/viant/mly/service/config" - "github.com/viant/mly/shared" -) - -func newTritonEvaluator(t *testing.T, cfg *config.Model) *TritonEvaluator { - cfg.Triton.Init() - evaluator, err := NewTritonEvaluator(cfg, nil) - require.NoError(t, err) - return evaluator -} - -func TestTritonEvaluator_Signature(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - URL: "http://localhost:8000", - - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - {Name: "input2", Index: 1, DataType: "int64"}, - }, - Outputs: []*shared.Field{ - {Name: "output1", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - }, - } - - evaluator := newTritonEvaluator(t, cfg) - defer evaluator.Close() - - sig := evaluator.Signature() - require.NotNil(t, sig) - assert.Equal(t, 2, len(sig.Inputs)) - assert.Equal(t, 1, len(sig.Outputs)) - assert.Equal(t, "input1", sig.Inputs[0].Name) - assert.Equal(t, "output1", sig.Outputs[0].Name) -} - -func TestTritonEvaluator_SignatureWithAuxiliaryInputs(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - URL: "http://localhost:8000", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - {Name: "auxiliary_input", Index: 1, DataType: "int64", Auxiliary: true}, - }, - Outputs: []*shared.Field{ - {Name: "output1", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - }, - } - - evaluator := newTritonEvaluator(t, cfg) - defer evaluator.Close() - - sig := evaluator.Signature() - require.NotNil(t, sig) - // Auxiliary inputs should be excluded from signature - assert.Equal(t, 1, len(sig.Inputs)) - assert.Equal(t, "input1", sig.Inputs[0].Name) -} - -func TestTritonEvaluator_Dictionary(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - URL: "http://localhost:8000", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "input1", Index: 0, DataType: "string"}, - }, - Outputs: []*shared.Field{ - {Name: "output1", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - }, - } - - evaluator := newTritonEvaluator(t, cfg) - defer evaluator.Close() - - dict := evaluator.Dictionary() - assert.Nil(t, dict) -} - -func TestTritonEvaluator_InputsMapping(t *testing.T) { - cfg := &config.Model{ - ID: "test_model", - Platform: "triton", - URL: "http://localhost:8000", - MetaInput: shared.MetaInput{ - Inputs: []*shared.Field{ - {Name: "string_input", Index: 0, DataType: "string"}, - {Name: "int64_input", Index: 1, DataType: "int64"}, - {Name: "int32_input", Index: 2, DataType: "int32"}, - {Name: "float32_input", Index: 3, DataType: "float32"}, - {Name: "float64_input", Index: 4, DataType: "float64"}, - }, - Outputs: []*shared.Field{ - {Name: "output", Index: 0, DataType: "float32"}, - }, - }, - Triton: &config.TritonConfig{ - ModelName: "test_model", - }, - } - - evaluator := newTritonEvaluator(t, cfg) - defer evaluator.Close() - - inputs := evaluator.Inputs() - assert.Len(t, inputs, 5) - assert.Contains(t, inputs, "string_input") - assert.Contains(t, inputs, "int64_input") - assert.Contains(t, inputs, "int32_input") - assert.Contains(t, inputs, "float32_input") - assert.Contains(t, inputs, "float64_input") -} diff --git a/service/triton/types.go b/service/triton/types.go new file mode 100644 index 0000000..d94901a --- /dev/null +++ b/service/triton/types.go @@ -0,0 +1,23 @@ +package triton + +import ( + "fmt" + "reflect" +) + +func TritonToGoType(datatype string) reflect.Type { + switch datatype { + case "INT64": + return reflect.TypeOf(int64(0)) + case "INT32": + return reflect.TypeOf(int32(0)) + case "FP32": + return reflect.TypeOf(float32(0)) + case "FP64": + return reflect.TypeOf(float64(0)) + case "BYTES": + return reflect.TypeOf("") + default: + panic(fmt.Sprintf("unsupported Triton datatype: %s", datatype)) + } +} diff --git a/shared/circut/breaker.go b/shared/circut/breaker.go index 04b1f73..10baba5 100644 --- a/shared/circut/breaker.go +++ b/shared/circut/breaker.go @@ -26,11 +26,19 @@ func (b *Breaker) IsUp() bool { } // FlagUp is used to reset the backoff. +// +// Uses CompareAndSwap so the resetDuration reset only fires on an actual +// down->up transition (not on idempotent FlagUp calls), and so the write +// to b.Down is atomic with respect to IsUp's atomic.LoadInt32. The +// resetDuration write is performed under the mutex so it cannot race +// with FlagDown's resetDuration *= 2 (lost-update bug). func (b *Breaker) FlagUp() { + if !atomic.CompareAndSwapInt32(&b.Down, 1, 0) { + return + } b.mux.Lock() - b.Down = 0 - b.mux.Unlock() b.resetDuration = b.initialResetDuration + b.mux.Unlock() } // resetIfDue will spawn a goroutine to probe the resource if the backoff time @@ -57,21 +65,19 @@ func (b *Breaker) resetIfDue() { } // FlagDown is used to indicate the resource is down. +// +// CompareAndSwap atomically transitions the Down flag exactly once per +// up->down edge, so backoff state (resetTime, resetDuration) is updated +// once per trip even under concurrent FlagDown calls. The atomic write +// is also synchronized with IsUp's atomic.LoadInt32. func (b *Breaker) FlagDown() { - down := atomic.LoadInt32(&b.Down) - if down == 1 { + if !atomic.CompareAndSwapInt32(&b.Down, 0, 1) { return } - b.mux.Lock() - defer b.mux.Unlock() - if b.Down == 1 { - return - } - b.Down = 1 - b.resetTime = time.Now().Add(b.resetDuration) b.resetDuration *= 2 //double reset time each time service is Down + b.mux.Unlock() } // New creates a new circut breaker diff --git a/shared/circut/breaker_test.go b/shared/circut/breaker_test.go new file mode 100644 index 0000000..0ddcd4e --- /dev/null +++ b/shared/circut/breaker_test.go @@ -0,0 +1,139 @@ +package circut + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeProber records probe invocations and never calls FlagUp. +// Tests that need probe-driven recovery call b.FlagUp() directly. +type fakeProber struct { + probes int64 +} + +func (f *fakeProber) Probe() { + atomic.AddInt64(&f.probes, 1) +} + +// TestBreaker_Concurrent_NoDataRace exercises FlagDown / FlagUp / IsUp +// from many goroutines simultaneously. Run with `-race` to catch the +// previously-existing data race on b.Down (atomic read, non-atomic write). +// +// Without the fix this test reliably triggers a race-detector report: +// +// WARNING: DATA RACE +// Read at 0x... by goroutine N (atomic.LoadInt32): +// shared/circut.(*Breaker).IsUp +// Previous write at 0x... by goroutine M: +// shared/circut.(*Breaker).FlagUp / FlagDown +func TestBreaker_Concurrent_NoDataRace(t *testing.T) { + b := New(50*time.Millisecond, &fakeProber{}) + + const goroutines = 16 + const iterations = 5000 + + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func(seed int) { + defer wg.Done() + for i := 0; i < iterations; i++ { + switch (seed + i) % 3 { + case 0: + b.FlagDown() + case 1: + b.FlagUp() + case 2: + _ = b.IsUp() + } + } + }(g) + } + wg.Wait() + + // Final state assertion is intentionally weak; the point of this test + // is the race detector, not the terminal flag value. + _ = b.IsUp() +} + +// TestBreaker_BackoffAccumulates verifies that resetDuration doubles on +// each successive trip and is NOT clobbered by an interleaved FlagUp. +// Catches the lost-update bug where FlagUp's resetDuration reset ran +// outside the mutex and could race with FlagDown's resetDuration *= 2. +func TestBreaker_BackoffAccumulates(t *testing.T) { + const initial = 50 * time.Millisecond + b := New(initial, &fakeProber{}) + + require.Equal(t, initial, b.resetDuration, "initial resetDuration") + + // First trip: doubles to 100ms. + b.FlagDown() + require.Equal(t, 2*initial, b.resetDuration, "after 1st trip") + + // Recover and trip again: must double from 100ms to 200ms (NOT + // reset to 100ms, which is what the lost-update bug would do + // under racy timing). + b.FlagUp() + require.Equal(t, initial, b.resetDuration, "FlagUp resets to initial") + + b.FlagDown() + require.Equal(t, 2*initial, b.resetDuration, "after 2nd trip from initial") + + // Multiple FlagDowns without intervening FlagUp must NOT + // re-double. Idempotency comes from the CAS. + b.FlagDown() + b.FlagDown() + b.FlagDown() + assert.Equal(t, 2*initial, b.resetDuration, "extra FlagDowns are no-ops while down") +} + +// TestBreaker_FlagUp_Idempotent verifies that FlagUp on an already-up +// breaker does NOT reset resetDuration (which would be wrong if the +// breaker is in the middle of a backoff sequence and a stale Probe +// callback fires FlagUp redundantly). +func TestBreaker_FlagUp_Idempotent(t *testing.T) { + const initial = 50 * time.Millisecond + b := New(initial, &fakeProber{}) + + // Trip and recover -- resetDuration is back to initial. + b.FlagDown() + b.FlagUp() + require.Equal(t, initial, b.resetDuration) + + // Trip again -- resetDuration doubles. + b.FlagDown() + require.Equal(t, 2*initial, b.resetDuration) + + // Spurious FlagUp callback on an already-up breaker would be the + // state where Down is already 0 -- but we just trip'd, so it's 1. + // The realistic spurious FlagUp scenario is Probe firing twice + // after recovery has already happened. Simulate that. + b.FlagUp() // legitimate recovery; resetDuration -> initial + require.Equal(t, initial, b.resetDuration) + b.FlagUp() // spurious; must NOT clobber any subsequent backoff state + require.Equal(t, initial, b.resetDuration, "spurious FlagUp is a no-op") +} + +// TestBreaker_FlagDown_Idempotent verifies that repeated FlagDown calls +// while the breaker is already down do not advance resetTime or grow +// resetDuration further. +func TestBreaker_FlagDown_Idempotent(t *testing.T) { + const initial = 50 * time.Millisecond + b := New(initial, &fakeProber{}) + + b.FlagDown() + firstResetTime := b.resetTime + require.Equal(t, 2*initial, b.resetDuration) + + // Subsequent FlagDowns while down must be no-ops. + for i := 0; i < 5; i++ { + b.FlagDown() + } + assert.Equal(t, 2*initial, b.resetDuration, "resetDuration unchanged across redundant FlagDowns") + assert.Equal(t, firstResetTime, b.resetTime, "resetTime unchanged across redundant FlagDowns") +} diff --git a/shared/circut/latency_breaker.go b/shared/circut/latency_breaker.go new file mode 100644 index 0000000..9d48053 --- /dev/null +++ b/shared/circut/latency_breaker.go @@ -0,0 +1,246 @@ +package circut + +import ( + "math/rand/v2" + "sync" + "sync/atomic" + "time" +) + +// LatencyBreaker is a state machine that sheds traffic when observed +// request latencies exceed configured thresholds. It is independent of +// (and parallel to) the connection-failure-based Breaker. +// +// Detection: +// +// - latest: the most recent observation. Compared against +// LatestThreshold. +// - rolling: average over a sliding window. Compared against +// RollingThreshold. +// +// State transitions: +// +// - OFF -> ON on every observation where +// latest > LatestThreshold OR rolling > RollingThreshold +// (zero-valued thresholds are skipped, so a single threshold +// can be used by leaving the other zero). +// - ON -> OFF after K consecutive observations satisfy +// latest < LatestThreshold AND rolling < RollingThreshold. +// +// While ON, IsUp() returns true with probability PassThroughFraction +// and false otherwise -- letting a small fraction of traffic through +// to drive recovery sensing without committing real load. +// +// Concurrency: +// +// State (Down) is read with atomic.LoadInt32 in IsUp() and updated +// via atomic.CompareAndSwapInt32 from Observe(). Compound state +// (latest / rolling buckets / consecutiveOK) is protected by a +// mutex; only one Observe can mutate at a time. IsUp does not block. +// +// Random number source for pass-through is math/rand/v2 top-level, +// which is concurrent-safe and lock-free in Go 1.22+. A test seam +// (randFloat) allows deterministic tests. +type LatencyBreaker struct { + // Configuration. Set at construction; not mutated after. + LatestThreshold time.Duration + RollingThreshold time.Duration + RollingWindow time.Duration + KConsecutive int + PassThroughFraction float64 + + // state holds the OFF (0) / ON (1) flag. Atomic. + state int32 + + mu sync.Mutex // guards latest, rolling, consecutiveOK + latest time.Duration + rolling *rollingAverage + consecutiveOK int + + // randFloat returns a value in [0, 1). Defaults to math/rand/v2.Float64. + // Override for deterministic tests. + randFloat func() float64 +} + +// NewLatencyBreaker constructs a LatencyBreaker. Zero-valued thresholds +// disable that branch of the trip predicate. If both thresholds are +// zero, Observe is a no-op and IsUp always returns true (effectively +// disabled). +func NewLatencyBreaker( + latestThreshold, rollingThreshold, rollingWindow time.Duration, + kConsecutive int, + passThroughFraction float64, +) *LatencyBreaker { + if rollingWindow <= 0 { + rollingWindow = time.Second + } + if kConsecutive < 1 { + kConsecutive = 1 + } + if passThroughFraction < 0 { + passThroughFraction = 0 + } + if passThroughFraction > 1 { + passThroughFraction = 1 + } + return &LatencyBreaker{ + LatestThreshold: latestThreshold, + RollingThreshold: rollingThreshold, + RollingWindow: rollingWindow, + KConsecutive: kConsecutive, + PassThroughFraction: passThroughFraction, + rolling: newRollingAverage(rollingWindow, 10), + randFloat: rand.Float64, + } +} + +// IsUp returns true if the breaker is OFF (allowing all traffic), or +// true with probability PassThroughFraction if ON (allowing a small +// fraction through for recovery sensing). +func (lb *LatencyBreaker) IsUp() bool { + if lb == nil { + return true + } + if atomic.LoadInt32(&lb.state) == 0 { + return true + } + return lb.randFloat() < lb.PassThroughFraction +} + +// Observe records the latency of a completed request and advances the +// state machine. Called from the bidder client after each httpPost +// attempt completes (success or failure -- timeouts and errors count +// as observations and the elapsed time captured by the caller). +func (lb *LatencyBreaker) Observe(latency time.Duration) { + if lb == nil { + return + } + if lb.LatestThreshold == 0 && lb.RollingThreshold == 0 { + // Both thresholds disabled; no signal to act on. + return + } + + lb.mu.Lock() + now := time.Now() + lb.latest = latency + lb.rolling.add(latency, now) + rollingAvg := lb.rolling.average(now) + + state := atomic.LoadInt32(&lb.state) + + // triggerOn: ANY threshold breached. Zero-valued thresholds skip. + triggerOn := false + if lb.LatestThreshold > 0 && latency > lb.LatestThreshold { + triggerOn = true + } + if lb.RollingThreshold > 0 && rollingAvg > lb.RollingThreshold { + triggerOn = true + } + + // triggerOffReady: BOTH thresholds satisfied as below. Zero-valued + // thresholds count as satisfied. + triggerOffReady := true + if lb.LatestThreshold > 0 && latency >= lb.LatestThreshold { + triggerOffReady = false + } + if lb.RollingThreshold > 0 && rollingAvg >= lb.RollingThreshold { + triggerOffReady = false + } + + switch state { + case 0: // OFF + if triggerOn { + atomic.StoreInt32(&lb.state, 1) + lb.consecutiveOK = 0 + } + case 1: // ON + if triggerOffReady { + lb.consecutiveOK++ + if lb.consecutiveOK >= lb.KConsecutive { + atomic.StoreInt32(&lb.state, 0) + lb.consecutiveOK = 0 + } + } else { + lb.consecutiveOK = 0 + } + } + lb.mu.Unlock() +} + +// State returns 0 (OFF / up) or 1 (ON / shedding). Primarily for tests. +func (lb *LatencyBreaker) State() int32 { + if lb == nil { + return 0 + } + return atomic.LoadInt32(&lb.state) +} + +// rollingAverage keeps a sliding-window average of durations using a +// fixed number of time-aligned buckets. Buckets that fall outside the +// current window are reset on next access. All access is serialized +// by LatencyBreaker.mu; this struct is not goroutine-safe on its own. +type rollingAverage struct { + window time.Duration + bucketDur time.Duration + bucketDurN int64 // bucketDur in nanoseconds, cached + buckets []rollingBucket +} + +type rollingBucket struct { + sum time.Duration + count int64 + until int64 // exclusive end of bucket period, in nanoseconds since epoch +} + +func newRollingAverage(window time.Duration, n int) *rollingAverage { + if n < 1 { + n = 1 + } + bd := window / time.Duration(n) + if bd <= 0 { + bd = window + n = 1 + } + return &rollingAverage{ + window: window, + bucketDur: bd, + bucketDurN: int64(bd), + buckets: make([]rollingBucket, n), + } +} + +// add records a value with completion time t. +func (r *rollingAverage) add(v time.Duration, t time.Time) { + tn := t.UnixNano() + idx := int((tn / r.bucketDurN) % int64(len(r.buckets))) + until := ((tn / r.bucketDurN) + 1) * r.bucketDurN + if r.buckets[idx].until != until { + // Bucket belongs to a different period; reset and reuse. + r.buckets[idx].sum = 0 + r.buckets[idx].count = 0 + r.buckets[idx].until = until + } + r.buckets[idx].sum += v + r.buckets[idx].count++ +} + +// average returns the average across all buckets whose period overlaps +// the window ending at t. Returns 0 if no in-window samples. +func (r *rollingAverage) average(t time.Time) time.Duration { + cutoff := t.UnixNano() - int64(r.window) + var sum time.Duration + var count int64 + for i := range r.buckets { + b := &r.buckets[i] + // Bucket's period end (until) must be after cutoff to be in-window. + if b.until <= cutoff { + continue + } + sum += b.sum + count += b.count + } + if count == 0 { + return 0 + } + return sum / time.Duration(count) +} diff --git a/shared/circut/latency_breaker_test.go b/shared/circut/latency_breaker_test.go new file mode 100644 index 0000000..ef31e7a --- /dev/null +++ b/shared/circut/latency_breaker_test.go @@ -0,0 +1,218 @@ +package circut + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + tinyWindow = 100 * time.Millisecond + winK = 3 +) + +func newTestLB(latest, rolling time.Duration, kConsecutive int, fraction float64) *LatencyBreaker { + return NewLatencyBreaker(latest, rolling, tinyWindow, kConsecutive, fraction) +} + +// TestLatencyBreaker_DisabledWhenZero verifies that a LatencyBreaker +// constructed with both thresholds = 0 is a no-op: Observe doesn't +// transition state, IsUp always returns true. +func TestLatencyBreaker_DisabledWhenZero(t *testing.T) { + lb := newTestLB(0, 0, winK, 0.01) + for i := 0; i < 10; i++ { + lb.Observe(time.Hour) + } + assert.Equal(t, int32(0), lb.State()) + assert.True(t, lb.IsUp()) +} + +// TestLatencyBreaker_TripOnLatest verifies OFF -> ON transition when +// the latest observation alone exceeds LatestThreshold (rolling +// threshold disabled). +func TestLatencyBreaker_TripOnLatest(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 0.01) + require.Equal(t, int32(0), lb.State()) + + lb.Observe(50 * time.Millisecond) + assert.Equal(t, int32(1), lb.State(), "single observation above latestThreshold trips ON") +} + +// TestLatencyBreaker_TripOnRolling verifies OFF -> ON transition when +// the rolling average crosses RollingThreshold even though no single +// observation hits the latest threshold. +func TestLatencyBreaker_TripOnRolling(t *testing.T) { + // LatestThreshold=0 (disabled), RollingThreshold=20ms. + lb := newTestLB(0, 20*time.Millisecond, winK, 0.01) + + // Stream of 25ms observations -- each below latestThreshold (which + // is disabled), but rolling crosses 20ms quickly. + for i := 0; i < 5; i++ { + lb.Observe(25 * time.Millisecond) + } + assert.Equal(t, int32(1), lb.State(), "sustained observations above rollingThreshold trip ON") +} + +// TestLatencyBreaker_RecoverViaKConsecutive verifies ON -> OFF takes K +// consecutive observations satisfying the configured thresholds. +// +// Uses RollingThreshold=0 (disabled) so the test isolates the K-consecutive +// state machine from rolling-window pollution -- a single slow observation +// in the rolling window keeps the rolling average elevated for the full +// window duration regardless of how many subsequent fast observations +// arrive, which is correct production behavior but obscures this test's +// intent. +func TestLatencyBreaker_RecoverViaKConsecutive(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 1.0) + + // Trip via latest. + lb.Observe(50 * time.Millisecond) + require.Equal(t, int32(1), lb.State()) + + // K-1 fast observations -- not enough to recover. + for i := 0; i < winK-1; i++ { + lb.Observe(5 * time.Millisecond) + require.Equal(t, int32(1), lb.State(), "still ON after %d observations", i+1) + } + + // Kth fast observation -- recovers. + lb.Observe(5 * time.Millisecond) + assert.Equal(t, int32(0), lb.State(), "ON -> OFF after K consecutive fast observations") +} + +// TestLatencyBreaker_RecoveryResetsOnSlow verifies that a single +// above-threshold observation resets the consecutive-OK counter. +// RollingThreshold=0 to isolate the consecutive-OK reset logic. +func TestLatencyBreaker_RecoveryResetsOnSlow(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 1.0) + + lb.Observe(50 * time.Millisecond) // trip + require.Equal(t, int32(1), lb.State()) + + lb.Observe(5 * time.Millisecond) // 1 OK + lb.Observe(5 * time.Millisecond) // 2 OK + require.Equal(t, int32(1), lb.State(), "still ON before K consecutive") + + lb.Observe(50 * time.Millisecond) // bad observation -- resets consecutiveOK to 0 + require.Equal(t, int32(1), lb.State()) + + // Now need K consecutive again from scratch. + lb.Observe(5 * time.Millisecond) + lb.Observe(5 * time.Millisecond) + require.Equal(t, int32(1), lb.State(), "still ON after only 2 fast observations following reset") + lb.Observe(5 * time.Millisecond) + assert.Equal(t, int32(0), lb.State(), "ON -> OFF after K consecutive following reset") +} + +// TestLatencyBreaker_PassThroughFraction verifies the probabilistic +// pass-through behavior while ON, using an injected deterministic +// random source. +func TestLatencyBreaker_PassThroughFraction(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 0.25) + + // Trip. + lb.Observe(50 * time.Millisecond) + require.Equal(t, int32(1), lb.State()) + + // Inject a deterministic counter generating values 0.0, 0.1, 0.2, + // 0.3, 0.4, ... -- pass-through fires when value < 0.25, i.e. for + // the first 3 (0.0, 0.1, 0.2) of every 10. + var counter int + lb.randFloat = func() float64 { + v := float64(counter%10) / 10.0 + counter++ + return v + } + + pass := 0 + const N = 1000 + for i := 0; i < N; i++ { + if lb.IsUp() { + pass++ + } + } + // Expected: 30% pass exactly with this generator (3/10 buckets). + // Allow ±2% drift for any rounding. + assert.InDelta(t, 0.30, float64(pass)/float64(N), 0.02, + "pass-through fraction should match injected random source") +} + +// TestLatencyBreaker_PassThroughZero verifies that PassThroughFraction=0 +// sheds 100% while ON. +func TestLatencyBreaker_PassThroughZero(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 0, winK, 0.0) + lb.Observe(50 * time.Millisecond) + require.Equal(t, int32(1), lb.State()) + + for i := 0; i < 100; i++ { + assert.False(t, lb.IsUp(), "PassThroughFraction=0 must shed 100%% while ON") + } +} + +// TestLatencyBreaker_NilSafe verifies a nil receiver behaves as +// a permanently-up no-op breaker (lets the caller treat +// "no LatencyBreaker configured" identically to "configured but OFF"). +func TestLatencyBreaker_NilSafe(t *testing.T) { + var lb *LatencyBreaker + assert.True(t, lb.IsUp()) + lb.Observe(time.Second) // must not panic + assert.Equal(t, int32(0), lb.State()) +} + +// TestLatencyBreaker_Concurrent_NoDataRace exercises Observe / IsUp +// from many goroutines simultaneously. Run with `-race` to catch any +// data races introduced by future edits. +func TestLatencyBreaker_Concurrent_NoDataRace(t *testing.T) { + lb := newTestLB(40*time.Millisecond, 20*time.Millisecond, winK, 0.5) + + const goroutines = 16 + const iterations = 2000 + + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func(seed int) { + defer wg.Done() + for i := 0; i < iterations; i++ { + switch (seed + i) % 3 { + case 0: + lb.Observe(10 * time.Millisecond) + case 1: + lb.Observe(50 * time.Millisecond) + case 2: + _ = lb.IsUp() + } + } + }(g) + } + wg.Wait() + + // Sanity: state must be 0 or 1. + st := lb.State() + assert.True(t, st == 0 || st == 1, "state must be 0 or 1, got %d", st) +} + +// TestRollingAverage_BucketRotation verifies that observations outside +// the rolling window are excluded from the average. +func TestRollingAverage_BucketRotation(t *testing.T) { + r := newRollingAverage(100*time.Millisecond, 10) + t0 := time.Unix(0, 1_000_000_000) // 1s exactly + + // Add 10 observations of 50ms within a single bucket period. + for i := 0; i < 10; i++ { + r.add(50*time.Millisecond, t0) + } + assert.Equal(t, 50*time.Millisecond, r.average(t0)) + + // 200ms later, all old buckets should be outside the 100ms window. + tLater := t0.Add(200 * time.Millisecond) + assert.Equal(t, time.Duration(0), r.average(tLater), + "window should expire all 10-bucket-old observations") +} + +// Sanity check: ensure the atomic types we use compile. +var _ = atomic.LoadInt32 diff --git a/shared/client/config.go b/shared/client/config.go index f103c15..8a07fa6 100644 --- a/shared/client/config.go +++ b/shared/client/config.go @@ -1,10 +1,25 @@ package client import ( + "fmt" + "time" + "github.com/viant/mly/shared/client/config" ) -//Config represents a client config +const ( + defaultLatencyBreakerRollingWindow = time.Second + defaultLatencyBreakerKConsecutive = 3 + defaultLatencyBreakerPassThroughFraction = 0.01 +) + +type latencyBreakerSettings struct { + latest, rolling, window time.Duration + k int + fraction float64 +} + +// Config represents a client config type Config struct { Hosts []*Host Model string @@ -22,9 +37,86 @@ type Config struct { Debug bool DictHashValidation bool + + // LatencyBreaker fields are passed through to circut.LatencyBreaker + // at Service init time. When both LatencyBreakerLatestThreshold and + // LatencyBreakerRollingThreshold are zero, the LatencyBreaker is not + // constructed and the host's IsUp() reflects only the connection + // breaker -- backward-compatible default. + // + // LatestThreshold is the per-attempt latency above which a single + // observation is enough to trip into the shedding state. The caller + // is expected to size this near (or just below) its own request + // timeout so the breaker fires before requests would have failed. + LatencyBreakerLatestThreshold time.Duration + + // RollingThreshold is the rolling-average latency above which the + // breaker trips. Detects sustained slow-creep that no single + // observation crosses LatestThreshold for. + LatencyBreakerRollingThreshold time.Duration + + // RollingWindow is the duration over which the rolling average is + // computed. Default 1s if zero. + LatencyBreakerRollingWindow time.Duration + + // KConsecutive is the number of consecutive observations satisfying + // (latest < LatestThreshold AND rolling < RollingThreshold) needed + // to transition from ON back to OFF. Higher = more conservative + // recovery, prevents flap on outliers. Default 3 if zero. + LatencyBreakerKConsecutive int + + // PassThroughFraction is the probability that a request is allowed + // through while the breaker is ON, to drive recovery sensing. + // Default 0.01 (1%). Set higher for low-QPS models that need more + // observations to recover. Valid range: [0, 1]. A zero value means + // use the default. + LatencyBreakerPassThroughFraction float64 +} + +func (c *Config) latencyBreakerSettings() (latencyBreakerSettings, bool, error) { + settings := latencyBreakerSettings{ + latest: c.LatencyBreakerLatestThreshold, + rolling: c.LatencyBreakerRollingThreshold, + window: c.LatencyBreakerRollingWindow, + k: c.LatencyBreakerKConsecutive, + fraction: c.LatencyBreakerPassThroughFraction, + } + + if settings.latest < 0 { + return settings, false, fmt.Errorf("LatencyBreakerLatestThreshold must be >= 0, got %s", settings.latest) + } + if settings.rolling < 0 { + return settings, false, fmt.Errorf("LatencyBreakerRollingThreshold must be >= 0, got %s", settings.rolling) + } + if settings.window < 0 { + return settings, false, fmt.Errorf("LatencyBreakerRollingWindow must be >= 0, got %s", settings.window) + } + if settings.k < 0 { + return settings, false, fmt.Errorf("LatencyBreakerKConsecutive must be >= 0, got %d", settings.k) + } + if settings.fraction < 0 || settings.fraction > 1 { + return settings, false, fmt.Errorf("LatencyBreakerPassThroughFraction must be in [0, 1], got %v", settings.fraction) + } + + enabled := settings.latest > 0 || settings.rolling > 0 + if !enabled { + return settings, false, nil + } + + if settings.window == 0 { + settings.window = defaultLatencyBreakerRollingWindow + } + if settings.k == 0 { + settings.k = defaultLatencyBreakerKConsecutive + } + if settings.fraction == 0 { + settings.fraction = defaultLatencyBreakerPassThroughFraction + } + + return settings, true, nil } -//CacheSize returns cache size +// CacheSize returns cache size func (c *Config) CacheSize() int { if c.CacheSizeMb == 0 { return 0 diff --git a/shared/client/config_test.go b/shared/client/config_test.go new file mode 100644 index 0000000..983e9d2 --- /dev/null +++ b/shared/client/config_test.go @@ -0,0 +1,76 @@ +package client + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig_LatencyBreakerSettings_Disabled(t *testing.T) { + settings, enabled, err := (&Config{}).latencyBreakerSettings() + require.NoError(t, err) + assert.False(t, enabled) + assert.Zero(t, settings.latest) + assert.Zero(t, settings.rolling) +} + +func TestConfig_LatencyBreakerSettings_Defaults(t *testing.T) { + cfg := &Config{LatencyBreakerLatestThreshold: 40 * time.Millisecond} + + settings, enabled, err := cfg.latencyBreakerSettings() + require.NoError(t, err) + assert.True(t, enabled) + assert.Equal(t, 40*time.Millisecond, settings.latest) + assert.Equal(t, defaultLatencyBreakerRollingWindow, settings.window) + assert.Equal(t, defaultLatencyBreakerKConsecutive, settings.k) + assert.Equal(t, defaultLatencyBreakerPassThroughFraction, settings.fraction) +} + +func TestConfig_LatencyBreakerSettings_ExplicitValues(t *testing.T) { + cfg := &Config{ + LatencyBreakerRollingThreshold: 20 * time.Millisecond, + LatencyBreakerRollingWindow: 500 * time.Millisecond, + LatencyBreakerKConsecutive: 5, + LatencyBreakerPassThroughFraction: 0.25, + } + + settings, enabled, err := cfg.latencyBreakerSettings() + require.NoError(t, err) + assert.True(t, enabled) + assert.Equal(t, 20*time.Millisecond, settings.rolling) + assert.Equal(t, 500*time.Millisecond, settings.window) + assert.Equal(t, 5, settings.k) + assert.Equal(t, 0.25, settings.fraction) +} + +func TestConfig_LatencyBreakerSettings_Invalid(t *testing.T) { + cases := []struct { + name string + config Config + }{ + {name: "negative latest threshold", config: Config{LatencyBreakerLatestThreshold: -time.Millisecond}}, + {name: "negative rolling threshold", config: Config{LatencyBreakerRollingThreshold: -time.Millisecond}}, + {name: "negative rolling window", config: Config{LatencyBreakerRollingWindow: -time.Millisecond}}, + {name: "negative consecutive count", config: Config{LatencyBreakerKConsecutive: -1}}, + {name: "negative pass-through fraction", config: Config{LatencyBreakerPassThroughFraction: -0.1}}, + {name: "pass-through fraction above one", config: Config{LatencyBreakerPassThroughFraction: 1.1}}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := tc.config.latencyBreakerSettings() + assert.Error(t, err) + }) + } +} + +func TestNew_InvalidLatencyBreakerOptionReturnsError(t *testing.T) { + _, err := New( + "invalid-latency-breaker", + []*Host{NewHost("localhost", 8080)}, + WithLatencyBreaker(-time.Millisecond, 0, 0, 0, 0), + ) + assert.ErrorContains(t, err, "LatencyBreakerLatestThreshold") +} diff --git a/shared/client/dictionary.go b/shared/client/dictionary.go index 663c411..10594e6 100644 --- a/shared/client/dictionary.go +++ b/shared/client/dictionary.go @@ -11,6 +11,7 @@ type fieldOffset int const ( // oov = out of vocabulary + // TODO - technically the OOV value can be overwritten - [UNK] may be a valid value oovString = "[UNK]" oovInt = 0 @@ -19,13 +20,16 @@ const ( unknownKeyField = fieldOffset(-1) ) -// Dictionary helps identify any out-of-vocabulary input values for reducing the cache space - this enables us to leverage any -// dimensionality reduction within the model to optimize wall-clock performance. This is primarily useful for categorical inputs -// as well as any continous inputs with an acceptable quantization. +// Dictionary helps identify any out-of-vocabulary input values for reducing the cache space, as well as an explicit cache-invalidation strategy via hash. +// See shared/common.Dictionary type Dictionary struct { - hash int + hash int + + // registry key is the input name registry map[string]*entry - inputs map[string]*shared.Field + + // inputs is an index, key is the input name + inputs map[string]*shared.Field } func (d *Dictionary) KeysLen() int { @@ -36,10 +40,6 @@ func (d *Dictionary) inputSize() int { return len(d.inputs) } -func (d *Dictionary) size() int { - return len(d.registry) -} - // TODO refactor, this has a singular use case func (d *Dictionary) Fields() map[string]*shared.Field { return d.inputs @@ -73,12 +73,15 @@ func (d *Dictionary) getEntry(n string) *entry { } if elem == nil { + // generally speaking, if d.registry has data, it should have data for ALL columns + // TODO this shouldn't print, it should tick some counter log.Printf("registry entry was nil for %v", n) } return elem } +// lookupString returns the mapped key, or unknownKeyField, meaning no mapping exists func (d *Dictionary) lookupString(key string, value string) (string, fieldOffset) { input := d.getInput(key) if input == nil { @@ -103,7 +106,7 @@ func (d *Dictionary) lookupString(key string, value string) (string, fieldOffset return oovString, ii } -// TODO integration and boundary testing; OOV may depend on vocabulary +// lookupInt returns the mapped key, or unknownKeyField, meaning no mapping exists func (d *Dictionary) lookupInt(key string, value int) (int, fieldOffset) { input := d.getInput(key) if input == nil { @@ -128,6 +131,7 @@ func (d *Dictionary) lookupInt(key string, value int) (int, fieldOffset) { return oovInt, ii } +// reduceFloat returns a lower-precision float key, or unknownKeyField, meaning no reduction exists func (d *Dictionary) reduceFloat(key string, value float32) (float32, int, fieldOffset) { input := d.getInput(key) if input == nil { diff --git a/shared/client/host.go b/shared/client/host.go index 5580db3..e7fb356 100644 --- a/shared/client/host.go +++ b/shared/client/host.go @@ -13,7 +13,7 @@ import ( var defaultRequestTimeout = 50 * time.Millisecond -//Host represents endpoint host +// Host represents endpoint host type Host struct { name string port int @@ -25,30 +25,52 @@ type Host struct { mux sync.RWMutex *circut.Breaker + // LatencyBreaker is an optional latency-driven shed mechanism that + // runs in parallel to the connection-failure-based Breaker. When + // configured (non-nil), getHost() requires both Breaker.IsUp() and + // LatencyBreaker.IsUp() to return true before letting a request + // through. nil = disabled (acts as permanently up). + LatencyBreaker *circut.LatencyBreaker + // memoization prefix string } +// IsUp combines the connection-failure Breaker and the (optional) +// LatencyBreaker. Both must say up for the host to be considered up. +// Shadows the embedded Breaker.IsUp(). +func (h *Host) IsUp() bool { + if h.Breaker != nil && !h.Breaker.IsUp() { + return false + } + if !h.LatencyBreaker.IsUp() { + return false + } + return true +} + func isSecurePort(port int) bool { return port == 443 || port == 1443 } -//IsSecurePort() returns true if secure port +// IsSecurePort() returns true if secure port func (h *Host) IsSecurePort() bool { return isSecurePort(h.port) } -//URL returns model eval URL +// URL returns model eval URL func (h *Host) evalURL(model string) string { return h.prefix + fmt.Sprintf(common.ModelURI, model) } -//URL returns meta config model eval URL +// URL returns meta config model eval URL +// See service/endpoint/meta.(*metaHandler).ServeHTTP func (h *Host) metaConfigURL(model string) string { return h.prefix + fmt.Sprintf(common.MetaConfigURI, model) } -//URL returns meta config model eval URL +// URL returns meta config model eval URL +// See service/endpoint/meta.(*metaHandler).ServeHTTP func (h *Host) metaDictionaryURL(model string) string { return h.prefix + fmt.Sprintf(common.MetaDictionaryURI, model) } @@ -93,7 +115,7 @@ func (h *Host) Port() int { return h.port } -//NewHost returns new host +// NewHost returns new host func NewHost(name string, port int) *Host { if port <= 0 { port = 80 @@ -106,7 +128,7 @@ func NewHost(name string, port int) *Host { } } -//NewHosts creates hosts +// NewHosts creates hosts func NewHosts(port int, names []string) []*Host { var result = make([]*Host, 0) for _, name := range names { diff --git a/shared/client/marshal.go b/shared/client/marshal.go index ca91fa5..9a144bc 100644 --- a/shared/client/marshal.go +++ b/shared/client/marshal.go @@ -7,6 +7,7 @@ import ( "github.com/francoispqt/gojay" ) +// Deprecated. No need to be exported. func Marshal(data interface{}, id string) ([]byte, error) { if data == nil { return nil, fmt.Errorf("data was nil") diff --git a/shared/client/message-spec.yaml b/shared/client/message-spec.yaml new file mode 100644 index 0000000..c18b352 --- /dev/null +++ b/shared/client/message-spec.yaml @@ -0,0 +1,43 @@ +"$schema": "https://json-schema.org/draft/2020-12/schema" +$title: Message Specification v1.0 +type: object + +properties: + batch_size: + type: integer + default: 0 + description: | + The number of samples in this request. + Must be the first property present if batch mode is to be enabled, with a value greater than 0. + + cache_key: + oneOf: + - type: string + description: A single cache key for this request. + - type: array + items: + type: string + description: | + An array of cache keys for this request, corresponding to each batched sample. + Must be of length equal to batch_size. + +patternProperties: + ".*": + oneOf: + - oneOf: + - type: string + - type: number + description: A single value for this input. + - type: array + items: + oneOf: + - type: string + - type: number + description: | + A batch of values for input. + Must be of length equal to batch_size, or of length 1. + +examples: + - cache_key: "1234567890" + input1: "value1" + input2: "value2" \ No newline at end of file diff --git a/shared/client/option.go b/shared/client/option.go index e626dfb..2667ed3 100644 --- a/shared/client/option.go +++ b/shared/client/option.go @@ -1,6 +1,8 @@ package client import ( + "time" + "github.com/viant/gmetric" cconfig "github.com/viant/mly/shared/client/config" "github.com/viant/mly/shared/datastore" @@ -39,6 +41,21 @@ func WithGmetrics(gmetrics *gmetric.Service) Option { return &gmetricsOpt{gmetrics: gmetrics} } +type prometheusMetricsOpt struct { + enable bool +} + +func (o *prometheusMetricsOpt) Apply(c *Service) { + c.noPrometheusMetrics = !o.enable +} + +// WithPrometheusMetrics enables or disables native Prometheus client +// metrics. Metrics are enabled by default; disable them for short-lived +// helper clients that should not register long-lived model series. +func WithPrometheusMetrics(enable bool) Option { + return &prometheusMetricsOpt{enable: enable} +} + type dictHashValidationOpt struct { enable bool } @@ -146,3 +163,29 @@ func (o *clientOptionsOption) Apply(c *Service) { func WithClientOptions(clientOptions ...dscli.Option) Option { return &clientOptionsOption{clientOptions: clientOptions} } + +type latencyBreakerOpt struct { + latest, rolling, window time.Duration + k int + fraction float64 +} + +func (o *latencyBreakerOpt) Apply(c *Service) { + c.Config.LatencyBreakerLatestThreshold = o.latest + c.Config.LatencyBreakerRollingThreshold = o.rolling + c.Config.LatencyBreakerRollingWindow = o.window + c.Config.LatencyBreakerKConsecutive = o.k + c.Config.LatencyBreakerPassThroughFraction = o.fraction +} + +// WithLatencyBreaker enables the latency-aware breaker on each host +// constructed for this Service. Defaults are applied during Service +// init: pass-through fraction 0.01, rolling window 1s, KConsecutive 3. +// +// Setting both latest and rolling to zero leaves the breaker disabled +// (backward compatible). Both thresholds are taken as raw durations; +// the caller is responsible for sizing them appropriately for the +// model's traffic profile and the caller's request timeout. +func WithLatencyBreaker(latest, rolling, window time.Duration, k int, fraction float64) Option { + return &latencyBreakerOpt{latest: latest, rolling: rolling, window: window, k: k, fraction: fraction} +} diff --git a/shared/client/prometheus.go b/shared/client/prometheus.go new file mode 100644 index 0000000..4a3158c --- /dev/null +++ b/shared/client/prometheus.go @@ -0,0 +1,320 @@ +package client + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/viant/mly/shared/stat/buckets" + "github.com/viant/mly/shared/stat/promc" +) + +const ( + promDescRunDuration = "Duration of client Run calls." + promDescHTTPDuration = "Duration of client HTTP calls, including retries." + promDescHTTPClientDuration = "Duration of client HTTP client calls." + promDescBatchSize = "Size of client batches." +) + +var ( + // EarlyCtxError + // loadFromCache error - this can only be a type error from Response.DataItemType(), (*Service).readFromCache() + + runErrorCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "prediction_error_counter", + Help: "Number of client and kind of prediction errors.", + }, + []string{"model", "error"}, + ) + + httpErrorCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_error_counter", + Help: "Number of client HTTP errors.", + }, + []string{"model", "error"}, + ) + + httpClientErrorCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_client_error_counter", + Help: "Number of client HTTP client errors.", + }, + []string{"model", "error"}, + ) +) + +type prometheusMetrics struct { + runDurationHistogram prometheus.Observer + batchSizeHistogram prometheus.Observer + httpDurationHistogram prometheus.Observer + httpClientDurationHistogram prometheus.Observer + + // The Summary metrics may be nil if noPrometheusSummaries is true. + + runDurationSummary prometheus.Observer + batchSizeSummary prometheus.Observer + httpDurationSummary prometheus.Observer + httpClientDurationSummary prometheus.Observer + + runErrorEarlyCtxCounter prometheus.Counter + runBaseErrorCounters promc.BaseErrorCounters + + httpDownCounter prometheus.Counter + httpBaseErrorCounters promc.BaseErrorCounters + + httpClientBaseErrorCounters promc.BaseErrorCounters +} + +func (m prometheusMetrics) observeRunDuration(duration float64) { + if m.runDurationHistogram != nil { + m.runDurationHistogram.Observe(duration) + } + if m.runDurationSummary != nil { + m.runDurationSummary.Observe(duration) + } +} + +func (m prometheusMetrics) observeBatchSize(batchSize float64) { + if m.batchSizeHistogram != nil { + m.batchSizeHistogram.Observe(batchSize) + } + if m.batchSizeSummary != nil { + m.batchSizeSummary.Observe(batchSize) + } +} + +func (m prometheusMetrics) observeHttpDuration(duration float64) { + if m.httpDurationHistogram != nil { + m.httpDurationHistogram.Observe(duration) + } + if m.httpDurationSummary != nil { + m.httpDurationSummary.Observe(duration) + } +} + +func (m prometheusMetrics) observeHttpClientDuration(duration float64) { + if m.httpClientDurationHistogram != nil { + m.httpClientDurationHistogram.Observe(duration) + } + if m.httpClientDurationSummary != nil { + m.httpClientDurationSummary.Observe(duration) + } +} + +// Used strictly to test for error type. +var are prometheus.AlreadyRegisteredError + +func isPrometheusAlreadyRegisteredError(err error) bool { + if err == nil { + return false + } + + return errors.As(err, &are) +} + +func (m *prometheusMetrics) registerPrometheusMetrics(registerer prometheus.Registerer, model string, noPrometheusSummaries bool) error { + // convenience function + register := func(metric prometheus.Collector) error { + err := registerer.Register(metric) + if err != nil && !isPrometheusAlreadyRegisteredError(err) { + + dc := make(chan *prometheus.Desc) + go func() { + metric.Describe(dc) + }() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + var metricName string + select { + case desc := <-dc: + metricName = desc.String() + case <-ctx.Done(): + metricName = "unknown" + } + + return fmt.Errorf("failed to register %s: %T, %w", metricName, err, err) + } + + return nil + } + + var err error + if !noPrometheusSummaries { + runDurationSummaryMicros := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "run_duration_summary_us", + Help: promDescRunDuration, + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + + err = register(runDurationSummaryMicros) + if err != nil { + return err + } + m.runDurationSummary = runDurationSummaryMicros.WithLabelValues(model) + + batchSizeSummary := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "batch_size_summary", + Help: promDescBatchSize, + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + + err = register(batchSizeSummary) + if err != nil { + return err + } + m.batchSizeSummary = batchSizeSummary.WithLabelValues(model) + + httpDurationSummaryMicros := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_duration_summary_us", + Help: promDescHTTPDuration, + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + + err = register(httpDurationSummaryMicros) + if err != nil { + return err + } + m.httpDurationSummary = httpDurationSummaryMicros.WithLabelValues(model) + + httpClientDurationSummaryMicros := prometheus.NewSummaryVec( + prometheus.SummaryOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_client_duration_summary_us", + Help: promDescHTTPClientDuration, + Objectives: buckets.CommonSummaryObjectives, + }, + []string{"model"}, + ) + + err = register(httpClientDurationSummaryMicros) + if err != nil { + return err + } + m.httpClientDurationSummary = httpClientDurationSummaryMicros.WithLabelValues(model) + } + + runDurationHistogramMicros := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "run_duration_histogram_us", + Help: promDescRunDuration, + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) + + err = register(runDurationHistogramMicros) + if err != nil { + return err + } + m.runDurationHistogram = runDurationHistogramMicros.WithLabelValues(model) + + batchSizeHistogram := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "batch_size_histogram", + Help: promDescBatchSize, + Buckets: []float64{1, 2, 3, 4, 5, 7, 10, 12, 15, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100}, + }, + []string{"model"}, + ) + err = register(batchSizeHistogram) + if err != nil { + return err + } + m.batchSizeHistogram = batchSizeHistogram.WithLabelValues(model) + + httpDurationHistogramMicros := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_duration_histogram_us", + Help: promDescHTTPDuration, + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) + err = register(httpDurationHistogramMicros) + if err != nil { + return err + } + m.httpDurationHistogram = httpDurationHistogramMicros.WithLabelValues(model) + + httpClientDurationHistogramMicros := prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "mly", + Subsystem: "client", + Name: "http_client_duration_histogram_us", + Help: promDescHTTPClientDuration, + Buckets: buckets.MicrosecondBuckets, + }, + []string{"model"}, + ) + + err = register(httpClientDurationHistogramMicros) + if err != nil { + return err + } + m.httpClientDurationHistogram = httpClientDurationHistogramMicros.WithLabelValues(model) + + err = register(runErrorCounter) + if err != nil { + return err + } + + m.runErrorEarlyCtxCounter = runErrorCounter.WithLabelValues(model, "earlyCtx") + + // convenience function + mkBECs := func(bec *promc.BaseErrorCounters, counter *prometheus.CounterVec) { + bec.OtherErrorCounter = counter.WithLabelValues(model, "error") + bec.DeadlineExceededCounter = counter.WithLabelValues(model, "deadlineExceeded") + bec.CanceledCounter = counter.WithLabelValues(model, "canceled") + } + + mkBECs(&m.runBaseErrorCounters, runErrorCounter) + + err = register(httpErrorCounter) + if err != nil { + return err + } + mkBECs(&m.httpBaseErrorCounters, httpErrorCounter) + + m.httpDownCounter = httpErrorCounter.WithLabelValues(model, "down") + + err = register(httpClientErrorCounter) + if err != nil { + return err + } + mkBECs(&m.httpClientBaseErrorCounters, runErrorCounter) + + return nil +} diff --git a/shared/client/service.go b/shared/client/service.go index 1b72425..a676150 100644 --- a/shared/client/service.go +++ b/shared/client/service.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "log" @@ -18,7 +19,9 @@ import ( "time" "github.com/francoispqt/gojay" + "github.com/prometheus/client_golang/prometheus" "github.com/viant/gmetric" + "github.com/viant/mly/shared/circut" "github.com/viant/mly/shared/client/config" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/common/storable" @@ -66,6 +69,23 @@ type Service struct { httpCliCounter *gmetric.Operation dictCounter *gmetric.Operation + // PrometheusRegisterer is used to register Prometheus metrics. + // If not provided, the default Prometheus registry will be used. + PrometheusRegisterer prometheus.Registerer + + // noPrometheusSummaries is used to disable Prometheus summaries. + // If true, only histograms will be registered and used. + // See https://prometheus.io/docs/practices/histograms for guidance. + noPrometheusSummaries bool + + // noPrometheusMetrics disables native Prometheus metric registration. + // This is useful for short-lived helper clients (for example server + // startup self-tests) that should not leave zero-valued model series in + // the process-wide registry after the helper has finished. + noPrometheusMetrics bool + + prometheusMetrics prometheusMetrics + ErrorHistory tracker.Tracker } @@ -80,15 +100,23 @@ func (s *Service) NewMessage() *Message { // input can vary in types, but if it is an instance of Cachable, then the configured // caching system will be used. func (s *Service) Run(ctx context.Context, input interface{}, response *Response) error { - onDone := s.counter.Begin(time.Now()) + startTime := time.Now() + onDone := s.counter.Begin(startTime) stats := stat.NewValues() + defer func() { onDone(time.Now(), *stats...) + + duration := time.Since(startTime).Microseconds() + s.prometheusMetrics.observeRunDuration(float64(duration)) s.releaseMessage(input) }() if ctx.Err() != nil { stats.Append(stat.EarlyCtxError) + if s.prometheusMetrics.runErrorEarlyCtxCounter != nil { + s.prometheusMetrics.runErrorEarlyCtxCounter.Inc() + } } if response.Data == nil { @@ -107,6 +135,8 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response cachedCount, err = s.loadFromCache(ctx, &cached, batchSize, response, cachable) if err != nil { stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) + if ctx.Err() == nil && s.ErrorHistory != nil { go s.ErrorHistory.AddBytes([]byte(err.Error())) } @@ -123,6 +153,8 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response s.reportBatch(cachedCount, cached) } + s.prometheusMetrics.observeBatchSize(float64(batchSize)) + if (batchSize > 0 && cachedCount == batchSize) || (batchSize == 0 && cachedCount > 0) { response.Status = common.StatusCached return s.handleResponse(ctx, response.Data, cached, cachable) @@ -131,6 +163,7 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response data, err := Marshal(input, modelName) if err != nil { stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) return err } @@ -140,13 +173,18 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response } body, err := func() ([]byte, error) { - httpOnDone := s.httpCounter.Begin(time.Now()) + startTime := time.Now() + httpOnDone := s.httpCounter.Begin(startTime) httpStats := stat.NewValues() - od := metric.EnterThenExit(s.httpCounter, time.Now(), stat.Enter, stat.Exit) + od := metric.EnterThenExit(s.httpCounter, startTime, stat.Enter, stat.Exit) defer func() { httpOnDone(time.Now(), httpStats.Values()...) + + duration := time.Since(startTime).Microseconds() + s.prometheusMetrics.observeHttpDuration(float64(duration)) + od() }() @@ -158,13 +196,26 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response if err != nil { httpStats.AppendError(err) + s.prometheusMetrics.httpBaseErrorCounters.Observe(err) } return body, err }() if err != nil { + // Best-effort: parse the body as a Response struct so callers + // that check response.Error see the server-side error message + // (v0.20.0+ servers emit a structured JSON error body alongside + // the HTTP 4xx/5xx; older servers emit plain text and the + // unmarshal silently fails, leaving response untouched). + // The returned err remains the source-of-truth signal; this is + // purely additive population of the response struct. + if len(body) > 0 { + _ = gojay.Unmarshal(body, response) + } + stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) if ctx.Err() == nil && s.ErrorHistory != nil { go s.ErrorHistory.AddBytes([]byte(err.Error())) } @@ -175,6 +226,7 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response err = gojay.Unmarshal(body, response) if err != nil { stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) return fmt.Errorf("failed to unmarshal: '%s'; due to %w", body, err) } @@ -188,6 +240,7 @@ func (s *Service) Run(ctx context.Context, input interface{}, response *Response if err = s.handleResponse(ctx, response.Data, cached, cachable); err != nil { stats.AppendError(err) + s.prometheusMetrics.runBaseErrorCounters.Observe(err) return fmt.Errorf("failed to handle resp: %w", err) } @@ -224,6 +277,7 @@ func (s *Service) loadFromCache(ctx context.Context, cached *[]interface{}, batc response.Status = common.StatusCached response.DictHash = dictHash } + return cachedCount, nil } @@ -257,6 +311,7 @@ func (s *Service) readFromCacheInBatch(ctx context.Context, batchSize int, dataT return cachedCount, err } +// readFromCache will return an error if target is not a pointer. func (s *Service) readFromCache(ctx context.Context, key string, target interface{}) (bool, int, error) { if s.datastore == nil || !s.datastore.Enabled() { return false, 0, nil @@ -264,7 +319,7 @@ func (s *Service) readFromCache(ctx context.Context, key string, target interfac dataType := reflect.TypeOf(target) if dataType.Kind() != reflect.Ptr { - return false, 0, fmt.Errorf("invalid response data type: expeted ptr but had: %T", target) + return false, 0, fmt.Errorf("invalid response data type: expected reflect.Ptr but had: %T", target) } storeKey := s.datastore.Key(key) @@ -292,6 +347,18 @@ func (s *Service) dictionary() *Dictionary { return dict } +func (s *Service) registerPrometheusMetrics() error { + if s.noPrometheusMetrics { + return nil + } + pr := prometheus.DefaultRegisterer + if s.PrometheusRegisterer != nil { + pr = s.PrometheusRegisterer + } + + return s.prometheusMetrics.registerPrometheusMetrics(pr, s.Model, s.noPrometheusSummaries) +} + func (s *Service) init() error { if s.gmetrics == nil { s.gmetrics = gmetric.New() @@ -303,6 +370,11 @@ func (s *Service) init() error { s.httpCliCounter = s.gmetrics.MultiOperationCounter(location, s.Model+"ClientHTTPCli", s.Model+" client HTTP client performance", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()) s.dictCounter = s.gmetrics.MultiOperationCounter(location, s.Model+"ClientDict", s.Model+" client dictionary performance", time.Microsecond, time.Minute, 1, stat.ErrorOnly()) + err := s.registerPrometheusMetrics() + if err != nil { + return fmt.Errorf("failed to register Prometheus metrics: %w", err) + } + if s.ErrorHistory == nil { s.ErrorHistory = mg.NewK(20) } @@ -311,33 +383,39 @@ func (s *Service) init() error { s.Config.MaxRetry = 3 } - err := s.initHTTPClient() + if err := s.initLatencyBreakers(); err != nil { + return fmt.Errorf("failed to initialize latency breakers: %w", err) + } + + err = s.initHTTPClient() if err != nil { - return err + return fmt.Errorf("failed to initialize HTTP client: %w", err) } if s.Config.Datastore == nil { if err := s.loadModelConfig(); err != nil { - return err + return fmt.Errorf("failed to load model config: %w", err) } } if s.dict == nil { if err := s.loadModelDictionary(); err != nil { - return err + return fmt.Errorf("failed to load model dictionary: %w", err) } } if ds := s.Config.Datastore; ds != nil { ds.Init() if err = ds.Validate(); err != nil { - return err + return fmt.Errorf("failed to validate datastore config: %w", err) } } if s.datastore == nil { err := s.initDatastore() - return err + if err != nil { + return fmt.Errorf("failed to initialize datastore: %w", err) + } } s.messages = NewMessages(s.dictionary) @@ -350,7 +428,7 @@ func (s *Service) initHTTPClient() error { if host != nil && host.IsSecurePort() { cert, err := getCertPool() if err != nil { - return fmt.Errorf("failed to create certificate: %v", err) + return fmt.Errorf("failed to create certificate: %w", err) } tslConfig = &tls.Config{ @@ -528,7 +606,7 @@ func (s *Service) discoverConfig(host *Host, URL string) (*config.Remote, error) cfg := &config.Remote{} err = json.Unmarshal(data, cfg) if err != nil { - return nil, fmt.Errorf("failed to parse load %v, config: %s, %v", URL, data, err) + return nil, fmt.Errorf("failed to parse load %v, config: %s, %v", URL, data, err) } if s.Config.Debug { @@ -604,39 +682,96 @@ func (s *Service) postRequest(ctx context.Context, data []byte, mvt *stat.Values // TODO per-host counters host, err := s.getHost() if err != nil { + // getHost returns ErrNodeDown when the host's breaker IsUp() is + // false. Mark the request as shed so the operator can distinguish + // requests rejected pre-flight by the breaker from requests that + // reached httpPost and failed there. Without this, every shed + // request was conflated into the generic _error counter. + if errors.Is(err, common.ErrNodeDown) { + mvt.Append(stat.Shed) + } return nil, err } var output []byte + start := time.Now() output, err = s.httpPost(ctx, data, host) + // Feed the latency observation to the latency breaker (if one is + // configured on this host). Observe is nil-safe. + host.LatencyBreaker.Observe(time.Since(start)) + if common.IsConnectionError(err) { if s.Config.Debug { log.Printf("[%s postRequest] connection error:%s", s.Config.Model, err) } + mvt.Append(stat.Down) + if s.prometheusMetrics.httpDownCounter != nil { + s.prometheusMetrics.httpDownCounter.Inc() + } + host.FlagDown() } return output, err } +// initLatencyBreakers attaches a circut.LatencyBreaker to each +// configured host when the Config has at least one non-zero threshold. +// Both thresholds zero -> no breaker constructed (backward-compatible +// no-op). +func (s *Service) initLatencyBreakers() error { + settings, enabled, err := s.Config.latencyBreakerSettings() + if err != nil { + return err + } + if !enabled { + return nil + } + for _, h := range s.Config.Hosts { + if h == nil || h.LatencyBreaker != nil { + continue + } + h.LatencyBreaker = circut.NewLatencyBreaker( + settings.latest, + settings.rolling, + settings.window, + settings.k, + settings.fraction, + ) + } + return nil +} + func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte, error) { evalUrl := host.evalURL(s.Model) var terminate bool var postErr error + // postBody captures the response body across retry iterations so that + // non-2xx terminal errors can return the JSON error body alongside the + // error. Run() does a best-effort unmarshal of this body to populate + // response.Error and response.Status from a v0.20.0+ server's structured + // error response. Older servers return plain-text bodies; the best-effort + // unmarshal silently fails on those, leaving response untouched. + var postBody []byte for i := 0; i < s.MaxRetry; i++ { data, err := func() ([]byte, error) { - onDone := s.httpCliCounter.Begin(time.Now()) + startTime := time.Now() + onDone := s.httpCliCounter.Begin(startTime) stats := stat.NewValues() defer func() { onDone(time.Now(), stats.Values()...) + + duration := time.Since(startTime).Microseconds() + s.prometheusMetrics.observeHttpClientDuration(float64(duration)) }() request, err := http.NewRequestWithContext(ctx, http.MethodPost, evalUrl, bytes.NewReader(data)) if err != nil { stats.AppendError(err) + s.prometheusMetrics.httpClientBaseErrorCounters.Observe(err) return nil, err } @@ -647,6 +782,7 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte if err != nil { stats.AppendError(err) + s.prometheusMetrics.httpClientBaseErrorCounters.Observe(err) return nil, err } @@ -660,15 +796,32 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte // as long as this func is run synchronously, // this is safe terminate = true - return nil, fmt.Errorf("HTTP Code:%d, Body:\"%s\" (read nil:%v error:%v)", + // Return the body so the caller can parse the JSON error response + // (v0.20.0+ servers emit a Response struct here; older servers emit + // plain text). The error keeps the same wrapping format for backward + // compatibility with consumers that string-match on it. + return data, fmt.Errorf("HTTP Code:%d, Body:\"%s\" (read nil:%v error:%v)", response.StatusCode, string(data), response.Body == nil, err) } + if err != nil { + // 200 OK with a partial / aborted body read is not a success. + // Surfacing this prevents callers from silently unmarshaling an empty body + // (observed downstream as "Invalid JSON, wrong char ' ' found at position 0"). + return nil, fmt.Errorf("HTTP Code:%d, partial body read: %w (got %d bytes)", + response.StatusCode, err, len(data)) + } + return data, nil }() if err != nil { postErr = err + // Capture body for terminal errors so the caller can parse it. + // On retryable errors data is nil, so this is a no-op there. + if data != nil { + postBody = data + } } if terminate || ctx.Err() != nil { @@ -676,12 +829,12 @@ func (s *Service) httpPost(ctx context.Context, data []byte, host *Host) ([]byte break } - if data != nil { + if data != nil && err == nil { return data, nil } } - return nil, postErr + return postBody, postErr } func (s *Service) getHost() (*Host, error) { diff --git a/shared/client/service_test.go b/shared/client/service_test.go index 4d3810f..ed6506b 100644 --- a/shared/client/service_test.go +++ b/shared/client/service_test.go @@ -2,6 +2,7 @@ package client import ( "context" + "errors" "fmt" "net/http" "path" @@ -9,13 +10,16 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/viant/bintly" + "github.com/viant/gmetric" "github.com/viant/mly/shared" cconfig "github.com/viant/mly/shared/client/config" "github.com/viant/mly/shared/client/faker" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/config" "github.com/viant/mly/shared/datastore/mock" + "github.com/viant/mly/shared/stat" "github.com/viant/scache" "github.com/viant/toolbox" ) @@ -85,6 +89,7 @@ func TestService_Run(t *testing.T) { WithDictionary(dictionary), WithDataStorer(mock.New()), WithDebug(true), + WithPrometheusMetrics(false), } } @@ -228,3 +233,196 @@ func TestService_Run(t *testing.T) { } } } + +// TestService_Run_ShedIncrementsBreakerShedMetric verifies that when the +// host's circuit breaker is in the down state at request time, Run() +// increments the new ClientHTTP_shed marker on the http counter (and does +// NOT increment _down, which is reserved for the trip event itself). +// +// Before this fix, shed requests were conflated into the generic _error +// counter, leaving operators unable to distinguish "request rejected +// pre-flight by the breaker" from "request reached httpPost and failed +// there." See shared/client/service.go postRequest. +func TestService_Run_ShedIncrementsBreakerShedMetric(t *testing.T) { + baseURL := toolbox.CallerDirectory(3) + + selectPort := 8089 + server := faker.Server{URL: path.Join(baseURL, "testdata"), Port: selectPort, Debug: true} + server.Start() + defer server.Stop() + + metaInput := shared.MetaInput{ + Inputs: []*shared.Field{ + {Name: "i1"}, + {Name: "i2", Wildcard: true}, + }, + } + dictionary := NewDictionary(&common.Dictionary{ + Layers: []common.Layer{{Name: "i1", Strings: []string{"v1", "v2"}}}, + Hash: 123, + }, metaInput.Inputs) + hosts := []*Host{NewHost("localhost", selectPort)} + + gmetrics := gmetric.New() + const modelID = "shed_metric_case" + options := []Option{ + WithGmetrics(gmetrics), + WithRemoteConfig(&cconfig.Remote{ + Datastore: config.Datastore{ + Cache: &scache.Config{SizeMb: 64, Shards: 10, EntrySize: 1024}, + }, + MetaInput: metaInput, + }), + WithCacheScope(CacheScopeLocal), + WithDictionary(dictionary), + WithDataStorer(mock.New()), + WithDebug(true), + } + srv, err := New(modelID, hosts, options...) + require.NoError(t, err) + + // Force the host's breaker into the down state so getHost() will + // return ErrNodeDown without ever calling httpPost. + hosts[0].FlagDown() + require.False(t, hosts[0].IsUp(), "host must be flagged down for the shed path") + + msg := srv.NewMessage() + msg.StringKey("i1", "v1") + msg.StringKey("i2", "v10") + + response := &Response{Data: &TestOutput{}} + err = srv.Run(context.Background(), msg, response) + + require.Error(t, err, "shed request must surface as a non-nil err") + assert.True(t, errors.Is(err, common.ErrNodeDown), "shed err must wrap ErrNodeDown, got %v", err) + + // Inspect the cumulative counter values for ClientHTTP. The + // new Shed marker must increment by 1; the existing Down marker must + // stay at 0 (no FlagDown was called by this request -- the breaker + // was already down before getHost was called). + shedCount := gmetrics.LookupOperationCumulativeMetric(modelID+"ClientHTTP", stat.Shed) + downCount := gmetrics.LookupOperationCumulativeMetric(modelID+"ClientHTTP", stat.Down) + errorCount := gmetrics.LookupOperationCumulativeMetric(modelID+"ClientHTTP", stat.ErrorKey) + + assert.EqualValues(t, 1, shedCount, "ClientHTTP_shed must increment on shed") + assert.EqualValues(t, 0, downCount, "ClientHTTP_down must NOT increment on shed (only on trip)") + // _error still increments because Run()'s AppendError fires for the + // non-context ErrNodeDown -- this is the historical behavior that + // the new _shed marker disambiguates without changing. + assert.EqualValues(t, 1, errorCount, "ClientHTTP_error continues to increment as before") +} + +// TestService_Run_ParsesErrorBody verifies that when the server returns +// a non-2xx response with a JSON-encoded Response body (the v0.20.0+ +// error-response contract), Run() does a best-effort unmarshal of the +// body so the caller's response struct has Status="error" and Error +// populated -- in addition to receiving a non-nil err return value. +// +// Backward-compatibility: when the server returns a plain-text body +// (older mly versions, or any non-JSON body), the unmarshal silently +// fails and the response struct stays untouched. The non-nil err +// return remains the source-of-truth signal in either case. +func TestService_Run_ParsesErrorBody(t *testing.T) { + baseURL := toolbox.CallerDirectory(3) + + selectPort := 8088 + server := faker.Server{URL: path.Join(baseURL, "testdata"), Port: selectPort, Debug: true} + server.Start() + defer server.Stop() + + metaInput := shared.MetaInput{ + Inputs: []*shared.Field{ + {Name: "i1"}, + {Name: "i2", Wildcard: true}, + }, + } + dictionary := NewDictionary(&common.Dictionary{ + Layers: []common.Layer{{Name: "i1", Strings: []string{"v1", "v2"}}}, + Hash: 123, + }, metaInput.Inputs) + hosts := []*Host{NewHost("localhost", selectPort)} + options := []Option{ + WithRemoteConfig(&cconfig.Remote{ + Datastore: config.Datastore{ + Cache: &scache.Config{SizeMb: 64, Shards: 10, EntrySize: 1024}, + }, + MetaInput: metaInput, + }), + WithCacheScope(CacheScopeLocal), + WithDictionary(dictionary), + WithDataStorer(mock.New()), + WithDebug(true), + } + + cases := []struct { + description string + bodyContentType string + body string + statusCode int + expectErrorMsg string // non-empty if response.Error should be populated + expectStatus string // non-empty if response.Status should be populated + }{ + { + description: "v0.20.0 server: 400 JSON error body populates response.Error", + bodyContentType: "application/json", + body: `{"status":"error","error":"invalid input shape","serviceTimeMcs":150}`, + statusCode: http.StatusBadRequest, + expectErrorMsg: "invalid input shape", + expectStatus: common.StatusError, + }, + { + description: "v0.20.0 server: 500 JSON error body populates response.Error", + bodyContentType: "application/json", + body: `{"status":"error","error":"upstream blew up","serviceTimeMcs":2200}`, + statusCode: http.StatusInternalServerError, + expectErrorMsg: "upstream blew up", + expectStatus: common.StatusError, + }, + { + description: "older server: 400 plain-text body leaves response.Error empty", + bodyContentType: "text/plain", + body: "bad request\n", + statusCode: http.StatusBadRequest, + expectErrorMsg: "", + expectStatus: "", // gojay.Unmarshal silently fails on non-JSON; struct untouched + }, + { + description: "older server: 500 plain-text body leaves response.Error empty", + bodyContentType: "text/plain", + body: "server error\n", + statusCode: http.StatusInternalServerError, + expectErrorMsg: "", + expectStatus: "", + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + body := tc.body + contentType := tc.bodyContentType + statusCode := tc.statusCode + server.Handler.Then(func(d []byte, w http.ResponseWriter) { + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) + w.WriteHeader(statusCode) + _, _ = w.Write([]byte(body)) + }) + + srv, err := New("error_body_case", hosts, options...) + require.NoError(t, err) + + msg := srv.NewMessage() + msg.StringKey("i1", "v1") + msg.StringKey("i2", "v10") + + response := &Response{Data: &TestOutput{}} + err = srv.Run(context.Background(), msg, response) + + assert.Error(t, err, "non-2xx must always surface as a non-nil err return") + assert.Equal(t, tc.expectErrorMsg, response.Error, + "response.Error population (best-effort JSON unmarshal of error body)") + assert.Equal(t, tc.expectStatus, response.Status, + "response.Status population (best-effort JSON unmarshal of error body)") + }) + } +} diff --git a/shared/common/dictionary.go b/shared/common/dictionary.go index 715bf0b..24d13ba 100644 --- a/shared/common/dictionary.go +++ b/shared/common/dictionary.go @@ -28,9 +28,8 @@ type ( } ) -// TODO this should use a fixed size integer? // UpdateHash will memoize dictionary hashing. -// Since wildcard fields don't provide an actual dictionary, we use the modification time information to generate a hash based on the file, passed in as fsHash. +// fsHash provides a base hash value, used for cases when a particular model doesn't have a vocabulary to be hashed. func (d *Dictionary) UpdateHash(fsHash int64) int { d.Hash = int(fsHash) diff --git a/shared/common/error.go b/shared/common/error.go index ba62b13..0039897 100644 --- a/shared/common/error.go +++ b/shared/common/error.go @@ -1,9 +1,10 @@ package common import ( + "strings" + "github.com/aerospike/aerospike-client-go/types" "github.com/pkg/errors" - "strings" ) const ( @@ -11,10 +12,10 @@ const ( connRefusedError = "refused" ) -//ErrNodeDown node down error +// ErrNodeDown node down error var ErrNodeDown = errors.New("node is down") -//IsKeyNotFound returns true if key not found error +// IsKeyNotFound returns true if key not found error func IsKeyNotFound(err error) bool { if err == nil { return false @@ -33,7 +34,7 @@ func IsKeyNotFound(err error) bool { return aeroError.ResultCode() == types.KEY_NOT_FOUND_ERROR } -//IsTimeout returns true if timeout error +// IsTimeout returns true if timeout error func IsTimeout(err error) bool { if err == nil { return false @@ -52,12 +53,13 @@ func IsTimeout(err error) bool { return aeroError.ResultCode() == types.TIMEOUT } -//IsTransientError returns if transient error +// IsTransientError returns if transient error +// NOTE: This has an inverted dependency on Aerospike; the downstream implementation detail is not abstracted out. func IsTransientError(err error) bool { return IsKeyNotFound(err) || IsInvalidNode(err) || IsTimeout(err) || IsInvalidNode(err) || IsConnectionError(err) } -//IsInvalidNode returns true is node/cluster is down +// IsInvalidNode returns true is node/cluster is down func IsInvalidNode(err error) bool { if err == nil { return false @@ -79,7 +81,7 @@ func IsInvalidNode(err error) bool { return aeroError.ResultCode() == types.INVALID_NODE_ERROR } -//IsConnectionError returns true if error is connection errpr +// IsConnectionError returns true if error is connection errpr func IsConnectionError(err error) bool { if err == nil { return false diff --git a/shared/common/type.go b/shared/common/type.go index f0cf627..697f77c 100644 --- a/shared/common/type.go +++ b/shared/common/type.go @@ -10,11 +10,11 @@ import ( // dataType == "" is treated as a string type. func DataType(dataType string) (reflect.Type, error) { switch strings.ToLower(dataType) { - case "string": + case "string", "": return reflect.TypeOf(""), nil case "float64": return reflect.TypeOf(float64(0)), nil - case "float32", "float": + case "float32": return reflect.TypeOf(float32(0)), nil case "int": return reflect.TypeOf(int(0)), nil @@ -35,9 +35,6 @@ func DataType(dataType string) (reflect.Type, error) { case "[]float64": return reflect.TypeOf([]float64{}), nil default: - if dataType == "" { - return reflect.TypeOf(""), nil - } return nil, fmt.Errorf("unsupported data type: %v", dataType) } } diff --git a/shared/config/datastores.go b/shared/config/datastores.go index c25d5d3..65be57e 100644 --- a/shared/config/datastores.go +++ b/shared/config/datastores.go @@ -2,16 +2,17 @@ package config import ( "fmt" + "github.com/viant/mly/shared/config/datastore" ) -//DatastoreList represents datastore list +// DatastoreList represents datastore list type DatastoreList struct { Connections []*datastore.Connection Datastores []*Datastore } -//Init initialises list +// Init initialises list func (d *DatastoreList) Init() { if len(d.Connections) > 0 { for i := range d.Connections { @@ -25,14 +26,16 @@ func (d *DatastoreList) Init() { } } -//Validate checks if datastore list is valid +// Validate checks if datastore list is valid func (d *DatastoreList) Validate() error { if len(d.Connections) == 0 && len(d.Datastores) == 0 { return nil } + if len(d.Connections) > 0 && len(d.Datastores) == 0 { return fmt.Errorf("item were empty, but item defined") } + if len(d.Connections) > 0 { for _, item := range d.Connections { if err := item.Validate(); err != nil { @@ -40,10 +43,12 @@ func (d *DatastoreList) Validate() error { } } } + for _, item := range d.Datastores { if err := item.Validate(); err != nil { return err } } + return nil } diff --git a/shared/config/router/router.go b/shared/config/router/router.go index 58bd880..4a49743 100644 --- a/shared/config/router/router.go +++ b/shared/config/router/router.go @@ -1,6 +1,6 @@ package router -type RouterConfig struct { +type RoutingConfig struct { EntityMapping []EntityKV `json:"entityMapping" yaml:"entityMapping"` GlobalModelName string `json:"globalModelName" yaml:"globalModelName"` diff --git a/shared/config/router/router_test.go b/shared/config/router/router_test.go index e77c58b..c5aeb3c 100644 --- a/shared/config/router/router_test.go +++ b/shared/config/router/router_test.go @@ -8,7 +8,7 @@ import ( ) func TestJSON_EncodeDecode_WithGlobal(t *testing.T) { - cfg := &RouterConfig{ + cfg := &RoutingConfig{ EntityMapping: []EntityKV{ {EntityID: 12345, ModelName: "roas_model_12345_202511121116"}, {EntityID: 12347, ModelName: "roas_model_12347_202511111116"}, @@ -22,7 +22,7 @@ func TestJSON_EncodeDecode_WithGlobal(t *testing.T) { expected := `{"entityMapping":[{"entityID":12345,"modelName":"roas_model_12345_202511121116"},{"entityID":12347,"modelName":"roas_model_12347_202511111116"}],"globalModelName":"roas_global_202511111116"}` require.Equal(t, expected, string(data)) - var decoded RouterConfig + var decoded RoutingConfig require.NoError(t, json.Unmarshal(data, &decoded)) require.Equal(t, cfg.GlobalModelName, decoded.GlobalModelName) @@ -35,7 +35,7 @@ func TestJSON_EncodeDecode_WithGlobal(t *testing.T) { func TestJSON_Decode_NoGlobal(t *testing.T) { data := []byte(`{"entityMapping":[{"entityID":1,"modelName":"m1"}]}`) - var cfg RouterConfig + var cfg RoutingConfig require.NoError(t, json.Unmarshal(data, &cfg)) require.Empty(t, cfg.GlobalModelName) require.Len(t, cfg.EntityMapping, 1) @@ -45,7 +45,7 @@ func TestJSON_Decode_NoGlobal(t *testing.T) { func TestJSON_Decode_EmptyArray(t *testing.T) { data := []byte(`{"entityMapping":[]}`) - var cfg RouterConfig + var cfg RoutingConfig require.NoError(t, json.Unmarshal(data, &cfg)) require.NotNil(t, cfg.EntityMapping) require.Len(t, cfg.EntityMapping, 0) @@ -53,6 +53,6 @@ func TestJSON_Decode_EmptyArray(t *testing.T) { func TestJSON_Decode_InvalidEntityID(t *testing.T) { data := []byte(`{"entityMapping":[{"entityID":"oops","modelName":"x"}]}`) - var cfg RouterConfig + var cfg RoutingConfig require.Error(t, json.Unmarshal(data, &cfg)) } diff --git a/shared/datastore/client/service.go b/shared/datastore/client/service.go index d8052b0..0bd7c1d 100644 --- a/shared/datastore/client/service.go +++ b/shared/datastore/client/service.go @@ -3,13 +3,16 @@ package client import ( "context" "fmt" + "reflect" "strings" "time" aero "github.com/aerospike/aerospike-client-go" + "github.com/viant/gmetric" "github.com/viant/mly/shared/circut" "github.com/viant/mly/shared/common" "github.com/viant/mly/shared/config/datastore" + "github.com/viant/mly/shared/stat" "golang.org/x/sync/singleflight" ) @@ -28,6 +31,10 @@ const ( type Service struct { Client Aero + gmOpGet *gmetric.Operation + gmOpPutReq *gmetric.Operation + gmOpPutExec *gmetric.Operation + config *datastore.Connection // bypassConfiguredTimeout is used to bypass the configured timeout if using WithClientPolicy or WithBasePolicy. @@ -54,6 +61,12 @@ func (s *Service) Get(ctx context.Context, key *aero.Key, binNames ...string) (r return nil, common.ErrNodeDown } + onDone := s.gmOpGet.Begin(time.Now()) + stats := stat.NewValues() + defer func() { + onDone(time.Now(), stats.Values()...) + }() + defer func() { if r := recover(); r != nil { connection := s.config.ID @@ -63,6 +76,8 @@ func (s *Service) Get(ctx context.Context, key *aero.Key, binNames ...string) (r record, err = s.Client.Get(s.basePolicy, key, binNames...) s.checkConnectionError(err) + stats.AppendError(err) + return record, err } @@ -79,6 +94,12 @@ func (s *Service) Put(writePolicy *aero.WritePolicy, key *aero.Key, value aero.B keyStr := keyString(key) + onDone := s.gmOpPutReq.Begin(time.Now()) + stats := stat.NewValues() + defer func() { + onDone(time.Now(), stats.Values()...) + }() + defer func() { if r := recover(); r != nil { connection := s.config.ID @@ -97,8 +118,16 @@ func (s *Service) Put(writePolicy *aero.WritePolicy, key *aero.Key, value aero.B defer cancel() ch := s.group.DoChan(keyStr, func() (interface{}, error) { + onDone := s.gmOpPutExec.Begin(time.Now()) + stats := stat.NewValues() + defer func() { + onDone(time.Now(), stats.Values()...) + }() + err := s.Client.Put(writePolicy, key, value) s.checkConnectionError(err) + stats.AppendError(err) + return nil, err }) @@ -110,7 +139,7 @@ func (s *Service) Put(writePolicy *aero.WritePolicy, key *aero.Key, value aero.B err = fmt.Errorf("put aerospike[%s] key: %s shared: %v error: %w", s.config.ID, keyStr, res.Shared, res.Err) } } - + stats.AppendError(err) return err } @@ -189,12 +218,29 @@ func New(config *datastore.Connection) (*Service, error) { } func NewWithOptions(config *datastore.Connection, options ...Option) (*Service, error) { + return NewWithOptionsV2(config, nil, options...) +} + +func NewWithOptionsV2(config *datastore.Connection, gmetrics *gmetric.Service, options ...Option) (*Service, error) { + if gmetrics == nil { + gmetrics = gmetric.New() + } + + location := reflect.TypeOf(Service{}).PkgPath() + gmOpGet := gmetrics.MultiOperationCounter(location, config.ID+"AerospikeGet", config.ID+" get performance", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()) + gmOpPutReq := gmetrics.MultiOperationCounter(location, config.ID+"AerospikePutRequested", config.ID+" put performance including singleflight", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()) + gmOpPutExec := gmetrics.MultiOperationCounter(location, config.ID+"AerospikePutExecuted", config.ID+" put performance", time.Microsecond, time.Minute, 2, stat.NewCtxErrOnly()) + srv := &Service{ - config: config, - group: new(singleflight.Group), + config: config, + group: new(singleflight.Group), + gmOpGet: gmOpGet, + gmOpPutReq: gmOpPutReq, + gmOpPutExec: gmOpPutExec, } srv.init(options...) + breaker := circut.New(time.Second, srv) srv.Breaker = breaker return srv, srv.connect() diff --git a/shared/datastore/client/service_test.go b/shared/datastore/client/service_test.go index b99136f..c1da7d0 100644 --- a/shared/datastore/client/service_test.go +++ b/shared/datastore/client/service_test.go @@ -9,7 +9,7 @@ import ( aero "github.com/aerospike/aerospike-client-go" "github.com/viant/mly/shared/circut" - "golang.org/x/sync/singleflight" + "github.com/viant/mly/shared/config/datastore" ) type MockAero struct { @@ -37,13 +37,13 @@ func TestPut(t *testing.T) { L: mockLock, } - service := &Service{ - Client: mockAero, - group: new(singleflight.Group), - basePolicy: &aero.BasePolicy{ - TotalTimeout: 15 * time.Second, - }, + config := &datastore.Connection{ + ID: "test", } + config.Init() + + service, _ := NewWithOptionsV2(config, nil) + service.Client = mockAero breaker := circut.New(time.Second, service) service.Breaker = breaker diff --git a/shared/datastore/service.go b/shared/datastore/service.go index 5ac6c1f..cfcb66c 100644 --- a/shared/datastore/service.go +++ b/shared/datastore/service.go @@ -263,7 +263,8 @@ func (s *Service) updateCache(keyString string, entryData EntryData, dictHash in return nil } -// reads from local +// readFromCache reads from local. +// This can return an error if the cache cannot be unmarshalled. func (s *Service) readFromCache(keyString string, value Value, stats *stat.Values) (CacheStatus, int, error) { data, _ := s.cache.Get(keyString) if len(data) == 0 { diff --git a/shared/datastore/service_test.go b/shared/datastore/service_test.go index 5323cf9..ed69cf8 100644 --- a/shared/datastore/service_test.go +++ b/shared/datastore/service_test.go @@ -3,13 +3,12 @@ package datastore import ( "context" "testing" - "time" "github.com/aerospike/aerospike-client-go" "github.com/stretchr/testify/assert" - "github.com/viant/mly/shared/circut" "github.com/viant/mly/shared/common" + "github.com/viant/mly/shared/config/datastore" "github.com/viant/mly/shared/datastore/client" ) @@ -42,10 +41,16 @@ func TestFromClientMapsRecordAndDoesNotMapHashBin(t *testing.T) { "Field": "value", common.HashBin: 123, }} - clientSvc := &client.Service{ - Client: stubAeroRecord{record: rec}, - Breaker: circut.New(time.Second*10, &stubProber{}), + + config := &datastore.Connection{ + ID: "test", } + config.Init() + + clientSvc, _ := client.New(config) + // ignore err since we're mocking the client and don't need ot actually connect + + clientSvc.Client = stubAeroRecord{record: rec} key := &Key{Namespace: "ns", Set: "set", Value: "key"} type Foo struct { diff --git a/shared/datastore/stores.go b/shared/datastore/stores.go index 29f943b..18c8149 100644 --- a/shared/datastore/stores.go +++ b/shared/datastore/stores.go @@ -53,7 +53,7 @@ func NewStoresV4(cfg *config.DatastoreList, gmetrics *gmetric.Service, verbose b continue } - aero, err := client.NewWithOptions(connection, clientOptions...) + aero, err := client.NewWithOptionsV2(connection, gmetrics, clientOptions...) if err != nil { return nil, fmt.Errorf("failed to create client for %v, due to %w", connID, err) } diff --git a/shared/field.go b/shared/field.go index f498040..e05bf2f 100644 --- a/shared/field.go +++ b/shared/field.go @@ -15,7 +15,6 @@ type ( Index int // The type of the field. - // Supports "float" which maps to float32. // Otherwise, refer to reflect.Type.Name(). DataType string `json:",omitempty" yaml:",omitempty"` @@ -39,8 +38,8 @@ type ( MetaInput struct { Inputs []*Field - // This is used to order inputs and provide extra caching information to the client. - // All inputs from the model will automatically be added here. + // KeyFields is a method of forcing inputs to be part of the key even if not part of the model input. + // The primary use case of this is when there is a Transformer that depends on an Auxiliary input. KeyFields []string `json:",omitempty" yaml:",omitempty"` // Deprecated: use Field.Auxiliary @@ -78,9 +77,6 @@ func (f *Field) DataTypeToRawType() { // fieldDataTypeToRawType is a subset of reverse Name() to reflect.Type func fieldDataTypeToRawType(dataType string) reflect.Type { switch dataType { - case "float": - // provided as a convenience - return reflect.TypeOf(float32(0)) case "": // this case is treated as string in common.DataType(), but here it's not OK. panic(fmt.Sprintf("unsupported data type: %s", dataType)) @@ -119,6 +115,8 @@ func (m *MetaInput) OutputByName() map[string]*Field { return outputByName } +// TODO look into history of this method then document its purpose. +// Is "key" key as in "important" or as in "cache key"? func (d *MetaInput) KeysLen() int { return len(d.Inputs) } @@ -135,7 +133,6 @@ func (m *MetaInput) FieldByName() map[string]*Field { // On the server, it is called after reading the configuration file. // On the client, it is called after fetching the configuration from the server, which will have already processed it via reconcileIOFromSignature(). func (m *MetaInput) Init() { - // TODO assess why this approach was taken - this condition could be improved by having a map to see if the field by name already exists if len(m.Inputs) == 0 { // Add KeyFields to Inputs if len(m.KeyFields) > 0 { diff --git a/shared/stat/buckets/prometheus.go b/shared/stat/buckets/prometheus.go index 24f2b64..f11e4f5 100644 --- a/shared/stat/buckets/prometheus.go +++ b/shared/stat/buckets/prometheus.go @@ -6,10 +6,12 @@ package buckets var MicrosecondBuckets []float64 = []float64{ 100, 500, + // 1 millisecond 1000, 2000, 3000, 5000, 7500, 10000, 20000, 30000, 50000, 75000, 100000, 200000, 400000, 800000, - 1000000, + // 1 second + 1000000, 2000000, } var MillisecondBuckets []float64 = []float64{ @@ -28,8 +30,9 @@ var SecondBuckets []float64 = []float64{ } var CommonSummaryObjectives = map[float64]float64{ - 0.5: 0.05, - 0.9: 0.01, - 0.95: 0.005, - 0.99: 0.001, + 0.5: 0.05, + 0.9: 0.01, + 0.95: 0.005, + 0.99: 0.001, + 0.999: 0.001, } diff --git a/shared/stat/http.go b/shared/stat/http.go index a2e4320..a70ad25 100644 --- a/shared/stat/http.go +++ b/shared/stat/http.go @@ -5,15 +5,63 @@ import "github.com/viant/gmetric/counter" // TODO move to shared/client type http struct{} -const Pending = "pending" +const ( + // Pending is the column for the in-flight gauge maintained by the + // metric.EnterThenExit Inc/Dec pattern. The exporter publishes: + // - _pending -- the current-in-flight counter (defective; see below) + // - _pending_Max -- per-bucket peak from the Occupancy CustomCounter + // + // Known defects in the underlying mechanism: + // + // 1. Bucket-mismatch on Exit. EnterThenExit captures the recent-bucket + // index at Enter time and decrements that same bucket on Exit. If + // the bucket has rotated between Enter and Exit, the wrong bucket + // is decremented -- the previous bucket's value goes negative + // while the current bucket's value drifts high. + // + // 2. Mutex serialization on Inc/Dec. The Dir typed value is not a + // string, so MultiCounter.incrementValueBy takes c.locker.Lock() + // on every Enter and Exit. Under high QPS this is a real + // serialization point. + // + // Defect #1 inflates both _pending (current) and _pending_Max (per-bucket + // peak); the inflation is in the conservative direction (over-estimation), + // so the metrics are still operationally useful in different regimes: + // + // - _pending_Max grouped per reporting dimension (e.g. by + // availability_zone, environment, op): for high-QPS operations + // the per-group peak rises substantially above baseline noise + // during fleet-wide saturation events, making this the cleaner + // saturation signal for those operations. + // + // - _pending summed across reporting instances: exhibits dramatic + // spikes during saturation for any QPS profile, partially + // amplified by defect #1. For low-QPS operations where the + // per-group _pending_Max signal is lost in baseline noise, the + // fleet sum is the more visible saturation signal. + // + // _pending_Max is the cleaner peak-concurrency signal for capacity + // sizing; pick per-operation based on QPS profile. + Pending = "pending" + // Shed marks a request that the client did NOT send because the host's + // circuit breaker was already in the down state when getHost() was + // called. Distinct from Down, which marks the trip event itself + // (the request that observed the connection error and called + // FlagDown). Shed is the count of subsequent requests that the + // breaker rejected before recovery. + Shed = "shed" +) func (p http) Keys() []string { + // New keys must be appended at the end so existing column indices + // remain stable for downstream consumers (Mimir queries, dashboards). return []string{ ErrorKey, Pending, Down, Canceled, DeadlineExceeded, + Shed, } } @@ -35,6 +83,8 @@ func (p http) Map(value interface{}) int { return 3 case DeadlineExceeded: return 4 + case Shed: + return 5 } case Dir: return 1 diff --git a/shared/stat/promc/error.go b/shared/stat/promc/error.go new file mode 100644 index 0000000..7f1d81f --- /dev/null +++ b/shared/stat/promc/error.go @@ -0,0 +1,25 @@ +package promc + +import ( + "context" + "errors" + + "github.com/prometheus/client_golang/prometheus" +) + +type BaseErrorCounters struct { + DeadlineExceededCounter prometheus.Counter + CanceledCounter prometheus.Counter + + OtherErrorCounter prometheus.Counter +} + +func (c BaseErrorCounters) Observe(err error) { + if c.DeadlineExceededCounter != nil && errors.Is(err, context.DeadlineExceeded) { + c.DeadlineExceededCounter.Inc() + } else if c.CanceledCounter != nil && errors.Is(err, context.Canceled) { + c.CanceledCounter.Inc() + } else if c.OtherErrorCounter != nil { + c.OtherErrorCounter.Inc() + } +} diff --git a/shared/transfer/input.go b/shared/transfer/input.go index 22dfd13..6775061 100644 --- a/shared/transfer/input.go +++ b/shared/transfer/input.go @@ -6,9 +6,14 @@ import ( type Input struct { BatchSize int - Keys Strings + + // cache keys + Keys Strings + Values - Unmapped Values // values that are not part of an input + + // values that are not part of an input + Unmapped Values } func (i *Input) BatchMode() bool {