fix multi-gpu issue on sycl (llama/8554)

---------

Signed-off-by: Chen Xi <xi2chen@intel.com>
Co-authored-by: Meng, Hengyu <hengyu.meng@intel.com>
This commit is contained in:
Chen Xi 2024-07-25 11:45:18 +00:00 committed by Georgi Gerganov
parent c06970dd72
commit 31d0a9a14f
2 changed files with 89 additions and 14 deletions

View File

@ -267,7 +267,7 @@ struct ggml_backend_sycl_context {
queue_ptr stream(int device, int stream) { queue_ptr stream(int device, int stream) {
if (qptrs[device][stream] == nullptr) { if (qptrs[device][stream] == nullptr) {
qptrs[device][stream] = &(dpct::get_current_device().default_queue()); qptrs[device][stream] = &(dpct::get_device(device).default_queue());
} }
return qptrs[device][stream]; return qptrs[device][stream];
} }

View File

@ -588,7 +588,7 @@ namespace dpct
out = prop; out = prop;
} }
/// dpct device extension /// dpct device extension
class device_ext : public sycl::device { class device_ext : public sycl::device {
typedef std::mutex mutex_type; typedef std::mutex mutex_type;
@ -697,7 +697,7 @@ namespace dpct
std::unique_lock<mutex_type> lock(m_mutex); std::unique_lock<mutex_type> lock(m_mutex);
lock.unlock(); lock.unlock();
for (auto &q : _queues) { for (auto &q : _queues) {
q.wait_and_throw(); q.wait_and_throw();
} }
// Guard the destruct of current_queues to make sure the ref count is // Guard the destruct of current_queues to make sure the ref count is
// safe. // safe.
@ -734,7 +734,12 @@ namespace dpct
void destroy_queue(sycl::queue queue) { void destroy_queue(sycl::queue queue) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
_queues.clear(); _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
[=](const sycl::queue &q) -> bool
{
return q == queue;
}),
_queues.end());
} }
void set_saved_queue(sycl::queue q) { void set_saved_queue(sycl::queue q) {
std::lock_guard<mutex_type> lock(m_mutex); std::lock_guard<mutex_type> lock(m_mutex);
@ -764,13 +769,13 @@ namespace dpct
if (enable_exception_handler) { if (enable_exception_handler) {
eh = exception_handler; eh = exception_handler;
} }
auto q = sycl::queue(*this, eh, _queues.push_back(sycl::queue(
sycl::property_list( *this, eh,
sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED #ifdef DPCT_PROFILING_ENABLED
sycl::property::queue::enable_profiling(), sycl::property::queue::enable_profiling(),
#endif #endif
properties...)); properties...)));
_queues.push_back(q);
return _queues.back(); return _queues.back();
} }
@ -783,8 +788,8 @@ namespace dpct
if (enable_exception_handler) { if (enable_exception_handler) {
eh = exception_handler; eh = exception_handler;
} }
_queues.push_back( _queues.push_back(sycl::queue(
sycl::queue(device, eh, device, eh,
sycl::property_list( sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED #ifdef DPCT_PROFILING_ENABLED
sycl::property::queue::enable_profiling(), sycl::property::queue::enable_profiling(),
@ -855,15 +860,75 @@ namespace dpct
unsigned int get_device_id(const sycl::device &dev) unsigned int get_device_id(const sycl::device &dev)
{ {
unsigned int id = 0; unsigned int id = 0;
for (auto dev_item : _devs) for (auto &dev_item : _devs)
{ {
if (*dev_item == dev) if (*dev_item == dev)
{ {
break; return id;
} }
id++; id++;
} }
return id; return -1;
}
/// Picks the SYCL platform whose devices should be used and returns its name.
///
/// A platform qualifies when it exposes at least one GPU device and its
/// lower-cased name contains the backend filter string. The filter defaults to
/// "level-zero" and can be overridden through the ONEAPI_DEVICE_SELECTOR
/// environment variable.
///
/// @return The platform name as reported by sycl::info::platform::name (original
///         casing). When several platforms qualify, the last one returned by
///         sycl::platform::get_platforms() is the one reported, because the
///         loop keeps overwriting `result` and never breaks.
/// @throws std::runtime_error if ONEAPI_DEVICE_SELECTOR is set but contains
///         none of "level_zero", "opencl", "cuda" or "hip", or if no platform
///         qualifies.
inline std::string get_preferred_gpu_platform_name() {
std::string result;
// Default backend when the environment variable is not set.
std::string filter = "level-zero";
char* env = getenv("ONEAPI_DEVICE_SELECTOR");
if (env) {
// Substring checks, evaluated in this order; the first hit decides the filter.
// The selector spells the backend "level_zero" (underscore) while the filter
// matched against platform names is "level-zero" (hyphen).
if (std::strstr(env, "level_zero")) {
filter = "level-zero";
}
else if (std::strstr(env, "opencl")) {
filter = "opencl";
}
else if (std::strstr(env, "cuda")) {
filter = "cuda";
}
else if (std::strstr(env, "hip")) {
filter = "hip";
}
else {
// The variable is set but names none of the recognised backends.
throw std::runtime_error("invalid device filter: " + std::string(env));
}
}
auto plaform_list = sycl::platform::get_platforms();
for (const auto& platform : plaform_list) {
auto devices = platform.get_devices();
// Only platforms that expose at least one GPU device are candidates.
auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
return d.is_gpu();
});
if (gpu_dev == devices.end()) {
// cout << "platform [" << platform_name
// << "] does not contain GPU devices, skipping\n";
continue;
}
auto platform_name = platform.get_info<sycl::info::platform::name>();
// Build a lower-cased copy so the filter comparison is case-insensitive;
// `platform_name` itself keeps its original casing for the return value.
// NOTE(review): ::tolower is applied directly to char values; this presumably
// relies on platform names being plain ASCII — confirm.
std::string platform_name_low_case;
platform_name_low_case.resize(platform_name.size());
std::transform(
platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
if (platform_name_low_case.find(filter) == std::string::npos) {
// cout << "platform [" << platform_name
// << "] does not match with requested "
// << filter << ", skipping\n";
continue;
}
// No break here: a later qualifying platform replaces an earlier one.
result = platform_name;
}
if (result.empty())
throw std::runtime_error("can not find preferred GPU platform");
return result;
} }
template <class DeviceSelector> template <class DeviceSelector>
@ -930,10 +995,15 @@ namespace dpct
// Keep track of the number of devices per backend // Keep track of the number of devices per backend
std::map<sycl::backend, size_t> DeviceNums; std::map<sycl::backend, size_t> DeviceNums;
std::map<std::string, std::vector<sycl::device>> backend_devices; std::map<std::string, std::vector<sycl::device>> backend_devices;
auto preferred_platform_name = get_preferred_gpu_platform_name();
while (!Platforms.empty()) { while (!Platforms.empty()) {
auto Platform = Platforms.back(); auto Platform = Platforms.back();
Platforms.pop_back(); Platforms.pop_back();
auto platform_name = Platform.get_info<sycl::info::platform::name>();
if (platform_name.compare(preferred_platform_name) != 0) {
continue;
}
auto devices = Platform.get_devices(); auto devices = Platform.get_devices();
std::string backend_type = get_device_backend_and_type(devices[0]); std::string backend_type = get_device_backend_and_type(devices[0]);
for (const auto &device : devices) { for (const auto &device : devices) {
@ -1989,6 +2059,11 @@ namespace dpct
return dev_mgr::instance().current_device(); return dev_mgr::instance().current_device();
} }
/// Returns the device held by the dpct device manager under index `id`.
///
/// Forwards to dev_mgr::get_device on the manager singleton; unlike
/// get_current_device(), the caller chooses the device explicitly.
static inline device_ext &get_device(unsigned int id)
{
    auto &&manager = dev_mgr::instance();
    return manager.get_device(id);
}
static inline sycl::queue &get_in_order_queue() static inline sycl::queue &get_in_order_queue()
{ {
return dev_mgr::instance().current_device().in_order_queue(); return dev_mgr::instance().current_device().in_order_queue();