"""runtime.py: torch device owned by a thread.
Notes:
    Avoid device switching, transfering all models will get too complex.
    To use a diffrent device signal the current render device to exit
    And then start a new clean thread for the new device.
"""
import json
import os, re
import traceback
import queue
import torch
import numpy as np
from gc import collect as gc_collect
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from transformers import logging

from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS

from threading import Lock
from safetensors.torch import load_file

import uuid

logging.set_verbosity_error()

# consts
config_yaml = "optimizedSD/v1-inference.yaml"
filename_regex = re.compile('[^a-zA-Z0-9]')
gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time.

# api stuff
from sd_internal import device_manager
from . import Request, Response, Image as ResponseImage
import base64
from io import BytesIO
#from colorama import Fore

from threading import local as LocalThreadVars
thread_data = LocalThreadVars()

def thread_init(device):
    # Thread bound properties
    thread_data.stop_processing = False
    thread_data.temp_images = {}

    thread_data.ckpt_file = None
    thread_data.vae_file = None
    thread_data.hypernetwork_file = None
    thread_data.gfpgan_file = None
    thread_data.real_esrgan_file = None

    thread_data.model = None
    thread_data.modelCS = None
    thread_data.modelFS = None
    thread_data.hypernetwork = None
    thread_data.hypernetwork_strength = 1
    thread_data.model_gfpgan = None
    thread_data.model_real_esrgan = None

    thread_data.model_is_half = False
    thread_data.model_fs_is_half = False
    thread_data.device = None
    thread_data.device_name = None
    thread_data.unet_bs = 1
    thread_data.precision = 'autocast'
    thread_data.sampler_plms = None
    thread_data.sampler_ddim = None

    thread_data.turbo = False
    thread_data.force_full_precision = False
    thread_data.reduced_memory = True

    thread_data.test_sd2 = isSD2()

    device_manager.device_init(thread_data, device)

# temp hack, will remove soon
def isSD2():
    try:
        SD_UI_DIR = os.getenv('SD_UI_PATH', None)
        CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        if not os.path.exists(config_json_path):
            return False
        with open(config_json_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
            return config.get('test_sd2', False)
    except Exception as e:
        return False

def load_model_ckpt():
    if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
    if os.path.exists(thread_data.ckpt_file + '.ckpt'):
        thread_data.ckpt_file += '.ckpt'
    elif os.path.exists(thread_data.ckpt_file + '.safetensors'):
        thread_data.ckpt_file += '.safetensors'
    elif not os.path.exists(thread_data.ckpt_file):
        raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt or .safetensors')

    if not thread_data.precision:
        thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'

    if not thread_data.unet_bs:
        thread_data.unet_bs = 1

    if thread_data.device == 'cpu':
        thread_data.precision = 'full'

    print('loading', thread_data.ckpt_file, 'to device', thread_data.device, 'using precision', thread_data.precision)

    if thread_data.test_sd2:
        load_model_ckpt_sd2()
    else:
        load_model_ckpt_sd1()

def load_model_ckpt_sd1():
    sd, model_ver = load_model_from_config(thread_data.ckpt_file)
    li, lo = [], []
    for key, value in sd.items():
        sp = key.split(".")
        if (sp[0]) == "model":
            if "input_blocks" in sp:
                li.append(key)
            elif "middle_block" in sp:
                li.append(key)
            elif "time_embed" in sp:
                li.append(key)
            else:
                lo.append(key)
    for key in li:
        sd["model1." + key[6:]] = sd.pop(key)
    for key in lo:
        sd["model2." + key[6:]] = sd.pop(key)

    config = OmegaConf.load(f"{config_yaml}")

    model = instantiate_from_config(config.modelUNet)
    _, _ = model.load_state_dict(sd, strict=False)
    model.eval()
    model.cdevice = torch.device(thread_data.device)
    model.unet_bs = thread_data.unet_bs
    model.turbo = thread_data.turbo
    # if thread_data.device != 'cpu':
    #     model.to(thread_data.device)
    #if thread_data.reduced_memory:
        #model.model1.to("cpu")
        #model.model2.to("cpu")
    thread_data.model = model

    modelCS = instantiate_from_config(config.modelCondStage)
    _, _ = modelCS.load_state_dict(sd, strict=False)
    modelCS.eval()
    modelCS.cond_stage_model.device = torch.device(thread_data.device)
    # if thread_data.device != 'cpu':
    #     if thread_data.reduced_memory:
    #         modelCS.to('cpu')
    #     else:
    #         modelCS.to(thread_data.device) # Preload on device if not already there.
    thread_data.modelCS = modelCS

    modelFS = instantiate_from_config(config.modelFirstStage)
    _, _ = modelFS.load_state_dict(sd, strict=False)

    if thread_data.vae_file is not None:
        try:
            loaded = False
            for model_extension in ['.ckpt', '.vae.pt']:
                if os.path.exists(thread_data.vae_file + model_extension):
                    print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}")
                    vae_ckpt = torch.load(thread_data.vae_file + model_extension, map_location="cpu")
                    vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
                    modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
                    loaded = True
                    break

            if not loaded:
                print(f'Cannot find VAE: {thread_data.vae_file}')
                thread_data.vae_file = None
        except:
            print(traceback.format_exc())
            print(f'Could not load VAE: {thread_data.vae_file}')
            thread_data.vae_file = None

    modelFS.eval()
    # if thread_data.device != 'cpu':
    #     if thread_data.reduced_memory:
    #         modelFS.to('cpu')
    #     else:
    #         modelFS.to(thread_data.device) # Preload on device if not already there.
    thread_data.modelFS = modelFS
    del sd

    if thread_data.device != "cpu" and thread_data.precision == "autocast":
        thread_data.model.half()
        thread_data.modelCS.half()
        thread_data.modelFS.half()
        thread_data.model_is_half = True
        thread_data.model_fs_is_half = True
    else:
        thread_data.model_is_half = False
        thread_data.model_fs_is_half = False

    print(f'''loaded model
 model file: {thread_data.ckpt_file}
 model.device: {model.device}
 modelCS.device: {modelCS.cond_stage_model.device}
 modelFS.device: {thread_data.modelFS.device}
 using precision: {thread_data.precision}''')

