Skip to content

Commit

Permalink
Add support for torch-amd and torch-rocm image
Browse files Browse the repository at this point in the history
  • Loading branch information
ChughShilpa committed Oct 7, 2024
1 parent 11fd6e3 commit 1086b7a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
8 changes: 5 additions & 3 deletions support/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ package support
// ***********************

const (
RayVersion = "2.35.0"
RayImage = "quay.io/modh/ray:2.35.0-py39-cu121"
RayROCmImage = "quay.io/modh/ray:2.35.0-py39-rocm61"
RayVersion = "2.35.0"
RayImage = "quay.io/modh/ray:2.35.0-py39-cu121"
RayROCmImage = "quay.io/modh/ray:2.35.0-py39-rocm61"
RayTorchCudaImage = "quay.io/rhoai/2.35.0-py39-cu121-torch24-fa26"
RayTorchROCmImage = "quay.io/rhoai/ray:2.35.0-py39-rocm61-torch24-fa26"
)
18 changes: 14 additions & 4 deletions support/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ const (
// The environment variables hereafter can be used to change the components
// used for testing.

CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
CodeFlareTestRayROCmImage = "CODEFLARE_TEST_RAY_ROCM_IMAGE"
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
CodeFlareTestRayROCmImage = "CODEFLARE_TEST_RAY_ROCM_IMAGE"
CodeFlareTestRayTorchCudaImage = "CODEFLARE_TEST_RAY_TORCH_CUDA_IMAGE"
CodeFlareTestRayTorchROCmImage = "CODEFLARE_TEST_RAY_TORCH_ROCM_IMAGE"
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"

// The testing output directory, to write output files into.
CodeFlareTestOutputDir = "CODEFLARE_TEST_OUTPUT_DIR"
Expand Down Expand Up @@ -83,6 +85,14 @@ func GetRayROCmImage() string {
return lookupEnvOrDefault(CodeFlareTestRayROCmImage, RayROCmImage)
}

func GetRayTorchCudaImage() string {
return lookupEnvOrDefault(CodeFlareTestRayTorchCudaImage, RayTorchCudaImage)
}

func GetRayTorchROCmImage() string {
return lookupEnvOrDefault(CodeFlareTestRayTorchROCmImage, RayTorchROCmImage)
}

func GetPyTorchImage() string {
return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime")
}
Expand Down

0 comments on commit 1086b7a

Please sign in to comment.