mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-27 02:33:10 +01:00
1083 lines
43 KiB
Python
1083 lines
43 KiB
Python
"""runtime.py: torch device owned by a thread.
|
|
Notes:
|
|
Avoid device switching, transfering all models will get too complex.
|
|
To use a diffrent device signal the current render device to exit
|
|
And then start a new clean thread for the new device.
|
|
"""
|
|
import json
|
|
import os, re
|
|
import traceback
|
|
import queue
|
|
import torch
|
|
import numpy as np
|
|
from gc import collect as gc_collect
|
|
from omegaconf import OmegaConf
|
|
from PIL import Image, ImageOps
|
|
from tqdm import tqdm, trange
|
|
from itertools import islice
|
|
from einops import rearrange
|
|
import time
|
|
from pytorch_lightning import seed_everything
|
|
from torch import autocast
|
|
from contextlib import nullcontext
|
|
from einops import rearrange, repeat
|
|
from ldm.util import instantiate_from_config
|
|
from transformers import logging
|
|
|
|
from gfpgan import GFPGANer
|
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
from realesrgan import RealESRGANer
|
|
|
|
from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS
|
|
|
|
from threading import Lock
|
|
from safetensors.torch import load_file
|
|
|
|
import uuid
|
|
|
|
logging.set_verbosity_error()
|
|
|
|
# consts
|
|
config_yaml = "optimizedSD/v1-inference.yaml"
|
|
filename_regex = re.compile('[^a-zA-Z0-9]')
|
|
gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time.
|
|
|
|
# api stuff
|
|
from sd_internal import device_manager
|
|
from . import Request, Response, Image as ResponseImage
|
|
import base64
|
|
from io import BytesIO
|
|
#from colorama import Fore
|
|
|
|
from threading import local as LocalThreadVars
|
|
thread_data = LocalThreadVars()
|
|
|
|
def thread_init(device):
|
|
# Thread bound properties
|
|
thread_data.stop_processing = False
|
|
thread_data.temp_images = {}
|
|
|
|
thread_data.ckpt_file = None
|
|
thread_data.vae_file = None
|
|
thread_data.hypernetwork_file = None
|
|
thread_data.gfpgan_file = None
|
|
thread_data.real_esrgan_file = None
|
|
|
|
thread_data.model = None
|
|
thread_data.modelCS = None
|
|
thread_data.modelFS = None
|
|
thread_data.hypernetwork = None
|
|
thread_data.hypernetwork_strength = 1
|
|
thread_data.model_gfpgan = None
|
|
thread_data.model_real_esrgan = None
|
|
|
|
thread_data.model_is_half = False
|
|
thread_data.model_fs_is_half = False
|
|
thread_data.device = None
|
|
thread_data.device_name = None
|
|
thread_data.unet_bs = 1
|
|
thread_data.precision = 'autocast'
|
|
thread_data.sampler_plms = None
|
|
thread_data.sampler_ddim = None
|
|
|
|
thread_data.turbo = False
|
|
thread_data.force_full_precision = False
|
|
thread_data.reduced_memory = True
|
|
|
|
thread_data.test_sd2 = isSD2()
|
|
|
|
device_manager.device_init(thread_data, device)
|
|
|
|
# temp hack, will remove soon
|
|
def isSD2():
|
|
try:
|
|
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
|
|
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
|
|
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
|
if not os.path.exists(config_json_path):
|
|
return False
|
|
with open(config_json_path, 'r', encoding='utf-8') as f:
|
|
config = json.load(f)
|
|
return config.get('test_sd2', False)
|
|
except Exception as e:
|
|
return False
|
|
|
|
def load_model_ckpt():
|
|
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
|
|
if os.path.exists(thread_data.ckpt_file + '.ckpt'):
|
|
thread_data.ckpt_file += '.ckpt'
|
|
elif os.path.exists(thread_data.ckpt_file + '.safetensors'):
|
|
thread_data.ckpt_file += '.safetensors'
|
|
elif not os.path.exists(thread_data.ckpt_file):
|
|
raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt or .safetensors')
|
|
|
|
if not thread_data.precision:
|
|
thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
|
|
|
|
if not thread_data.unet_bs:
|
|
thread_data.unet_bs = 1
|
|
|
|
if thread_data.device == 'cpu':
|
|
thread_data.precision = 'full'
|
|
|
|
print('loading', thread_data.ckpt_file, 'to device', thread_data.device, 'using precision', thread_data.precision)
|
|
|
|
if thread_data.test_sd2:
|
|
load_model_ckpt_sd2()
|
|
else:
|
|
load_model_ckpt_sd1()
|
|
|
|
def load_model_ckpt_sd1():
|
|
sd, model_ver = load_model_from_config(thread_data.ckpt_file)
|
|
li, lo = [], []
|
|
for key, value in sd.items():
|
|
sp = key.split(".")
|
|
if (sp[0]) == "model":
|
|
if "input_blocks" in sp:
|
|
li.append(key)
|
|
elif "middle_block" in sp:
|
|
li.append(key)
|
|
elif "time_embed" in sp:
|
|
li.append(key)
|
|
else:
|
|
lo.append(key)
|
|
for key in li:
|
|
sd["model1." + key[6:]] = sd.pop(key)
|
|
for key in lo:
|
|
sd["model2." + key[6:]] = sd.pop(key)
|
|
|
|
config = OmegaConf.load(f"{config_yaml}")
|
|
|
|
model = instantiate_from_config(config.modelUNet)
|
|
_, _ = model.load_state_dict(sd, strict=False)
|
|
model.eval()
|
|
model.cdevice = torch.device(thread_data.device)
|
|
model.unet_bs = thread_data.unet_bs
|
|
model.turbo = thread_data.turbo
|
|
# if thread_data.device != 'cpu':
|
|
# model.to(thread_data.device)
|
|
#if thread_data.reduced_memory:
|
|
#model.model1.to("cpu")
|
|
#model.model2.to("cpu")
|
|
thread_data.model = model
|
|
|
|
modelCS = instantiate_from_config(config.modelCondStage)
|
|
_, _ = modelCS.load_state_dict(sd, strict=False)
|
|
modelCS.eval()
|
|
modelCS.cond_stage_model.device = torch.device(thread_data.device)
|
|
# if thread_data.device != 'cpu':
|
|
# if thread_data.reduced_memory:
|
|
# modelCS.to('cpu')
|
|
# else:
|
|
# modelCS.to(thread_data.device) # Preload on device if not already there.
|
|
thread_data.modelCS = modelCS
|
|
|
|
modelFS = instantiate_from_config(config.modelFirstStage)
|
|
_, _ = modelFS.load_state_dict(sd, strict=False)
|
|
|
|
if thread_data.vae_file is not None:
|
|
try:
|
|
loaded = False
|
|
for model_extension in ['.ckpt', '.vae.pt']:
|
|
if os.path.exists(thread_data.vae_file + model_extension):
|
|
print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}")
|
|
vae_ckpt = torch.load(thread_data.vae_file + model_extension, map_location="cpu")
|
|
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
|
modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
|
|
loaded = True
|
|
break
|
|
|
|
if not loaded:
|
|
print(f'Cannot find VAE: {thread_data.vae_file}')
|
|
thread_data.vae_file = None
|
|
except:
|
|
print(traceback.format_exc())
|
|
print(f'Could not load VAE: {thread_data.vae_file}')
|
|
thread_data.vae_file = None
|
|
|
|
modelFS.eval()
|
|
# if thread_data.device != 'cpu':
|
|
# if thread_data.reduced_memory:
|
|
# modelFS.to('cpu')
|
|
# else:
|
|
# modelFS.to(thread_data.device) # Preload on device if not already there.
|
|
thread_data.modelFS = modelFS
|
|
del sd
|
|
|
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
|
thread_data.model.half()
|
|
thread_data.modelCS.half()
|
|
thread_data.modelFS.half()
|
|
thread_data.model_is_half = True
|
|
thread_data.model_fs_is_half = True
|
|
else:
|
|
thread_data.model_is_half = False
|
|
thread_data.model_fs_is_half = False
|
|
|
|
print(f'''loaded model
|
|
model file: {thread_data.ckpt_file}
|
|
model.device: {model.device}
|
|
modelCS.device: {modelCS.cond_stage_model.device}
|
|
modelFS.device: {thread_data.modelFS.device}
|
|
using precision: {thread_data.precision}''')
|
|
|
|
def load_model_ckpt_sd2():
|
|
sd, model_ver = load_model_from_config(thread_data.ckpt_file)
|
|
|
|
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if model_ver == 'sd2' else "configs/stable-diffusion/v1-inference.yaml"
|
|
config = OmegaConf.load(config_file)
|
|
verbose = False
|
|
|
|
thread_data.model = instantiate_from_config(config.model)
|
|
m, u = thread_data.model.load_state_dict(sd, strict=False)
|
|
if len(m) > 0 and verbose:
|
|
print("missing keys:")
|
|
print(m)
|
|
if len(u) > 0 and verbose:
|
|
print("unexpected keys:")
|
|
print(u)
|
|
|
|
thread_data.model.to(thread_data.device)
|
|
thread_data.model.eval()
|
|
del sd
|
|
|
|
thread_data.model.cond_stage_model.device = torch.device(thread_data.device)
|
|
|
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
|
thread_data.model.half()
|
|
thread_data.model_is_half = True
|
|
thread_data.model_fs_is_half = True
|
|
else:
|
|
thread_data.model_is_half = False
|
|
thread_data.model_fs_is_half = False
|
|
|
|
print(f'''loaded model
|
|
model file: {thread_data.ckpt_file}
|
|
using precision: {thread_data.precision}''')
|
|
|
|
def unload_filters():
|
|
if thread_data.model_gfpgan is not None:
|
|
if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
|
|
|
|
del thread_data.model_gfpgan
|
|
thread_data.model_gfpgan = None
|
|
|
|
if thread_data.model_real_esrgan is not None:
|
|
if thread_data.device != 'cpu': thread_data.model_real_esrgan.model.to('cpu')
|
|
|
|
del thread_data.model_real_esrgan
|
|
thread_data.model_real_esrgan = None
|
|
|
|
gc()
|
|
|
|
def unload_models():
|
|
if thread_data.model is not None:
|
|
print('Unloading models...')
|
|
if thread_data.device != 'cpu':
|
|
if not thread_data.test_sd2:
|
|
thread_data.modelFS.to('cpu')
|
|
thread_data.modelCS.to('cpu')
|
|
thread_data.model.model1.to("cpu")
|
|
thread_data.model.model2.to("cpu")
|
|
|
|
del thread_data.model
|
|
del thread_data.modelCS
|
|
del thread_data.modelFS
|
|
|
|
thread_data.model = None
|
|
thread_data.modelCS = None
|
|
thread_data.modelFS = None
|
|
|
|
gc()
|
|
|
|
# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
|
|
# if thread_data.device == target_device: return
|
|
# start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
|
|
# if start_mem <= 0: return
|
|
# model_name = model.__class__.__name__
|
|
# print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
|
|
# start_time = time.time()
|
|
# model.to(target_device)
|
|
# time_step = start_time
|
|
# WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
|
|
# last_mem = start_mem
|
|
# is_transfering = True
|
|
# while is_transfering:
|
|
# time.sleep(0.5) # 500ms
|
|
# mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
|
|
# is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
|
|
# last_mem = mem
|
|
# if not is_transfering:
|
|
# break;
|
|
# if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
|
|
# print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
|
|
# time_step = time.time()
|
|
# print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')
|
|
|
|
def move_to_cpu(model):
|
|
if thread_data.device != "cpu":
|
|
d = torch.device(thread_data.device)
|
|
mem = torch.cuda.memory_allocated(d) / 1e6
|
|
model.to("cpu")
|
|
while torch.cuda.memory_allocated(d) / 1e6 >= mem:
|
|
time.sleep(1)
|
|
|
|
def load_model_gfpgan():
|
|
if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
|
|
model_path = thread_data.gfpgan_file + ".pth"
|
|
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
|
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
|
|
|
|
def load_model_real_esrgan():
|
|
if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.')
|
|
model_path = thread_data.real_esrgan_file + ".pth"
|
|
|
|
RealESRGAN_models = {
|
|
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
|
|
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
|
}
|
|
|
|
model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]
|
|
|
|
if thread_data.device == 'cpu':
|
|
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
|
|
#thread_data.model_real_esrgan.device = torch.device(thread_data.device)
|
|
thread_data.model_real_esrgan.model.to('cpu')
|
|
else:
|
|
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)
|
|
|
|
thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
|
|
print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
|
|
|
|
|
|
def get_session_out_path(disk_path, session_id):
|
|
if disk_path is None: return None
|
|
if session_id is None: return None
|
|
|
|
session_out_path = os.path.join(disk_path, filename_regex.sub('_',session_id))
|
|
os.makedirs(session_out_path, exist_ok=True)
|
|
return session_out_path
|
|
|
|
def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
|
|
if disk_path is None: return None
|
|
if session_id is None: return None
|
|
if ext is None: raise Exception('Missing ext')
|
|
|
|
session_out_path = get_session_out_path(disk_path, session_id)
|
|
|
|
prompt_flattened = filename_regex.sub('_', prompt)[:50]
|
|
|
|
if suffix is not None:
|
|
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
|
|
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
|
|
|
|
def apply_filters(filter_name, image_data, model_path=None):
|
|
print(f'Applying filter {filter_name}...')
|
|
gc() # Free space before loading new data.
|
|
|
|
if isinstance(image_data, torch.Tensor):
|
|
image_data.to(thread_data.device)
|
|
|
|
if filter_name == 'gfpgan':
|
|
# This lock is only ever used here. No need to use timeout for the request. Should never deadlock.
|
|
with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting.
|
|
# hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
|
|
from facexlib.detection import retinaface
|
|
retinaface.device = torch.device(thread_data.device)
|
|
print('forced retinaface.device to', thread_data.device)
|
|
|
|
if model_path is not None and model_path != thread_data.gfpgan_file:
|
|
thread_data.gfpgan_file = model_path
|
|
load_model_gfpgan()
|
|
elif not thread_data.model_gfpgan:
|
|
load_model_gfpgan()
|
|
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
|
|
|
|
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
|
|
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
|
image_data = output[:,:,::-1]
|
|
|
|
if filter_name == 'real_esrgan':
|
|
if model_path is not None and model_path != thread_data.real_esrgan_file:
|
|
thread_data.real_esrgan_file = model_path
|
|
load_model_real_esrgan()
|
|
elif not thread_data.model_real_esrgan:
|
|
load_model_real_esrgan()
|
|
if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
|
|
print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
|
|
output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
|
|
image_data = output[:,:,::-1]
|
|
|
|
return image_data
|
|
|
|
def is_model_reload_necessary(req: Request):
|
|
# custom model support:
|
|
# the req.use_stable_diffusion_model needs to be a valid path
|
|
# to the ckpt file (without the extension).
|
|
if os.path.exists(req.use_stable_diffusion_model + '.ckpt'):
|
|
req.use_stable_diffusion_model += '.ckpt'
|
|
elif os.path.exists(req.use_stable_diffusion_model + '.safetensors'):
|
|
req.use_stable_diffusion_model += '.safetensors'
|
|
elif not os.path.exists(req.use_stable_diffusion_model):
|
|
raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt or .safetensors')
|
|
|
|
needs_model_reload = False
|
|
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
|
|
thread_data.ckpt_file = req.use_stable_diffusion_model
|
|
thread_data.vae_file = req.use_vae_model
|
|
needs_model_reload = True
|
|
|
|
if thread_data.device != 'cpu':
|
|
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
|
|
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
|
|
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
|
needs_model_reload = True
|
|
|
|
return needs_model_reload
|
|
|
|
def reload_model():
|
|
unload_models()
|
|
unload_filters()
|
|
load_model_ckpt()
|
|
|
|
def is_hypernetwork_reload_necessary(req: Request):
|
|
needs_model_reload = False
|
|
if thread_data.hypernetwork_file != req.use_hypernetwork_model:
|
|
thread_data.hypernetwork_file = req.use_hypernetwork_model
|
|
needs_model_reload = True
|
|
|
|
return needs_model_reload
|
|
|
|
def load_hypernetwork():
|
|
if thread_data.test_sd2:
|
|
# Not yet supported in SD2
|
|
return
|
|
|
|
from . import hypernetwork
|
|
if thread_data.hypernetwork_file is not None:
|
|
try:
|
|
loaded = False
|
|
for model_extension in HYPERNETWORK_MODEL_EXTENSIONS:
|
|
if os.path.exists(thread_data.hypernetwork_file + model_extension):
|
|
print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}")
|
|
thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension)
|
|
loaded = True
|
|
break
|
|
|
|
if not loaded:
|
|
print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}')
|
|
thread_data.hypernetwork_file = None
|
|
except:
|
|
print(traceback.format_exc())
|
|
print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}')
|
|
thread_data.hypernetwork_file = None
|
|
|
|
def unload_hypernetwork():
|
|
if thread_data.hypernetwork is not None:
|
|
print('Unloading hypernetwork...')
|
|
if thread_data.device != 'cpu':
|
|
for i in thread_data.hypernetwork:
|
|
thread_data.hypernetwork[i][0].to('cpu')
|
|
thread_data.hypernetwork[i][1].to('cpu')
|
|
del thread_data.hypernetwork
|
|
thread_data.hypernetwork = None
|
|
|
|
gc()
|
|
|
|
def reload_hypernetwork():
|
|
unload_hypernetwork()
|
|
load_hypernetwork()
|
|
|
|
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
|
try:
|
|
return do_mk_img(req, data_queue, task_temp_images, step_callback)
|
|
except Exception as e:
|
|
print(traceback.format_exc())
|
|
|
|
if thread_data.device != 'cpu' and not thread_data.test_sd2:
|
|
thread_data.modelFS.to('cpu')
|
|
thread_data.modelCS.to('cpu')
|
|
thread_data.model.model1.to("cpu")
|
|
thread_data.model.model2.to("cpu")
|
|
|
|
gc() # Release from memory.
|
|
data_queue.put(json.dumps({
|
|
"status": 'failed',
|
|
"detail": str(e)
|
|
}))
|
|
raise e
|
|
|
|
def update_temp_img(req, x_samples, task_temp_images: list):
|
|
partial_images = []
|
|
for i in range(req.num_outputs):
|
|
if thread_data.test_sd2:
|
|
x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
|
|
else:
|
|
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
|
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
|
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
|
x_sample = x_sample.astype(np.uint8)
|
|
img = Image.fromarray(x_sample)
|
|
buf = img_to_buffer(img, output_format='JPEG')
|
|
|
|
del img, x_sample, x_sample_ddim
|
|
# don't delete x_samples, it is used in the code that called this callback
|
|
|
|
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
|
|
task_temp_images[i] = buf
|
|
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
|
return partial_images
|
|
|
|
# Build and return the apropriate generator for do_mk_img
|
|
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
|
|
if not req.stream_progress_updates:
|
|
def empty_callback(x_samples, i): return x_samples
|
|
return empty_callback
|
|
|
|
thread_data.partial_x_samples = None
|
|
last_callback_time = -1
|
|
def img_callback(x_samples, i):
|
|
nonlocal last_callback_time
|
|
|
|
thread_data.partial_x_samples = x_samples
|
|
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
|
|
last_callback_time = time.time()
|
|
|
|
progress = {"step": i, "step_time": step_time}
|
|
if extra_props is not None:
|
|
progress.update(extra_props)
|
|
|
|
if req.stream_image_progress and i % 5 == 0:
|
|
progress['output'] = update_temp_img(req, x_samples, task_temp_images)
|
|
|
|
data_queue.put(json.dumps(progress))
|
|
|
|
step_callback()
|
|
|
|
if thread_data.stop_processing:
|
|
raise UserInitiatedStop("User requested that we stop processing")
|
|
return img_callback
|
|
|
|
def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
|
thread_data.stop_processing = False
|
|
|
|
res = Response()
|
|
res.request = req
|
|
res.images = []
|
|
thread_data.hypernetwork_strength = req.hypernetwork_strength
|
|
|
|
thread_data.temp_images.clear()
|
|
|
|
if thread_data.turbo != req.turbo and not thread_data.test_sd2:
|
|
thread_data.turbo = req.turbo
|
|
thread_data.model.turbo = req.turbo
|
|
|
|
# Start by cleaning memory, loading and unloading things can leave memory allocated.
|
|
gc()
|
|
|
|
opt_prompt = req.prompt
|
|
opt_seed = req.seed
|
|
opt_n_iter = 1
|
|
opt_C = 4
|
|
opt_f = 8
|
|
opt_ddim_eta = 0.0
|
|
|
|
print(req, '\n device', torch.device(thread_data.device), "as", thread_data.device_name)
|
|
print('\n\n Using precision:', thread_data.precision)
|
|
|
|
seed_everything(opt_seed)
|
|
|
|
batch_size = req.num_outputs
|
|
prompt = opt_prompt
|
|
assert prompt is not None
|
|
data = [batch_size * [prompt]]
|
|
|
|
if thread_data.precision == "autocast" and thread_data.device != "cpu":
|
|
precision_scope = autocast
|
|
else:
|
|
precision_scope = nullcontext
|
|
|
|
mask = None
|
|
|
|
if req.init_image is None:
|
|
handler = _txt2img
|
|
|
|
init_latent = None
|
|
t_enc = None
|
|
else:
|
|
handler = _img2img
|
|
|
|
init_image = load_img(req.init_image, req.width, req.height)
|
|
init_image = init_image.to(thread_data.device)
|
|
|
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
|
init_image = init_image.half()
|
|
|
|
if not thread_data.test_sd2:
|
|
thread_data.modelFS.to(thread_data.device)
|
|
|
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
|
if thread_data.test_sd2:
|
|
init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space
|
|
else:
|
|
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
|
|
|
|
if req.mask is not None:
|
|
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
|
|
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
|
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
|
|
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
|
mask = mask.half()
|
|
|
|
# Send to CPU and wait until complete.
|
|
# wait_model_move_to(thread_data.modelFS, 'cpu')
|
|
if not thread_data.test_sd2:
|
|
move_to_cpu(thread_data.modelFS)
|
|
|
|
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
|
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
|
print(f"target t_enc is {t_enc} steps")
|
|
|
|
if req.save_to_disk_path is not None:
|
|
session_out_path = get_session_out_path(req.save_to_disk_path, req.session_id)
|
|
else:
|
|
session_out_path = None
|
|
|
|
with torch.no_grad():
|
|
for n in trange(opt_n_iter, desc="Sampling"):
|
|
for prompts in tqdm(data, desc="data"):
|
|
|
|
with precision_scope("cuda"):
|
|
if thread_data.reduced_memory and not thread_data.test_sd2:
|
|
thread_data.modelCS.to(thread_data.device)
|
|
uc = None
|
|
if req.guidance_scale != 1.0:
|
|
if thread_data.test_sd2:
|
|
uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
|
|
else:
|
|
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
|
if isinstance(prompts, tuple):
|
|
prompts = list(prompts)
|
|
|
|
subprompts, weights = split_weighted_subprompts(prompts[0])
|
|
if len(subprompts) > 1:
|
|
c = torch.zeros_like(uc)
|
|
totalWeight = sum(weights)
|
|
# normalize each "sub prompt" and add it
|
|
for i in range(len(subprompts)):
|
|
weight = weights[i]
|
|
# if not skip_normalize:
|
|
weight = weight / totalWeight
|
|
if thread_data.test_sd2:
|
|
c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
|
else:
|
|
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
|
else:
|
|
if thread_data.test_sd2:
|
|
c = thread_data.model.get_learned_conditioning(prompts)
|
|
else:
|
|
c = thread_data.modelCS.get_learned_conditioning(prompts)
|
|
|
|
if thread_data.reduced_memory and not thread_data.test_sd2:
|
|
thread_data.modelFS.to(thread_data.device)
|
|
|
|
n_steps = req.num_inference_steps if req.init_image is None else t_enc
|
|
img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})
|
|
|
|
# run the handler
|
|
try:
|
|
print('Running handler...')
|
|
if handler == _txt2img:
|
|
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
|
|
else:
|
|
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
|
|
except UserInitiatedStop:
|
|
if not hasattr(thread_data, 'partial_x_samples'):
|
|
continue
|
|
if thread_data.partial_x_samples is None:
|
|
del thread_data.partial_x_samples
|
|
continue
|
|
x_samples = thread_data.partial_x_samples
|
|
del thread_data.partial_x_samples
|
|
|
|
print("decoding images")
|
|
img_data = [None] * batch_size
|
|
for i in range(batch_size):
|
|
if thread_data.test_sd2:
|
|
x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
|
|
else:
|
|
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
|
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
|
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
|
x_sample = x_sample.astype(np.uint8)
|
|
img_data[i] = x_sample
|
|
del x_samples, x_samples_ddim, x_sample
|
|
|
|
print("saving images")
|
|
for i in range(batch_size):
|
|
img = Image.fromarray(img_data[i])
|
|
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
|
|
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
|
|
|
|
has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
|
|
(req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
|
|
|
|
return_orig_img = not has_filters or not req.show_only_filtered_image
|
|
|
|
if thread_data.stop_processing:
|
|
return_orig_img = True
|
|
|
|
if req.save_to_disk_path is not None:
|
|
if return_orig_img:
|
|
img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format)
|
|
save_image(img, img_out_path, req.output_format, req.output_quality)
|
|
meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, 'txt')
|
|
save_metadata(meta_out_path, req, prompts[0], opt_seed)
|
|
|
|
if return_orig_img:
|
|
img_buffer = img_to_buffer(img, req.output_format, req.output_quality)
|
|
img_str = buffer_to_base64_str(img_buffer, req.output_format)
|
|
res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
|
|
res.images.append(res_image_orig)
|
|
task_temp_images[i] = img_buffer
|
|
|
|
if req.save_to_disk_path is not None:
|
|
res_image_orig.path_abs = img_out_path
|
|
del img
|
|
|
|
if has_filters and not thread_data.stop_processing:
|
|
filters_applied = []
|
|
if req.use_face_correction:
|
|
img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
|
|
filters_applied.append(req.use_face_correction)
|
|
if req.use_upscale:
|
|
img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
|
|
filters_applied.append(req.use_upscale)
|
|
if (len(filters_applied) > 0):
|
|
filtered_image = Image.fromarray(img_data[i])
|
|
filtered_buffer = img_to_buffer(filtered_image, req.output_format, req.output_quality)
|
|
filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
|
|
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
|
|
res.images.append(response_image)
|
|
task_temp_images[i] = filtered_buffer
|
|
if req.save_to_disk_path is not None:
|
|
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
|
|
save_image(filtered_image, filtered_img_out_path, req.output_format, req.output_quality)
|
|
response_image.path_abs = filtered_img_out_path
|
|
del filtered_image
|
|
# Filter Applied, move to next seed
|
|
opt_seed += 1
|
|
|
|
# if thread_data.reduced_memory:
|
|
# unload_filters()
|
|
if not thread_data.test_sd2:
|
|
move_to_cpu(thread_data.modelFS)
|
|
del img_data
|
|
gc()
|
|
if thread_data.device != 'cpu':
|
|
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')
|
|
|
|
print('Task completed')
|
|
res = res.json()
|
|
data_queue.put(json.dumps(res))
|
|
|
|
return res
|
|
|
|
def save_image(img, img_out_path, output_format="", output_quality=75):
|
|
try:
|
|
if output_format.upper() == "JPEG":
|
|
img.save(img_out_path, quality=output_quality)
|
|
else:
|
|
img.save(img_out_path)
|
|
except:
|
|
print('could not save the file', traceback.format_exc())
|
|
|
|
def save_metadata(meta_out_path, req, prompt, opt_seed):
|
|
metadata = f'''{prompt}
|
|
Width: {req.width}
|
|
Height: {req.height}
|
|
Seed: {opt_seed}
|
|
Steps: {req.num_inference_steps}
|
|
Guidance Scale: {req.guidance_scale}
|
|
Prompt Strength: {req.prompt_strength}
|
|
Use Face Correction: {req.use_face_correction}
|
|
Use Upscaling: {req.use_upscale}
|
|
Sampler: {req.sampler}
|
|
Negative Prompt: {req.negative_prompt}
|
|
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
|
|
VAE model: {req.use_vae_model}
|
|
Hypernetwork Model: {req.use_hypernetwork_model}
|
|
Hypernetwork Strength: {req.hypernetwork_strength}
|
|
'''
|
|
try:
|
|
with open(meta_out_path, 'w', encoding='utf-8') as f:
|
|
f.write(metadata)
|
|
except:
|
|
print('could not save the file', traceback.format_exc())
|
|
|
|
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
|
|
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
|
|
|
# Send to CPU and wait until complete.
|
|
# wait_model_move_to(thread_data.modelCS, 'cpu')
|
|
|
|
if not thread_data.test_sd2:
|
|
move_to_cpu(thread_data.modelCS)
|
|
|
|
if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim', 'dpm2'):
|
|
raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
|
|
|
|
|
|
# samples, _ = sampler.sample(S=opt.steps,
|
|
# conditioning=c,
|
|
# batch_size=opt.n_samples,
|
|
# shape=shape,
|
|
# verbose=False,
|
|
# unconditional_guidance_scale=opt.scale,
|
|
# unconditional_conditioning=uc,
|
|
# eta=opt.ddim_eta,
|
|
# x_T=start_code)
|
|
|
|
if thread_data.test_sd2:
|
|
if sampler_name == 'plms':
|
|
from ldm.models.diffusion.plms import PLMSSampler
|
|
sampler = PLMSSampler(thread_data.model)
|
|
elif sampler_name == 'ddim':
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
sampler = DDIMSampler(thread_data.model)
|
|
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
|
elif sampler_name == 'dpm2':
|
|
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
|
sampler = DPMSolverSampler(thread_data.model)
|
|
|
|
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
|
|
|
|
samples_ddim, intermediates = sampler.sample(
|
|
S=opt_ddim_steps,
|
|
conditioning=c,
|
|
batch_size=opt_n_samples,
|
|
seed=opt_seed,
|
|
shape=shape,
|
|
verbose=False,
|
|
unconditional_guidance_scale=opt_scale,
|
|
unconditional_conditioning=uc,
|
|
eta=opt_ddim_eta,
|
|
x_T=start_code,
|
|
img_callback=img_callback,
|
|
mask=mask,
|
|
sampler = sampler_name,
|
|
)
|
|
else:
|
|
if sampler_name == 'ddim':
|
|
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
|
|
|
samples_ddim = thread_data.model.sample(
|
|
S=opt_ddim_steps,
|
|
conditioning=c,
|
|
seed=opt_seed,
|
|
shape=shape,
|
|
verbose=False,
|
|
unconditional_guidance_scale=opt_scale,
|
|
unconditional_conditioning=uc,
|
|
eta=opt_ddim_eta,
|
|
x_T=start_code,
|
|
img_callback=img_callback,
|
|
mask=mask,
|
|
sampler = sampler_name,
|
|
)
|
|
return samples_ddim
|
|
|
|
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
|
|
# encode (scaled latent)
|
|
x_T = None if mask is None else init_latent
|
|
|
|
if thread_data.test_sd2:
|
|
from ldm.models.diffusion.ddim import DDIMSampler
|
|
|
|
sampler = DDIMSampler(thread_data.model)
|
|
|
|
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
|
|
|
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device))
|
|
|
|
samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback)
|
|
|
|
else:
|
|
z_enc = thread_data.model.stochastic_encode(
|
|
init_latent,
|
|
torch.tensor([t_enc] * batch_size).to(thread_data.device),
|
|
opt_seed,
|
|
opt_ddim_eta,
|
|
opt_ddim_steps,
|
|
)
|
|
|
|
# decode it
|
|
samples_ddim = thread_data.model.sample(
|
|
t_enc,
|
|
c,
|
|
z_enc,
|
|
unconditional_guidance_scale=opt_scale,
|
|
unconditional_conditioning=uc,
|
|
img_callback=img_callback,
|
|
mask=mask,
|
|
x_T=x_T,
|
|
sampler = 'ddim'
|
|
)
|
|
return samples_ddim
|
|
|
|
def gc():
|
|
gc_collect()
|
|
if thread_data.device == 'cpu':
|
|
return
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.ipc_collect()
|
|
|
|
# internal
|
|
|
|
def chunk(it, size):
|
|
it = iter(it)
|
|
return iter(lambda: tuple(islice(it, size)), ())
|
|
|
|
def load_model_from_config(ckpt, verbose=False):
|
|
print(f"Loading model from {ckpt}")
|
|
model_ver = 'sd1'
|
|
|
|
if ckpt.endswith(".safetensors"):
|
|
print("Loading from safetensors")
|
|
pl_sd = load_file(ckpt, device="cpu")
|
|
else:
|
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
|
|
if "global_step" in pl_sd:
|
|
print(f"Global Step: {pl_sd['global_step']}")
|
|
|
|
if "state_dict" in pl_sd:
|
|
# check for a key that only seems to be present in SD2 models
|
|
if 'cond_stage_model.model.ln_final.bias' in pl_sd['state_dict'].keys():
|
|
model_ver = 'sd2'
|
|
|
|
return pl_sd["state_dict"], model_ver
|
|
else:
|
|
return pl_sd, model_ver
|
|
|
|
class UserInitiatedStop(Exception):
|
|
pass
|
|
|
|
def load_img(img_str, w0, h0):
|
|
image = base64_str_to_img(img_str).convert("RGB")
|
|
w, h = image.size
|
|
print(f"loaded input image of size ({w}, {h}) from base64")
|
|
if h0 is not None and w0 is not None:
|
|
h, w = h0, w0
|
|
|
|
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
|
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
|
|
image = np.array(image).astype(np.float32) / 255.0
|
|
image = image[None].transpose(0, 3, 1, 2)
|
|
image = torch.from_numpy(image)
|
|
return 2.*image - 1.
|
|
|
|
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
|
|
image = base64_str_to_img(mask_str).convert("RGB")
|
|
w, h = image.size
|
|
print(f"loaded input mask of size ({w}, {h})")
|
|
|
|
if invert:
|
|
print("inverted")
|
|
image = ImageOps.invert(image)
|
|
# where_0, where_1 = np.where(image == 0), np.where(image == 255)
|
|
# image[where_0], image[where_1] = 255, 0
|
|
|
|
if h0 is not None and w0 is not None:
|
|
h, w = h0, w0
|
|
|
|
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
|
|
|
print(f"New mask size ({w}, {h})")
|
|
image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
|
|
image = np.array(image)
|
|
|
|
image = image.astype(np.float32) / 255.0
|
|
image = image[None].transpose(0, 3, 1, 2)
|
|
image = torch.from_numpy(image)
|
|
return image
|
|
|
|
# https://stackoverflow.com/a/61114178
|
|
def img_to_base64_str(img, output_format="PNG", output_quality=75):
|
|
buffered = img_to_buffer(img, output_format, quality=output_quality)
|
|
return buffer_to_base64_str(buffered, output_format)
|
|
|
|
def img_to_buffer(img, output_format="PNG", output_quality=75):
|
|
buffered = BytesIO()
|
|
if ( output_format.upper() == "JPEG" ):
|
|
img.save(buffered, format=output_format, quality=output_quality)
|
|
else:
|
|
img.save(buffered, format=output_format)
|
|
buffered.seek(0)
|
|
return buffered
|
|
|
|
def buffer_to_base64_str(buffered, output_format="PNG"):
|
|
buffered.seek(0)
|
|
img_byte = buffered.getvalue()
|
|
mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
|
|
img_str = f"data:{mime_type};base64," + base64.b64encode(img_byte).decode()
|
|
return img_str
|
|
|
|
def base64_str_to_buffer(img_str):
|
|
mime_type = "image/png" if img_str.startswith("data:image/png;") else "image/jpeg"
|
|
img_str = img_str[len(f"data:{mime_type};base64,"):]
|
|
data = base64.b64decode(img_str)
|
|
buffered = BytesIO(data)
|
|
return buffered
|
|
|
|
def base64_str_to_img(img_str):
|
|
buffered = base64_str_to_buffer(img_str)
|
|
img = Image.open(buffered)
|
|
return img
|
|
|
|
def split_weighted_subprompts(text):
|
|
"""
|
|
grabs all text up to the first occurrence of ':'
|
|
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
|
if ':' has no value defined, defaults to 1.0
|
|
repeats until no text remaining
|
|
"""
|
|
remaining = len(text)
|
|
prompts = []
|
|
weights = []
|
|
while remaining > 0:
|
|
if ":" in text:
|
|
idx = text.index(":") # first occurrence from start
|
|
# grab up to index as sub-prompt
|
|
prompt = text[:idx]
|
|
remaining -= idx
|
|
# remove from main text
|
|
text = text[idx+1:]
|
|
# find value for weight
|
|
if " " in text:
|
|
idx = text.index(" ") # first occurence
|
|
else: # no space, read to end
|
|
idx = len(text)
|
|
if idx != 0:
|
|
try:
|
|
weight = float(text[:idx])
|
|
except: # couldn't treat as float
|
|
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
|
|
weight = 1.0
|
|
else: # no value found
|
|
weight = 1.0
|
|
# remove from main text
|
|
remaining -= idx
|
|
text = text[idx+1:]
|
|
# append the sub-prompt and its weight
|
|
prompts.append(prompt)
|
|
weights.append(weight)
|
|
else: # no : found
|
|
if len(text) > 0: # there is still text though
|
|
# take remainder as weight 1
|
|
prompts.append(text)
|
|
weights.append(1.0)
|
|
remaining = 0
|
|
return prompts, weights
|