From cde8c2d3bd3f307d2c19a3fcdda689b9ab2e26b9 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 21:30:18 +0530 Subject: [PATCH] Use a logger --- ui/sd_internal/app.py | 23 +++++++++---- ui/sd_internal/device_manager.py | 23 +++++++------ ui/sd_internal/model_manager.py | 23 +++++++++---- ui/sd_internal/runtime2.py | 7 ++-- ui/sd_internal/task_manager.py | 56 +++++++++++++++++--------------- ui/server.py | 13 ++++---- 6 files changed, 87 insertions(+), 58 deletions(-) diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index d1cec46f..c34f2ea6 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -3,9 +3,21 @@ import socket import sys import json import traceback +import logging +from rich.logging import RichHandler from sd_internal import task_manager +LOG_FORMAT = '[%(threadName)s] %(message)s' +logging.basicConfig( + level=logging.INFO, + format=LOG_FORMAT, + datefmt="[%X.%f]", + handlers=[RichHandler(markup=True)] +) + +log = logging.getLogger() + SD_DIR = os.getcwd() SD_UI_DIR = os.getenv('SD_UI_PATH', None) @@ -49,8 +61,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS): config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0') return config except Exception as e: - print(str(e)) - print(traceback.format_exc()) + log.warn(traceback.format_exc()) return default_val def setConfig(config): @@ -59,7 +70,7 @@ def setConfig(config): with open(config_json_path, 'w', encoding='utf-8') as f: json.dump(config, f) except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) try: # config.bat config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') @@ -78,7 +89,7 @@ def setConfig(config): with open(config_bat_path, 'w', encoding='utf-8') as f: f.write('\r\n'.join(config_bat)) except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) try: # config.sh config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') @@ -97,7 +108,7 @@ def setConfig(config): with open(config_sh_path, 'w', encoding='utf-8') as f: f.write('\n'.join(config_sh)) except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name): config = getConfig() @@ -120,7 +131,7 @@ def update_render_threads(): render_devices = config.get('render_devices', 'auto') active_devices = task_manager.get_devices()['active'].keys() - print('requesting for render_devices', render_devices) + log.debug(f'requesting for render_devices: {render_devices}') task_manager.update_render_threads(render_devices, active_devices) def getUIPlugins(): diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index c490a0c6..a3f91cfb 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/sd_internal/device_manager.py @@ -2,6 +2,9 @@ import os import torch import traceback import re +import logging + +log = logging.getLogger() COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked @@ -34,7 +37,7 @@ def get_device_delta(render_devices, active_devices): if 'auto' in render_devices: render_devices = auto_pick_devices(active_devices) if 'cpu' in render_devices: - print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!') + log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!') active_devices = set(active_devices) render_devices = set(render_devices) @@ -53,7 +56,7 @@ def auto_pick_devices(currently_active_devices): if device_count == 1: return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu'] - print('Autoselecting GPU. Using most free memory.') + log.debug('Autoselecting GPU. Using most free memory.') devices = [] for device in range(device_count): device = f'cuda:{device}' @@ -64,7 +67,7 @@ def auto_pick_devices(currently_active_devices): mem_free /= float(10**9) mem_total /= float(10**9) device_name = torch.cuda.get_device_name(device) - print(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb') + log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb') devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free}) devices.sort(key=lambda x:x['mem_free'], reverse=True) @@ -94,7 +97,7 @@ def device_init(context, device): context.device = 'cpu' context.device_name = get_processor_name() context.precision = 'full' - print('Render device CPU available as', context.device_name) + log.debug(f'Render device CPU available as {context.device_name}') return context.device_name = torch.cuda.get_device_name(device) @@ -102,11 +105,11 @@ def device_init(context, device): # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images if needs_to_force_full_precision(context): - print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') + log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') # Apply force_full_precision now before models are loaded. context.precision = 'full' - print(f'Setting {device} as active') + log.info(f'Setting {device} as active') torch.cuda.device(device) return @@ -135,7 +138,7 @@ def is_device_compatible(device): try: validate_device_id(device, log_prefix='is_device_compatible') except: - print(str(e)) + log.error(str(e)) return False if device == 'cpu': return True @@ -144,10 +147,10 @@ def is_device_compatible(device): _, mem_total = torch.cuda.mem_get_info(device) mem_total /= float(10**9) if mem_total < 3.0: - print(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion') + log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion') return False except RuntimeError as e: - print(str(e)) + log.error(str(e)) return False return True @@ -167,5 +170,5 @@ def get_processor_name(): if "model name" in line: return re.sub(".*model name.*:", "", line, 1).strip() except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) return "cpu" diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 6cf9428a..827434d7 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -1,9 +1,12 @@ import os +import logging +import picklescan.scanner +import rich from sd_internal import app, device_manager from sd_internal import Request -import picklescan.scanner -import rich + +log = logging.getLogger() KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] MODEL_EXTENSIONS = { @@ -42,7 +45,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): if model_name: is_sd2 = config.get('test_sd2', False) if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4 - print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!') + log.error('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!') model_name = 'sd-v1-4' # Check models directory @@ -67,7 +70,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): for model_extension in model_extensions: if os.path.exists(default_model_path + model_extension): if model_name is not None: - print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') + log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') return default_model_path + model_extension return None @@ -88,13 +91,13 @@ def is_malicious_model(file_path): try: scan_result = picklescan.scanner.scan_file_path(file_path) if scan_result.issues_count > 0 or scan_result.infected_files > 0: - rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) return True else: - rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) return False except Exception as e: - print('error while scanning', file_path, 'error:', e) + log.error(f'error while scanning: {file_path}, error: {e}') return False def getModels(): @@ -111,7 +114,10 @@ def getModels(): }, } + models_scanned = 0 def listModels(model_type): + nonlocal models_scanned + model_extensions = MODEL_EXTENSIONS.get(model_type, []) models_dir = os.path.join(app.MODELS_DIR, model_type) if not os.path.exists(models_dir): @@ -126,6 +132,7 @@ def getModels(): mtime = os.path.getmtime(model_path) mod_time = known_models[model_path] if model_path in known_models else -1 if mod_time != mtime: + models_scanned += 1 if is_malicious_model(model_path): models['scan-error'] = file return @@ -142,6 +149,8 @@ def getModels(): listModels(model_type='vae') listModels(model_type='hypernetwork') + if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. 0 infected[/]') + # legacy custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') if os.path.exists(custom_weight_path): diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index f74acc12..c7845079 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -6,12 +6,15 @@ import os import base64 import re import traceback +import logging from sd_internal import device_manager, model_manager from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop from modules import model_loader, image_generator, image_utils, filters as image_filters +log = logging.getLogger() + thread_data = threading.local() ''' runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc @@ -69,7 +72,7 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s try: return _make_images_internal(req, data_queue, task_temp_images, step_callback) except Exception as e: - print(traceback.format_exc()) + log.error(traceback.format_exc()) data_queue.put(json.dumps({ "status": 'failed', @@ -91,7 +94,7 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_image res = Response(req, images=construct_response(req, images)) res = res.json() data_queue.put(json.dumps(res)) - print('Task completed') + log.info('Task completed') return res diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 6b0a1d8c..852631c6 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -6,6 +6,7 @@ Notes: """ import json import traceback +import logging TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout @@ -16,6 +17,8 @@ from typing import Any, Hashable from pydantic import BaseModel from sd_internal import Request, device_manager +log = logging.getLogger() + THREAD_NAME_PREFIX = 'Runtime-Render/' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. @@ -140,11 +143,11 @@ class DataCache(): for key in to_delete: (_, val) = self._base[key] if isinstance(val, RenderTask): - print(f'RenderTask {key} expired. Data removed.') + log.debug(f'RenderTask {key} expired. Data removed.') elif isinstance(val, SessionState): - print(f'Session {key} expired. Data removed.') + log.debug(f'Session {key} expired. Data removed.') else: - print(f'Key {key} expired. Data removed.') + log.debug(f'Key {key} expired. Data removed.') del self._base[key] finally: self._lock.release() @@ -178,8 +181,7 @@ class DataCache(): self._get_ttl_time(ttl), value ) except Exception as e: - print(str(e)) - print(traceback.format_exc()) + log.error(traceback.format_exc()) return False else: return True @@ -190,7 +192,7 @@ class DataCache(): try: ttl, value = self._base.get(key, (None, None)) if ttl is not None and self._is_expired(ttl): - print(f'Session {key} expired. Discarding data.') + log.debug(f'Session {key} expired. Discarding data.') del self._base[key] return None return value @@ -234,7 +236,7 @@ class SessionState(): def thread_get_next_task(): from sd_internal import runtime2 if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): - print('Render thread on device', runtime2.thread_data.device, 'failed to acquire manager lock.') + log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.') return None if len(tasks_queue) <= 0: manager_lock.release() @@ -269,7 +271,7 @@ def thread_render(device): try: runtime2.init(device) except Exception as e: - print(traceback.format_exc()) + log.error(traceback.format_exc()) weak_thread_data[threading.current_thread()] = { 'error': e } @@ -287,7 +289,7 @@ def thread_render(device): session_cache.clean() task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: - print(f'Shutting down thread for device {runtime2.thread_data.device}') + log.info(f'Shutting down thread for device {runtime2.thread_data.device}') runtime2.destroy() return if isinstance(current_state_error, SystemExit): @@ -299,7 +301,7 @@ def thread_render(device): idle_event.wait(timeout=1) continue if task.error is not None: - print(task.error) + log.error(task.error) task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue @@ -308,7 +310,7 @@ def thread_render(device): task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue - print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') + log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: def step_callback(): @@ -319,7 +321,7 @@ def thread_render(device): if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None - print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') current_state = ServerStates.LoadingModel runtime2.reload_models_if_necessary(task.request) @@ -331,7 +333,7 @@ def thread_render(device): session_cache.keep(task.request.session_id, TASK_TTL) except Exception as e: task.error = e - print(traceback.format_exc()) + log.error(traceback.format_exc()) continue finally: # Task completed @@ -339,11 +341,11 @@ def thread_render(device): task_cache.keep(id(task), TASK_TTL) session_cache.keep(task.request.session_id, TASK_TTL) if isinstance(task.error, StopAsyncIteration): - print(f'Session {task.request.session_id} task {id(task)} cancelled!') + log.info(f'Session {task.request.session_id} task {id(task)} cancelled!') elif task.error is not None: - print(f'Session {task.request.session_id} task {id(task)} failed!') + log.info(f'Session {task.request.session_id} task {id(task)} failed!') else: - print(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') + log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') current_state = ServerStates.Online def get_cached_task(task_id:str, update_ttl:bool=False): @@ -429,7 +431,7 @@ def is_alive(device=None): def start_render_thread(device): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED) - print('Start new Rendering Thread on device', device) + log.info(f'Start new Rendering Thread on device: {device}') try: rthread = threading.Thread(target=thread_render, kwargs={'device': device}) rthread.daemon = True @@ -441,7 +443,7 @@ def start_render_thread(device): timeout = DEVICE_START_TIMEOUT while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]: if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]: - print(rthread, device, 'error:', weak_thread_data[rthread]['error']) + log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}") return False if timeout <= 0: return False @@ -453,11 +455,11 @@ def stop_render_thread(device): try: device_manager.validate_device_id(device, log_prefix='stop_render_thread') except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) return False if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED) - print('Stopping Rendering Thread on device', device) + log.info(f'Stopping Rendering Thread on device: {device}') try: thread_to_remove = None @@ -480,27 +482,27 @@ def stop_render_thread(device): def update_render_threads(render_devices, active_devices): devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices) - print('devices_to_start', devices_to_start) - print('devices_to_stop', devices_to_stop) + log.debug(f'devices_to_start: {devices_to_start}') + log.debug(f'devices_to_stop: {devices_to_stop}') for device in devices_to_stop: if is_alive(device) <= 0: - print(device, 'is not alive') + log.debug(f'{device} is not alive') continue if not stop_render_thread(device): - print(device, 'could not stop render thread') + log.warn(f'{device} could not stop render thread') for device in devices_to_start: if is_alive(device) >= 1: - print(device, 'already registered.') + log.debug(f'{device} already registered.') continue if not start_render_thread(device): - print(device, 'failed to start.') + log.warn(f'{device} failed to start.') if is_alive() <= 0: # No running devices, probably invalid user config. raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json') - print('active devices', get_devices()['active']) + log.debug(f"active devices: {get_devices()['active']}") def shutdown_event(): # Signal render thread to close on shutdown global current_state_error diff --git a/ui/server.py b/ui/server.py index 458bc1aa..aaa16ce0 100644 --- a/ui/server.py +++ b/ui/server.py @@ -14,7 +14,9 @@ from pydantic import BaseModel from sd_internal import app, model_manager, task_manager -print('started in ', app.SD_DIR) +log = logging.getLogger() + +log.info(f'started in {app.SD_DIR}') server_api = FastAPI() @@ -84,7 +86,7 @@ async def setAppConfig(req : SetAppConfigRequest): return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS) except Exception as e: - print(traceback.format_exc()) + log.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) def update_render_devices_in_config(config, render_devices): @@ -153,8 +155,7 @@ def render(req : task_manager.ImageRequest): except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable except Exception as e: - print(e) - print(traceback.format_exc()) + log.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @server_api.get('/image/stream/{task_id:int}') @@ -165,10 +166,10 @@ def stream(task_id:int): #if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict if task.buffer_queue.empty() and not task.lock.locked(): if task.response: - #print(f'Session {session_id} sending cached response') + #log.info(f'Session {session_id} sending cached response') return JSONResponse(task.response, headers=NOCACHE_HEADERS) raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early - #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') + #log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') return StreamingResponse(task.read_buffer_generator(), media_type='application/json') @server_api.get('/image/stop')