mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-02-18 03:11:10 +01:00
Use a logger
This commit is contained in:
parent
79cc84b611
commit
cde8c2d3bd
@ -3,9 +3,21 @@ import socket
|
|||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
|
import logging
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
from sd_internal import task_manager
|
from sd_internal import task_manager
|
||||||
|
|
||||||
|
LOG_FORMAT = '[%(threadName)s] %(message)s'
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format=LOG_FORMAT,
|
||||||
|
datefmt="[%X.%f]",
|
||||||
|
handlers=[RichHandler(markup=True)]
|
||||||
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
SD_DIR = os.getcwd()
|
SD_DIR = os.getcwd()
|
||||||
|
|
||||||
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
|
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
|
||||||
@ -49,8 +61,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
|||||||
config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
|
config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
|
||||||
return config
|
return config
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(str(e))
|
log.warn(traceback.format_exc())
|
||||||
print(traceback.format_exc())
|
|
||||||
return default_val
|
return default_val
|
||||||
|
|
||||||
def setConfig(config):
|
def setConfig(config):
|
||||||
@ -59,7 +70,7 @@ def setConfig(config):
|
|||||||
with open(config_json_path, 'w', encoding='utf-8') as f:
|
with open(config_json_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
except:
|
except:
|
||||||
print(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
|
|
||||||
try: # config.bat
|
try: # config.bat
|
||||||
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
||||||
@ -78,7 +89,7 @@ def setConfig(config):
|
|||||||
with open(config_bat_path, 'w', encoding='utf-8') as f:
|
with open(config_bat_path, 'w', encoding='utf-8') as f:
|
||||||
f.write('\r\n'.join(config_bat))
|
f.write('\r\n'.join(config_bat))
|
||||||
except:
|
except:
|
||||||
print(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
|
|
||||||
try: # config.sh
|
try: # config.sh
|
||||||
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
||||||
@ -97,7 +108,7 @@ def setConfig(config):
|
|||||||
with open(config_sh_path, 'w', encoding='utf-8') as f:
|
with open(config_sh_path, 'w', encoding='utf-8') as f:
|
||||||
f.write('\n'.join(config_sh))
|
f.write('\n'.join(config_sh))
|
||||||
except:
|
except:
|
||||||
print(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
|
|
||||||
def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
|
def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
@ -120,7 +131,7 @@ def update_render_threads():
|
|||||||
render_devices = config.get('render_devices', 'auto')
|
render_devices = config.get('render_devices', 'auto')
|
||||||
active_devices = task_manager.get_devices()['active'].keys()
|
active_devices = task_manager.get_devices()['active'].keys()
|
||||||
|
|
||||||
print('requesting for render_devices', render_devices)
|
log.debug(f'requesting for render_devices: {render_devices}')
|
||||||
task_manager.update_render_threads(render_devices, active_devices)
|
task_manager.update_render_threads(render_devices, active_devices)
|
||||||
|
|
||||||
def getUIPlugins():
|
def getUIPlugins():
|
||||||
|
@ -2,6 +2,9 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
import traceback
|
import traceback
|
||||||
import re
|
import re
|
||||||
|
import logging
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
|
COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
|
||||||
|
|
||||||
@ -34,7 +37,7 @@ def get_device_delta(render_devices, active_devices):
|
|||||||
if 'auto' in render_devices:
|
if 'auto' in render_devices:
|
||||||
render_devices = auto_pick_devices(active_devices)
|
render_devices = auto_pick_devices(active_devices)
|
||||||
if 'cpu' in render_devices:
|
if 'cpu' in render_devices:
|
||||||
print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
|
log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
|
||||||
|
|
||||||
active_devices = set(active_devices)
|
active_devices = set(active_devices)
|
||||||
render_devices = set(render_devices)
|
render_devices = set(render_devices)
|
||||||
@ -53,7 +56,7 @@ def auto_pick_devices(currently_active_devices):
|
|||||||
if device_count == 1:
|
if device_count == 1:
|
||||||
return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']
|
return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']
|
||||||
|
|
||||||
print('Autoselecting GPU. Using most free memory.')
|
log.debug('Autoselecting GPU. Using most free memory.')
|
||||||
devices = []
|
devices = []
|
||||||
for device in range(device_count):
|
for device in range(device_count):
|
||||||
device = f'cuda:{device}'
|
device = f'cuda:{device}'
|
||||||
@ -64,7 +67,7 @@ def auto_pick_devices(currently_active_devices):
|
|||||||
mem_free /= float(10**9)
|
mem_free /= float(10**9)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10**9)
|
||||||
device_name = torch.cuda.get_device_name(device)
|
device_name = torch.cuda.get_device_name(device)
|
||||||
print(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
|
log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
|
||||||
devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})
|
devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})
|
||||||
|
|
||||||
devices.sort(key=lambda x:x['mem_free'], reverse=True)
|
devices.sort(key=lambda x:x['mem_free'], reverse=True)
|
||||||
@ -94,7 +97,7 @@ def device_init(context, device):
|
|||||||
context.device = 'cpu'
|
context.device = 'cpu'
|
||||||
context.device_name = get_processor_name()
|
context.device_name = get_processor_name()
|
||||||
context.precision = 'full'
|
context.precision = 'full'
|
||||||
print('Render device CPU available as', context.device_name)
|
log.debug(f'Render device CPU available as {context.device_name}')
|
||||||
return
|
return
|
||||||
|
|
||||||
context.device_name = torch.cuda.get_device_name(device)
|
context.device_name = torch.cuda.get_device_name(device)
|
||||||
@ -102,11 +105,11 @@ def device_init(context, device):
|
|||||||
|
|
||||||
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
|
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
|
||||||
if needs_to_force_full_precision(context):
|
if needs_to_force_full_precision(context):
|
||||||
print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
|
log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
|
||||||
# Apply force_full_precision now before models are loaded.
|
# Apply force_full_precision now before models are loaded.
|
||||||
context.precision = 'full'
|
context.precision = 'full'
|
||||||
|
|
||||||
print(f'Setting {device} as active')
|
log.info(f'Setting {device} as active')
|
||||||
torch.cuda.device(device)
|
torch.cuda.device(device)
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -135,7 +138,7 @@ def is_device_compatible(device):
|
|||||||
try:
|
try:
|
||||||
validate_device_id(device, log_prefix='is_device_compatible')
|
validate_device_id(device, log_prefix='is_device_compatible')
|
||||||
except:
|
except:
|
||||||
print(str(e))
|
log.error(str(e))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if device == 'cpu': return True
|
if device == 'cpu': return True
|
||||||
@ -144,10 +147,10 @@ def is_device_compatible(device):
|
|||||||
_, mem_total = torch.cuda.mem_get_info(device)
|
_, mem_total = torch.cuda.mem_get_info(device)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10**9)
|
||||||
if mem_total < 3.0:
|
if mem_total < 3.0:
|
||||||
print(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
|
log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
|
||||||
return False
|
return False
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
print(str(e))
|
log.error(str(e))
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -167,5 +170,5 @@ def get_processor_name():
|
|||||||
if "model name" in line:
|
if "model name" in line:
|
||||||
return re.sub(".*model name.*:", "", line, 1).strip()
|
return re.sub(".*model name.*:", "", line, 1).strip()
|
||||||
except:
|
except:
|
||||||
print(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
|
import picklescan.scanner
|
||||||
|
import rich
|
||||||
|
|
||||||
from sd_internal import app, device_manager
|
from sd_internal import app, device_manager
|
||||||
from sd_internal import Request
|
from sd_internal import Request
|
||||||
import picklescan.scanner
|
|
||||||
import rich
|
log = logging.getLogger()
|
||||||
|
|
||||||
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
||||||
MODEL_EXTENSIONS = {
|
MODEL_EXTENSIONS = {
|
||||||
@ -42,7 +45,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
|||||||
if model_name:
|
if model_name:
|
||||||
is_sd2 = config.get('test_sd2', False)
|
is_sd2 = config.get('test_sd2', False)
|
||||||
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
|
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
|
||||||
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
|
log.error('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
|
||||||
model_name = 'sd-v1-4'
|
model_name = 'sd-v1-4'
|
||||||
|
|
||||||
# Check models directory
|
# Check models directory
|
||||||
@ -67,7 +70,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
|||||||
for model_extension in model_extensions:
|
for model_extension in model_extensions:
|
||||||
if os.path.exists(default_model_path + model_extension):
|
if os.path.exists(default_model_path + model_extension):
|
||||||
if model_name is not None:
|
if model_name is not None:
|
||||||
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
|
log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
|
||||||
return default_model_path + model_extension
|
return default_model_path + model_extension
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@ -88,13 +91,13 @@ def is_malicious_model(file_path):
|
|||||||
try:
|
try:
|
||||||
scan_result = picklescan.scanner.scan_file_path(file_path)
|
scan_result = picklescan.scanner.scan_file_path(file_path)
|
||||||
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
||||||
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('error while scanning', file_path, 'error:', e)
|
log.error(f'error while scanning: {file_path}, error: {e}')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def getModels():
|
def getModels():
|
||||||
@ -111,7 +114,10 @@ def getModels():
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
models_scanned = 0
|
||||||
def listModels(model_type):
|
def listModels(model_type):
|
||||||
|
nonlocal models_scanned
|
||||||
|
|
||||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
models_dir = os.path.join(app.MODELS_DIR, model_type)
|
models_dir = os.path.join(app.MODELS_DIR, model_type)
|
||||||
if not os.path.exists(models_dir):
|
if not os.path.exists(models_dir):
|
||||||
@ -126,6 +132,7 @@ def getModels():
|
|||||||
mtime = os.path.getmtime(model_path)
|
mtime = os.path.getmtime(model_path)
|
||||||
mod_time = known_models[model_path] if model_path in known_models else -1
|
mod_time = known_models[model_path] if model_path in known_models else -1
|
||||||
if mod_time != mtime:
|
if mod_time != mtime:
|
||||||
|
models_scanned += 1
|
||||||
if is_malicious_model(model_path):
|
if is_malicious_model(model_path):
|
||||||
models['scan-error'] = file
|
models['scan-error'] = file
|
||||||
return
|
return
|
||||||
@ -142,6 +149,8 @@ def getModels():
|
|||||||
listModels(model_type='vae')
|
listModels(model_type='vae')
|
||||||
listModels(model_type='hypernetwork')
|
listModels(model_type='hypernetwork')
|
||||||
|
|
||||||
|
if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. 0 infected[/]')
|
||||||
|
|
||||||
# legacy
|
# legacy
|
||||||
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
|
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
|
||||||
if os.path.exists(custom_weight_path):
|
if os.path.exists(custom_weight_path):
|
||||||
|
@ -6,12 +6,15 @@ import os
|
|||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
|
import logging
|
||||||
|
|
||||||
from sd_internal import device_manager, model_manager
|
from sd_internal import device_manager, model_manager
|
||||||
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
|
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
|
||||||
|
|
||||||
from modules import model_loader, image_generator, image_utils, filters as image_filters
|
from modules import model_loader, image_generator, image_utils, filters as image_filters
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
thread_data = threading.local()
|
thread_data = threading.local()
|
||||||
'''
|
'''
|
||||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||||
@ -69,7 +72,7 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s
|
|||||||
try:
|
try:
|
||||||
return _make_images_internal(req, data_queue, task_temp_images, step_callback)
|
return _make_images_internal(req, data_queue, task_temp_images, step_callback)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
|
|
||||||
data_queue.put(json.dumps({
|
data_queue.put(json.dumps({
|
||||||
"status": 'failed',
|
"status": 'failed',
|
||||||
@ -91,7 +94,7 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_image
|
|||||||
res = Response(req, images=construct_response(req, images))
|
res = Response(req, images=construct_response(req, images))
|
||||||
res = res.json()
|
res = res.json()
|
||||||
data_queue.put(json.dumps(res))
|
data_queue.put(json.dumps(res))
|
||||||
print('Task completed')
|
log.info('Task completed')
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ Notes:
|
|||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
|
import logging
|
||||||
|
|
||||||
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
||||||
|
|
||||||
@ -16,6 +17,8 @@ from typing import Any, Hashable
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sd_internal import Request, device_manager
|
from sd_internal import Request, device_manager
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
THREAD_NAME_PREFIX = 'Runtime-Render/'
|
THREAD_NAME_PREFIX = 'Runtime-Render/'
|
||||||
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
||||||
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||||
@ -140,11 +143,11 @@ class DataCache():
|
|||||||
for key in to_delete:
|
for key in to_delete:
|
||||||
(_, val) = self._base[key]
|
(_, val) = self._base[key]
|
||||||
if isinstance(val, RenderTask):
|
if isinstance(val, RenderTask):
|
||||||
print(f'RenderTask {key} expired. Data removed.')
|
log.debug(f'RenderTask {key} expired. Data removed.')
|
||||||
elif isinstance(val, SessionState):
|
elif isinstance(val, SessionState):
|
||||||
print(f'Session {key} expired. Data removed.')
|
log.debug(f'Session {key} expired. Data removed.')
|
||||||
else:
|
else:
|
||||||
print(f'Key {key} expired. Data removed.')
|
log.debug(f'Key {key} expired. Data removed.')
|
||||||
del self._base[key]
|
del self._base[key]
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
@ -178,8 +181,7 @@ class DataCache():
|
|||||||
self._get_ttl_time(ttl), value
|
self._get_ttl_time(ttl), value
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(str(e))
|
log.error(traceback.format_exc())
|
||||||
print(traceback.format_exc())
|
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
@ -190,7 +192,7 @@ class DataCache():
|
|||||||
try:
|
try:
|
||||||
ttl, value = self._base.get(key, (None, None))
|
ttl, value = self._base.get(key, (None, None))
|
||||||
if ttl is not None and self._is_expired(ttl):
|
if ttl is not None and self._is_expired(ttl):
|
||||||
print(f'Session {key} expired. Discarding data.')
|
log.debug(f'Session {key} expired. Discarding data.')
|
||||||
del self._base[key]
|
del self._base[key]
|
||||||
return None
|
return None
|
||||||
return value
|
return value
|
||||||
@ -234,7 +236,7 @@ class SessionState():
|
|||||||
def thread_get_next_task():
|
def thread_get_next_task():
|
||||||
from sd_internal import runtime2
|
from sd_internal import runtime2
|
||||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
print('Render thread on device', runtime2.thread_data.device, 'failed to acquire manager lock.')
|
log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.')
|
||||||
return None
|
return None
|
||||||
if len(tasks_queue) <= 0:
|
if len(tasks_queue) <= 0:
|
||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
@ -269,7 +271,7 @@ def thread_render(device):
|
|||||||
try:
|
try:
|
||||||
runtime2.init(device)
|
runtime2.init(device)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
weak_thread_data[threading.current_thread()] = {
|
weak_thread_data[threading.current_thread()] = {
|
||||||
'error': e
|
'error': e
|
||||||
}
|
}
|
||||||
@ -287,7 +289,7 @@ def thread_render(device):
|
|||||||
session_cache.clean()
|
session_cache.clean()
|
||||||
task_cache.clean()
|
task_cache.clean()
|
||||||
if not weak_thread_data[threading.current_thread()]['alive']:
|
if not weak_thread_data[threading.current_thread()]['alive']:
|
||||||
print(f'Shutting down thread for device {runtime2.thread_data.device}')
|
log.info(f'Shutting down thread for device {runtime2.thread_data.device}')
|
||||||
runtime2.destroy()
|
runtime2.destroy()
|
||||||
return
|
return
|
||||||
if isinstance(current_state_error, SystemExit):
|
if isinstance(current_state_error, SystemExit):
|
||||||
@ -299,7 +301,7 @@ def thread_render(device):
|
|||||||
idle_event.wait(timeout=1)
|
idle_event.wait(timeout=1)
|
||||||
continue
|
continue
|
||||||
if task.error is not None:
|
if task.error is not None:
|
||||||
print(task.error)
|
log.error(task.error)
|
||||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
task.response = {"status": 'failed', "detail": str(task.error)}
|
||||||
task.buffer_queue.put(json.dumps(task.response))
|
task.buffer_queue.put(json.dumps(task.response))
|
||||||
continue
|
continue
|
||||||
@ -308,7 +310,7 @@ def thread_render(device):
|
|||||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
task.response = {"status": 'failed', "detail": str(task.error)}
|
||||||
task.buffer_queue.put(json.dumps(task.response))
|
task.buffer_queue.put(json.dumps(task.response))
|
||||||
continue
|
continue
|
||||||
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
|
log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
|
||||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||||
try:
|
try:
|
||||||
def step_callback():
|
def step_callback():
|
||||||
@ -319,7 +321,7 @@ def thread_render(device):
|
|||||||
if isinstance(current_state_error, StopAsyncIteration):
|
if isinstance(current_state_error, StopAsyncIteration):
|
||||||
task.error = current_state_error
|
task.error = current_state_error
|
||||||
current_state_error = None
|
current_state_error = None
|
||||||
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||||
|
|
||||||
current_state = ServerStates.LoadingModel
|
current_state = ServerStates.LoadingModel
|
||||||
runtime2.reload_models_if_necessary(task.request)
|
runtime2.reload_models_if_necessary(task.request)
|
||||||
@ -331,7 +333,7 @@ def thread_render(device):
|
|||||||
session_cache.keep(task.request.session_id, TASK_TTL)
|
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
task.error = e
|
task.error = e
|
||||||
print(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
continue
|
continue
|
||||||
finally:
|
finally:
|
||||||
# Task completed
|
# Task completed
|
||||||
@ -339,11 +341,11 @@ def thread_render(device):
|
|||||||
task_cache.keep(id(task), TASK_TTL)
|
task_cache.keep(id(task), TASK_TTL)
|
||||||
session_cache.keep(task.request.session_id, TASK_TTL)
|
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
if isinstance(task.error, StopAsyncIteration):
|
if isinstance(task.error, StopAsyncIteration):
|
||||||
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
log.info(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
||||||
elif task.error is not None:
|
elif task.error is not None:
|
||||||
print(f'Session {task.request.session_id} task {id(task)} failed!')
|
log.info(f'Session {task.request.session_id} task {id(task)} failed!')
|
||||||
else:
|
else:
|
||||||
print(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
|
log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
|
||||||
current_state = ServerStates.Online
|
current_state = ServerStates.Online
|
||||||
|
|
||||||
def get_cached_task(task_id:str, update_ttl:bool=False):
|
def get_cached_task(task_id:str, update_ttl:bool=False):
|
||||||
@ -429,7 +431,7 @@ def is_alive(device=None):
|
|||||||
|
|
||||||
def start_render_thread(device):
|
def start_render_thread(device):
|
||||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
|
||||||
print('Start new Rendering Thread on device', device)
|
log.info(f'Start new Rendering Thread on device: {device}')
|
||||||
try:
|
try:
|
||||||
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
|
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
|
||||||
rthread.daemon = True
|
rthread.daemon = True
|
||||||
@ -441,7 +443,7 @@ def start_render_thread(device):
|
|||||||
timeout = DEVICE_START_TIMEOUT
|
timeout = DEVICE_START_TIMEOUT
|
||||||
while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
|
while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
|
||||||
if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
|
if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
|
||||||
print(rthread, device, 'error:', weak_thread_data[rthread]['error'])
|
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
|
||||||
return False
|
return False
|
||||||
if timeout <= 0:
|
if timeout <= 0:
|
||||||
return False
|
return False
|
||||||
@ -453,11 +455,11 @@ def stop_render_thread(device):
|
|||||||
try:
|
try:
|
||||||
device_manager.validate_device_id(device, log_prefix='stop_render_thread')
|
device_manager.validate_device_id(device, log_prefix='stop_render_thread')
|
||||||
except:
|
except:
|
||||||
print(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
|
||||||
print('Stopping Rendering Thread on device', device)
|
log.info(f'Stopping Rendering Thread on device: {device}')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
thread_to_remove = None
|
thread_to_remove = None
|
||||||
@ -480,27 +482,27 @@ def stop_render_thread(device):
|
|||||||
|
|
||||||
def update_render_threads(render_devices, active_devices):
|
def update_render_threads(render_devices, active_devices):
|
||||||
devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
|
devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
|
||||||
print('devices_to_start', devices_to_start)
|
log.debug(f'devices_to_start: {devices_to_start}')
|
||||||
print('devices_to_stop', devices_to_stop)
|
log.debug(f'devices_to_stop: {devices_to_stop}')
|
||||||
|
|
||||||
for device in devices_to_stop:
|
for device in devices_to_stop:
|
||||||
if is_alive(device) <= 0:
|
if is_alive(device) <= 0:
|
||||||
print(device, 'is not alive')
|
log.debug(f'{device} is not alive')
|
||||||
continue
|
continue
|
||||||
if not stop_render_thread(device):
|
if not stop_render_thread(device):
|
||||||
print(device, 'could not stop render thread')
|
log.warn(f'{device} could not stop render thread')
|
||||||
|
|
||||||
for device in devices_to_start:
|
for device in devices_to_start:
|
||||||
if is_alive(device) >= 1:
|
if is_alive(device) >= 1:
|
||||||
print(device, 'already registered.')
|
log.debug(f'{device} already registered.')
|
||||||
continue
|
continue
|
||||||
if not start_render_thread(device):
|
if not start_render_thread(device):
|
||||||
print(device, 'failed to start.')
|
log.warn(f'{device} failed to start.')
|
||||||
|
|
||||||
if is_alive() <= 0: # No running devices, probably invalid user config.
|
if is_alive() <= 0: # No running devices, probably invalid user config.
|
||||||
raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
|
raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
|
||||||
|
|
||||||
print('active devices', get_devices()['active'])
|
log.debug(f"active devices: {get_devices()['active']}")
|
||||||
|
|
||||||
def shutdown_event(): # Signal render thread to close on shutdown
|
def shutdown_event(): # Signal render thread to close on shutdown
|
||||||
global current_state_error
|
global current_state_error
|
||||||
|
13
ui/server.py
13
ui/server.py
@ -14,7 +14,9 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from sd_internal import app, model_manager, task_manager
|
from sd_internal import app, model_manager, task_manager
|
||||||
|
|
||||||
print('started in ', app.SD_DIR)
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
log.info(f'started in {app.SD_DIR}')
|
||||||
|
|
||||||
server_api = FastAPI()
|
server_api = FastAPI()
|
||||||
|
|
||||||
@ -84,7 +86,7 @@ async def setAppConfig(req : SetAppConfigRequest):
|
|||||||
|
|
||||||
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
def update_render_devices_in_config(config, render_devices):
|
def update_render_devices_in_config(config, render_devices):
|
||||||
@ -153,8 +155,7 @@ def render(req : task_manager.ImageRequest):
|
|||||||
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
||||||
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
log.error(traceback.format_exc())
|
||||||
print(traceback.format_exc())
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@server_api.get('/image/stream/{task_id:int}')
|
@server_api.get('/image/stream/{task_id:int}')
|
||||||
@ -165,10 +166,10 @@ def stream(task_id:int):
|
|||||||
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||||
if task.buffer_queue.empty() and not task.lock.locked():
|
if task.buffer_queue.empty() and not task.lock.locked():
|
||||||
if task.response:
|
if task.response:
|
||||||
#print(f'Session {session_id} sending cached response')
|
#log.info(f'Session {session_id} sending cached response')
|
||||||
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
||||||
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
|
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
|
||||||
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
#log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||||
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
||||||
|
|
||||||
@server_api.get('/image/stop')
|
@server_api.get('/image/stop')
|
||||||
|
Loading…
Reference in New Issue
Block a user