Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions runner/cmd/runner/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func main() {

func mainInner() int {
var tempDir string
var httpAddress string
var httpPort int
var sshPort int
var sshAuthorizedKeys []string
Expand Down Expand Up @@ -61,6 +62,13 @@ func mainInner() int {
Destination: &tempDir,
TakesFile: true,
},
&cli.StringFlag{
Name: "http-address",
Usage: "Set a http bind address",
Value: "",
DefaultText: "all interfaces",
Destination: &httpAddress,
},
&cli.IntFlag{
Name: "http-port",
Usage: "Set a http port",
Expand All @@ -86,7 +94,7 @@ func mainInner() int {
},
},
Action: func(ctx context.Context, cmd *cli.Command) error {
return start(ctx, tempDir, httpPort, sshPort, sshAuthorizedKeys, logLevel, Version)
return start(ctx, logLevel, tempDir, httpAddress, httpPort, sshPort, sshAuthorizedKeys)
},
},
},
Expand All @@ -103,7 +111,12 @@ func mainInner() int {
return 0
}

func start(ctx context.Context, tempDir string, httpPort int, sshPort int, sshAuthorizedKeys []string, logLevel int, version string) error {
func start(
ctx context.Context,
logLevel int, tempDir string,
httpAddress string, httpPort int,
sshPort int, sshAuthorizedKeys []string,
) error {
if err := os.MkdirAll(tempDir, 0o755); err != nil {
return fmt.Errorf("create temp directory: %w", err)
}
Expand Down Expand Up @@ -191,7 +204,7 @@ func start(ctx context.Context, tempDir string, httpPort int, sshPort int, sshAu
return fmt.Errorf("create executor: %w", err)
}

server, err := api.NewServer(ctx, fmt.Sprintf(":%d", httpPort), version, ex)
server, err := api.NewServer(ctx, fmt.Sprintf("%s:%d", httpAddress, httpPort), Version, ex)
if err != nil {
return fmt.Errorf("create server: %w", err)
}
Expand Down
23 changes: 17 additions & 6 deletions runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -806,8 +806,6 @@ func (d *DockerRunner) createContainer(ctx context.Context, task *Task) error {
}
mounts = append(mounts, instanceMounts...)

ports := d.dockerParams.DockerPorts()

// Set the environment variables
envVars := []string{}
if d.dockerParams.DockerPJRTDevice() != "" {
Expand All @@ -827,9 +825,19 @@ func (d *DockerRunner) createContainer(ctx context.Context, task *Task) error {
}
}

networkMode := getNetworkMode(task.config.NetworkMode)
ports := d.dockerParams.DockerPorts()

// Bridge mode - all interfaces
runnerHttpAddress := ""
if networkMode.IsHost() {
runnerHttpAddress = "localhost"
}
shellCommands := d.dockerParams.DockerShellCommands(task.config.ContainerSshKeys, runnerHttpAddress)

containerConfig := &container.Config{
Image: task.config.ImageName,
Cmd: []string{strings.Join(d.dockerParams.DockerShellCommands(task.config.ContainerSshKeys), " && ")},
Cmd: []string{strings.Join(shellCommands, " && ")},
Entrypoint: []string{"/bin/sh", "-c"},
ExposedPorts: exposePorts(ports),
Env: envVars,
Expand All @@ -843,7 +851,7 @@ func (d *DockerRunner) createContainer(ctx context.Context, task *Task) error {
}
hostConfig := &container.HostConfig{
Privileged: task.config.Privileged || d.dockerParams.DockerPrivileged(),
NetworkMode: getNetworkMode(task.config.NetworkMode),
NetworkMode: networkMode,
PortBindings: bindPorts(ports),
Mounts: mounts,
ShmSize: task.config.ShmSize,
Expand Down Expand Up @@ -1182,7 +1190,7 @@ func (c *CLIArgs) DockerPJRTDevice() string {
return c.Docker.PJRTDevice
}

func (c *CLIArgs) DockerShellCommands(publicKeys []string) []string {
func (c *CLIArgs) DockerShellCommands(authorizedKeys []string, runnerHttpAddress string) []string {
commands := getSSHShellCommands()
runnerCommand := []string{
consts.RunnerBinaryPath,
Expand All @@ -1192,7 +1200,10 @@ func (c *CLIArgs) DockerShellCommands(publicKeys []string) []string {
"--http-port", strconv.Itoa(c.Runner.HTTPPort),
"--ssh-port", strconv.Itoa(c.Runner.SSHPort),
}
for _, key := range publicKeys {
if runnerHttpAddress != "" {
runnerCommand = append(runnerCommand, "--http-address", runnerHttpAddress)
}
for _, key := range authorizedKeys {
runnerCommand = append(runnerCommand, "--ssh-authorized-key", fmt.Sprintf("'%s'", key))
}
return append(commands, strings.Join(runnerCommand, " "))
Expand Down
2 changes: 1 addition & 1 deletion runner/internal/shim/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (c *dockerParametersMock) DockerPJRTDevice() string {
return ""
}

func (c *dockerParametersMock) DockerShellCommands(publicKeys []string) []string {
func (c *dockerParametersMock) DockerShellCommands(authorizedKeys []string, runnerHttpAddress string) []string {
commands := make([]string, 0)
if c.sshShellCommands {
commands = append(commands, getSSHShellCommands()...)
Expand Down
2 changes: 1 addition & 1 deletion runner/internal/shim/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

type DockerParameters interface {
DockerPrivileged() bool
DockerShellCommands([]string) []string
DockerShellCommands(authorizedKeys []string, runnerHttpAddress string) []string
DockerMounts(string) ([]mount.Mount, error)
DockerPorts() []int
MakeRunnerDir(name string) (string, error)
Expand Down