diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index fb380695..795eea73 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -1,12 +1,11 @@ import os -from easydiffusion import app, device_manager +from easydiffusion import app from easydiffusion.types import TaskData from easydiffusion.utils import log from sdkit import Context -from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model -from sdkit.utils import hash_file_quick +from sdkit.models import load_model, unload_model, scan_model KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan"] MODEL_EXTENSIONS = { @@ -102,6 +101,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData): "hypernetwork": task_data.use_hypernetwork_model, "gfpgan": task_data.use_face_correction, "realesrgan": task_data.use_upscale, + "nsfw_checker": True if task_data.block_nsfw else None, } models_to_reload = { model_type: path diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py index 3b8adaa7..2bca77e5 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -102,15 +102,20 @@ def generate_images_internal( def filter_images(task_data: TaskData, images: list, user_stopped): - if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None): + if user_stopped: return images filters_to_apply = [] + if task_data.block_nsfw: + filters_to_apply.append("nsfw_checker") if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower(): filters_to_apply.append("gfpgan") if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower(): filters_to_apply.append("realesrgan") + if len(filters_to_apply) == 0: + return images + return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount) diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index b45eafd2..4437ee7f 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -38,6 +38,7 @@ class TaskData(BaseModel): use_hypernetwork_model: str = None show_only_filtered_image: bool = False + block_nsfw: bool = False output_format: str = "jpeg" # or "png" output_quality: int = 75 metadata_output_format: str = "txt" # or "json" diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js index 95fa7f8a..5fe7bc7c 100644 --- a/ui/media/js/auto-save.js +++ b/ui/media/js/auto-save.js @@ -30,6 +30,7 @@ const SETTINGS_IDS_LIST = [ "gfpgan_model", "use_upscale", "upscale_amount", + "block_nsfw", "show_only_filtered_image", "upscale_model", "preview-image", diff --git a/ui/media/js/engine.js b/ui/media/js/engine.js index 25c4ba2e..ae24e5b8 100644 --- a/ui/media/js/engine.js +++ b/ui/media/js/engine.js @@ -741,6 +741,7 @@ "stream_progress_updates": true, "stream_image_progress": true, "show_only_filtered_image": true, + "block_nsfw": false, "output_format": "png", "output_quality": 75, } diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 2d99e341..53088c61 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -43,6 +43,7 @@ let hypernetworkModelField = new ModelDropdown(document.querySelector('#hypernet let hypernetworkStrengthSlider = document.querySelector('#hypernetwork_strength_slider') let hypernetworkStrengthField = document.querySelector('#hypernetwork_strength') let outputFormatField = document.querySelector('#output_format') +let blockNSFWField = document.querySelector('#block_nsfw') let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image") let updateBranchLabel = document.querySelector("#updateBranchLabel") let streamImageProgressField = document.querySelector("#stream_image_progress") @@ -967,6 +968,7 @@ function getCurrentUserRequest() { stream_progress_updates: true, stream_image_progress: (numOutputsTotal > 50 ? false : streamImageProgressField.checked), show_only_filtered_image: showOnlyFilteredImageField.checked, + block_nsfw: blockNSFWField.checked, output_format: outputFormatField.value, output_quality: parseInt(outputQualityField.value), metadata_output_format: metadataOutputFormatField.value, diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index c25b1b65..84b7b60f 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -78,6 +78,14 @@ var PARAMETERS = [ } ], }, + { + id: "block_nsfw", + type: ParameterType.checkbox, + label: "Block NSFW images", + note: "blurs out NSFW images", + icon: "fa-land-mine-on", + default: false, + }, { id: "sound_toggle", type: ParameterType.checkbox,