def load_model_ckpt_sd2():
    sd, model_ver = load_model_from_config(thread_data.ckpt_file)

    config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if model_ver == 'sd2' else "configs/stable-diffusion/v1-inference.yaml"
    config = OmegaConf.load(config_file)
    verbose = False

    thread_data.model = instantiate_from_config(config.model)
    m, u = thread_data.model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    thread_data.model.to(thread_data.device)
    thread_data.model.eval()
    del sd

    thread_data.model.cond_stage_model.device = torch.device(thread_data.device)

    if thread_data.device != "cpu" and thread_data.precision == "autocast":
        thread_data.model.half()
        thread_data.model_is_half = True
        thread_data.model_fs_is_half = True
    else:
        thread_data.model_is_half = False
        thread_data.model_fs_is_half = False

    print(f'''loaded model
 model file: {thread_data.ckpt_file}
 using precision: {thread_data.precision}''')

def unload_filters():
    if thread_data.model_gfpgan is not None:
        if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')

        del thread_data.model_gfpgan
    thread_data.model_gfpgan = None

    if thread_data.model_real_esrgan is not None:
        if thread_data.device != 'cpu': thread_data.model_real_esrgan.model.to('cpu')

        del thread_data.model_real_esrgan
    thread_data.model_real_esrgan = None

    gc()

def unload_models():
    if thread_data.model is not None:
        print('Unloading models...')
        if thread_data.device != 'cpu':
            if not thread_data.test_sd2:
                thread_data.modelFS.to('cpu')
                thread_data.modelCS.to('cpu')
                thread_data.model.model1.to("cpu")
                thread_data.model.model2.to("cpu")

        del thread_data.model
        del thread_data.modelCS
        del thread_data.modelFS

    thread_data.model = None
    thread_data.modelCS = None
    thread_data.modelFS = None

    gc()

# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
#     if thread_data.device == target_device: return
#     start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
#     if start_mem <= 0: return
#     model_name = model.__class__.__name__
#     print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
#     start_time = time.time()
#     model.to(target_device)
#     time_step = start_time
#     WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
#     last_mem = start_mem
#     is_transfering = True
#     while is_transfering:
#         time.sleep(0.5) # 500ms
#         mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
#         is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
#         last_mem = mem
#         if not is_transfering:
#             break;
#         if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
#             print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
#             time_step = time.time()
#     print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')

