diff --git a/ui/media/js/main.js b/ui/media/js/main.js
index fe14d50e..23f0ef74 100644
--- a/ui/media/js/main.js
+++ b/ui/media/js/main.js
@@ -803,7 +803,6 @@ function getCurrentUserRequest() {
             height: heightField.value,
             // allow_nsfw: allowNSFWField.checked,
             turbo: turboField.checked,
-            render_device: getCurrentRenderDeviceSelection(),
             use_full_precision: useFullPrecisionField.checked,
             use_stable_diffusion_model: stableDiffusionModelField.value,
             use_vae_model: vaeModelField.value,
@@ -839,19 +838,6 @@ function getCurrentUserRequest() {
     return newTask
 }
 
-function getCurrentRenderDeviceSelection() {
-    let selectedGPUs = $('#use_gpus').val()
-
-    if (useCPUField.checked && !autoPickGPUsField.checked) {
-        return 'cpu'
-    }
-    if (autoPickGPUsField.checked || selectedGPUs.length == 0) {
-        return 'auto'
-    }
-
-    return selectedGPUs.join(',')
-}
-
 function makeImage() {
     if (!isServerAvailable()) {
         alert('The server is not available.')
@@ -1165,6 +1151,19 @@ promptStrengthSlider.addEventListener('input', updatePromptStrength)
 promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
 updatePromptStrength()
 
+function getCurrentRenderDeviceSelection() {
+    let selectedGPUs = $('#use_gpus').val()
+
+    if (useCPUField.checked && !autoPickGPUsField.checked) {
+        return 'cpu'
+    }
+    if (autoPickGPUsField.checked || selectedGPUs.length == 0) {
+        return 'auto'
+    }
+
+    return selectedGPUs.join(',')
+}
+
 useCPUField.addEventListener('click', function() {
     let gpuSettingEntry = getParameterSettingsEntry('use_gpus')
     let autoPickGPUSettingEntry = getParameterSettingsEntry('auto_pick_gpus')
@@ -1184,11 +1183,19 @@ useCPUField.addEventListener('click', function() {
         }
         gpuSettingEntry.style.display = (autoPickGPUsField.checked ? 'none' : '')
     }
+
+    changeAppConfig({
+        'render_devices': getCurrentRenderDeviceSelection()
+    })
 })
 
 useGPUsField.addEventListener('click', function() {
     let selectedGPUs = $('#use_gpus').val()
     autoPickGPUsField.checked = (selectedGPUs.length === 0)
+
+    changeAppConfig({
+        'render_devices': getCurrentRenderDeviceSelection()
+    })
 })
 
 autoPickGPUsField.addEventListener('click', function() {
@@ -1198,6 +1205,10 @@ autoPickGPUsField.addEventListener('click', function() {
 
     let gpuSettingEntry = getParameterSettingsEntry('use_gpus')
     gpuSettingEntry.style.display = (this.checked ? 'none' : '')
+
+    changeAppConfig({
+        'render_devices': getCurrentRenderDeviceSelection()
+    })
 })
 
 async function changeAppConfig(configDelta) {
diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py
index 18631469..a71ab015 100644
--- a/ui/sd_internal/device_manager.py
+++ b/ui/sd_internal/device_manager.py
@@ -13,10 +13,12 @@ def get_device_delta(render_devices, active_devices):
     active_devices: ['cpu', 'cuda:N'...]
     '''
 
-    if render_devices is not None:
-        if render_devices in ('cpu', 'auto'):
+    if render_devices in ('cpu', 'auto'):
+        render_devices = [render_devices]
+    elif render_devices is not None:
+        if isinstance(render_devices, str):
             render_devices = [render_devices]
-        elif isinstance(render_devices, list) and len(render_devices) > 0:
+        if isinstance(render_devices, list) and len(render_devices) > 0:
             render_devices = list(filter(lambda x: x.startswith('cuda:'), render_devices))
             if len(render_devices) == 0:
                 raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py
index a1b70418..343c60a0 100644
--- a/ui/sd_internal/task_manager.py
+++ b/ui/sd_internal/task_manager.py
@@ -40,7 +40,7 @@ class RenderTask(): # Task with output queue and completion lock.
     def __init__(self, req: Request):
         self.request: Request = req # Initial Request
         self.response: Any = None # Copy of the last reponse
-        self.render_device = None
+        self.render_device = None # Select the task affinity. (Not used to change active devices).
         self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
         self.error: Exception = None
         self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
@@ -72,7 +72,7 @@ class ImageRequest(BaseModel):
     save_to_disk_path: str = None
     turbo: bool = True
     use_cpu: bool = False ##TODO Remove after UI and plugins transition.
-    render_device: str = 'auto'
+    render_device: str = None # Select the task affinity. (Not used to change active devices).
     use_full_precision: bool = False
     use_face_correction: str = None # or "GFPGANv1.3"
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
diff --git a/ui/server.py b/ui/server.py
index e6a0a9f4..7ad1d63e 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -161,6 +161,8 @@ async def setAppConfig(req : SetAppConfigRequest):
     config = getConfig()
     if req.update_branch:
         config['update_branch'] = req.update_branch
+    if req.render_devices:
+        update_render_threads_from_request(req.render_devices)
     try:
         setConfig(config)
         return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
@@ -287,27 +289,18 @@ def save_render_devices_to_config(render_devices):
 
     setConfig(config)
 
-def update_render_threads_on_request(req : task_manager.ImageRequest):
-    if req.use_cpu: # TODO Remove after transition.
-        print('WARNING Replace {use_cpu: true} by {render_device: "cpu"}')
-        req.render_device = 'cpu'
-        del req.use_cpu
+def update_render_threads_from_request(render_device):
+    if render_device not in ('cpu', 'auto') and not render_device.startswith('cuda:'):
+        raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_device}')
 
-    if req.render_device not in ('cpu', 'auto') and not req.render_device.startswith('cuda:'):
-        raise HTTPException(status_code=400, detail=f'Invalid render device requested: {req.render_device}')
-
-    if req.render_device.startswith('cuda:'):
-        req.render_device = req.render_device.split(',')
-
-    save_render_devices_to_config(req.render_device)
-    del req.render_device
+    if render_device.startswith('cuda:'):
+        render_device = render_device.split(',')
+    save_render_devices_to_config(render_device)
 
     update_render_threads()
 
 @app.post('/render')
 def render(req : task_manager.ImageRequest):
-    update_render_threads_on_request(req)
-
     try:
         save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model)
         req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)