Use a logger

cmdr2 2022-12-09 21:30:18 +05:30
parent 79cc84b611
commit cde8c2d3bd
6 changed files with 87 additions and 58 deletions

View File

@@ -3,9 +3,21 @@ import socket
 import sys
 import json
 import traceback
+import logging
+from rich.logging import RichHandler
 from sd_internal import task_manager
 
+LOG_FORMAT = '[%(threadName)s] %(message)s'
+logging.basicConfig(
+    level=logging.INFO,
+    format=LOG_FORMAT,
+    datefmt="[%X.%f]",
+    handlers=[RichHandler(markup=True)]
+)
+log = logging.getLogger()
+
 SD_DIR = os.getcwd()
 SD_UI_DIR = os.getenv('SD_UI_PATH', None)
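
As a standalone illustration of the configuration above, a minimal, runnable sketch (the worker thread and its name are hypothetical, added only to show why %(threadName)s is in the format string):

    import logging
    import threading
    from rich.logging import RichHandler

    logging.basicConfig(
        level=logging.INFO,
        format='[%(threadName)s] %(message)s',
        handlers=[RichHandler(markup=True)]   # markup=True renders tags like [green]...[/green]
    )
    log = logging.getLogger()

    def worker():
        # Each record is prefixed with the emitting thread's name, which is how
        # log lines can be traced back to a specific render thread.
        log.info('[green]worker ready[/green]')

    threading.Thread(target=worker, name='Runtime-Render/cuda:0').start()
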
@@ -49,8 +61,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
         config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
         return config
     except Exception as e:
-        print(str(e))
-        print(traceback.format_exc())
+        log.warn(traceback.format_exc())
         return default_val
 
 def setConfig(config):
@@ -59,7 +70,7 @@ def setConfig(config):
         with open(config_json_path, 'w', encoding='utf-8') as f:
             json.dump(config, f)
     except:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
 
     try: # config.bat
         config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
@@ -78,7 +89,7 @@ def setConfig(config):
         with open(config_bat_path, 'w', encoding='utf-8') as f:
             f.write('\r\n'.join(config_bat))
     except:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
 
     try: # config.sh
         config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
@@ -97,7 +108,7 @@ def setConfig(config):
         with open(config_sh_path, 'w', encoding='utf-8') as f:
             f.write('\n'.join(config_sh))
     except:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
 
 def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
     config = getConfig()
@@ -120,7 +131,7 @@ def update_render_threads():
     render_devices = config.get('render_devices', 'auto')
     active_devices = task_manager.get_devices()['active'].keys()
 
-    print('requesting for render_devices', render_devices)
+    log.debug(f'requesting for render_devices: {render_devices}')
     task_manager.update_render_threads(render_devices, active_devices)
 
 def getUIPlugins():
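
A side note on the f-string style adopted for these calls: an f-string is formatted eagerly, even when the record is later filtered out, while logging's %-style arguments are merged only if the record is actually emitted. A small illustrative sketch:

    import logging

    log = logging.getLogger()
    render_devices = 'auto'

    # Formatted immediately, whether or not DEBUG records are being discarded:
    log.debug(f'requesting for render_devices: {render_devices}')

    # Deferred: the template is merged with its arguments only on emission:
    log.debug('requesting for render_devices: %s', render_devices)
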

View File

@@ -2,6 +2,9 @@ import os
 import torch
 import traceback
 import re
+import logging
+
+log = logging.getLogger()
 
 COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
@@ -34,7 +37,7 @@ def get_device_delta(render_devices, active_devices):
     if 'auto' in render_devices:
         render_devices = auto_pick_devices(active_devices)
         if 'cpu' in render_devices:
-            print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
+            log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
 
     active_devices = set(active_devices)
     render_devices = set(render_devices)
@@ -53,7 +56,7 @@ def auto_pick_devices(currently_active_devices):
     if device_count == 1:
         return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']
 
-    print('Autoselecting GPU. Using most free memory.')
+    log.debug('Autoselecting GPU. Using most free memory.')
     devices = []
     for device in range(device_count):
         device = f'cuda:{device}'
@@ -64,7 +67,7 @@ def auto_pick_devices(currently_active_devices):
         mem_free /= float(10**9)
         mem_total /= float(10**9)
         device_name = torch.cuda.get_device_name(device)
-        print(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
+        log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
         devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})
 
     devices.sort(key=lambda x:x['mem_free'], reverse=True)
@@ -94,7 +97,7 @@ def device_init(context, device):
         context.device = 'cpu'
         context.device_name = get_processor_name()
         context.precision = 'full'
-        print('Render device CPU available as', context.device_name)
+        log.debug(f'Render device CPU available as {context.device_name}')
         return
 
     context.device_name = torch.cuda.get_device_name(device)
@@ -102,11 +105,11 @@ def device_init(context, device):
     # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
     if needs_to_force_full_precision(context):
-        print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
+        log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
         # Apply force_full_precision now before models are loaded.
         context.precision = 'full'
 
-    print(f'Setting {device} as active')
+    log.info(f'Setting {device} as active')
     torch.cuda.device(device)
 
     return
@@ -135,7 +138,7 @@ def is_device_compatible(device):
     try:
         validate_device_id(device, log_prefix='is_device_compatible')
     except:
-        print(str(e))
+        log.error(str(e))
         return False
 
     if device == 'cpu': return True
@@ -144,10 +147,10 @@ def is_device_compatible(device):
         _, mem_total = torch.cuda.mem_get_info(device)
         mem_total /= float(10**9)
         if mem_total < 3.0:
-            print(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
+            log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
             return False
     except RuntimeError as e:
-        print(str(e))
+        log.error(str(e))
         return False
     return True
@@ -167,5 +170,5 @@ def get_processor_name():
             if "model name" in line:
                 return re.sub(".*model name.*:", "", line, 1).strip()
     except:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
     return "cpu"

View File

@@ -1,9 +1,12 @@
 import os
+import logging
+import picklescan.scanner
+import rich
 
 from sd_internal import app, device_manager
 from sd_internal import Request
-import picklescan.scanner
-import rich
+
+log = logging.getLogger()
 
 KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
 MODEL_EXTENSIONS = {
@@ -42,7 +45,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
     if model_name:
         is_sd2 = config.get('test_sd2', False)
         if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
-            print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
+            log.error('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
             model_name = 'sd-v1-4'
 
     # Check models directory
@@ -67,7 +70,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
     for model_extension in model_extensions:
         if os.path.exists(default_model_path + model_extension):
             if model_name is not None:
-                print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
+                log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
             return default_model_path + model_extension
 
     return None
@@ -88,13 +91,13 @@ def is_malicious_model(file_path):
     try:
         scan_result = picklescan.scanner.scan_file_path(file_path)
         if scan_result.issues_count > 0 or scan_result.infected_files > 0:
-            rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
+            log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
             return True
         else:
-            rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
+            log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
             return False
     except Exception as e:
-        print('error while scanning', file_path, 'error:', e)
+        log.error(f'error while scanning: {file_path}, error: {e}')
         return False
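
The scan helper above can be exercised in isolation. A hedged sketch (the function name is a hypothetical stand-in for is_malicious_model(); it uses only the result fields visible in this diff):

    import logging
    import picklescan.scanner

    log = logging.getLogger()

    def is_suspect(file_path):
        # Flag a checkpoint if the pickle scanner reports any issues or infected files.
        result = picklescan.scanner.scan_file_path(file_path)
        if result.issues_count > 0 or result.infected_files > 0:
            log.warning('Scan %s: %d scanned, %d issue, %d infected.',
                        file_path, result.scanned_files, result.issues_count, result.infected_files)
            return True
        return False
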
def getModels(): def getModels():
@ -111,7 +114,10 @@ def getModels():
}, },
} }
models_scanned = 0
def listModels(model_type): def listModels(model_type):
nonlocal models_scanned
model_extensions = MODEL_EXTENSIONS.get(model_type, []) model_extensions = MODEL_EXTENSIONS.get(model_type, [])
models_dir = os.path.join(app.MODELS_DIR, model_type) models_dir = os.path.join(app.MODELS_DIR, model_type)
if not os.path.exists(models_dir): if not os.path.exists(models_dir):
@ -126,6 +132,7 @@ def getModels():
mtime = os.path.getmtime(model_path) mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1 mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime: if mod_time != mtime:
models_scanned += 1
if is_malicious_model(model_path): if is_malicious_model(model_path):
models['scan-error'] = file models['scan-error'] = file
return return
@ -142,6 +149,8 @@ def getModels():
listModels(model_type='vae') listModels(model_type='vae')
listModels(model_type='hypernetwork') listModels(model_type='hypernetwork')
if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. 0 infected[/]')
# legacy # legacy
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path): if os.path.exists(custom_weight_path):
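
The models_scanned counter relies on a closure: the nested listModels() assigns to a variable in the enclosing getModels() scope, which is only legal with the nonlocal declaration (without it, += would raise UnboundLocalError). An isolated sketch of the pattern, with hypothetical names:

    def get_models(paths):
        models_scanned = 0

        def list_models(path):
            nonlocal models_scanned   # bind to the enclosing function's variable
            models_scanned += 1

        for p in paths:
            list_models(p)
        return models_scanned

    assert get_models(['sd-v1-4.ckpt', 'vae.pt']) == 2
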

View File

@@ -6,12 +6,15 @@ import os
 import base64
 import re
 import traceback
+import logging
 
 from sd_internal import device_manager, model_manager
 from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
 from modules import model_loader, image_generator, image_utils, filters as image_filters
 
+log = logging.getLogger()
+
 thread_data = threading.local()
 '''
 runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
@@ -69,7 +72,7 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
         return _make_images_internal(req, data_queue, task_temp_images, step_callback)
     except Exception as e:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
 
         data_queue.put(json.dumps({
             "status": 'failed',
@@ -91,7 +94,7 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
     res = Response(req, images=construct_response(req, images))
     res = res.json()
     data_queue.put(json.dumps(res))
-    print('Task completed')
+    log.info('Task completed')
 
     return res
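
On the log.error(traceback.format_exc()) pattern used throughout this commit: inside an except block, log.exception() logs at ERROR level and appends the active traceback automatically, which would make the explicit traceback import unnecessary. An equivalent alternative (not what this commit does), as a sketch:

    import logging

    log = logging.getLogger()

    try:
        1 / 0
    except Exception:
        # ERROR-level record with the current traceback attached.
        log.exception('image generation failed')
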

View File

@@ -6,6 +6,7 @@ Notes:
 """
 import json
 import traceback
+import logging
 
 TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
@@ -16,6 +17,8 @@ from typing import Any, Hashable
 from pydantic import BaseModel
 from sd_internal import Request, device_manager
 
+log = logging.getLogger()
+
 THREAD_NAME_PREFIX = 'Runtime-Render/'
 ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
 LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
@@ -140,11 +143,11 @@ class DataCache():
             for key in to_delete:
                 (_, val) = self._base[key]
                 if isinstance(val, RenderTask):
-                    print(f'RenderTask {key} expired. Data removed.')
+                    log.debug(f'RenderTask {key} expired. Data removed.')
                 elif isinstance(val, SessionState):
-                    print(f'Session {key} expired. Data removed.')
+                    log.debug(f'Session {key} expired. Data removed.')
                 else:
-                    print(f'Key {key} expired. Data removed.')
+                    log.debug(f'Key {key} expired. Data removed.')
                 del self._base[key]
         finally:
             self._lock.release()
@@ -178,8 +181,7 @@ class DataCache():
                 self._get_ttl_time(ttl), value
             )
         except Exception as e:
-            print(str(e))
-            print(traceback.format_exc())
+            log.error(traceback.format_exc())
            return False
        else:
            return True
@@ -190,7 +192,7 @@ class DataCache():
         try:
             ttl, value = self._base.get(key, (None, None))
             if ttl is not None and self._is_expired(ttl):
-                print(f'Session {key} expired. Discarding data.')
+                log.debug(f'Session {key} expired. Discarding data.')
                 del self._base[key]
                 return None
             return value
@@ -234,7 +236,7 @@ class SessionState():
 def thread_get_next_task():
     from sd_internal import runtime2
     if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
-        print('Render thread on device', runtime2.thread_data.device, 'failed to acquire manager lock.')
+        log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.')
         return None
     if len(tasks_queue) <= 0:
         manager_lock.release()
@@ -269,7 +271,7 @@ def thread_render(device):
     try:
         runtime2.init(device)
     except Exception as e:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
         weak_thread_data[threading.current_thread()] = {
             'error': e
         }
@@ -287,7 +289,7 @@ def thread_render(device):
         session_cache.clean()
         task_cache.clean()
         if not weak_thread_data[threading.current_thread()]['alive']:
-            print(f'Shutting down thread for device {runtime2.thread_data.device}')
+            log.info(f'Shutting down thread for device {runtime2.thread_data.device}')
             runtime2.destroy()
             return
         if isinstance(current_state_error, SystemExit):
@@ -299,7 +301,7 @@ def thread_render(device):
             idle_event.wait(timeout=1)
             continue
         if task.error is not None:
-            print(task.error)
+            log.error(task.error)
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
@@ -308,7 +310,7 @@ def thread_render(device):
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
-        print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
+        log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
         if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
         try:
             def step_callback():
@@ -319,7 +321,7 @@ def thread_render(device):
                 if isinstance(current_state_error, StopAsyncIteration):
                     task.error = current_state_error
                     current_state_error = None
-                    print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
+                    log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
 
             current_state = ServerStates.LoadingModel
             runtime2.reload_models_if_necessary(task.request)
@@ -331,7 +333,7 @@ def thread_render(device):
             session_cache.keep(task.request.session_id, TASK_TTL)
         except Exception as e:
             task.error = e
-            print(traceback.format_exc())
+            log.error(traceback.format_exc())
             continue
         finally:
             # Task completed
@@ -339,11 +341,11 @@ def thread_render(device):
         task_cache.keep(id(task), TASK_TTL)
         session_cache.keep(task.request.session_id, TASK_TTL)
         if isinstance(task.error, StopAsyncIteration):
-            print(f'Session {task.request.session_id} task {id(task)} cancelled!')
+            log.info(f'Session {task.request.session_id} task {id(task)} cancelled!')
         elif task.error is not None:
-            print(f'Session {task.request.session_id} task {id(task)} failed!')
+            log.info(f'Session {task.request.session_id} task {id(task)} failed!')
         else:
-            print(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
+            log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
         current_state = ServerStates.Online
 
 def get_cached_task(task_id:str, update_ttl:bool=False):
@@ -429,7 +431,7 @@ def is_alive(device=None):
 def start_render_thread(device):
     if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
-    print('Start new Rendering Thread on device', device)
+    log.info(f'Start new Rendering Thread on device: {device}')
     try:
         rthread = threading.Thread(target=thread_render, kwargs={'device': device})
         rthread.daemon = True
@@ -441,7 +443,7 @@ def start_render_thread(device):
     timeout = DEVICE_START_TIMEOUT
     while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
         if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
-            print(rthread, device, 'error:', weak_thread_data[rthread]['error'])
+            log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
             return False
         if timeout <= 0:
             return False
@@ -453,11 +455,11 @@ def stop_render_thread(device):
     try:
         device_manager.validate_device_id(device, log_prefix='stop_render_thread')
     except:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
         return False
 
     if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
-    print('Stopping Rendering Thread on device', device)
+    log.info(f'Stopping Rendering Thread on device: {device}')
 
     try:
         thread_to_remove = None
@@ -480,27 +482,27 @@ def stop_render_thread(device):
 def update_render_threads(render_devices, active_devices):
     devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
-    print('devices_to_start', devices_to_start)
-    print('devices_to_stop', devices_to_stop)
+    log.debug(f'devices_to_start: {devices_to_start}')
+    log.debug(f'devices_to_stop: {devices_to_stop}')
 
     for device in devices_to_stop:
         if is_alive(device) <= 0:
-            print(device, 'is not alive')
+            log.debug(f'{device} is not alive')
             continue
         if not stop_render_thread(device):
-            print(device, 'could not stop render thread')
+            log.warn(f'{device} could not stop render thread')
 
     for device in devices_to_start:
         if is_alive(device) >= 1:
-            print(device, 'already registered.')
+            log.debug(f'{device} already registered.')
             continue
         if not start_render_thread(device):
-            print(device, 'failed to start.')
+            log.warn(f'{device} failed to start.')
 
     if is_alive() <= 0: # No running devices, probably invalid user config.
         raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
 
-    print('active devices', get_devices()['active'])
+    log.debug(f"active devices: {get_devices()['active']}")
 
 def shutdown_event(): # Signal render thread to close on shutdown
     global current_state_error
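
One consequence of level=logging.INFO in the basicConfig() call from the first file: all of the log.debug() calls introduced in this commit are filtered out by default, and only appear once the root level is lowered. A quick sketch:

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger()

    log.debug('devices_to_start: ...')   # suppressed: DEBUG is below the INFO threshold
    log.info('Task completed')           # emitted

    log.setLevel(logging.DEBUG)          # lower the root level to surface debug output
    log.debug('devices_to_start: ...')   # now emitted
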

View File

@@ -14,7 +14,9 @@ from pydantic import BaseModel
 from sd_internal import app, model_manager, task_manager
 
-print('started in ', app.SD_DIR)
+log = logging.getLogger()
+
+log.info(f'started in {app.SD_DIR}')
 
 server_api = FastAPI()
@@ -84,7 +86,7 @@ async def setAppConfig(req : SetAppConfigRequest):
         return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
     except Exception as e:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))
 
 def update_render_devices_in_config(config, render_devices):
@@ -153,8 +155,7 @@ def render(req : task_manager.ImageRequest):
     except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
         raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
     except Exception as e:
-        print(e)
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))
 
 @server_api.get('/image/stream/{task_id:int}')
@@ -165,10 +166,10 @@ def stream(task_id:int):
     #if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
     if task.buffer_queue.empty() and not task.lock.locked():
         if task.response:
-            #print(f'Session {session_id} sending cached response')
+            #log.info(f'Session {session_id} sending cached response')
             return JSONResponse(task.response, headers=NOCACHE_HEADERS)
         raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
-    #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
+    #log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
     return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
 
 @server_api.get('/image/stop')