import { Group, Text, TextInput } from '@mantine/core'; import { InfoTooltip } from '@/components/InfoTooltip'; import { Select } from '@/components/Select'; import { useLaunchConfigStore } from '@/stores/launchConfig'; import type { AccelerationOption } from '@/types'; const GPU_ACCELERATIONS = new Set(['cuda', 'rocm', 'vulkan']); const TENSOR_SPLIT_ACCELERATIONS = new Set(['cuda', 'rocm', 'vulkan']); interface GpuDeviceSelectorProps { availableAccelerations: AccelerationOption[]; } export const GpuDeviceSelector = ({ availableAccelerations }: GpuDeviceSelectorProps) => { const { acceleration, gpuDeviceSelection, tensorSplit, setGpuDeviceSelection, setTensorSplit } = useLaunchConfigStore(); const selectedAcceleration = availableAccelerations.find((a) => a.value === acceleration); const isGpuAcceleration = GPU_ACCELERATIONS.has(acceleration); const getDiscreteDeviceCount = () => { if (!selectedAcceleration?.devices) { return 0; } if (acceleration === 'vulkan' || acceleration === 'rocm') { return selectedAcceleration.devices.filter( (device) => typeof device === 'string' || !device.isIntegrated, ).length; } return selectedAcceleration.devices.length; }; const hasMultipleDevices = getDiscreteDeviceCount() > 1; const showTensorSplit = TENSOR_SPLIT_ACCELERATIONS.has(acceleration) && hasMultipleDevices && gpuDeviceSelection === 'all'; if (!isGpuAcceleration || !hasMultipleDevices) { return null; } const deviceOptions = (() => { if (!selectedAcceleration?.devices) { return []; } if (acceleration === 'vulkan' || acceleration === 'rocm') { const discreteDeviceOptions = selectedAcceleration.devices .map((device, index) => { if (typeof device === 'object' && device.isIntegrated) { return null; } const deviceName = typeof device === 'string' ? device : device.name; return { label: `GPU ${index}: ${deviceName}`, value: index.toString(), }; }) .filter((option): option is NonNullable => option !== null); return [{ label: 'All GPUs', value: 'all' }, ...discreteDeviceOptions]; } return [ { label: 'All GPUs', value: 'all' }, ...selectedAcceleration.devices.map((device, index) => { const deviceName = typeof device === 'string' ? device : typeof device === 'object' && 'name' in device ? device.name : String(device); return { label: `GPU ${index}: ${deviceName}`, value: index.toString(), }; }), ]; })(); return (
GPU Device