def move_to_cpu(model):
    if thread_data.device != "cpu":
        d = torch.device(thread_data.device)
        mem = torch.cuda.memory_allocated(d) / 1e6
        model.to("cpu")
        while torch.cuda.memory_allocated(d) / 1e6 >= mem:
            time.sleep(1)

def load_model_gfpgan():
    if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
    model_path = thread_data.gfpgan_file + ".pth"
    thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)

def load_model_real_esrgan():
    if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.')
    model_path = thread_data.real_esrgan_file + ".pth"

    RealESRGAN_models = {
        'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
        'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
    }

    model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]

    if thread_data.device == 'cpu':
        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
        #thread_data.model_real_esrgan.device = torch.device(thread_data.device)
        thread_data.model_real_esrgan.model.to('cpu')
    else:
        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)

    thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
    print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)


def get_session_out_path(disk_path, session_id):
    if disk_path is None: return None
    if session_id is None: return None

    session_out_path = os.path.join(disk_path, filename_regex.sub('_',session_id))
    os.makedirs(session_out_path, exist_ok=True)
    return session_out_path

def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
    if disk_path is None: return None
    if session_id is None: return None
    if ext is None: raise Exception('Missing ext')

    session_out_path = get_session_out_path(disk_path, session_id)

    prompt_flattened = filename_regex.sub('_', prompt)[:50]

    if suffix is not None:
        return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
    return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")

def apply_filters(filter_name, image_data, model_path=None):
    print(f'Applying filter {filter_name}...')
    gc() # Free space before loading new data.

    if isinstance(image_data, torch.Tensor):
        image_data.to(thread_data.device)

    if filter_name == 'gfpgan':
        # This lock is only ever used here. No need to use timeout for the request. Should never deadlock.
        with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting.
            # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
            from facexlib.detection import retinaface
            retinaface.device = torch.device(thread_data.device)
            print('forced retinaface.device to', thread_data.device)

            if model_path is not None and model_path != thread_data.gfpgan_file:
                thread_data.gfpgan_file = model_path
                load_model_gfpgan()
            elif not thread_data.model_gfpgan:
                load_model_gfpgan()
            if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')

            print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
            _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
            image_data = output[:,:,::-1]

    if filter_name == 'real_esrgan':
        if model_path is not None and model_path != thread_data.real_esrgan_file:
            thread_data.real_esrgan_file = model_path
            load_model_real_esrgan()
        elif not thread_data.model_real_esrgan:
            load_model_real_esrgan()
        if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
        print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
        output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
        image_data = output[:,:,::-1]

    return image_data

def is_model_reload_necessary(req: Request):
    # custom model support:
    #  the req.use_stable_diffusion_model needs to be a valid path
    #  to the ckpt file (without the extension).
    if os.path.exists(req.use_stable_diffusion_model + '.ckpt'):
        req.use_stable_diffusion_model += '.ckpt'
    elif os.path.exists(req.use_stable_diffusion_model + '.safetensors'):
        req.use_stable_diffusion_model += '.safetensors'
    elif not os.path.exists(req.use_stable_diffusion_model): 
        raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt or .safetensors')

    needs_model_reload = False
    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
        thread_data.ckpt_file = req.use_stable_diffusion_model
        thread_data.vae_file = req.use_vae_model
        needs_model_reload = True

    if thread_data.device != 'cpu':
        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
            (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
            needs_model_reload = True

    return needs_model_reload

def reload_model():
    unload_models()
    unload_filters()
    load_model_ckpt()

def is_hypernetwork_reload_necessary(req: Request):
    needs_model_reload = False
    if thread_data.hypernetwork_file != req.use_hypernetwork_model:
        thread_data.hypernetwork_file = req.use_hypernetwork_model
        needs_model_reload = True

    return needs_model_reload

def load_hypernetwork():
    if thread_data.test_sd2:
        # Not yet supported in SD2
        return

    from . import hypernetwork
    if thread_data.hypernetwork_file is not None:
        try:
            loaded = False
            for model_extension in HYPERNETWORK_MODEL_EXTENSIONS:
                if os.path.exists(thread_data.hypernetwork_file + model_extension):
                    print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}")
                    thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension)
                    loaded = True
                    break

            if not loaded:
                print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}')
                thread_data.hypernetwork_file = None
        except:
            print(traceback.format_exc())
            print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}')
            thread_data.hypernetwork_file = None

