mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-22 00:03:20 +01:00
UI for TensorRT installation and conversion
This commit is contained in:
parent
a9f1000af8
commit
2e3059a7c8
@ -18,7 +18,7 @@ os_name = platform.system()
|
||||
modules_to_check = {
|
||||
"torch": ("1.11.0", "1.13.1", "2.0.0"),
|
||||
"torchvision": ("0.12.0", "0.14.1", "0.15.1"),
|
||||
"sdkit": "1.0.143",
|
||||
"sdkit": "1.0.146",
|
||||
"stable-diffusion-sdkit": "2.1.4",
|
||||
"rich": "12.6.0",
|
||||
"uvicorn": "0.19.0",
|
||||
|
@ -32,6 +32,8 @@ logging.basicConfig(
|
||||
|
||||
SD_DIR = os.getcwd()
|
||||
|
||||
ROOT_DIR = os.path.abspath(os.path.join(SD_DIR, ".."))
|
||||
|
||||
SD_UI_DIR = os.getenv("SD_UI_PATH", None)
|
||||
|
||||
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
|
||||
|
93
ui/easydiffusion/package_manager.py
Normal file
93
ui/easydiffusion/package_manager.py
Normal file
@ -0,0 +1,93 @@
|
||||
import sys
|
||||
import os
|
||||
import platform
|
||||
from importlib.metadata import version as pkg_version
|
||||
|
||||
from sdkit.utils import log
|
||||
|
||||
from easydiffusion import app
|
||||
|
||||
# future home of scripts/check_modules.py
|
||||
|
||||
manifest = {
|
||||
"tensorrt": {
|
||||
"install": ["nvidia-cudnn", "tensorrt-libs", "tensorrt"],
|
||||
"uninstall": ["tensorrt"],
|
||||
# TODO also uninstall tensorrt-libs and nvidia-cudnn, but do it upon restarting (avoid 'file in use' error)
|
||||
}
|
||||
}
|
||||
installing = []
|
||||
|
||||
# remove this once TRT releases on pypi
|
||||
if platform.system() == "Windows":
|
||||
trt_dir = os.path.join(app.ROOT_DIR, "tensorrt")
|
||||
if os.path.exists(trt_dir):
|
||||
files = os.listdir(trt_dir)
|
||||
|
||||
packages = manifest["tensorrt"]["install"]
|
||||
packages = tuple(p.replace("-", "_") for p in packages)
|
||||
|
||||
wheels = []
|
||||
for p in packages:
|
||||
f = next((f for f in files if f.startswith(p) and f.endswith((".whl", ".tar.gz"))), None)
|
||||
if f:
|
||||
wheels.append(os.path.join(trt_dir, f))
|
||||
|
||||
manifest["tensorrt"]["install"] = wheels
|
||||
|
||||
|
||||
def get_installed_packages() -> list:
|
||||
return {module_name: version(module_name) for module_name in manifest if is_installed(module_name)}
|
||||
|
||||
|
||||
def is_installed(module_name) -> bool:
|
||||
return version(module_name) is not None
|
||||
|
||||
|
||||
def install(module_name):
|
||||
if is_installed(module_name):
|
||||
log.info(f"{module_name} has already been installed!")
|
||||
return
|
||||
if module_name in installing:
|
||||
log.info(f"{module_name} is already installing!")
|
||||
return
|
||||
|
||||
if module_name not in manifest:
|
||||
raise RuntimeError(f"Can't install unknown package: {module_name}!")
|
||||
|
||||
commands = manifest[module_name]["install"]
|
||||
commands = [f"python -m pip install --upgrade {cmd}" for cmd in commands]
|
||||
|
||||
installing.append(module_name)
|
||||
|
||||
try:
|
||||
for cmd in commands:
|
||||
print(">", cmd)
|
||||
if os.system(cmd) != 0:
|
||||
raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
|
||||
finally:
|
||||
installing.remove(module_name)
|
||||
|
||||
|
||||
def uninstall(module_name):
|
||||
if not is_installed(module_name):
|
||||
log.info(f"{module_name} hasn't been installed!")
|
||||
return
|
||||
|
||||
if module_name not in manifest:
|
||||
raise RuntimeError(f"Can't uninstall unknown package: {module_name}!")
|
||||
|
||||
commands = manifest[module_name]["uninstall"]
|
||||
commands = [f"python -m pip uninstall -y {cmd}" for cmd in commands]
|
||||
|
||||
for cmd in commands:
|
||||
print(">", cmd)
|
||||
if os.system(cmd) != 0:
|
||||
raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
|
||||
|
||||
|
||||
def version(module_name: str) -> str:
|
||||
try:
|
||||
return pkg_version(module_name)
|
||||
except:
|
||||
return None
|
@ -8,7 +8,7 @@ import os
|
||||
import traceback
|
||||
from typing import List, Union
|
||||
|
||||
from easydiffusion import app, model_manager, task_manager
|
||||
from easydiffusion import app, model_manager, task_manager, package_manager
|
||||
from easydiffusion.tasks import RenderTask, FilterTask
|
||||
from easydiffusion.types import (
|
||||
GenerateImageRequest,
|
||||
@ -135,6 +135,10 @@ def init():
|
||||
def stop_cloudflare_tunnel(req: dict):
|
||||
return stop_cloudflare_tunnel_internal(req)
|
||||
|
||||
@server_api.post("/package/{package_name:str}")
|
||||
def modify_package(package_name: str, req: dict):
|
||||
return modify_package_internal(package_name, req)
|
||||
|
||||
@server_api.get("/")
|
||||
def read_root():
|
||||
return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS)
|
||||
@ -226,16 +230,24 @@ def ping_internal(session_id: str = None):
|
||||
if task_manager.current_state_error:
|
||||
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
raise HTTPException(status_code=500, detail="Render thread is dead.")
|
||||
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration):
|
||||
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
|
||||
# Alive
|
||||
response = {"status": str(task_manager.current_state)}
|
||||
|
||||
if session_id:
|
||||
session = task_manager.get_cached_session(session_id, update_ttl=True)
|
||||
response["tasks"] = {id(t): t.status for t in session.tasks}
|
||||
|
||||
response["devices"] = task_manager.get_devices()
|
||||
response["packages_installed"] = package_manager.get_installed_packages()
|
||||
response["packages_installing"] = package_manager.installing
|
||||
|
||||
if cloudflare.address != None:
|
||||
response["cloudflare"] = cloudflare.address
|
||||
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
|
||||
@ -423,3 +435,19 @@ def stop_cloudflare_tunnel_internal(req: dict):
|
||||
log.error(str(e))
|
||||
log.error(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def modify_package_internal(package_name: str, req: dict):
|
||||
try:
|
||||
cmd = req["command"]
|
||||
if cmd not in ("install", "uninstall"):
|
||||
raise RuntimeError(f"Unknown command: {cmd}")
|
||||
|
||||
cmd = getattr(package_manager, cmd)
|
||||
cmd(package_name)
|
||||
|
||||
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||
except Exception as e:
|
||||
log.error(str(e))
|
||||
log.error(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
@ -373,6 +373,12 @@ def get_devices():
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
# temp until TRT releases
|
||||
import os
|
||||
from easydiffusion import app
|
||||
|
||||
devices["enable_trt"] = os.path.exists(os.path.join(app.ROOT_DIR, "tensorrt"))
|
||||
|
||||
return devices
|
||||
|
||||
|
||||
|
@ -60,7 +60,11 @@ class RenderTask(Task):
|
||||
model_manager.resolve_model_paths(self.models_data)
|
||||
|
||||
models_to_force_reload = []
|
||||
if runtime.set_vram_optimizations(context) or self.has_clip_skip_changed(context):
|
||||
if (
|
||||
runtime.set_vram_optimizations(context)
|
||||
or self.has_param_changed(context, "clip_skip")
|
||||
or self.has_param_changed(context, "convert_to_tensorrt")
|
||||
):
|
||||
models_to_force_reload.append("stable-diffusion")
|
||||
|
||||
model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload)
|
||||
@ -78,13 +82,15 @@ class RenderTask(Task):
|
||||
step_callback,
|
||||
)
|
||||
|
||||
def has_clip_skip_changed(self, context):
|
||||
def has_param_changed(self, context, param_name):
|
||||
if not context.test_diffusers:
|
||||
return False
|
||||
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
|
||||
return True
|
||||
|
||||
model = context.models["stable-diffusion"]
|
||||
new_clip_skip = self.models_data.model_params.get("stable-diffusion", {}).get("clip_skip", False)
|
||||
return model["clip_skip"] != new_clip_skip
|
||||
new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False)
|
||||
return model["params"].get(param_name) != new_val
|
||||
|
||||
|
||||
def make_images(
|
||||
|
@ -217,7 +217,10 @@ def convert_legacy_render_req_to_new(old_req: dict):
|
||||
|
||||
# move the model params
|
||||
if model_paths["stable-diffusion"]:
|
||||
model_params["stable-diffusion"] = {"clip_skip": bool(old_req.get("clip_skip", False))}
|
||||
model_params["stable-diffusion"] = {
|
||||
"clip_skip": bool(old_req.get("clip_skip", False)),
|
||||
"convert_to_tensorrt": bool(old_req.get("convert_to_tensorrt", False)),
|
||||
}
|
||||
|
||||
# move the filter params
|
||||
if model_paths["realesrgan"]:
|
||||
|
@ -146,6 +146,14 @@
|
||||
<button id="reload-models" class="secondaryButton reloadModels"><i class='fa-solid fa-rotate'></i></button>
|
||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
|
||||
</td></tr>
|
||||
<tr class="pl-5 displayNone" id="enable_trt_config">
|
||||
<td><label for="convert_to_tensorrt">Convert to TensorRT:</label></td>
|
||||
<td class="diffusers-restart-needed">
|
||||
<input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox">
|
||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/TensorRT" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about TensorRT</span></i></a>
|
||||
<label><small>Takes upto 20 mins the first time</small></label>
|
||||
</td>
|
||||
</tr>
|
||||
<tr class="pl-5 displayNone" id="clip_skip_config">
|
||||
<td><label for="clip_skip">Clip Skip:</label></td>
|
||||
<td class="diffusers-restart-needed">
|
||||
@ -363,7 +371,7 @@
|
||||
<div id="install-extras-container" class="displayNone">
|
||||
<br/>
|
||||
<div id="install-extras">
|
||||
<h3><i class="fa fa-bolt"></i> Accelerate Easy Diffusion</h3>
|
||||
<h3><i class="fa fa-cubes-stacked"></i> Optional Packages</h3>
|
||||
<div class="parameters-table" id="system-settings-install-extras-table"></div>
|
||||
</div>
|
||||
</div>
|
||||
@ -677,7 +685,7 @@ async function init() {
|
||||
events: {
|
||||
statusChange: setServerStatus,
|
||||
idle: onIdle,
|
||||
ping: tunnelUpdate
|
||||
ping: onPing
|
||||
}
|
||||
})
|
||||
splashScreen()
|
||||
|
@ -275,24 +275,24 @@ function setServerStatus(event) {
|
||||
// e : MouseEvent
|
||||
// prompt : Text to be shown as prompt. Should be a question to which "yes" is a good answer.
|
||||
// fn : function to be called if the user confirms the dialog or has the shift key pressed
|
||||
// allowSkip: Allow skipping the dialog using the shift key or the confirm_dangerous_actions setting (default: true)
|
||||
//
|
||||
// If the user had the shift key pressed while clicking, the function fn will be executed.
|
||||
// If the setting "confirm_dangerous_actions" in the system settings is disabled, the function
|
||||
// fn will be executed.
|
||||
// Otherwise, a confirmation dialog is shown. If the user confirms, the function fn will also
|
||||
// be executed.
|
||||
function shiftOrConfirm(e, prompt, fn) {
|
||||
function shiftOrConfirm(e, prompt, fn, allowSkip = true) {
|
||||
e.stopPropagation()
|
||||
if (e.shiftKey || !confirmDangerousActionsField.checked) {
|
||||
let tip = allowSkip
|
||||
? '<small>Tip: To skip this dialog, use shift-click or disable the "Confirm dangerous actions" setting in the Settings tab.</small>'
|
||||
: ""
|
||||
if (allowSkip && (e.shiftKey || !confirmDangerousActionsField.checked)) {
|
||||
fn(e)
|
||||
} else {
|
||||
confirm(
|
||||
'<small>Tip: To skip this dialog, use shift-click or disable the "Confirm dangerous actions" setting in the Settings tab.</small>',
|
||||
prompt,
|
||||
() => {
|
||||
fn(e)
|
||||
}
|
||||
)
|
||||
confirm(tip, prompt, () => {
|
||||
fn(e)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -1417,6 +1417,11 @@ function getCurrentUserRequest() {
|
||||
newTask.reqBody.lora_alpha = modelStrengths
|
||||
}
|
||||
}
|
||||
if (testDiffusers.checked && document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall") {
|
||||
// TRT is installed
|
||||
newTask.reqBody.convert_to_tensorrt = document.querySelector("#convert_to_tensorrt").checked
|
||||
}
|
||||
|
||||
return newTask
|
||||
}
|
||||
|
||||
@ -2217,6 +2222,11 @@ resumeBtn.addEventListener("click", function() {
|
||||
document.body.classList.remove("wait-pause")
|
||||
})
|
||||
|
||||
function onPing(event) {
|
||||
tunnelUpdate(event)
|
||||
packagesUpdate(event)
|
||||
}
|
||||
|
||||
function tunnelUpdate(event) {
|
||||
if ("cloudflare" in event) {
|
||||
document.getElementById("cloudflare-off").classList.add("displayNone")
|
||||
@ -2230,6 +2240,23 @@ function tunnelUpdate(event) {
|
||||
}
|
||||
}
|
||||
|
||||
function packagesUpdate(event) {
|
||||
let trtBtn = document.getElementById("toggle-tensorrt-install")
|
||||
let trtInstalled = "packages_installed" in event && "tensorrt" in event["packages_installed"]
|
||||
|
||||
if ("packages_installing" in event && event["packages_installing"].includes("tensorrt")) {
|
||||
trtBtn.innerHTML = "Installing.."
|
||||
trtBtn.disabled = true
|
||||
} else {
|
||||
trtBtn.innerHTML = trtInstalled ? "Uninstall" : "Install"
|
||||
trtBtn.disabled = false
|
||||
}
|
||||
|
||||
if (document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall") {
|
||||
document.querySelector("#enable_trt_config").classList.remove("displayNone")
|
||||
}
|
||||
}
|
||||
|
||||
document.getElementById("toggle-cloudflare-tunnel").addEventListener("click", async function() {
|
||||
let command = "stop"
|
||||
if (document.getElementById("toggle-cloudflare-tunnel").innerHTML == "Start") {
|
||||
@ -2249,6 +2276,63 @@ document.getElementById("toggle-cloudflare-tunnel").addEventListener("click", as
|
||||
console.log(`Cloudflare tunnel ${command} result:`, res)
|
||||
})
|
||||
|
||||
document.getElementById("toggle-tensorrt-install").addEventListener("click", function(e) {
|
||||
if (this.disabled === true) {
|
||||
return
|
||||
}
|
||||
|
||||
let command = this.innerHTML.toLowerCase()
|
||||
let self = this
|
||||
|
||||
shiftOrConfirm(
|
||||
e,
|
||||
"Are you sure you want to " + command + " TensorRT?",
|
||||
async function() {
|
||||
showToast(`TensorRT ${command} started. Please wait.`)
|
||||
|
||||
self.disabled = true
|
||||
|
||||
if (command === "install") {
|
||||
self.innerHTML = "Installing.."
|
||||
} else if (command === "uninstall") {
|
||||
self.innerHTML = "Uninstalling.."
|
||||
}
|
||||
|
||||
if (command === "installing..") {
|
||||
alert("Already installing TensorRT!")
|
||||
return
|
||||
}
|
||||
if (command !== "install" && command !== "uninstall") {
|
||||
return
|
||||
}
|
||||
|
||||
let res = await fetch("/package/tensorrt", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
command: command,
|
||||
}),
|
||||
})
|
||||
res = await res.json()
|
||||
|
||||
self.disabled = false
|
||||
|
||||
if (res.status === "OK") {
|
||||
alert("TensorRT " + command + "ed successfully!")
|
||||
self.innerHTML = command === "install" ? "Uninstall" : "Install"
|
||||
} else if (res.status_code === 500) {
|
||||
alert("TensorselfRT failed to " + command + ": " + res.detail)
|
||||
self.innerHTML = command === "install" ? "Install" : "Uninstall"
|
||||
}
|
||||
|
||||
console.log(`Package ${command} result:`, res)
|
||||
},
|
||||
false
|
||||
)
|
||||
})
|
||||
|
||||
/* Embeddings */
|
||||
|
||||
function updateEmbeddingsList(filter = "") {
|
||||
|
@ -247,10 +247,10 @@ var PARAMETERS = [
|
||||
type: ParameterType.custom,
|
||||
label: "NVIDIA TensorRT",
|
||||
note: `Faster image generation by converting your Stable Diffusion models to the NVIDIA TensorRT format. You can choose the
|
||||
models to convert. Requires an NVIDIA graphics card.<br/><br/>
|
||||
models to convert. Download size: approximately 2 GB.<br/><br/>
|
||||
<b>Early access version:</b> support for LoRA is still under development.`,
|
||||
icon: "fa-angles-up",
|
||||
render: () => '<button id="install-tensorrt" class="primaryButton">Install</button>',
|
||||
render: () => '<button id="toggle-tensorrt-install" class="primaryButton">Install</button>',
|
||||
table: installExtrasTable,
|
||||
},
|
||||
]
|
||||
@ -596,12 +596,10 @@ function setDeviceInfo(devices) {
|
||||
systemInfoEl.querySelector("#system-info-rendering-devices").innerHTML = activeGPUs.join("</br>")
|
||||
|
||||
// tensorRT
|
||||
if (devices.active) {
|
||||
console.log(devices.active)
|
||||
if (devices.active && testDiffusers.checked && devices.enable_trt === true) {
|
||||
let nvidiaGPUs = Object.keys(devices.active).filter((d) => {
|
||||
let gpuName = devices.active[d].name
|
||||
gpuName = gpuName.toLowerCase()
|
||||
console.log(gpuName)
|
||||
return (
|
||||
gpuName.includes("nvidia") ||
|
||||
gpuName.includes("geforce") ||
|
||||
|
Loading…
Reference in New Issue
Block a user