Rename the python package name to easydiffusion (from sd_internal)

This commit is contained in:
cmdr2 2022-12-19 19:39:15 +05:30
parent e483071894
commit 47e3884994
10 changed files with 40 additions and 37 deletions

View File

View File

@ -6,7 +6,8 @@ import traceback
import logging import logging
from rich.logging import RichHandler from rich.logging import RichHandler
from sd_internal import task_manager from easydiffusion import task_manager
from easydiffusion.utils import log
LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s' LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
logging.basicConfig( logging.basicConfig(
@ -16,8 +17,6 @@ logging.basicConfig(
handlers=[RichHandler(markup=True, rich_tracebacks=True, show_time=False, show_level=False)] handlers=[RichHandler(markup=True, rich_tracebacks=True, show_time=False, show_level=False)]
) )
log = logging.getLogger()
SD_DIR = os.getcwd() SD_DIR = os.getcwd()
SD_UI_DIR = os.getenv('SD_UI_PATH', None) SD_UI_DIR = os.getenv('SD_UI_PATH', None)

View File

@ -2,9 +2,8 @@ import os
import torch import torch
import traceback import traceback
import re import re
import logging
log = logging.getLogger() from easydiffusion.utils import log
''' '''
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32). Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).

View File

@ -1,13 +1,13 @@
import os import os
import logging
import picklescan.scanner import picklescan.scanner
from sd_internal import app, TaskData, device_manager from easydiffusion import app, device_manager
from easydiffusion.types import TaskData
from easydiffusion.utils import log
from sdkit.models import model_loader from sdkit.models import model_loader
from sdkit.types import Context from sdkit.types import Context
log = logging.getLogger()
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
MODEL_EXTENSIONS = { MODEL_EXTENSIONS = {
'stable-diffusion': ['.ckpt', '.safetensors'], 'stable-diffusion': ['.ckpt', '.safetensors'],

View File

@ -1,16 +1,15 @@
import queue import queue
import time import time
import json import json
import logging
from sd_internal import device_manager, save_utils from easydiffusion import device_manager
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
from sdkit import model_loader, image_generator, image_utils, filters as image_filters from sdkit import model_loader, image_generator, filters as image_filters
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images
from sdkit.types import Context, GenerateImageRequest, FilterImageRequest from sdkit.types import Context, GenerateImageRequest, FilterImageRequest
log = logging.getLogger()
context = Context() # thread-local context = Context() # thread-local
''' '''
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
@ -28,7 +27,7 @@ def init(device):
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
context.stop_processing = False context.stop_processing = False
log.info(f'request: {save_utils.get_printable_request(req)}') log.info(f'request: {get_printable_request(req)}')
log.info(f'task data: {task_data.dict()}') log.info(f'task data: {task_data.dict()}')
images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
@ -45,7 +44,7 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
filtered_images = apply_filters(task_data, images, user_stopped) filtered_images = apply_filters(task_data, images, user_stopped)
if task_data.save_to_disk_path is not None: if task_data.save_to_disk_path is not None:
save_utils.save_to_disk(images, filtered_images, req, task_data) save_images_to_disk(images, filtered_images, req, task_data)
return filtered_images if task_data.show_only_filtered_image else images + filtered_images return filtered_images if task_data.show_only_filtered_image else images + filtered_images
@ -61,7 +60,7 @@ def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue:
images = [] images = []
user_stopped = True user_stopped = True
if context.partial_x_samples is not None: if context.partial_x_samples is not None:
images = image_utils.latent_samples_to_images(context, context.partial_x_samples) images = latent_samples_to_images(context, context.partial_x_samples)
context.partial_x_samples = None context.partial_x_samples = None
finally: finally:
model_loader.gc(context) model_loader.gc(context)
@ -89,7 +88,7 @@ def apply_filters(task_data: TaskData, images: list, user_stopped):
def construct_response(images: list, task_data: TaskData, base_seed: int): def construct_response(images: list, task_data: TaskData, base_seed: int):
return [ return [
ResponseImage( ResponseImage(
data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality), data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
seed=base_seed + i seed=base_seed + i
) for i, img in enumerate(images) ) for i, img in enumerate(images)
] ]
@ -100,9 +99,9 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
def update_temp_img(x_samples, task_temp_images: list): def update_temp_img(x_samples, task_temp_images: list):
partial_images = [] partial_images = []
images = image_utils.latent_samples_to_images(context, x_samples) images = latent_samples_to_images(context, x_samples)
for i, img in enumerate(images): for i, img in enumerate(images):
buf = image_utils.img_to_buffer(img, output_format='JPEG') buf = img_to_buffer(img, output_format='JPEG')
context.temp_images[f"{task_data.request_id}/{i}"] = buf context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf task_temp_images[i] = buf

View File

@ -6,7 +6,6 @@ Notes:
""" """
import json import json
import traceback import traceback
import logging
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
@ -14,10 +13,11 @@ import torch
import queue, threading, time, weakref import queue, threading, time, weakref
from typing import Any, Hashable from typing import Any, Hashable
from sd_internal import TaskData, device_manager from easydiffusion import device_manager
from sdkit.types import GenerateImageRequest from easydiffusion.types import TaskData
from easydiffusion.utils import log
log = logging.getLogger() from sdkit.types import GenerateImageRequest
THREAD_NAME_PREFIX = '' THREAD_NAME_PREFIX = ''
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
@ -186,7 +186,7 @@ class SessionState():
return True return True
def thread_get_next_task(): def thread_get_next_task():
from sd_internal import renderer from easydiffusion import renderer
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.') log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
return None return None
@ -219,7 +219,7 @@ def thread_get_next_task():
def thread_render(device): def thread_render(device):
global current_state, current_state_error global current_state, current_state_error
from sd_internal import renderer, model_manager from easydiffusion import renderer, model_manager
try: try:
renderer.init(device) renderer.init(device)

View File

@ -0,0 +1,8 @@
import logging
log = logging.getLogger('easydiffusion')
from .save_utils import (
save_images_to_disk,
get_printable_request,
)

View File

@ -3,11 +3,11 @@ import time
import base64 import base64
import re import re
from easydiffusion.types import TaskData
from sdkit.utils import save_images, save_dicts from sdkit.utils import save_images, save_dicts
from sdkit.types import GenerateImageRequest from sdkit.types import GenerateImageRequest
from sd_internal import TaskData
filename_regex = re.compile('[^a-zA-Z0-9]') filename_regex = re.compile('[^a-zA-Z0-9]')
# keep in sync with `ui/media/js/dnd.js` # keep in sync with `ui/media/js/dnd.js`
@ -28,9 +28,9 @@ TASK_TEXT_MAPPING = {
'hypernetwork_strength': 'Hypernetwork Strength' 'hypernetwork_strength': 'Hypernetwork Strength'
} }
def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
metadata_entries = get_metadata_entries(req, task_data) metadata_entries = get_metadata_entries_for_request(req, task_data)
if task_data.show_only_filtered_image or filtered_images == images: if task_data.show_only_filtered_image or filtered_images == images:
save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
@ -40,7 +40,7 @@ def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest,
save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
metadata = get_printable_request(req) metadata = get_printable_request(req)
metadata.update({ metadata.update({
'use_stable_diffusion_model': task_data.use_stable_diffusion_model, 'use_stable_diffusion_model': task_data.use_stable_diffusion_model,

View File

@ -4,7 +4,6 @@ Notes:
""" """
import os import os
import traceback import traceback
import logging
import datetime import datetime
from typing import List, Union from typing import List, Union
@ -13,12 +12,11 @@ from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from sd_internal import app, model_manager, task_manager from easydiffusion import app, model_manager, task_manager
from sd_internal import TaskData from easydiffusion.types import TaskData
from easydiffusion.utils import log
from sdkit.types import GenerateImageRequest from sdkit.types import GenerateImageRequest
log = logging.getLogger()
log.info(f'started in {app.SD_DIR}') log.info(f'started in {app.SD_DIR}')
log.info(f'started at {datetime.datetime.now():%x %X}') log.info(f'started at {datetime.datetime.now():%x %X}')