def unload_hypernetwork():
    if thread_data.hypernetwork is not None:
        print('Unloading hypernetwork...')
        if thread_data.device != 'cpu':
            for i in thread_data.hypernetwork:
                thread_data.hypernetwork[i][0].to('cpu')
                thread_data.hypernetwork[i][1].to('cpu')
        del thread_data.hypernetwork
    thread_data.hypernetwork = None

    gc()

def reload_hypernetwork():
    unload_hypernetwork()
    load_hypernetwork()

def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
    try:
        return do_mk_img(req, data_queue, task_temp_images, step_callback)
    except Exception as e:
        print(traceback.format_exc())

        if thread_data.device != 'cpu' and not thread_data.test_sd2:
            thread_data.modelFS.to('cpu')
            thread_data.modelCS.to('cpu')
            thread_data.model.model1.to("cpu")
            thread_data.model.model2.to("cpu")

        gc() # Release from memory.
        data_queue.put(json.dumps({
            "status": 'failed',
            "detail": str(e)
        }))
        raise e

def update_temp_img(req, x_samples, task_temp_images: list):
    partial_images = []
    for i in range(req.num_outputs):
        if thread_data.test_sd2:
            x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
        else:
            x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
        x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
        x_sample = x_sample.astype(np.uint8)
        img = Image.fromarray(x_sample)
        buf = img_to_buffer(img, output_format='JPEG')

        del img, x_sample, x_sample_ddim
        # don't delete x_samples, it is used in the code that called this callback

        thread_data.temp_images[f'{req.request_id}/{i}'] = buf
        task_temp_images[i] = buf
        partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
    return partial_images

# Build and return the apropriate generator for do_mk_img
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
    if not req.stream_progress_updates:
        def empty_callback(x_samples, i):
            step_callback()
        return empty_callback

    thread_data.partial_x_samples = None
    last_callback_time = -1
    def img_callback(x_samples, i):
        nonlocal last_callback_time

        thread_data.partial_x_samples = x_samples
        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
        last_callback_time = time.time()

        progress = {"step": i, "step_time": step_time}
        if extra_props is not None:
            progress.update(extra_props)

        if req.stream_image_progress and i % 5 == 0:
            progress['output'] = update_temp_img(req, x_samples, task_temp_images)

        data_queue.put(json.dumps(progress))

        step_callback()

        if thread_data.stop_processing:
            raise UserInitiatedStop("User requested that we stop processing")
    return img_callback

