diff --git a/README.md b/README.md index 4495d54..c209a01 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ A lightweight, idiomatic AI SDK for Go — inspired by [Vercel AI SDK](https://s ## Features -- **Simple API** — `GenerateText`, `StreamText`, `Embed`, `EmbedMany`, `GenerateImage`, `EditImage`, `GenerateSpeech`, and `StreamSpeech` cover most use cases +- **Simple API** — `GenerateText`, `StreamText`, `Embed`, `EmbedMany`, `GenerateImage`, `EditImage`, `GenerateVideo`, `GenerateSpeech`, and `StreamSpeech` cover most use cases - **Provider-agnostic** — swap between OpenAI, Anthropic, Google, GitHub Copilot, Edge TTS, or any OpenAI-compatible endpoint - **Model discovery** — `ListModels` fetches available models, `Test` checks provider connectivity and model support - **Tool calling** — define tools with Go structs, SDK infers JSON Schema and handles multi-step execution @@ -17,6 +17,7 @@ A lightweight, idiomatic AI SDK for Go — inspired by [Vercel AI SDK](https://s - **Rich message types** — text, images, files, reasoning content, tool calls/results - **Embeddings** — generate embeddings with `Embed` / `EmbedMany`, supports OpenAI and Google providers - **Image generation** — generate and edit images with `GenerateImage` / `EditImage`, supports dall-e-2, dall-e-3, and gpt-image-1 +- **Video generation** — create, poll, and download video jobs with OpenRouter and Ark/ModelArk providers - **Speech synthesis** — generate speech with `GenerateSpeech` / `StreamSpeech`, supports Edge TTS with an open provider model - **Approval flow** — optional human-in-the-loop approval for sensitive tool calls diff --git a/docs/videos.md b/docs/videos.md new file mode 100644 index 0000000..5684496 --- /dev/null +++ b/docs/videos.md @@ -0,0 +1,207 @@ +# Videos + +The Twilight AI SDK provides asynchronous video generation through `sdk.CreateVideo`, `sdk.GetVideo`, `sdk.CancelVideo`, `sdk.DownloadVideo`, and the convenience helper `sdk.GenerateVideo`. 
+ +Video generation jobs can take minutes, so providers are modeled as create-and-poll backends instead of streaming or single-response calls. + +## Quick Start + +```go +package main + +import ( + "context" + "fmt" + "log" + "time" + + openroutervideos "github.com/memohai/twilight-ai/provider/openrouter/videos" + "github.com/memohai/twilight-ai/sdk" +) + +func main() { + provider := openroutervideos.New( + openroutervideos.WithAPIKey("sk-or-..."), + ) + model := provider.VideoModel("google/veo-3.1") + + result, err := sdk.GenerateVideo(context.Background(), + sdk.WithVideoModel(model), + sdk.WithVideoPrompt("A cinematic tracking shot of waves at sunrise"), + sdk.WithVideoDuration(8), + sdk.WithVideoResolution("720p"), + sdk.WithVideoAspectRatio("16:9"), + sdk.WithVideoPollInterval(5*time.Second), + ) + if err != nil { + log.Fatal(err) + } + + fmt.Println(result.Job.ID, result.Job.Status) + if result.Output != nil { + fmt.Println(result.Output.URL) + } +} +``` + +For manual polling: + +```go +job, err := sdk.CreateVideo(ctx, + sdk.WithVideoModel(model), + sdk.WithVideoPrompt("A short product reveal clip"), + sdk.WithVideoWait(false), +) +if err != nil { + return err +} + +for job.Status != sdk.VideoJobSucceeded { + job, err = sdk.GetVideo(ctx, model, job.ID) + if err != nil { + return err + } + time.Sleep(5 * time.Second) +} + +data, contentType, err := sdk.DownloadVideo(ctx, model, job.Outputs[0]) +``` + +## Unified API + +`VideoParams` covers common video fields and keeps provider-specific options in `Config`: + +| Field | Purpose | +|-------|---------| +| `Prompt` | Text description | +| `Size` | Exact output dimensions, e.g. `1280x720` | +| `Resolution` | Provider resolution label, e.g. 
`720p` | +| `AspectRatio` | Ratio such as `16:9`, `9:16`, `1:1` | +| `DurationSeconds` | Clip duration | +| `Seed` | Deterministic generation seed when supported | +| `GenerateAudio` | Request generated audio when supported | +| `CallbackURL` | Provider webhook callback URL | +| `InputImage` | First-frame or image-to-video input | +| `InputVideo` | Video-to-video or edit input when supported | +| `ReferenceImages` | Image references | +| `ReferenceVideos` | Video references | +| `ReferenceAudio` | Audio references | +| `Config` | Provider-specific passthrough | + +Statuses are normalized to: + +```go +sdk.VideoJobQueued +sdk.VideoJobRunning +sdk.VideoJobSucceeded +sdk.VideoJobFailed +sdk.VideoJobCanceled +``` + +`sdk.GenerateVideo` waits by default with a 10 minute timeout and 5 second poll interval. Use `WithVideoWait(false)`, `WithVideoPollTimeout(...)`, and `WithVideoPollInterval(...)` to override this behavior. + +## OpenRouter Videos + +Package: + +```go +import openroutervideos "github.com/memohai/twilight-ai/provider/openrouter/videos" +``` + +Default base URL: + +```text +https://openrouter.ai/api +``` + +The provider appends `/v1/...`, so custom base URLs should follow the same convention and include `/api` when targeting OpenRouter directly. 
+ +Official docs: + +- [Video Generation](https://openrouter.ai/docs/guides/overview/multimodal/video-generation) +- [Submit a video generation request](https://openrouter.ai/docs/api/api-reference/video-generation/create-videos) +- [Poll video generation status](https://openrouter.ai/docs/api/api-reference/video-generation/get-videos) +- [Download generated video content](https://openrouter.ai/docs/api/api-reference/video-generation/list-videos-content) +- [List all video generation models](https://openrouter.ai/docs/api/api-reference/video-generation/list-videos-models) + +Endpoints: + +| SDK operation | OpenRouter API | +|---------------|----------------| +| `CreateVideo` | `POST /v1/videos` | +| `GetVideo` | `GET /v1/videos/{jobId}` | +| `DownloadVideo` | returned `unsigned_urls` | +| `ListModels` | `GET /v1/videos/models` | + +Field mapping: + +| SDK field | OpenRouter field | +|-----------|------------------| +| `Prompt` | `prompt` | +| `Model.ID` | `model` | +| `DurationSeconds` | `duration` | +| `Resolution` | `resolution` | +| `AspectRatio` | `aspect_ratio` | +| `Size` | `size` | +| `Seed` | `seed` | +| `GenerateAudio` | `generate_audio` | +| `CallbackURL` | `callback_url` | +| `InputImage` | `frame_images[0]` as `first_frame` | +| references | `input_references` | +| `Config["provider"]` | `provider` passthrough object | + +## Ark / ModelArk Videos + +Package: + +```go +import arkvideos "github.com/memohai/twilight-ai/provider/ark/videos" +``` + +Base URL constants: + +```go +arkvideos.BytePlusBaseURL // https://ark.ap-southeast.bytepluses.com/api/v3 +arkvideos.VolcengineBaseURL // https://ark.cn-beijing.volces.com/api/v3 +``` + +Official docs: + +- [BytePlus Seedance 2.0 API Reference](https://docs.byteplus.com/en/docs/ModelArk/1520757) +- [BytePlus Video generation API](https://docs.byteplus.com/en/docs/ModelArk/Video_Generation_API) +- [火山方舟查询视频生成任务 API](https://www.volcengine.com/docs/82379/1521309) + +Example: + +```go +provider := arkvideos.New( + 
arkvideos.WithAPIKey("ark-..."), + arkvideos.WithBaseURL(arkvideos.VolcengineBaseURL), +) +model := provider.VideoModel("doubao-seedance-2-0-260128") +``` + +Endpoints: + +| SDK operation | Ark / ModelArk API | +|---------------|--------------------| +| `CreateVideo` | `POST /contents/generations/tasks` | +| `GetVideo` | `GET /contents/generations/tasks/{id}` | +| `CancelVideo` | `DELETE /contents/generations/tasks/{id}` | + +Field mapping: + +| SDK field | Ark / ModelArk field | +|-----------|----------------------| +| `Model.ID` | `model` | +| `Prompt` | `content[0] = {type:"text", text:...}` | +| media inputs | `content[]` URL items | +| `DurationSeconds` | `duration` | +| `Resolution` | `resolution` | +| `AspectRatio` | `ratio` | +| `GenerateAudio` | `generate_audio` | +| `Seed` | `seed` | +| `CallbackURL` | `callback_url` | +| `Config` | top-level passthrough fields | + +The provider reads completed outputs from `content.video_url` and other video URL fields returned by the task API. `ListModels` intentionally returns an empty list in v1 because model discovery for Ark lives in control-plane APIs rather than the data-plane task API. diff --git a/provider/ark/videos/videos.go b/provider/ark/videos/videos.go new file mode 100644 index 0000000..e243f15 --- /dev/null +++ b/provider/ark/videos/videos.go @@ -0,0 +1,375 @@ +// Package videos provides Ark / ModelArk video generation support for +// BytePlus ModelArk and Volcengine Ark data-plane APIs. 
+package videos + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + + "github.com/memohai/twilight-ai/internal/utils" + "github.com/memohai/twilight-ai/sdk" +) + +const ( + BytePlusBaseURL = "https://ark.ap-southeast.bytepluses.com/api/v3" + VolcengineBaseURL = "https://ark.cn-beijing.volces.com/api/v3" +) + +type Provider struct { + apiKey string + baseURL string + httpClient *http.Client +} + +type Option func(*Provider) + +func WithAPIKey(apiKey string) Option { + return func(p *Provider) { p.apiKey = apiKey } +} + +func WithBaseURL(baseURL string) Option { + return func(p *Provider) { p.baseURL = strings.TrimRight(baseURL, "/") } +} + +func WithHTTPClient(client *http.Client) Option { + return func(p *Provider) { p.httpClient = client } +} + +func New(options ...Option) *Provider { + p := &Provider{httpClient: &http.Client{}} + for _, opt := range options { + opt(p) + } + p.baseURL = strings.TrimRight(p.baseURL, "/") + return p +} + +func (p *Provider) VideoModel(id string) *sdk.VideoModel { + return &sdk.VideoModel{ID: id, Provider: p} +} + +func (p *Provider) ListModels(context.Context) ([]*sdk.VideoModel, error) { + return []*sdk.VideoModel{}, nil +} + +//nolint:gocritic // VideoProvider keeps VideoParams value-based for a stable public SDK contract. 
+func (p *Provider) DoCreate(ctx context.Context, params sdk.VideoParams) (*sdk.VideoJob, error) { + if params.Model == nil { + return nil, fmt.Errorf("ark videos: model is required") + } + body := p.buildCreateBody(&params) + resp, err := utils.FetchJSON[map[string]any](ctx, p.httpClient, &utils.RequestOptions{ + Method: http.MethodPost, + BaseURL: p.baseURL, + Path: "/contents/generations/tasks", + Headers: utils.AuthHeader(p.apiKey), + Body: body, + }) + if err != nil { + return nil, fmt.Errorf("ark videos: create request failed: %w", err) + } + return toVideoJob(*resp, params.Model.ID), nil +} + +func (p *Provider) DoGet(ctx context.Context, model *sdk.VideoModel, id string) (*sdk.VideoJob, error) { + resp, err := utils.FetchJSON[map[string]any](ctx, p.httpClient, &utils.RequestOptions{ + Method: http.MethodGet, + BaseURL: p.baseURL, + Path: "/contents/generations/tasks/" + id, + Headers: utils.AuthHeader(p.apiKey), + }) + if err != nil { + return nil, fmt.Errorf("ark videos: get request failed: %w", err) + } + modelID := "" + if model != nil { + modelID = model.ID + } + return toVideoJob(*resp, modelID), nil +} + +func (p *Provider) DoCancel(ctx context.Context, _ *sdk.VideoModel, id string) error { + resp, err := utils.FetchRaw(ctx, p.httpClient, &utils.RequestOptions{ + Method: http.MethodDelete, + BaseURL: p.baseURL, + Path: "/contents/generations/tasks/" + id, + Headers: utils.AuthHeader(p.apiKey), + }) + if err != nil { + return fmt.Errorf("ark videos: cancel/delete request failed: %w", err) + } + _ = resp.Body.Close() + return nil +} + +func (p *Provider) DoDownload(ctx context.Context, _ *sdk.VideoModel, output sdk.VideoOutput) (data []byte, contentType string, err error) { + if output.URL == "" { + return nil, "", fmt.Errorf("ark videos: output URL is required") + } + url := output.URL + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + built, err := utils.BuildURL(p.baseURL, url) + if err != nil { + return nil, "",
fmt.Errorf("ark videos: build download URL: %w", err) + } + url = built + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, "", fmt.Errorf("ark videos: build download request: %w", err) + } + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, "", fmt.Errorf("ark videos: download request failed: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(resp.Body) + return nil, "", fmt.Errorf("ark videos: download failed with status %d: %s", resp.StatusCode, string(body)) + } + data, err = io.ReadAll(resp.Body) + if err != nil { + return nil, "", fmt.Errorf("ark videos: read download response: %w", err) + } + contentType = resp.Header.Get("Content-Type") + if contentType == "" { + contentType = output.ContentType + } + if contentType == "" { + contentType = "video/mp4" + } + return data, contentType, nil +} + +func (p *Provider) buildCreateBody(params *sdk.VideoParams) map[string]any { + body := map[string]any{ + "model": params.Model.ID, + "content": buildContent(params), + } + if params.DurationSeconds != nil { + body["duration"] = *params.DurationSeconds + } + if params.Resolution != "" { + body["resolution"] = params.Resolution + } + if params.AspectRatio != "" { + body["ratio"] = params.AspectRatio + } + if params.GenerateAudio != nil { + body["generate_audio"] = *params.GenerateAudio + } + if params.Seed != nil { + body["seed"] = *params.Seed + } + if params.CallbackURL != "" { + body["callback_url"] = params.CallbackURL + } + for k, v := range params.Config { + body[k] = v + } + return body +} + +func buildContent(params *sdk.VideoParams) []map[string]any { + content := []map[string]any{{"type": "text", "text": params.Prompt}} + if params.InputImage != nil { + if url := mediaURL(params.InputImage); url != "" { + content = append(content, mediaItem("image_url", url, "first_frame")) + } + } + if params.InputVideo != nil { + 
if url := mediaURL(params.InputVideo); url != "" { + content = append(content, mediaItem("video_url", url, "input_video")) + } + } + for i := range params.ReferenceImages { + if url := mediaURL(&params.ReferenceImages[i]); url != "" { + content = append(content, mediaItem("image_url", url, "reference_image")) + } + } + for i := range params.ReferenceVideos { + if url := mediaURL(&params.ReferenceVideos[i]); url != "" { + content = append(content, mediaItem("video_url", url, "reference_video")) + } + } + for i := range params.ReferenceAudio { + if url := mediaURL(&params.ReferenceAudio[i]); url != "" { + content = append(content, mediaItem("audio_url", url, "reference_audio")) + } + } + return content +} + +func mediaItem(field, url, role string) map[string]any { + item := map[string]any{ + "type": field, + field: map[string]any{"url": url}, + } + if role != "" { + item["role"] = role + } + return item +} + +func toVideoJob(raw map[string]any, fallbackModelID string) *sdk.VideoJob { + inner := unwrap(raw) + id := firstString(inner, "id", "task_id", "taskId") + if id == "" { + id = firstString(raw, "id", "task_id", "taskId") + } + modelID := firstString(inner, "model", "model_id") + if modelID == "" { + modelID = fallbackModelID + } + status := firstString(inner, "status", "task_status") + progress := firstFloat(inner, "progress", "percent") + + job := &sdk.VideoJob{ + ID: id, + ModelID: modelID, + Status: mapStatus(status), + Progress: progress, + ProviderMetadata: raw, + } + if errMsg := extractError(inner); errMsg != "" { + job.Error = &sdk.VideoError{Message: errMsg} + } + for _, url := range extractVideoURLs(inner) { + job.Outputs = append(job.Outputs, sdk.VideoOutput{ + URL: url, + ContentType: "video/mp4", + ProviderMetadata: map[string]any{ + "task_id": id, + }, + }) + } + return job +} + +func unwrap(raw map[string]any) map[string]any { + for _, key := range []string{"data", "task"} { + if nested, ok := raw[key].(map[string]any); ok { + return nested + } + } + return
raw +} + +func mapStatus(status string) sdk.VideoJobStatus { + switch strings.ToLower(strings.TrimSpace(status)) { + case "queued", "pending": + return sdk.VideoJobQueued + case "running", "in_progress", "processing": + return sdk.VideoJobRunning + case "succeeded", "success", "completed": + return sdk.VideoJobSucceeded + case "failed", "error": + return sdk.VideoJobFailed + case "canceled", "cancelled", "deleted": + return sdk.VideoJobCanceled + default: + return sdk.VideoJobRunning + } +} + +func mediaURL(input *sdk.MediaInput) string { + if input == nil { + return "" + } + if input.URL != "" { + return input.URL + } + if len(input.Data) == 0 { + return input.FileID + } + contentType := input.ContentType + if contentType == "" { + contentType = "application/octet-stream" + } + return "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(input.Data) +} + +func firstString(obj map[string]any, keys ...string) string { + for _, key := range keys { + if s, ok := obj[key].(string); ok { + return s + } + } + return "" +} + +func firstFloat(obj map[string]any, keys ...string) *float64 { + for _, key := range keys { + if f, ok := utils.ToFloat64(obj[key]); ok { + return &f + } + } + return nil +} + +func extractError(obj map[string]any) string { + if s, ok := obj["error"].(string); ok { + return s + } + if m, ok := obj["error"].(map[string]any); ok { + if s, ok := m["message"].(string); ok { + return s + } + } + if s, ok := obj["message"].(string); ok && mapStatus(firstString(obj, "status")) == sdk.VideoJobFailed { + return s + } + return "" +} + +func extractVideoURLs(v any) []string { + var out []string + var walk func(any) + walk = func(value any) { + switch typed := value.(type) { + case map[string]any: + for _, key := range []string{"video_url", "videoUrl", "output_url", "outputUrl"} { + if s, ok := typed[key].(string); ok && s != "" { + out = append(out, s) + } + } + if s, ok := typed["url"].(string); ok && hasVideoExtension(s) { + out = append(out, 
s) + } + for _, nested := range typed { + walk(nested) + } + case []any: + for _, item := range typed { + walk(item) + } + } + } + walk(v) + return dedupe(out) +} + +func hasVideoExtension(s string) bool { + if s == "" { + return false + } + lower := strings.ToLower(s) + return strings.Contains(lower, ".mp4") || strings.Contains(lower, ".mov") +} + +func dedupe(values []string) []string { + seen := map[string]struct{}{} + out := make([]string, 0, len(values)) + for _, value := range values { + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + return out +} diff --git a/provider/ark/videos/videos_test.go b/provider/ark/videos/videos_test.go new file mode 100644 index 0000000..4504444 --- /dev/null +++ b/provider/ark/videos/videos_test.go @@ -0,0 +1,127 @@ +package videos + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/memohai/twilight-ai/sdk" +) + +func TestDoCreateBuildsArkContent(t *testing.T) { + var got map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/contents/generations/tasks" { + t.Fatalf("path = %s", r.URL.Path) + } + if r.Header.Get("Authorization") != "Bearer ark-key" { + t.Fatalf("missing auth header") + } + if err := json.NewDecoder(r.Body).Decode(&got); err != nil { + t.Fatalf("decode request: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{"id": "task-1", "status": "queued"}) + })) + defer server.Close() + + prov := New(WithAPIKey("ark-key"), WithBaseURL(server.URL)) + _, err := prov.DoCreate(context.Background(), sdk.VideoParams{ + Model: prov.VideoModel("doubao-seedance-2-0-260128"), + Prompt: "city at night", + DurationSeconds: intPtr(8), + Resolution: "720p", + AspectRatio: "16:9", + GenerateAudio: boolPtr(true), + Seed: int64Ptr(9), + CallbackURL: "https://example.com/hook", + InputImage: &sdk.MediaInput{URL: "https://example.com/first.png"}, + 
InputVideo: &sdk.MediaInput{URL: "https://example.com/source.mp4"}, + ReferenceImages: []sdk.MediaInput{{URL: "https://example.com/ref.png"}}, + ReferenceVideos: []sdk.MediaInput{{URL: "https://example.com/ref.mp4"}}, + ReferenceAudio: []sdk.MediaInput{{URL: "https://example.com/ref.mp3"}}, + Config: map[string]any{"watermark": false}, + }) + if err != nil { + t.Fatalf("DoCreate returned error: %v", err) + } + + if got["model"] != "doubao-seedance-2-0-260128" || got["duration"].(float64) != 8 { + t.Fatalf("unexpected top-level body: %#v", got) + } + if got["resolution"] != "720p" || got["ratio"] != "16:9" || got["callback_url"] != "https://example.com/hook" { + t.Fatalf("unexpected mapped fields: %#v", got) + } + if got["watermark"] != false { + t.Fatalf("config passthrough missing: %#v", got) + } + content := got["content"].([]any) + if len(content) != 6 { + t.Fatalf("content len = %d, want 6: %#v", len(content), content) + } + text := content[0].(map[string]any) + if text["type"] != "text" || text["text"] != "city at night" { + t.Fatalf("unexpected text item: %#v", text) + } + image := content[1].(map[string]any) + if image["type"] != "image_url" || image["role"] != "first_frame" { + t.Fatalf("unexpected image item: %#v", image) + } +} + +func TestDoGetExtractsVideoURL(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/contents/generations/tasks/task-1" { + t.Fatalf("path = %s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "task-1", + "model": "doubao-seedance-2-0-260128", + "status": "succeeded", + "content": map[string]any{ + "video_url": "https://example.com/out.mp4", + }, + }) + })) + defer server.Close() + + prov := New(WithAPIKey("ark-key"), WithBaseURL(server.URL)) + job, err := prov.DoGet(context.Background(), prov.VideoModel("doubao-seedance-2-0-260128"), "task-1") + if err != nil { + t.Fatalf("DoGet returned error: %v", err) + } + if job.Status != 
sdk.VideoJobSucceeded { + t.Fatalf("status = %s", job.Status) + } + if len(job.Outputs) != 1 || job.Outputs[0].URL != "https://example.com/out.mp4" { + t.Fatalf("unexpected outputs: %#v", job.Outputs) + } +} + +func TestDoCancelUsesDeleteTaskEndpoint(t *testing.T) { + called := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + if r.Method != http.MethodDelete { + t.Fatalf("method = %s", r.Method) + } + if r.URL.Path != "/contents/generations/tasks/task-1" { + t.Fatalf("path = %s", r.URL.Path) + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + prov := New(WithAPIKey("ark-key"), WithBaseURL(server.URL)) + if err := prov.DoCancel(context.Background(), prov.VideoModel("doubao-seedance-2-0-260128"), "task-1"); err != nil { + t.Fatalf("DoCancel returned error: %v", err) + } + if !called { + t.Fatalf("server was not called") + } +} + +func intPtr(v int) *int { return &v } +func int64Ptr(v int64) *int64 { return &v } +func boolPtr(v bool) *bool { return &v } diff --git a/provider/openrouter/videos/types.go b/provider/openrouter/videos/types.go new file mode 100644 index 0000000..b0f5a81 --- /dev/null +++ b/provider/openrouter/videos/types.go @@ -0,0 +1,65 @@ +package videos + +type createRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + AspectRatio string `json:"aspect_ratio,omitempty"` + CallbackURL string `json:"callback_url,omitempty"` + Duration *int `json:"duration,omitempty"` + FrameImages []frameImage `json:"frame_images,omitempty"` + GenerateAudio *bool `json:"generate_audio,omitempty"` + InputReferences []inputReference `json:"input_references,omitempty"` + Provider any `json:"provider,omitempty"` + Resolution string `json:"resolution,omitempty"` + Seed *int64 `json:"seed,omitempty"` + Size string `json:"size,omitempty"` +} + +type mediaURLObject struct { + URL string `json:"url"` +} + +type frameImage struct { + ImageURL mediaURLObject 
`json:"image_url"` + Type string `json:"type"` + FrameType string `json:"frame_type"` +} + +type inputReference struct { + Type string `json:"type"` + AudioURL *mediaURLObject `json:"audio_url,omitempty"` + ImageURL *mediaURLObject `json:"image_url,omitempty"` + VideoURL *mediaURLObject `json:"video_url,omitempty"` +} + +type videoResponse struct { + ID string `json:"id"` + PollingURL string `json:"polling_url,omitempty"` + Status string `json:"status"` + Error string `json:"error,omitempty"` + GenerationID string `json:"generation_id,omitempty"` + UnsignedURLs []string `json:"unsigned_urls,omitempty"` + Usage map[string]any `json:"usage,omitempty"` +} + +type listModelsResponse struct { + Data []modelResponse `json:"data"` +} + +type modelResponse struct { + AllowedPassthroughParameters []string `json:"allowed_passthrough_parameters"` + CanonicalSlug string `json:"canonical_slug"` + Created int64 `json:"created"` + Description string `json:"description"` + GenerateAudio *bool `json:"generate_audio"` + HuggingFaceID *string `json:"hugging_face_id"` + ID string `json:"id"` + Name string `json:"name"` + PricingSKUs map[string]any `json:"pricing_skus"` + Seed *bool `json:"seed"` + SupportedAspectRatios []string `json:"supported_aspect_ratios"` + SupportedDurations []int `json:"supported_durations"` + SupportedFrameImages []string `json:"supported_frame_images"` + SupportedResolutions []string `json:"supported_resolutions"` + SupportedSizes []string `json:"supported_sizes"` +} diff --git a/provider/openrouter/videos/videos.go b/provider/openrouter/videos/videos.go new file mode 100644 index 0000000..9e71395 --- /dev/null +++ b/provider/openrouter/videos/videos.go @@ -0,0 +1,259 @@ +// Package videos provides OpenRouter video generation support. 
+package videos + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + + "github.com/memohai/twilight-ai/internal/utils" + "github.com/memohai/twilight-ai/sdk" +) + +const defaultBaseURL = "https://openrouter.ai/api" + +type Provider struct { + apiKey string + baseURL string + httpClient *http.Client +} + +type Option func(*Provider) + +func WithAPIKey(apiKey string) Option { + return func(p *Provider) { p.apiKey = apiKey } +} + +func WithBaseURL(baseURL string) Option { + return func(p *Provider) { p.baseURL = strings.TrimRight(baseURL, "/") } +} + +func WithHTTPClient(client *http.Client) Option { + return func(p *Provider) { p.httpClient = client } +} + +func New(options ...Option) *Provider { + p := &Provider{ + baseURL: defaultBaseURL, + httpClient: &http.Client{}, + } + for _, opt := range options { + opt(p) + } + p.baseURL = strings.TrimRight(p.baseURL, "/") + return p +} + +func (p *Provider) VideoModel(id string) *sdk.VideoModel { + return &sdk.VideoModel{ID: id, Provider: p} +} + +func (p *Provider) ListModels(ctx context.Context) ([]*sdk.VideoModel, error) { + resp, err := utils.FetchJSON[listModelsResponse](ctx, p.httpClient, &utils.RequestOptions{ + Method: http.MethodGet, + BaseURL: p.baseURL, + Path: "/v1/videos/models", + Headers: utils.AuthHeader(p.apiKey), + }) + if err != nil { + return nil, fmt.Errorf("openrouter videos: list models request failed: %w", err) + } + models := make([]*sdk.VideoModel, 0, len(resp.Data)) + for i := range resp.Data { + item := &resp.Data[i] + models = append(models, &sdk.VideoModel{ + ID: item.ID, + Provider: p, + ProviderMetadata: map[string]any{ + "canonical_slug": item.CanonicalSlug, + "name": item.Name, + "description": item.Description, + "allowed_passthrough_parameters": item.AllowedPassthroughParameters, + "generate_audio": item.GenerateAudio, + "seed": item.Seed, + "supported_aspect_ratios": item.SupportedAspectRatios, + "supported_durations": item.SupportedDurations, + 
"supported_frame_images": item.SupportedFrameImages, + "supported_resolutions": item.SupportedResolutions, + "supported_sizes": item.SupportedSizes, + "pricing_skus": item.PricingSKUs, + }, + }) + } + return models, nil +} + +//nolint:gocritic // VideoProvider keeps VideoParams value-based for a stable public SDK contract. +func (p *Provider) DoCreate(ctx context.Context, params sdk.VideoParams) (*sdk.VideoJob, error) { + if params.Model == nil { + return nil, fmt.Errorf("openrouter videos: model is required") + } + req := createRequest{ + Model: params.Model.ID, + Prompt: params.Prompt, + AspectRatio: params.AspectRatio, + CallbackURL: params.CallbackURL, + Duration: params.DurationSeconds, + GenerateAudio: params.GenerateAudio, + Resolution: params.Resolution, + Seed: params.Seed, + Size: params.Size, + } + if params.InputImage != nil { + if url := mediaURL(params.InputImage); url != "" { + req.FrameImages = append(req.FrameImages, frameImage{ + Type: "image_url", + FrameType: "first_frame", + ImageURL: mediaURLObject{URL: url}, + }) + } + } + for i := range params.ReferenceImages { + if url := mediaURL(&params.ReferenceImages[i]); url != "" { + req.InputReferences = append(req.InputReferences, inputReference{Type: "image_url", ImageURL: &mediaURLObject{URL: url}}) + } + } + for i := range params.ReferenceAudio { + if url := mediaURL(&params.ReferenceAudio[i]); url != "" { + req.InputReferences = append(req.InputReferences, inputReference{Type: "audio_url", AudioURL: &mediaURLObject{URL: url}}) + } + } + for i := range params.ReferenceVideos { + if url := mediaURL(&params.ReferenceVideos[i]); url != "" { + req.InputReferences = append(req.InputReferences, inputReference{Type: "video_url", VideoURL: &mediaURLObject{URL: url}}) + } + } + if providerOptions, ok := params.Config["provider"]; ok { + req.Provider = providerOptions + } + + resp, err := utils.FetchJSON[videoResponse](ctx, p.httpClient, &utils.RequestOptions{ + Method: http.MethodPost, + BaseURL: p.baseURL, + Path:
"/v1/videos", + Headers: utils.AuthHeader(p.apiKey), + Body: req, + }) + if err != nil { + return nil, fmt.Errorf("openrouter videos: create request failed: %w", err) + } + return toVideoJob(resp, params.Model.ID), nil +} + +func (p *Provider) DoGet(ctx context.Context, model *sdk.VideoModel, id string) (*sdk.VideoJob, error) { + resp, err := utils.FetchJSON[videoResponse](ctx, p.httpClient, &utils.RequestOptions{ + Method: http.MethodGet, + BaseURL: p.baseURL, + Path: "/v1/videos/" + id, + Headers: utils.AuthHeader(p.apiKey), + }) + if err != nil { + return nil, fmt.Errorf("openrouter videos: get request failed: %w", err) + } + modelID := "" + if model != nil { + modelID = model.ID + } + return toVideoJob(resp, modelID), nil +} + +func (p *Provider) DoCancel(_ context.Context, _ *sdk.VideoModel, _ string) error { + return fmt.Errorf("openrouter videos: cancel is not supported") +} + +func (p *Provider) DoDownload(ctx context.Context, _ *sdk.VideoModel, output sdk.VideoOutput) (data []byte, contentType string, err error) { + if output.URL == "" { + return nil, "", fmt.Errorf("openrouter videos: output URL is required") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, output.URL, http.NoBody) + if err != nil { + return nil, "", fmt.Errorf("openrouter videos: build download request: %w", err) + } + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, "", fmt.Errorf("openrouter videos: download request failed: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(resp.Body) + return nil, "", fmt.Errorf("openrouter videos: download failed with status %d: %s", resp.StatusCode, string(body)) + } + data, err = io.ReadAll(resp.Body) + if err != nil { + return nil, "", fmt.Errorf("openrouter videos: read download response: %w", err) + } + contentType = resp.Header.Get("Content-Type") + if contentType == "" { + contentType = output.ContentType + } + if contentType == "" { + 
contentType = "video/mp4" + } + return data, contentType, nil +} + +func toVideoJob(resp *videoResponse, modelID string) *sdk.VideoJob { + job := &sdk.VideoJob{ + ID: resp.ID, + ModelID: modelID, + Status: mapStatus(resp.Status), + ProviderMetadata: map[string]any{ + "polling_url": resp.PollingURL, + "generation_id": resp.GenerationID, + "usage": resp.Usage, + }, + } + if resp.Error != "" { + job.Error = &sdk.VideoError{Message: resp.Error} + } + for _, url := range resp.UnsignedURLs { + if strings.TrimSpace(url) == "" { + continue + } + job.Outputs = append(job.Outputs, sdk.VideoOutput{ + URL: url, + ContentType: "video/mp4", + ProviderMetadata: map[string]any{ + "generation_id": resp.GenerationID, + }, + }) + } + return job +} + +func mapStatus(status string) sdk.VideoJobStatus { + switch strings.ToLower(strings.TrimSpace(status)) { + case "pending", "queued": + return sdk.VideoJobQueued + case "running", "in_progress", "processing": + return sdk.VideoJobRunning + case "completed", "succeeded", "success": + return sdk.VideoJobSucceeded + case "failed", "error": + return sdk.VideoJobFailed + case "canceled", "cancelled": + return sdk.VideoJobCanceled + default: + return sdk.VideoJobRunning + } +} + +func mediaURL(input *sdk.MediaInput) string { + if input == nil { + return "" + } + if input.URL != "" { + return input.URL + } + if len(input.Data) == 0 { + return input.FileID + } + contentType := input.ContentType + if contentType == "" { + contentType = "application/octet-stream" + } + return "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(input.Data) +} diff --git a/provider/openrouter/videos/videos_test.go b/provider/openrouter/videos/videos_test.go new file mode 100644 index 0000000..3bf08e2 --- /dev/null +++ b/provider/openrouter/videos/videos_test.go @@ -0,0 +1,129 @@ +package videos + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/memohai/twilight-ai/sdk" +) + +func 
TestDoCreateMapsVideoRequest(t *testing.T) { + var got map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/videos" { + t.Fatalf("path = %s, want /v1/videos", r.URL.Path) + } + if r.Header.Get("Authorization") != "Bearer test-key" { + t.Fatalf("missing auth header") + } + if err := json.NewDecoder(r.Body).Decode(&got); err != nil { + t.Fatalf("decode request: %v", err) + } + _ = json.NewEncoder(w).Encode(videoResponse{ID: "job-1", Status: "pending", GenerationID: "gen-1"}) + })) + defer server.Close() + + prov := New(WithAPIKey("test-key"), WithBaseURL(server.URL)) + model := prov.VideoModel("google/veo-3.1") + _, err := prov.DoCreate(context.Background(), sdk.VideoParams{ + Model: model, + Prompt: "cinematic ocean", + DurationSeconds: intPtr(8), + Resolution: "720p", + AspectRatio: "16:9", + Size: "1280x720", + Seed: int64Ptr(42), + GenerateAudio: boolPtr(true), + CallbackURL: "https://example.com/hook", + InputImage: &sdk.MediaInput{URL: "https://example.com/first.png"}, + ReferenceImages: []sdk.MediaInput{{URL: "https://example.com/ref.png"}}, + ReferenceAudio: []sdk.MediaInput{{URL: "https://example.com/ref.mp3"}}, + ReferenceVideos: []sdk.MediaInput{{URL: "https://example.com/ref.mp4"}}, + Config: map[string]any{"provider": map[string]any{"sort": "price"}}, + }) + if err != nil { + t.Fatalf("DoCreate returned error: %v", err) + } + + if got["model"] != "google/veo-3.1" || got["prompt"] != "cinematic ocean" { + t.Fatalf("unexpected core fields: %#v", got) + } + if got["duration"].(float64) != 8 || got["resolution"] != "720p" || got["aspect_ratio"] != "16:9" { + t.Fatalf("unexpected video fields: %#v", got) + } + frameImages := got["frame_images"].([]any) + frame := frameImages[0].(map[string]any) + if frame["frame_type"] != "first_frame" { + t.Fatalf("frame_type = %#v", frame["frame_type"]) + } + refs := got["input_references"].([]any) + if len(refs) != 3 { + t.Fatalf("input refs 
len = %d, want 3", len(refs)) + } + if got["provider"].(map[string]any)["sort"] != "price" { + t.Fatalf("provider passthrough missing: %#v", got["provider"]) + } +} + +func TestDoGetMapsCompletedResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/videos/job-1" { + t.Fatalf("path = %s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(videoResponse{ + ID: "job-1", + Status: "completed", + GenerationID: "gen-1", + UnsignedURLs: []string{"https://storage.example.com/out.mp4"}, + Usage: map[string]any{"cost": 0.5}, + }) + })) + defer server.Close() + + prov := New(WithAPIKey("test-key"), WithBaseURL(server.URL)) + job, err := prov.DoGet(context.Background(), prov.VideoModel("google/veo-3.1"), "job-1") + if err != nil { + t.Fatalf("DoGet returned error: %v", err) + } + if job.Status != sdk.VideoJobSucceeded { + t.Fatalf("status = %s", job.Status) + } + if len(job.Outputs) != 1 || job.Outputs[0].URL != "https://storage.example.com/out.mp4" { + t.Fatalf("unexpected outputs: %#v", job.Outputs) + } +} + +func TestListModelsMapsCapabilities(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/videos/models" { + t.Fatalf("path = %s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(listModelsResponse{Data: []modelResponse{{ + ID: "google/veo-3.1", + Name: "Veo 3.1", + SupportedDurations: []int{5, 8}, + SupportedResolutions: []string{"720p"}, + SupportedAspectRatios: []string{"16:9"}, + }}}) + })) + defer server.Close() + + prov := New(WithAPIKey("test-key"), WithBaseURL(server.URL)) + models, err := prov.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels returned error: %v", err) + } + if len(models) != 1 || models[0].ID != "google/veo-3.1" { + t.Fatalf("unexpected models: %#v", models) + } + if models[0].ProviderMetadata["name"] != "Veo 3.1" { + t.Fatalf("metadata missing: %#v", 
// --- Video convenience functions ---

// CreateVideo forwards to CreateVideo on the package-level default Client.
func CreateVideo(ctx context.Context, options ...VideoOption) (*VideoJob, error) {
	return defaultClient.CreateVideo(ctx, options...)
}

// GetVideo forwards to GetVideo on the package-level default Client.
func GetVideo(ctx context.Context, model *VideoModel, id string) (*VideoJob, error) {
	return defaultClient.GetVideo(ctx, model, id)
}

// CancelVideo forwards to CancelVideo on the package-level default Client.
func CancelVideo(ctx context.Context, model *VideoModel, id string) error {
	return defaultClient.CancelVideo(ctx, model, id)
}

// DownloadVideo forwards to DownloadVideo on the package-level default
// Client and returns the downloaded bytes plus their content type.
func DownloadVideo(ctx context.Context, model *VideoModel, output VideoOutput) (data []byte, contentType string, err error) {
	return defaultClient.DownloadVideo(ctx, model, output)
}

// GenerateVideo forwards to GenerateVideo on the package-level default Client.
func GenerateVideo(ctx context.Context, options ...VideoOption) (*VideoResult, error) {
	return defaultClient.GenerateVideo(ctx, options...)
}
// VideoJobStatus is the provider-independent lifecycle state of an
// asynchronous video generation job.
type VideoJobStatus string

// Unified job states. Succeeded, failed, and canceled are final.
const (
	VideoJobQueued    VideoJobStatus = "queued"
	VideoJobRunning   VideoJobStatus = "running"
	VideoJobSucceeded VideoJobStatus = "succeeded"
	VideoJobFailed    VideoJobStatus = "failed"
	VideoJobCanceled  VideoJobStatus = "canceled"
)

// Terminal reports whether s is a final state, i.e. one of succeeded,
// failed, or canceled. Every other value, including unknown strings, is
// treated as still in progress.
func (s VideoJobStatus) Terminal() bool {
	return s == VideoJobSucceeded || s == VideoJobFailed || s == VideoJobCanceled
}
+type VideoOption func(*videoConfig) + +func WithVideoModel(model *VideoModel) VideoOption { + return func(c *videoConfig) { c.Params.Model = model } +} + +func WithVideoPrompt(prompt string) VideoOption { + return func(c *videoConfig) { c.Params.Prompt = prompt } +} + +func WithVideoSize(size string) VideoOption { + return func(c *videoConfig) { c.Params.Size = size } +} + +func WithVideoResolution(resolution string) VideoOption { + return func(c *videoConfig) { c.Params.Resolution = resolution } +} + +func WithVideoAspectRatio(aspectRatio string) VideoOption { + return func(c *videoConfig) { c.Params.AspectRatio = aspectRatio } +} + +func WithVideoDuration(seconds int) VideoOption { + return func(c *videoConfig) { c.Params.DurationSeconds = &seconds } +} + +func WithVideoSeed(seed int64) VideoOption { + return func(c *videoConfig) { c.Params.Seed = &seed } +} + +func WithVideoGenerateAudio(generate bool) VideoOption { + return func(c *videoConfig) { c.Params.GenerateAudio = &generate } +} + +func WithVideoCallbackURL(url string) VideoOption { + return func(c *videoConfig) { c.Params.CallbackURL = url } +} + +func WithVideoInputImage(input *MediaInput) VideoOption { + return func(c *videoConfig) { c.Params.InputImage = input } +} + +func WithVideoInputVideo(input *MediaInput) VideoOption { + return func(c *videoConfig) { c.Params.InputVideo = input } +} + +func WithVideoReferenceImages(inputs ...MediaInput) VideoOption { + return func(c *videoConfig) { c.Params.ReferenceImages = inputs } +} + +func WithVideoReferenceVideos(inputs ...MediaInput) VideoOption { + return func(c *videoConfig) { c.Params.ReferenceVideos = inputs } +} + +func WithVideoReferenceAudio(inputs ...MediaInput) VideoOption { + return func(c *videoConfig) { c.Params.ReferenceAudio = inputs } +} + +func WithVideoConfig(config map[string]any) VideoOption { + return func(c *videoConfig) { c.Params.Config = config } +} + +func WithVideoWait(wait bool) VideoOption { + return func(c *videoConfig) { 
c.Wait = wait } +} + +func WithVideoPollInterval(interval time.Duration) VideoOption { + return func(c *videoConfig) { c.PollInterval = interval } +} + +func WithVideoPollTimeout(timeout time.Duration) VideoOption { + return func(c *videoConfig) { c.PollTimeout = timeout } +} + +func WithVideoDownload(download bool) VideoOption { + return func(c *videoConfig) { c.Download = download } +} + +func buildVideoConfig(options []VideoOption) (*videoConfig, VideoProvider, error) { + cfg := &videoConfig{ + Wait: true, + PollInterval: defaultVideoPollInterval, + PollTimeout: defaultVideoPollTimeout, + } + for _, opt := range options { + opt(cfg) + } + if cfg.Params.Model == nil { + return nil, nil, fmt.Errorf("twilightai: video model is required (use WithVideoModel)") + } + if cfg.Params.Model.Provider == nil { + return nil, nil, fmt.Errorf("twilightai: video model %q has no provider", cfg.Params.Model.ID) + } + if cfg.Params.Prompt == "" { + return nil, nil, fmt.Errorf("twilightai: prompt is required (use WithVideoPrompt)") + } + if cfg.PollInterval <= 0 { + return nil, nil, fmt.Errorf("twilightai: video poll interval must be positive") + } + if cfg.PollTimeout <= 0 { + return nil, nil, fmt.Errorf("twilightai: video poll timeout must be positive") + } + return cfg, cfg.Params.Model.Provider, nil +} + +// CreateVideo starts an asynchronous video generation job. +func (c *Client) CreateVideo(ctx context.Context, options ...VideoOption) (*VideoJob, error) { + cfg, prov, err := buildVideoConfig(options) + if err != nil { + return nil, err + } + return prov.DoCreate(ctx, cfg.Params) +} + +// GetVideo retrieves a video generation job by ID. 
+func (c *Client) GetVideo(ctx context.Context, model *VideoModel, id string) (*VideoJob, error) { + prov, err := videoProviderFromModel(model) + if err != nil { + return nil, err + } + if id == "" { + return nil, fmt.Errorf("twilightai: video id is required") + } + return prov.DoGet(ctx, model, id) +} + +// CancelVideo requests cancellation of a video generation job. +func (c *Client) CancelVideo(ctx context.Context, model *VideoModel, id string) error { + prov, err := videoProviderFromModel(model) + if err != nil { + return err + } + if id == "" { + return fmt.Errorf("twilightai: video id is required") + } + return prov.DoCancel(ctx, model, id) +} + +// DownloadVideo downloads a provider output and returns bytes plus content type. +func (c *Client) DownloadVideo(ctx context.Context, model *VideoModel, output VideoOutput) (data []byte, contentType string, err error) { + prov, err := videoProviderFromModel(model) + if err != nil { + return nil, "", err + } + return prov.DoDownload(ctx, model, output) +} + +// GenerateVideo starts a job, waits for it by default, and optionally downloads +// the first output when WithVideoDownload(true) is set. 
+func (c *Client) GenerateVideo(ctx context.Context, options ...VideoOption) (*VideoResult, error) { + cfg, prov, err := buildVideoConfig(options) + if err != nil { + return nil, err + } + + job, err := prov.DoCreate(ctx, cfg.Params) + if err != nil { + return nil, err + } + result := &VideoResult{Job: job} + if !cfg.Wait { + return result, nil + } + + waitCtx, cancel := context.WithTimeout(ctx, cfg.PollTimeout) + defer cancel() + + ticker := time.NewTicker(cfg.PollInterval) + defer ticker.Stop() + + for job == nil || !job.Status.Terminal() { + select { + case <-waitCtx.Done(): + return nil, fmt.Errorf("twilightai: video generation timed out after %s", cfg.PollTimeout) + case <-ticker.C: + if job == nil || job.ID == "" { + return nil, fmt.Errorf("twilightai: video provider returned empty job id") + } + job, err = prov.DoGet(waitCtx, cfg.Params.Model, job.ID) + if err != nil { + return nil, err + } + result.Job = job + } + } + + if job.Status != VideoJobSucceeded { + if job.Error != nil && job.Error.Message != "" { + return result, fmt.Errorf("twilightai: video generation failed: %s", job.Error.Message) + } + return result, fmt.Errorf("twilightai: video generation finished with status %s", job.Status) + } + if len(job.Outputs) > 0 { + result.Output = &job.Outputs[0] + } + if cfg.Download && result.Output != nil { + mergeVideoOutputMetadata(result.Output, cfg.Params.Config) + data, contentType, err := prov.DoDownload(ctx, cfg.Params.Model, *result.Output) + if err != nil { + return result, err + } + result.Data = data + result.ContentType = contentType + } + return result, nil +} + +func videoProviderFromModel(model *VideoModel) (VideoProvider, error) { + if model == nil { + return nil, fmt.Errorf("twilightai: video model is required") + } + if model.Provider == nil { + return nil, fmt.Errorf("twilightai: video model %q has no provider", model.ID) + } + return model.Provider, nil +} + +func mergeVideoOutputMetadata(output *VideoOutput, config map[string]any) { + if 
output == nil || len(config) == 0 { + return + } + if output.ProviderMetadata == nil { + output.ProviderMetadata = map[string]any{} + } + for k, v := range config { + if _, exists := output.ProviderMetadata[k]; !exists { + output.ProviderMetadata[k] = v + } + } +} diff --git a/sdk/video_generate_test.go b/sdk/video_generate_test.go new file mode 100644 index 0000000..f101791 --- /dev/null +++ b/sdk/video_generate_test.go @@ -0,0 +1,135 @@ +package sdk + +import ( + "context" + "errors" + "strings" + "testing" + "time" +) + +func TestCreateVideoValidation(t *testing.T) { + _, err := CreateVideo(context.Background(), WithVideoPrompt("hello")) + if err == nil || !strings.Contains(err.Error(), "video model is required") { + t.Fatalf("expected missing model error, got %v", err) + } + + _, err = CreateVideo(context.Background(), WithVideoModel(&VideoModel{ID: "m"}), WithVideoPrompt("hello")) + if err == nil || !strings.Contains(err.Error(), "has no provider") { + t.Fatalf("expected missing provider error, got %v", err) + } + + _, err = CreateVideo(context.Background(), WithVideoModel(testVideoModel(&fakeVideoProvider{}))) + if err == nil || !strings.Contains(err.Error(), "prompt is required") { + t.Fatalf("expected missing prompt error, got %v", err) + } +} + +func TestGenerateVideoPollsUntilSucceeded(t *testing.T) { + prov := &fakeVideoProvider{ + createJob: &VideoJob{ID: "job-1", Status: VideoJobQueued}, + getJobs: []*VideoJob{ + {ID: "job-1", Status: VideoJobRunning}, + {ID: "job-1", Status: VideoJobSucceeded, Outputs: []VideoOutput{{URL: "https://example.com/out.mp4"}}}, + }, + downloadData: []byte("video"), + } + + result, err := GenerateVideo(context.Background(), + WithVideoModel(testVideoModel(prov)), + WithVideoPrompt("make a clip"), + WithVideoPollInterval(time.Millisecond), + WithVideoPollTimeout(time.Second), + WithVideoDownload(true), + ) + if err != nil { + t.Fatalf("GenerateVideo returned error: %v", err) + } + if result.Job.Status != VideoJobSucceeded { + 
t.Fatalf("status = %s, want succeeded", result.Job.Status) + } + if prov.getCalls != 2 { + t.Fatalf("get calls = %d, want 2", prov.getCalls) + } + if string(result.Data) != "video" || result.ContentType != "video/mp4" { + t.Fatalf("unexpected download result: %q %q", result.Data, result.ContentType) + } +} + +func TestGenerateVideoTimeout(t *testing.T) { + prov := &fakeVideoProvider{ + createJob: &VideoJob{ID: "job-1", Status: VideoJobQueued}, + getJobs: []*VideoJob{{ID: "job-1", Status: VideoJobRunning}}, + } + + _, err := GenerateVideo(context.Background(), + WithVideoModel(testVideoModel(prov)), + WithVideoPrompt("make a clip"), + WithVideoPollInterval(time.Millisecond), + WithVideoPollTimeout(3*time.Millisecond), + ) + if err == nil || !strings.Contains(err.Error(), "timed out") { + t.Fatalf("expected timeout error, got %v", err) + } +} + +func TestGenerateVideoFailedStatus(t *testing.T) { + prov := &fakeVideoProvider{ + createJob: &VideoJob{ID: "job-1", Status: VideoJobQueued}, + getJobs: []*VideoJob{ + {ID: "job-1", Status: VideoJobFailed, Error: &VideoError{Message: "blocked"}}, + }, + } + + result, err := GenerateVideo(context.Background(), + WithVideoModel(testVideoModel(prov)), + WithVideoPrompt("make a clip"), + WithVideoPollInterval(time.Millisecond), + WithVideoPollTimeout(time.Second), + ) + if err == nil || !strings.Contains(err.Error(), "blocked") { + t.Fatalf("expected failed status error, got %v", err) + } + if result == nil || result.Job.Status != VideoJobFailed { + t.Fatalf("expected failed result, got %#v", result) + } +} + +func testVideoModel(prov VideoProvider) *VideoModel { + return &VideoModel{ID: "model-1", Provider: prov} +} + +type fakeVideoProvider struct { + createJob *VideoJob + getJobs []*VideoJob + getCalls int + downloadData []byte +} + +func (p *fakeVideoProvider) ListModels(context.Context) ([]*VideoModel, error) { + return nil, nil +} + +func (p *fakeVideoProvider) DoCreate(context.Context, VideoParams) (*VideoJob, error) { + 
return p.createJob, nil +} + +func (p *fakeVideoProvider) DoGet(context.Context, *VideoModel, string) (*VideoJob, error) { + if len(p.getJobs) == 0 { + return nil, errors.New("no jobs configured") + } + idx := p.getCalls + if idx >= len(p.getJobs) { + idx = len(p.getJobs) - 1 + } + p.getCalls++ + return p.getJobs[idx], nil +} + +func (p *fakeVideoProvider) DoCancel(context.Context, *VideoModel, string) error { + return nil +} + +func (p *fakeVideoProvider) DoDownload(context.Context, *VideoModel, VideoOutput) ([]byte, string, error) { + return p.downloadData, "video/mp4", nil +}