Make stream_image_progress accept an integer for the rate the progress frames should be generated. (#889)

* Make stream_image_progress accept an integer

for the rate the progress frames should be generated.

* Use a different field for the progress interval.
This commit is contained in:
ayunami2000 2023-02-20 22:08:21 -05:00 committed by GitHub
parent 2f0e8a8a4a
commit e25e1bfe10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 3 deletions

View File

@ -55,8 +55,9 @@ def print_task_info(req: GenerateImageRequest, task_data: TaskData):
def make_images_internal(
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
):
images, user_stopped = generate_images_internal(
req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress
req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress, task_data.stream_image_progress_interval
)
filtered_images = filter_images(task_data, images, user_stopped)
@ -77,10 +78,11 @@ def generate_images_internal(
task_temp_images: list,
step_callback,
stream_image_progress: bool,
stream_image_progress_interval: int,
):
context.temp_images.clear()
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress, stream_image_progress_interval)
try:
if req.init_image is not None:
@ -136,6 +138,7 @@ def make_step_callback(
task_temp_images: list,
step_callback,
stream_image_progress: bool,
stream_image_progress_interval: int,
):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1
@ -161,7 +164,7 @@ def make_step_callback(
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
if stream_image_progress and i % 5 == 0:
if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0:
progress["output"] = update_temp_img(x_samples, task_temp_images)
data_queue.put(json.dumps(progress))

View File

@ -43,6 +43,7 @@ class TaskData(BaseModel):
output_quality: int = 75
metadata_output_format: str = "txt" # or "json"
stream_image_progress: bool = False
stream_image_progress_interval: int = 5
class MergeRequest(BaseModel):