def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
    thread_data.stop_processing = False

    res = Response()
    res.request = req
    res.images = []
    thread_data.hypernetwork_strength = req.hypernetwork_strength

    thread_data.temp_images.clear()

    if thread_data.turbo != req.turbo and not thread_data.test_sd2:
        thread_data.turbo = req.turbo
        thread_data.model.turbo = req.turbo

    # Start by cleaning memory, loading and unloading things can leave memory allocated.
    gc()

    opt_prompt = req.prompt
    opt_seed = req.seed
    opt_n_iter = 1
    opt_C = 4
    opt_f = 8
    opt_ddim_eta = 0.0

    print(req, '\n    device', torch.device(thread_data.device), "as", thread_data.device_name)
    print('\n\n    Using precision:', thread_data.precision)

    seed_everything(opt_seed)

    batch_size = req.num_outputs
    prompt = opt_prompt
    assert prompt is not None
    data = [batch_size * [prompt]]

    if thread_data.precision == "autocast" and thread_data.device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext

    mask = None

    if req.init_image is None:
        handler = _txt2img

        init_latent = None
        t_enc = None
    else:
        handler = _img2img

        init_image = load_img(req.init_image, req.width, req.height)
        init_image = init_image.to(thread_data.device)

        if thread_data.device != "cpu" and thread_data.precision == "autocast":
            init_image = init_image.half()

        if not thread_data.test_sd2:
            thread_data.modelFS.to(thread_data.device)

        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        if thread_data.test_sd2:
            init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image))  # move to latent space
        else:
            init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image))  # move to latent space

        if req.mask is not None:
            mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
            mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
            mask = repeat(mask, '1 ... -> b ...', b=batch_size)

            if thread_data.device != "cpu" and thread_data.precision == "autocast":
                mask = mask.half()

        # Send to CPU and wait until complete.
        # wait_model_move_to(thread_data.modelFS, 'cpu')
        if not thread_data.test_sd2:
            move_to_cpu(thread_data.modelFS)

        assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(req.prompt_strength * req.num_inference_steps)
        print(f"target t_enc is {t_enc} steps")

    with torch.no_grad():
        for n in trange(opt_n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):

                with precision_scope("cuda"):
                    if thread_data.reduced_memory and not thread_data.test_sd2:
                        thread_data.modelCS.to(thread_data.device)
                    uc = None
                    if req.guidance_scale != 1.0:
                        if thread_data.test_sd2:
                            uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
                        else:
                            uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)

                    subprompts, weights = split_weighted_subprompts(prompts[0])
                    if len(subprompts) > 1:
                        c = torch.zeros_like(uc)
                        totalWeight = sum(weights)
                        # normalize each "sub prompt" and add it
                        for i in range(len(subprompts)):
                            weight = weights[i]
                            # if not skip_normalize:
                            weight = weight / totalWeight
                            if thread_data.test_sd2:
                                c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
                            else:
                                c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
                    else:
                        if thread_data.test_sd2:
                            c = thread_data.model.get_learned_conditioning(prompts)
                        else:
                            c = thread_data.modelCS.get_learned_conditioning(prompts)

                    if thread_data.reduced_memory and not thread_data.test_sd2:
                        thread_data.modelFS.to(thread_data.device)

                    n_steps = req.num_inference_steps if req.init_image is None else t_enc
                    img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})

                    # run the handler
                    try:
                        print('Running handler...')
                        if handler == _txt2img:
                            x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
                        else:
                            x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
                    except UserInitiatedStop:
                        if not hasattr(thread_data, 'partial_x_samples'):
                            continue
                        if thread_data.partial_x_samples is None:
                            del thread_data.partial_x_samples
                            continue
                        x_samples = thread_data.partial_x_samples
                        del thread_data.partial_x_samples

                    print("decoding images")
                    img_data = [None] * batch_size
                    for i in range(batch_size):
                        if thread_data.test_sd2:
                            x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
                        else:
                            x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                        x_sample = x_sample.astype(np.uint8)
                        img_data[i] = x_sample
                    del x_samples, x_samples_ddim, x_sample

                    print("saving images")
                    for i in range(batch_size):
                        img = Image.fromarray(img_data[i])
                        img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
                        img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.

                        has_filters =   (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
                                        (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))

                        return_orig_img = not has_filters or not req.show_only_filtered_image

                        if thread_data.stop_processing:
                            return_orig_img = True

                        if req.save_to_disk_path is not None:
                            if return_orig_img:
                                img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format)
                                save_image(img, img_out_path, req.output_format, req.output_quality)
                            meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, 'txt')
                            save_metadata(meta_out_path, req, prompts[0], opt_seed)

                        if return_orig_img:
                            img_buffer = img_to_buffer(img, req.output_format, req.output_quality)
                            img_str = buffer_to_base64_str(img_buffer, req.output_format)
                            res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
                            res.images.append(res_image_orig)
                            task_temp_images[i] = img_buffer

                            if req.save_to_disk_path is not None:
                                res_image_orig.path_abs = img_out_path
                        del img

                        if has_filters and not thread_data.stop_processing:
                            filters_applied = []
                            if req.use_face_correction:
                                img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
                                filters_applied.append(req.use_face_correction)
                            if req.use_upscale:
                                img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
                                filters_applied.append(req.use_upscale)
                            if (len(filters_applied) > 0):
                                filtered_image = Image.fromarray(img_data[i])
                                filtered_buffer = img_to_buffer(filtered_image, req.output_format, req.output_quality)
                                filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
                                response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
                                res.images.append(response_image)
                                task_temp_images[i] = filtered_buffer
                                if req.save_to_disk_path is not None:
                                    filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
                                    save_image(filtered_image, filtered_img_out_path, req.output_format, req.output_quality)
                                    response_image.path_abs = filtered_img_out_path
                                del filtered_image
                        # Filter Applied, move to next seed
                        opt_seed += 1

                    # if thread_data.reduced_memory:
                    #     unload_filters()
                    if not thread_data.test_sd2:
                        move_to_cpu(thread_data.modelFS)
                    del img_data
                    gc()
                    if thread_data.device != 'cpu':
                        print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')

    print('Task completed')
    res = res.json()
    data_queue.put(json.dumps(res))

    return res

