Use a logger

This commit is contained in:
cmdr2 2022-12-09 21:30:18 +05:30
parent 79cc84b611
commit cde8c2d3bd
6 changed files with 87 additions and 58 deletions

View File

@ -3,9 +3,21 @@ import socket
import sys
import json
import traceback
import logging
from rich.logging import RichHandler
from sd_internal import task_manager
LOG_FORMAT = '[%(threadName)s] %(message)s'
logging.basicConfig(
level=logging.INFO,
format=LOG_FORMAT,
datefmt="[%X.%f]",
handlers=[RichHandler(markup=True)]
)
log = logging.getLogger()
SD_DIR = os.getcwd()
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
@ -49,8 +61,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
return config
except Exception as e:
print(str(e))
print(traceback.format_exc())
log.warn(traceback.format_exc())
return default_val
def setConfig(config):
@ -59,7 +70,7 @@ def setConfig(config):
with open(config_json_path, 'w', encoding='utf-8') as f:
json.dump(config, f)
except:
print(traceback.format_exc())
log.error(traceback.format_exc())
try: # config.bat
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
@ -78,7 +89,7 @@ def setConfig(config):
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
except:
print(traceback.format_exc())
log.error(traceback.format_exc())
try: # config.sh
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
@ -97,7 +108,7 @@ def setConfig(config):
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
except:
print(traceback.format_exc())
log.error(traceback.format_exc())
def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
config = getConfig()
@ -120,7 +131,7 @@ def update_render_threads():
render_devices = config.get('render_devices', 'auto')
active_devices = task_manager.get_devices()['active'].keys()
print('requesting for render_devices', render_devices)
log.debug(f'requesting for render_devices: {render_devices}')
task_manager.update_render_threads(render_devices, active_devices)
def getUIPlugins():

View File

@ -2,6 +2,9 @@ import os
import torch
import traceback
import re
import logging
log = logging.getLogger()
COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
@ -34,7 +37,7 @@ def get_device_delta(render_devices, active_devices):
if 'auto' in render_devices:
render_devices = auto_pick_devices(active_devices)
if 'cpu' in render_devices:
print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
active_devices = set(active_devices)
render_devices = set(render_devices)
@ -53,7 +56,7 @@ def auto_pick_devices(currently_active_devices):
if device_count == 1:
return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']
print('Autoselecting GPU. Using most free memory.')
log.debug('Autoselecting GPU. Using most free memory.')
devices = []
for device in range(device_count):
device = f'cuda:{device}'
@ -64,7 +67,7 @@ def auto_pick_devices(currently_active_devices):
mem_free /= float(10**9)
mem_total /= float(10**9)
device_name = torch.cuda.get_device_name(device)
print(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})
devices.sort(key=lambda x:x['mem_free'], reverse=True)
@ -94,7 +97,7 @@ def device_init(context, device):
context.device = 'cpu'
context.device_name = get_processor_name()
context.precision = 'full'
print('Render device CPU available as', context.device_name)
log.debug(f'Render device CPU available as {context.device_name}')
return
context.device_name = torch.cuda.get_device_name(device)
@ -102,11 +105,11 @@ def device_init(context, device):
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
if needs_to_force_full_precision(context):
print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
# Apply force_full_precision now before models are loaded.
context.precision = 'full'
print(f'Setting {device} as active')
log.info(f'Setting {device} as active')
torch.cuda.device(device)
return
@ -135,7 +138,7 @@ def is_device_compatible(device):
try:
validate_device_id(device, log_prefix='is_device_compatible')
except:
print(str(e))
log.error(str(e))
return False
if device == 'cpu': return True
@ -144,10 +147,10 @@ def is_device_compatible(device):
_, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9)
if mem_total < 3.0:
print(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
return False
except RuntimeError as e:
print(str(e))
log.error(str(e))
return False
return True
@ -167,5 +170,5 @@ def get_processor_name():
if "model name" in line:
return re.sub(".*model name.*:", "", line, 1).strip()
except:
print(traceback.format_exc())
log.error(traceback.format_exc())
return "cpu"

View File

@ -1,9 +1,12 @@
import os
import logging
import picklescan.scanner
import rich
from sd_internal import app, device_manager
from sd_internal import Request
import picklescan.scanner
import rich
log = logging.getLogger()
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
MODEL_EXTENSIONS = {
@ -42,7 +45,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
if model_name:
is_sd2 = config.get('test_sd2', False)
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
log.error('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
model_name = 'sd-v1-4'
# Check models directory
@ -67,7 +70,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path + model_extension
return None
@ -88,13 +91,13 @@ def is_malicious_model(file_path):
try:
scan_result = picklescan.scanner.scan_file_path(file_path)
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return True
else:
rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return False
except Exception as e:
print('error while scanning', file_path, 'error:', e)
log.error(f'error while scanning: {file_path}, error: {e}')
return False
def getModels():
@ -111,7 +114,10 @@ def getModels():
},
}
models_scanned = 0
def listModels(model_type):
nonlocal models_scanned
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
models_dir = os.path.join(app.MODELS_DIR, model_type)
if not os.path.exists(models_dir):
@ -126,6 +132,7 @@ def getModels():
mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
models_scanned += 1
if is_malicious_model(model_path):
models['scan-error'] = file
return
@ -142,6 +149,8 @@ def getModels():
listModels(model_type='vae')
listModels(model_type='hypernetwork')
if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. 0 infected[/]')
# legacy
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):

