mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-27 08:39:28 +01:00
Rename the python package name to easydiffusion (from sd_internal)
This commit is contained in:
parent
e483071894
commit
47e3884994
0
ui/easydiffusion/__init__.py
Normal file
0
ui/easydiffusion/__init__.py
Normal file
@ -6,7 +6,8 @@ import traceback
|
|||||||
import logging
|
import logging
|
||||||
from rich.logging import RichHandler
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
from sd_internal import task_manager
|
from easydiffusion import task_manager
|
||||||
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
|
LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -16,8 +17,6 @@ logging.basicConfig(
|
|||||||
handlers=[RichHandler(markup=True, rich_tracebacks=True, show_time=False, show_level=False)]
|
handlers=[RichHandler(markup=True, rich_tracebacks=True, show_time=False, show_level=False)]
|
||||||
)
|
)
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
SD_DIR = os.getcwd()
|
SD_DIR = os.getcwd()
|
||||||
|
|
||||||
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
|
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
|
@ -2,9 +2,8 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
import traceback
|
import traceback
|
||||||
import re
|
import re
|
||||||
import logging
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
@ -1,13 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
|
||||||
import picklescan.scanner
|
import picklescan.scanner
|
||||||
|
|
||||||
from sd_internal import app, TaskData, device_manager
|
from easydiffusion import app, device_manager
|
||||||
|
from easydiffusion.types import TaskData
|
||||||
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
from sdkit.models import model_loader
|
from sdkit.models import model_loader
|
||||||
from sdkit.types import Context
|
from sdkit.types import Context
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
||||||
MODEL_EXTENSIONS = {
|
MODEL_EXTENSIONS = {
|
||||||
'stable-diffusion': ['.ckpt', '.safetensors'],
|
'stable-diffusion': ['.ckpt', '.safetensors'],
|
@ -1,16 +1,15 @@
|
|||||||
import queue
|
import queue
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
|
|
||||||
from sd_internal import device_manager, save_utils
|
from easydiffusion import device_manager
|
||||||
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
|
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop
|
||||||
|
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
|
||||||
|
|
||||||
from sdkit import model_loader, image_generator, image_utils, filters as image_filters
|
from sdkit import model_loader, image_generator, filters as image_filters
|
||||||
|
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images
|
||||||
from sdkit.types import Context, GenerateImageRequest, FilterImageRequest
|
from sdkit.types import Context, GenerateImageRequest, FilterImageRequest
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
context = Context() # thread-local
|
context = Context() # thread-local
|
||||||
'''
|
'''
|
||||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||||
@ -28,7 +27,7 @@ def init(device):
|
|||||||
|
|
||||||
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
context.stop_processing = False
|
context.stop_processing = False
|
||||||
log.info(f'request: {save_utils.get_printable_request(req)}')
|
log.info(f'request: {get_printable_request(req)}')
|
||||||
log.info(f'task data: {task_data.dict()}')
|
log.info(f'task data: {task_data.dict()}')
|
||||||
|
|
||||||
images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||||
@ -45,7 +44,7 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
|
|||||||
filtered_images = apply_filters(task_data, images, user_stopped)
|
filtered_images = apply_filters(task_data, images, user_stopped)
|
||||||
|
|
||||||
if task_data.save_to_disk_path is not None:
|
if task_data.save_to_disk_path is not None:
|
||||||
save_utils.save_to_disk(images, filtered_images, req, task_data)
|
save_images_to_disk(images, filtered_images, req, task_data)
|
||||||
|
|
||||||
return filtered_images if task_data.show_only_filtered_image else images + filtered_images
|
return filtered_images if task_data.show_only_filtered_image else images + filtered_images
|
||||||
|
|
||||||
@ -61,7 +60,7 @@ def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue:
|
|||||||
images = []
|
images = []
|
||||||
user_stopped = True
|
user_stopped = True
|
||||||
if context.partial_x_samples is not None:
|
if context.partial_x_samples is not None:
|
||||||
images = image_utils.latent_samples_to_images(context, context.partial_x_samples)
|
images = latent_samples_to_images(context, context.partial_x_samples)
|
||||||
context.partial_x_samples = None
|
context.partial_x_samples = None
|
||||||
finally:
|
finally:
|
||||||
model_loader.gc(context)
|
model_loader.gc(context)
|
||||||
@ -89,7 +88,7 @@ def apply_filters(task_data: TaskData, images: list, user_stopped):
|
|||||||
def construct_response(images: list, task_data: TaskData, base_seed: int):
|
def construct_response(images: list, task_data: TaskData, base_seed: int):
|
||||||
return [
|
return [
|
||||||
ResponseImage(
|
ResponseImage(
|
||||||
data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality),
|
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
|
||||||
seed=base_seed + i
|
seed=base_seed + i
|
||||||
) for i, img in enumerate(images)
|
) for i, img in enumerate(images)
|
||||||
]
|
]
|
||||||
@ -100,9 +99,9 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
|
|||||||
|
|
||||||
def update_temp_img(x_samples, task_temp_images: list):
|
def update_temp_img(x_samples, task_temp_images: list):
|
||||||
partial_images = []
|
partial_images = []
|
||||||
images = image_utils.latent_samples_to_images(context, x_samples)
|
images = latent_samples_to_images(context, x_samples)
|
||||||
for i, img in enumerate(images):
|
for i, img in enumerate(images):
|
||||||
buf = image_utils.img_to_buffer(img, output_format='JPEG')
|
buf = img_to_buffer(img, output_format='JPEG')
|
||||||
|
|
||||||
context.temp_images[f"{task_data.request_id}/{i}"] = buf
|
context.temp_images[f"{task_data.request_id}/{i}"] = buf
|
||||||
task_temp_images[i] = buf
|
task_temp_images[i] = buf
|
@ -6,7 +6,6 @@ Notes:
|
|||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
|
||||||
|
|
||||||
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
||||||
|
|
||||||
@ -14,10 +13,11 @@ import torch
|
|||||||
import queue, threading, time, weakref
|
import queue, threading, time, weakref
|
||||||
from typing import Any, Hashable
|
from typing import Any, Hashable
|
||||||
|
|
||||||
from sd_internal import TaskData, device_manager
|
from easydiffusion import device_manager
|
||||||
from sdkit.types import GenerateImageRequest
|
from easydiffusion.types import TaskData
|
||||||
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
log = logging.getLogger()
|
from sdkit.types import GenerateImageRequest
|
||||||
|
|
||||||
THREAD_NAME_PREFIX = ''
|
THREAD_NAME_PREFIX = ''
|
||||||
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
||||||
@ -186,7 +186,7 @@ class SessionState():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def thread_get_next_task():
|
def thread_get_next_task():
|
||||||
from sd_internal import renderer
|
from easydiffusion import renderer
|
||||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
|
log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
|
||||||
return None
|
return None
|
||||||
@ -219,7 +219,7 @@ def thread_get_next_task():
|
|||||||
def thread_render(device):
|
def thread_render(device):
|
||||||
global current_state, current_state_error
|
global current_state, current_state_error
|
||||||
|
|
||||||
from sd_internal import renderer, model_manager
|
from easydiffusion import renderer, model_manager
|
||||||
try:
|
try:
|
||||||
renderer.init(device)
|
renderer.init(device)
|
||||||
|
|
8
ui/easydiffusion/utils/__init__.py
Normal file
8
ui/easydiffusion/utils/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
log = logging.getLogger('easydiffusion')
|
||||||
|
|
||||||
|
from .save_utils import (
|
||||||
|
save_images_to_disk,
|
||||||
|
get_printable_request,
|
||||||
|
)
|
@ -3,11 +3,11 @@ import time
|
|||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from easydiffusion.types import TaskData
|
||||||
|
|
||||||
from sdkit.utils import save_images, save_dicts
|
from sdkit.utils import save_images, save_dicts
|
||||||
from sdkit.types import GenerateImageRequest
|
from sdkit.types import GenerateImageRequest
|
||||||
|
|
||||||
from sd_internal import TaskData
|
|
||||||
|
|
||||||
filename_regex = re.compile('[^a-zA-Z0-9]')
|
filename_regex = re.compile('[^a-zA-Z0-9]')
|
||||||
|
|
||||||
# keep in sync with `ui/media/js/dnd.js`
|
# keep in sync with `ui/media/js/dnd.js`
|
||||||
@ -28,9 +28,9 @@ TASK_TEXT_MAPPING = {
|
|||||||
'hypernetwork_strength': 'Hypernetwork Strength'
|
'hypernetwork_strength': 'Hypernetwork Strength'
|
||||||
}
|
}
|
||||||
|
|
||||||
def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
||||||
save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
|
save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
|
||||||
metadata_entries = get_metadata_entries(req, task_data)
|
metadata_entries = get_metadata_entries_for_request(req, task_data)
|
||||||
|
|
||||||
if task_data.show_only_filtered_image or filtered_images == images:
|
if task_data.show_only_filtered_image or filtered_images == images:
|
||||||
save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
||||||
@ -40,7 +40,7 @@ def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest,
|
|||||||
save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
||||||
save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
|
save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
|
||||||
|
|
||||||
def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData):
|
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
|
||||||
metadata = get_printable_request(req)
|
metadata = get_printable_request(req)
|
||||||
metadata.update({
|
metadata.update({
|
||||||
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
|
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
|
@ -4,7 +4,6 @@ Notes:
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
@ -13,12 +12,11 @@ from fastapi.staticfiles import StaticFiles
|
|||||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from sd_internal import app, model_manager, task_manager
|
from easydiffusion import app, model_manager, task_manager
|
||||||
from sd_internal import TaskData
|
from easydiffusion.types import TaskData
|
||||||
|
from easydiffusion.utils import log
|
||||||
from sdkit.types import GenerateImageRequest
|
from sdkit.types import GenerateImageRequest
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
log.info(f'started in {app.SD_DIR}')
|
log.info(f'started in {app.SD_DIR}')
|
||||||
log.info(f'started at {datetime.datetime.now():%x %X}')
|
log.info(f'started at {datetime.datetime.now():%x %X}')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user