def save_image(img, img_out_path, output_format="", output_quality=75):
    try:
        if output_format.upper() == "JPEG":
            img.save(img_out_path, quality=output_quality)
        else:
            img.save(img_out_path)
    except:
        print('could not save the file', traceback.format_exc())

def save_metadata(meta_out_path, req, prompt, opt_seed):
    metadata = f'''{prompt}
Width: {req.width}
Height: {req.height}
Seed: {opt_seed}
Steps: {req.num_inference_steps}
Guidance Scale: {req.guidance_scale}
Prompt Strength: {req.prompt_strength}
Use Face Correction: {req.use_face_correction}
Use Upscaling: {req.use_upscale}
Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
VAE model: {req.use_vae_model}
Hypernetwork Model: {req.use_hypernetwork_model}
Hypernetwork Strength: {req.hypernetwork_strength}
'''
    try:
        with open(meta_out_path, 'w', encoding='utf-8') as f:
            f.write(metadata)
    except:
        print('could not save the file', traceback.format_exc())

def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
    shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

    # Send to CPU and wait until complete.
    # wait_model_move_to(thread_data.modelCS, 'cpu')

    if not thread_data.test_sd2:
        move_to_cpu(thread_data.modelCS)

    if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim', 'dpm2'):
        raise Exception('Only plms, ddim and dpm2 samplers are supported right now, in SD 2.0')


    # samples, _ = sampler.sample(S=opt.steps,
    #                                                  conditioning=c,
    #                                                  batch_size=opt.n_samples,
    #                                                  shape=shape,
    #                                                  verbose=False,
    #                                                  unconditional_guidance_scale=opt.scale,
    #                                                  unconditional_conditioning=uc,
    #                                                  eta=opt.ddim_eta,
    #                                                  x_T=start_code)

    if thread_data.test_sd2:
        if sampler_name == 'plms':
            from ldm.models.diffusion.plms import PLMSSampler
            sampler = PLMSSampler(thread_data.model)
        elif sampler_name == 'ddim':
            from ldm.models.diffusion.ddim import DDIMSampler
            sampler = DDIMSampler(thread_data.model)
            sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
        elif sampler_name == 'dpm2':
            from ldm.models.diffusion.dpm_solver import DPMSolverSampler
            sampler = DPMSolverSampler(thread_data.model)

        shape = [opt_C, opt_H // opt_f, opt_W // opt_f]

        samples_ddim, intermediates = sampler.sample(
            S=opt_ddim_steps,
            conditioning=c,
            batch_size=opt_n_samples,
            seed=opt_seed,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=opt_scale,
            unconditional_conditioning=uc,
            eta=opt_ddim_eta,
            x_T=start_code,
            img_callback=img_callback,
            mask=mask,
            sampler = sampler_name,
        )
    else:
        if sampler_name == 'ddim':
            thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

        samples_ddim = thread_data.model.sample(
            S=opt_ddim_steps,
            conditioning=c,
            seed=opt_seed,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=opt_scale,
            unconditional_conditioning=uc,
            eta=opt_ddim_eta,
            x_T=start_code,
            img_callback=img_callback,
            mask=mask,
            sampler = sampler_name,
        )
    return samples_ddim

def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
    # encode (scaled latent)
    x_T = None if mask is None else init_latent

    if thread_data.test_sd2:
        from ldm.models.diffusion.ddim import DDIMSampler

        sampler = DDIMSampler(thread_data.model)

        sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device))

        samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback)

    else:
        z_enc = thread_data.model.stochastic_encode(
            init_latent,
            torch.tensor([t_enc] * batch_size).to(thread_data.device),
            opt_seed,
            opt_ddim_eta,
            opt_ddim_steps,
        )

        # decode it
        samples_ddim = thread_data.model.sample(
            t_enc,
            c,
            z_enc,
            unconditional_guidance_scale=opt_scale,
            unconditional_conditioning=uc,
            img_callback=img_callback,
            mask=mask,
            x_T=x_T,
            sampler = 'ddim'
        )
    return samples_ddim

