Commit b7e9f39

Automatically use ROCm when appropriate
On Linux at least.

Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent acab8b5 commit b7e9f39

10 files changed: +202 −6 lines changed


main.go

Lines changed: 9 additions & 0 deletions
@@ -116,6 +116,15 @@ func main() {
 
     log.Infof("LLAMA_SERVER_PATH: %s", llamaServerPath)
 
+    // Auto-detect GPU type and set appropriate variant
+    // Check if we have supported AMD GPUs and set ROCm variant accordingly
+    if hasAMD, err := gpuInfo.HasSupportedAMDGPU(); err == nil && hasAMD {
+        log.Info("Supported AMD GPU detected, ROCm will be used automatically")
+        // This will be handled by the llama.cpp backend during server download
+    } else if err != nil {
+        log.Debugf("AMD GPU detection failed: %v", err)
+    }
+
     // Create llama.cpp configuration from environment variables
     llamaCppConfig := createLlamaCppConfigFromEnv()
 
pkg/gpuinfo/amd_gpu_linux.go

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+//go:build linux
+
+package gpuinfo
+
+import (
+    "bufio"
+    "os"
+    "path/filepath"
+    "regexp"
+    "sort"
+    "strconv"
+    "strings"
+)
+
+// supportedAMDGPUs are the AMD GPU targets that should use ROCm
+var supportedAMDGPUs = map[string]bool{
+    "gfx908":  true,
+    "gfx90a":  true,
+    "gfx942":  true,
+    "gfx1010": true,
+    "gfx1030": true,
+    "gfx1100": true,
+    "gfx1200": true,
+    "gfx1201": true,
+    "gfx1151": true,
+}
+
+func hasSupportedAMDGPU() (bool, error) {
+    // Check if KFD topology directory exists
+    topologyDir := "/sys/class/kfd/kfd/topology/nodes/"
+    info, err := os.Stat(topologyDir)
+    if err != nil || !info.IsDir() {
+        return false, nil // KFD not available
+    }
+
+    entries, err := os.ReadDir(topologyDir)
+    if err != nil {
+        return false, err
+    }
+
+    // Sort entries by name to maintain consistent order
+    sort.Slice(entries, func(i, j int) bool {
+        return entries[i].Name() < entries[j].Name()
+    })
+
+    // Compile regex to match gfx_target_version lines
+    reTarget := regexp.MustCompile(`gfx_target_version[ \t]+([0-9]+)`)
+
+    for _, e := range entries {
+        if !e.IsDir() {
+            continue
+        }
+        nodePath := filepath.Join(topologyDir, e.Name())
+        propPath := filepath.Join(nodePath, "properties")
+
+        // Attempt to open the properties file directly; skip on error (e.g., permissions)
+        f, err := os.Open(propPath)
+        if err != nil {
+            // Could be permission denied or file doesn't exist; just skip like the Python code
+            continue
+        }
+
+        sc := bufio.NewScanner(f)
+        for sc.Scan() {
+            line := sc.Text()
+            matches := reTarget.FindStringSubmatch(line)
+            if len(matches) < 2 {
+                continue
+            }
+
+            deviceIDStr := matches[1]
+            deviceID, err := strconv.Atoi(deviceIDStr)
+            if err != nil || deviceID == 0 {
+                continue
+            }
+
+            var majorVer, minorVer, steppingVer int
+            if gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION"); gfxOverride != "" {
+                parts := strings.Split(strings.TrimSpace(gfxOverride), ".")
+                if len(parts) != 3 {
+                    // Invalid format, skip
+                    continue
+                }
+                mv, err1 := strconv.Atoi(parts[0])
+                nv, err2 := strconv.Atoi(parts[1])
+                sv, err3 := strconv.Atoi(parts[2])
+                if err1 != nil || err2 != nil || err3 != nil {
+                    // Invalid format, skip
+                    continue
+                }
+                if mv > 63 || nv > 255 || sv > 255 {
+                    // Invalid values, skip
+                    continue
+                }
+                majorVer, minorVer, steppingVer = mv, nv, sv
+            } else {
+                majorVer = (deviceID / 10000) % 100
+                minorVer = (deviceID / 100) % 100
+                steppingVer = deviceID % 100
+            }
+
+            gfx := "gfx" +
+                strconv.FormatInt(int64(majorVer), 10) +
+                strconv.FormatInt(int64(minorVer), 16) +
+                strconv.FormatInt(int64(steppingVer), 16)
+
+            if supportedAMDGPUs[gfx] {
+                f.Close()
+                return true, nil // Found a supported AMD GPU
+            }
+        }
+        f.Close()
+    }
+
+    return false, nil // No supported AMD GPU found
+}
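
The gfx name above is derived from the KFD gfx_target_version value: the integer packs major*10000 + minor*100 + stepping, and the name is printed as decimal major followed by hex minor and stepping. A small standalone sketch of that arithmetic (the sample inputs are illustrative, not taken from this commit):

package main

import (
    "fmt"
    "strconv"
)

// decodeGFXTarget mirrors the decoding in hasSupportedAMDGPU: split the
// packed version into major/minor/stepping, then print major in decimal
// and minor/stepping in hex.
func decodeGFXTarget(version int) string {
    major := (version / 10000) % 100
    minor := (version / 100) % 100
    stepping := version % 100
    return "gfx" +
        strconv.FormatInt(int64(major), 10) +
        strconv.FormatInt(int64(minor), 16) +
        strconv.FormatInt(int64(stepping), 16)
}

func main() {
    // Illustrative values: 90010 decodes to gfx90a, 110000 to gfx1100.
    for _, v := range []int{90010, 110000} {
        fmt.Printf("%d -> %s\n", v, decodeGFXTarget(v))
    }
}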

pkg/gpuinfo/gpuinfo.go

Lines changed: 4 additions & 0 deletions
@@ -15,3 +15,7 @@ func New(modelRuntimeInstallPath string) *GPUInfo {
 func (g *GPUInfo) GetVRAMSize() (uint64, error) {
     return getVRAMSize(g.modelRuntimeInstallPath)
 }
+
+func (g *GPUInfo) HasSupportedAMDGPU() (bool, error) {
+    return hasSupportedAMDGPU()
+}
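
Callers outside the package go through the exported wrapper added above. A minimal usage sketch, assuming a hypothetical install path (HasSupportedAMDGPU does not use the path; gpuinfo.New takes it for its VRAM probing):

package main

import (
    "fmt"

    "github.com/docker/model-runner/pkg/gpuinfo"
)

func main() {
    // "/opt/model-runner" is a placeholder install path for this sketch.
    g := gpuinfo.New("/opt/model-runner")
    hasAMD, err := g.HasSupportedAMDGPU()
    if err != nil {
        fmt.Println("AMD GPU detection failed:", err)
        return
    }
    fmt.Println("supported AMD GPU present:", hasAMD)
}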

pkg/gpuinfo/gpuinfo_linux.go

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+//go:build linux
+
+package gpuinfo

pkg/gpuinfo/gpuinfo_not_linux.go

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+//go:build !linux
+
+package gpuinfo
+
+func (g *GPUInfo) HasSupportedAMDGPU() (bool, error) {
+    // AMD GPU detection is only supported on Linux
+    return false, nil
+}

pkg/gpuinfo/memory_darwin_cgo.go

Lines changed: 6 additions & 0 deletions
@@ -17,3 +17,9 @@ func getVRAMSize(_ string) (uint64, error) {
     }
     return uint64(vramSize), nil
 }
+
+// hasSupportedAMDGPU returns true if the system has supported AMD GPUs
+func hasSupportedAMDGPU() (bool, error) {
+    // AMD GPU detection is only supported on Linux
+    return false, nil
+}

pkg/gpuinfo/memory_darwin_nocgo.go

Lines changed: 6 additions & 0 deletions
@@ -8,3 +8,9 @@ import "errors"
 func getVRAMSize(_ string) (uint64, error) {
     return 0, errors.New("unimplemented without cgo")
 }
+
+// hasSupportedAMDGPU returns true if the system has supported AMD GPUs
+func hasSupportedAMDGPU() (bool, error) {
+    // AMD GPU detection is only supported on Linux
+    return false, nil
+}

pkg/gpuinfo/memory_linux_nocgo.go

Lines changed: 5 additions & 0 deletions
@@ -8,3 +8,8 @@ import "errors"
 func getVRAMSize(_ string) (uint64, error) {
     return 0, errors.New("unimplemented without cgo")
 }
+
+// hasSupportedAMDGPU returns true if the system has supported AMD GPUs
+func hasSupportedAMDGPU() (bool, error) {
+    return false, errors.New("unimplemented without cgo")
+}

pkg/gpuinfo/memory_windows.go

Lines changed: 6 additions & 0 deletions
@@ -38,3 +38,9 @@ func getVRAMSize(modelRuntimeInstallPath string) (uint64, error) {
     }
     return 0, errors.New("unexpected nv-gpu-info output format")
 }
+
+// hasSupportedAMDGPU returns true if the system has supported AMD GPUs
+func hasSupportedAMDGPU() (bool, error) {
+    // AMD GPU detection is only supported on Linux
+    return false, nil
+}

Lines changed: 39 additions & 6 deletions
@@ -1,18 +1,51 @@
+//go:build linux
+
 package llamacpp
 
 import (
     "context"
     "fmt"
     "net/http"
-    "path/filepath"
+    "os"
 
+    "github.com/docker/model-runner/pkg/gpuinfo"
     "github.com/docker/model-runner/pkg/logging"
 )
 
-func (l *llamaCpp) ensureLatestLlamaCpp(_ context.Context, log logging.Logger, _ *http.Client,
-    _, vendoredServerStoragePath string,
+func init() {
+    // Enable GPU variant detection by default on Linux
+    ShouldUseGPUVariantLock.Lock()
+    defer ShouldUseGPUVariantLock.Unlock()
+    ShouldUseGPUVariant = true
+}
+
+func (l *llamaCpp) ensureLatestLlamaCpp(ctx context.Context, log logging.Logger, httpClient *http.Client,
+    llamaCppPath, vendoredServerStoragePath string,
 ) error {
-    l.status = fmt.Sprintf("running llama.cpp version: %s",
-        getLlamaCppVersion(log, filepath.Join(vendoredServerStoragePath, "com.docker.llama-server")))
-    return errLlamaCppUpdateDisabled
+    var hasAMD bool
+    var err error
+
+    ShouldUseGPUVariantLock.Lock()
+    defer ShouldUseGPUVariantLock.Unlock()
+    if ShouldUseGPUVariant {
+        // Create GPU info to check for supported AMD GPUs
+        gpuInfo := gpuinfo.New(vendoredServerStoragePath)
+        hasAMD, err = gpuInfo.HasSupportedAMDGPU()
+        if err != nil {
+            log.Debugf("AMD GPU detection failed: %v", err)
+        }
+    }
+
+    desiredVersion := GetDesiredServerVersion()
+    desiredVariant := "cpu"
+
+    // Use ROCm if supported AMD GPU is detected
+    if hasAMD {
+        log.Info("Supported AMD GPU detected, using ROCm variant")
+        desiredVariant = "rocm"
+    }
+
+    l.status = fmt.Sprintf("looking for updates for %s variant", desiredVariant)
+    return l.downloadLatestLlamaCpp(ctx, log, httpClient, llamaCppPath, vendoredServerStoragePath, desiredVersion,
+        desiredVariant)
 }
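
The variant selection above reduces to a small decision: ROCm only when GPU-variant detection is enabled and a supported AMD GPU was found, CPU in every other case (including detection errors). A hypothetical helper restating that logic, not part of the commit:

package main

import "fmt"

// chooseServerVariant restates the selection made in ensureLatestLlamaCpp.
func chooseServerVariant(useGPUVariant, hasSupportedAMD bool) string {
    if useGPUVariant && hasSupportedAMD {
        return "rocm"
    }
    return "cpu"
}

func main() {
    fmt.Println(chooseServerVariant(true, true))  // rocm
    fmt.Println(chooseServerVariant(true, false)) // cpu
    fmt.Println(chooseServerVariant(false, true)) // cpu
}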
