From 83baea79352b1a71906cab7cb09d6b5a0d593e73 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Mon, 27 Oct 2025 23:58:18 +0000 Subject: [PATCH] Automatically use ROCm when appropriate On Linux at least Signed-off-by: Eric Curtin --- main.go | 9 ++ pkg/gpuinfo/amd_gpu_linux.go | 116 ++++++++++++++++++ pkg/gpuinfo/gpuinfo.go | 4 + pkg/gpuinfo/gpuinfo_linux.go | 3 + pkg/gpuinfo/gpuinfo_not_linux.go | 8 ++ pkg/gpuinfo/memory_darwin_cgo.go | 6 + pkg/gpuinfo/memory_darwin_nocgo.go | 6 + pkg/gpuinfo/memory_linux_nocgo.go | 5 + pkg/gpuinfo/memory_windows.go | 6 + .../backends/llamacpp/download_linux.go | 44 ++++++- 10 files changed, 201 insertions(+), 6 deletions(-) create mode 100644 pkg/gpuinfo/amd_gpu_linux.go create mode 100644 pkg/gpuinfo/gpuinfo_linux.go create mode 100644 pkg/gpuinfo/gpuinfo_not_linux.go diff --git a/main.go b/main.go index bce27810..e04ee8ee 100644 --- a/main.go +++ b/main.go @@ -116,6 +116,15 @@ func main() { log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath) + // Auto-detect GPU type and set appropriate variant + // Check if we have supported AMD GPUs and set ROCm variant accordingly + if hasAMD, err := gpuInfo.HasSupportedAMDGPU(); err == nil && hasAMD { + log.Info("Supported AMD GPU detected, ROCm will be used automatically") + // This will be handled by the llama.cpp backend during server download + } else if err != nil { + log.Debugf("AMD GPU detection failed: %v", err) + } + // Create llama.cpp configuration from environment variables llamaCppConfig := createLlamaCppConfigFromEnv() diff --git a/pkg/gpuinfo/amd_gpu_linux.go b/pkg/gpuinfo/amd_gpu_linux.go new file mode 100644 index 00000000..d929df0d --- /dev/null +++ b/pkg/gpuinfo/amd_gpu_linux.go @@ -0,0 +1,116 @@ +//go:build linux + +package gpuinfo + +import ( + "bufio" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" +) + +// supportedAMDGPUs are the AMD GPU targets that should use ROCm +var supportedAMDGPUs = map[string]bool{ + "gfx908": true, + "gfx90a": true, + 
"gfx942": true, + "gfx1010": true, + "gfx1030": true, + "gfx1100": true, + "gfx1200": true, + "gfx1201": true, + "gfx1151": true, +} + +func hasSupportedAMDGPU() (bool, error) { + // Check if KFD topology directory exists + topologyDir := "/sys/class/kfd/kfd/topology/nodes/" + info, err := os.Stat(topologyDir) + if err != nil || !info.IsDir() { + return false, nil // KFD not available + } + + entries, err := os.ReadDir(topologyDir) + if err != nil { + return false, err + } + + // Sort entries by name to maintain consistent order + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name() < entries[j].Name() + }) + + // Compile regex to match gfx_target_version lines + reTarget := regexp.MustCompile(`gfx_target_version[ \t]+([0-9]+)`) + + for _, e := range entries { + if !e.IsDir() { + continue + } + nodePath := filepath.Join(topologyDir, e.Name()) + propPath := filepath.Join(nodePath, "properties") + + // Attempt to open the properties file directly; skip on error (e.g., permissions) + f, err := os.Open(propPath) + if err != nil { + // Could be permission denied or file doesn't exist; just skip like the Python code + continue + } + + sc := bufio.NewScanner(f) + for sc.Scan() { + line := sc.Text() + matches := reTarget.FindStringSubmatch(line) + if len(matches) < 2 { + continue + } + + deviceIDStr := matches[1] + deviceID, err := strconv.Atoi(deviceIDStr) + if err != nil || deviceID == 0 { + continue + } + + var majorVer, minorVer, steppingVer int + if gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION"); gfxOverride != "" { + parts := strings.Split(strings.TrimSpace(gfxOverride), ".") + if len(parts) != 3 { + // Invalid format, skip + continue + } + mv, err1 := strconv.Atoi(parts[0]) + nv, err2 := strconv.Atoi(parts[1]) + sv, err3 := strconv.Atoi(parts[2]) + if err1 != nil || err2 != nil || err3 != nil { + // Invalid format, skip + continue + } + if mv > 63 || nv > 255 || sv > 255 { + // Invalid values, skip + continue + } + majorVer, minorVer, 
steppingVer = mv, nv, sv + } else { + majorVer = (deviceID / 10000) % 100 + minorVer = (deviceID / 100) % 100 + steppingVer = deviceID % 100 + } + + gfx := "gfx" + + strconv.FormatInt(int64(majorVer), 10) + + strconv.FormatInt(int64(minorVer), 16) + + strconv.FormatInt(int64(steppingVer), 16) + + if supportedAMDGPUs[gfx] { + f.Close() + return true, nil // Found a supported AMD GPU + } + } + f.Close() + } + + return false, nil // No supported AMD GPU found +} \ No newline at end of file diff --git a/pkg/gpuinfo/gpuinfo.go b/pkg/gpuinfo/gpuinfo.go index 3bc8f66e..6dbd0f5c 100644 --- a/pkg/gpuinfo/gpuinfo.go +++ b/pkg/gpuinfo/gpuinfo.go @@ -15,3 +15,7 @@ func New(modelRuntimeInstallPath string) *GPUInfo { func (g *GPUInfo) GetVRAMSize() (uint64, error) { return getVRAMSize(g.modelRuntimeInstallPath) } + +func (g *GPUInfo) HasSupportedAMDGPU() (bool, error) { + return hasSupportedAMDGPU() +} diff --git a/pkg/gpuinfo/gpuinfo_linux.go b/pkg/gpuinfo/gpuinfo_linux.go new file mode 100644 index 00000000..87966e24 --- /dev/null +++ b/pkg/gpuinfo/gpuinfo_linux.go @@ -0,0 +1,3 @@ +//go:build linux + +package gpuinfo \ No newline at end of file diff --git a/pkg/gpuinfo/memory_darwin_cgo.go b/pkg/gpuinfo/memory_darwin_cgo.go index 95a20e3d..84a0ce73 100644 --- a/pkg/gpuinfo/memory_darwin_cgo.go +++ b/pkg/gpuinfo/memory_darwin_cgo.go @@ -17,3 +17,9 @@ func getVRAMSize(_ string) (uint64, error) { } return uint64(vramSize), nil } + +// hasSupportedAMDGPU returns true if the system has supported AMD GPUs +func hasSupportedAMDGPU() (bool, error) { + // AMD GPU detection is only supported on Linux 
+ return false, nil +} diff --git a/pkg/gpuinfo/memory_darwin_nocgo.go b/pkg/gpuinfo/memory_darwin_nocgo.go index 915af448..3bd690bf 100644 --- a/pkg/gpuinfo/memory_darwin_nocgo.go +++ b/pkg/gpuinfo/memory_darwin_nocgo.go @@ -8,3 +8,9 @@ import "errors" func getVRAMSize(_ string) (uint64, error) { return 0, errors.New("unimplemented without cgo") } + +// hasSupportedAMDGPU returns true if the system has supported AMD GPUs +func hasSupportedAMDGPU() (bool, error) { + // AMD GPU detection is only supported on Linux + return false, nil +} diff --git a/pkg/gpuinfo/memory_windows.go b/pkg/gpuinfo/memory_windows.go index 7ca9a0e4..db3851dc 100644 --- a/pkg/gpuinfo/memory_windows.go +++ b/pkg/gpuinfo/memory_windows.go @@ -38,3 +38,9 @@ func getVRAMSize(modelRuntimeInstallPath string) (uint64, error) { } return 0, errors.New("unexpected nv-gpu-info output format") } + +// hasSupportedAMDGPU returns true if the system has supported AMD GPUs +func hasSupportedAMDGPU() (bool, error) { + // AMD GPU detection is only supported on Linux + return false, nil +} diff --git a/pkg/inference/backends/llamacpp/download_linux.go b/pkg/inference/backends/llamacpp/download_linux.go index 2b7d55ff..53f7d609 100644 --- a/pkg/inference/backends/llamacpp/download_linux.go +++ b/pkg/inference/backends/llamacpp/download_linux.go @@ -1,18 +1,50 @@ +//go:build linux + package llamacpp import ( "context" "fmt" "net/http" - "path/filepath" + "github.com/docker/model-runner/pkg/gpuinfo" 
"github.com/docker/model-runner/pkg/logging" ) -func (l *llamaCpp) ensureLatestLlamaCpp(_ context.Context, log logging.Logger, _ *http.Client, - _, vendoredServerStoragePath string, +func init() { + // Enable GPU variant detection by default on Linux + ShouldUseGPUVariantLock.Lock() + defer ShouldUseGPUVariantLock.Unlock() + ShouldUseGPUVariant = true +} + +func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client, + llamaCppPath, vendoredServerStoragePath string, ) error { - l.status = fmt.Sprintf("running llama.cpp version: %s", - getLlamaCppVersion(log, filepath.Join(vendoredServerStoragePath, "com.docker.llama-server"))) - return errLlamaCppUpdateDisabled + var hasAMD bool + var err error + + ShouldUseGPUVariantLock.Lock() + defer ShouldUseGPUVariantLock.Unlock() + if ShouldUseGPUVariant { + // Create GPU info to check for supported AMD GPUs + gpuInfo := gpuinfo.New(vendoredServerStoragePath) + hasAMD, err = gpuInfo.HasSupportedAMDGPU() + if err != nil { + log.Debugf("AMD GPU detection failed: %v", err) + } + } + + desiredVersion := GetDesiredServerVersion() + desiredVariant := "cpu" + + // Use ROCm if supported AMD GPU is detected + if hasAMD { + log.Info("Supported AMD GPU detected, using ROCm variant") + desiredVariant = "rocm" + } + + l.status = fmt.Sprintf("looking for updates for %s variant", desiredVariant) + return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion, + desiredVariant) }