def gc():
    gc_collect()
    if thread_data.device == 'cpu':
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# internal

def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

def load_model_from_config(ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    model_ver = 'sd1'

    if ckpt.endswith(".safetensors"):
        print("Loading from safetensors")
        pl_sd = load_file(ckpt, device="cpu")
    else:
        pl_sd = torch.load(ckpt, map_location="cpu")

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    if "state_dict" in pl_sd:
        # check for a key that only seems to be present in SD2 models
        if 'cond_stage_model.model.ln_final.bias' in pl_sd['state_dict'].keys():
            model_ver = 'sd2'

        return pl_sd["state_dict"], model_ver
    else:
        return pl_sd, model_ver

class UserInitiatedStop(Exception):
    pass

def load_img(img_str, w0, h0):
    image = base64_str_to_img(img_str).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from base64")
    if h0 is not None and w0 is not None:
        h, w = h0, w0

    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.*image - 1.

def load_mask(mask_str, h0, w0, newH, newW, invert=False):
    image = base64_str_to_img(mask_str).convert("RGB")
    w, h = image.size
    print(f"loaded input mask of size ({w}, {h})")

    if invert:
        print("inverted")
        image = ImageOps.invert(image)
        # where_0, where_1 = np.where(image == 0), np.where(image == 255)
        # image[where_0], image[where_1] = 255, 0

    if h0 is not None and w0 is not None:
        h, w = h0, w0

    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64

    print(f"New mask size ({w}, {h})")
    image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
    image = np.array(image)

    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image

# https://stackoverflow.com/a/61114178
def img_to_base64_str(img, output_format="PNG", output_quality=75):
    buffered = img_to_buffer(img, output_format, quality=output_quality)
    return buffer_to_base64_str(buffered, output_format)

def img_to_buffer(img, output_format="PNG", output_quality=75):
    buffered = BytesIO()
    if ( output_format.upper() == "JPEG" ):
        img.save(buffered, format=output_format, quality=output_quality)
    else:
        img.save(buffered, format=output_format)
    buffered.seek(0)
    return buffered

def buffer_to_base64_str(buffered, output_format="PNG"):
    buffered.seek(0)
    img_byte = buffered.getvalue()
    mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
    img_str = f"data:{mime_type};base64," + base64.b64encode(img_byte).decode()
    return img_str

def base64_str_to_buffer(img_str):
    mime_type = "image/png" if img_str.startswith("data:image/png;") else "image/jpeg"
    img_str = img_str[len(f"data:{mime_type};base64,"):]
    data = base64.b64decode(img_str)
    buffered = BytesIO(data)
    return buffered

def base64_str_to_img(img_str):
    buffered = base64_str_to_buffer(img_str)
    img = Image.open(buffered)
    return img

def split_weighted_subprompts(text):
    """
    grabs all text up to the first occurrence of ':' 
    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
    if ':' has no value defined, defaults to 1.0
    repeats until no text remaining
    """
    remaining = len(text)
    prompts = []
    weights = []
    while remaining > 0:
        if ":" in text:
            idx = text.index(":") # first occurrence from start
            # grab up to index as sub-prompt
            prompt = text[:idx]
            remaining -= idx
            # remove from main text
            text = text[idx+1:]
            # find value for weight 
            if " " in text:
                idx = text.index(" ") # first occurence
            else: # no space, read to end
                idx = len(text)
            if idx != 0:
                try:
                    weight = float(text[:idx])
                except: # couldn't treat as float
                    print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
                    weight = 1.0
            else: # no value found
                weight = 1.0
            # remove from main text
            remaining -= idx
            text = text[idx+1:]
            # append the sub-prompt and its weight
            prompts.append(prompt)
            weights.append(weight)
        else: # no : found
            if len(text) > 0: # there is still text though
                # take remainder as weight 1
                prompts.append(text)
                weights.append(1.0)
            remaining = 0
    return prompts, weights