Hypernetwork support (#619)

* Update README.md

* Update README.md

* Make on_sd_start.sh executable

* Merge pull request #542 from patriceac/patch-1

Fix restoration of model and VAE

* Merge pull request #541 from patriceac/patch-2

Fix restoration of parallel output setting

* Hypernetwork support

Adds support for hypernetworks. Hypernetworks are stored in /models/hypernetworks

* forgot to remove unused code

Co-authored-by: cmdr2 <secondary.cmdr2@gmail.com>
This commit is contained in:
Guillaume Mercier
2022-12-07 00:54:16 -05:00
committed by GitHub
parent 1283c6483d
commit cbe91251ac
11 changed files with 346 additions and 7 deletions

View File

@ -77,6 +77,8 @@ class ImageRequest(BaseModel):
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
use_hypernetwork_model: str = None
hypernetwork_strength: float = None
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
@ -177,28 +179,35 @@ current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
current_vae_path = None
current_hypernetwork_path = None
tasks_queue = []
task_cache = TaskCache()
default_model_to_load = None
default_vae_to_load = None
default_hypernetwork_to_load = None
weak_thread_data = weakref.WeakKeyDictionary()
def preload_model(ckpt_file_path=None, vae_file_path=None):
global current_state, current_state_error, current_model_path, current_vae_path
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
if ckpt_file_path == None:
ckpt_file_path = default_model_to_load
if vae_file_path == None:
vae_file_path = default_vae_to_load
if hypernetwork_file_path == None:
hypernetwork_file_path = default_hypernetwork_to_load
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
return
current_state = ServerStates.LoadingModel
try:
from . import runtime
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
runtime.thread_data.ckpt_file = ckpt_file_path
runtime.thread_data.vae_file = vae_file_path
runtime.load_model_ckpt()
runtime.load_hypernetwork()
current_model_path = ckpt_file_path
current_vae_path = vae_file_path
current_hypernetwork_path = hypernetwork_file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
@ -240,7 +249,7 @@ def thread_get_next_task():
manager_lock.release()
def thread_render(device):
global current_state, current_state_error, current_model_path, current_vae_path
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
from . import runtime
try:
runtime.thread_init(device)
@ -285,6 +294,10 @@ def thread_render(device):
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
if runtime.is_hypernetwork_reload_necessary(task.request):
runtime.reload_hypernetwork()
current_hypernetwork_path = task.request.use_hypernetwork_model
if runtime.is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_model()
@ -504,6 +517,8 @@ def render(req : ImageRequest):
r.use_face_correction = req.use_face_correction
r.use_stable_diffusion_model = req.use_stable_diffusion_model
r.use_vae_model = req.use_vae_model
r.use_hypernetwork_model = req.use_hypernetwork_model
r.hypernetwork_strength = req.hypernetwork_strength
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.output_quality = req.output_quality