View File

@ -6,12 +6,15 @@ import os
import base64
import re
import traceback
import logging
from sd_internal import device_manager, model_manager
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
from modules import model_loader, image_generator, image_utils, filters as image_filters
log = logging.getLogger()
thread_data = threading.local()
'''
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
@ -69,7 +72,7 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s
try:
return _make_images_internal(req, data_queue, task_temp_images, step_callback)
except Exception as e:
print(traceback.format_exc())
log.error(traceback.format_exc())
data_queue.put(json.dumps({
"status": 'failed',
@ -91,7 +94,7 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_image
res = Response(req, images=construct_response(req, images))
res = res.json()
data_queue.put(json.dumps(res))
print('Task completed')
log.info('Task completed')
return res

View File

@ -6,6 +6,7 @@ Notes:
"""
import json
import traceback
import logging
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
@ -16,6 +17,8 @@ from typing import Any, Hashable
from pydantic import BaseModel
from sd_internal import Request, device_manager
log = logging.getLogger()
THREAD_NAME_PREFIX = 'Runtime-Render/'
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
@ -140,11 +143,11 @@ class DataCache():
for key in to_delete:
(_, val) = self._base[key]
if isinstance(val, RenderTask):
print(f'RenderTask {key} expired. Data removed.')
log.debug(f'RenderTask {key} expired. Data removed.')
elif isinstance(val, SessionState):
print(f'Session {key} expired. Data removed.')
log.debug(f'Session {key} expired. Data removed.')
else:
print(f'Key {key} expired. Data removed.')
log.debug(f'Key {key} expired. Data removed.')
del self._base[key]
finally:
self._lock.release()
@ -178,8 +181,7 @@ class DataCache():
self._get_ttl_time(ttl), value
)
except Exception as e:
print(str(e))
print(traceback.format_exc())
log.error(traceback.format_exc())
return False
else:
return True
@ -190,7 +192,7 @@ class DataCache():
try:
ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl):
print(f'Session {key} expired. Discarding data.')
log.debug(f'Session {key} expired. Discarding data.')
del self._base[key]
return None
return value
@ -234,7 +236,7 @@ class SessionState():
def thread_get_next_task():
from sd_internal import runtime2
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
print('Render thread on device', runtime2.thread_data.device, 'failed to acquire manager lock.')
log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.')
return None
if len(tasks_queue) <= 0:
manager_lock.release()
@ -269,7 +271,7 @@ def thread_render(device):
try:
runtime2.init(device)
except Exception as e:
print(traceback.format_exc())
log.error(traceback.format_exc())
weak_thread_data[threading.current_thread()] = {
'error': e
}
@ -287,7 +289,7 @@ def thread_render(device):
session_cache.clean()
task_cache.clean()
if not weak_thread_data[threading.current_thread()]['alive']:
print(f'Shutting down thread for device {runtime2.thread_data.device}')
log.info(f'Shutting down thread for device {runtime2.thread_data.device}')
runtime2.destroy()
return
if isinstance(current_state_error, SystemExit):
@ -299,7 +301,7 @@ def thread_render(device):
idle_event.wait(timeout=1)
continue
if task.error is not None:
print(task.error)
log.error(task.error)
task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
continue
@ -308,7 +310,7 @@ def thread_render(device):
task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
continue
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
def step_callback():
@ -319,7 +321,7 @@ def thread_render(device):
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
current_state = ServerStates.LoadingModel
runtime2.reload_models_if_necessary(task.request)
@ -331,7 +333,7 @@ def thread_render(device):
session_cache.keep(task.request.session_id, TASK_TTL)
except Exception as e:
task.error = e
print(traceback.format_exc())
log.error(traceback.format_exc())
continue
finally:
# Task completed
@ -339,11 +341,11 @@ def thread_render(device):
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.request.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration):
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
log.info(f'Session {task.request.session_id} task {id(task)} cancelled!')
elif task.error is not None:
print(f'Session {task.request.session_id} task {id(task)} failed!')
log.info(f'Session {task.request.session_id} task {id(task)} failed!')
else:
print(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
current_state = ServerStates.Online
def get_cached_task(task_id:str, update_ttl:bool=False):
@ -429,7 +431,7 @@ def is_alive(device=None):
def start_render_thread(device):
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
print('Start new Rendering Thread on device', device)
log.info(f'Start new Rendering Thread on device: {device}')
try:
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
rthread.daemon = True
@ -441,7 +443,7 @@ def start_render_thread(device):
timeout = DEVICE_START_TIMEOUT
while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
print(rthread, device, 'error:', weak_thread_data[rthread]['error'])
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
return False
if timeout <= 0:
return False
@ -453,11 +455,11 @@ def stop_render_thread(device):
try:
device_manager.validate_device_id(device, log_prefix='stop_render_thread')
except:
print(traceback.format_exc())
log.error(traceback.format_exc())
return False
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
print('Stopping Rendering Thread on device', device)
log.info(f'Stopping Rendering Thread on device: {device}')
try:
thread_to_remove = None
@ -480,27 +482,27 @@ def stop_render_thread(device):
def update_render_threads(render_devices, active_devices):
devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
print('devices_to_start', devices_to_start)
print('devices_to_stop', devices_to_stop)
log.debug(f'devices_to_start: {devices_to_start}')
log.debug(f'devices_to_stop: {devices_to_stop}')
for device in devices_to_stop:
if is_alive(device) <= 0:
print(device, 'is not alive')
log.debug(f'{device} is not alive')
continue
if not stop_render_thread(device):
print(device, 'could not stop render thread')
log.warn(f'{device} could not stop render thread')
for device in devices_to_start:
if is_alive(device) >= 1:
print(device, 'already registered.')
log.debug(f'{device} already registered.')
continue
if not start_render_thread(device):
print(device, 'failed to start.')
log.warn(f'{device} failed to start.')
if is_alive() <= 0: # No running devices, probably invalid user config.
raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
print('active devices', get_devices()['active'])
log.debug(f"active devices: {get_devices()['active']}")
def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error

View File

@ -14,7 +14,9 @@ from pydantic import BaseModel
from sd_internal import app, model_manager, task_manager
print('started in ', app.SD_DIR)
log = logging.getLogger()
log.info(f'started in {app.SD_DIR}')
server_api = FastAPI()
@ -84,7 +86,7 @@ async def setAppConfig(req : SetAppConfigRequest):
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e:
print(traceback.format_exc())
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def update_render_devices_in_config(config, render_devices):
@ -153,8 +155,7 @@ def render(req : task_manager.ImageRequest):
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
except Exception as e:
print(e)
print(traceback.format_exc())
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@server_api.get('/image/stream/{task_id:int}')
@ -165,10 +166,10 @@ def stream(task_id:int):
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
#log.info(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
#log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
@server_api.get('/image/stop')