diff --git a/ui/frontend/build_src/src/api/index.ts b/ui/frontend/build_src/src/api/index.ts index 113d1359..d12a88ba 100644 --- a/ui/frontend/build_src/src/api/index.ts +++ b/ui/frontend/build_src/src/api/index.ts @@ -22,7 +22,10 @@ export const healthPing = async () => { * the local list of modifications */ export const loadModifications = async () => { - const response = await fetch(`${API_URL}/modifiers.json`); + const url = `${API_URL}/modifiers.json`; + + console.log('loadModifications', url); + const response = await fetch(url); const data = await response.json(); return data; } @@ -40,6 +43,9 @@ export const getSaveDirectory = async () => { export const MakeImageKey = 'MakeImage'; export const doMakeImage = async (reqBody: ImageRequest) => { + const {seed, num_outputs} = reqBody; + console.log('doMakeImage', seed, num_outputs); + const res = await fetch(`${API_URL}/image`, { method: 'POST', headers: { diff --git a/ui/frontend/build_src/src/components/creationPanel/advancedSettings/index.tsx b/ui/frontend/build_src/src/components/creationPanel/advancedSettings/index.tsx index 7519a30f..45eb3535 100644 --- a/ui/frontend/build_src/src/components/creationPanel/advancedSettings/index.tsx +++ b/ui/frontend/build_src/src/components/creationPanel/advancedSettings/index.tsx @@ -23,8 +23,8 @@ const IMAGE_DIMENSIONS = [ function SettingsList() { - const requestCount = useImageCreate((state) => state.requestCount); - const setRequestCount = useImageCreate((state) => state.setRequestCount); + const parallelCount = useImageCreate((state) => state.parallelCount); + const setParallelCount = useImageCreate((state) => state.setParallelCount); const setRequestOption = useImageCreate((state) => state.setRequestOptions); @@ -134,9 +134,9 @@ function SettingsList() { Number of images to make:{" "} - setRequestCount(parseInt(e.target.value, 10)) + setRequestOption("num_outputs", parseInt(e.target.value, 10)) } size={4} /> @@ -145,9 +145,9 @@ function SettingsList() { Generate in parallel: - setRequestOption("num_outputs", parseInt(e.target.value, 10)) + setParallelCount(parseInt(e.target.value, 10)) } size={4} /> diff --git a/ui/frontend/build_src/src/components/creationPanel/makeButton/index.tsx b/ui/frontend/build_src/src/components/creationPanel/makeButton/index.tsx index 3afeb625..6a74433b 100644 --- a/ui/frontend/build_src/src/components/creationPanel/makeButton/index.tsx +++ b/ui/frontend/build_src/src/components/creationPanel/makeButton/index.tsx @@ -1,24 +1,71 @@ import React, {useEffect, useState}from "react"; import { useImageCreate } from "../../../store/imageCreateStore"; -// import { useImageDisplay } from "../../../store/imageDisplayStore"; import { useImageQueue } from "../../../store/imageQueueStore"; -// import { doMakeImage } from "../../../api"; import {v4 as uuidv4} from 'uuid'; +import { useRandomSeed } from "../../../utils"; + export default function MakeButton() { + const parallelCount = useImageCreate((state) => state.parallelCount); const builtRequest = useImageCreate((state) => state.builtRequest); const addNewImage = useImageQueue((state) => state.addNewImage); - const makeImage = () => { - // todo turn this into a loop and adjust the parallel count - // - const req = builtRequest(); - addNewImage(uuidv4(), req) + const makeImages = () => { + + // the request that we have built + const req = builtRequest(); + // the actual number of request we will make + let requests = []; + // the number of images we will make + let { num_outputs } = req; + + // if making fewer images than the parallel count + // then it is only 1 request + if( parallelCount > num_outputs ) { + requests.push(num_outputs); + } + + else { + // while we have at least 1 image to make + while (num_outputs >= 1) { + // subtract the parallel count from the number of images to make + num_outputs -= parallelCount; + + // if we are still 0 or greater we can make the full parallel count + if(num_outputs <= 0) { + requests.push(parallelCount) + } + // otherwise we can only make the remaining images + else { + requests.push(Math.abs(num_outputs)) + } + } + } + + // make the requests + requests.forEach((num, index) => { + + // get the seed we want to use + let seed = req.seed; + if(index !== 0) { + // we want to use a random seed for subsequent requests + seed = useRandomSeed(); + } + // add the request to the queue + addNewImage(uuidv4(), { + ...req, + // updated the number of images to make + num_outputs: num, + // update the seed + seed: seed + }) + }); + }; return ( - + ); } \ No newline at end of file diff --git a/ui/frontend/build_src/src/components/displayPanel/currentImage/index.tsx b/ui/frontend/build_src/src/components/displayPanel/currentImage/index.tsx index 0afcb601..a655a350 100644 --- a/ui/frontend/build_src/src/components/displayPanel/currentImage/index.tsx +++ b/ui/frontend/build_src/src/components/displayPanel/currentImage/index.tsx @@ -14,7 +14,6 @@ export default function CurrentImage() { const {id, options} = useImageQueue((state) => state.firstInQueue()); console.log('CurrentImage id', id) - const removeFirstInQueue = useImageQueue((state) => state.removeFirstInQueue); const { status, data } = useQuery( diff --git a/ui/frontend/build_src/src/store/imageCreateStore.ts b/ui/frontend/build_src/src/store/imageCreateStore.ts index 55bfec7c..493746e9 100644 --- a/ui/frontend/build_src/src/store/imageCreateStore.ts +++ b/ui/frontend/build_src/src/store/imageCreateStore.ts @@ -35,11 +35,11 @@ export type ImageRequest = { }; interface ImageCreateState { - requestCount: number; + parallelCount: number; requestOptions: ImageRequest; tags: string[]; - setRequestCount: (count: number) => void; + setParallelCount: (count: number) => void; setRequestOptions: (key: keyof ImageRequest, value: any) => void; getValueForRequestKey: (key: keyof ImageRequest) => any; @@ -67,7 +67,7 @@ interface ImageCreateState { // @ts-ignore export const useImageCreate = create(devtools((set, get) => ({ - requestCount: 1, + parallelCount: 1, requestOptions:{ prompt: 'a photograph of an astronaut riding a horse', @@ -90,8 +90,8 @@ export const useImageCreate = create(devtools((set, get) => ({ tags: [] as string[], - setRequestCount: (count: number) => set(produce((state) => { - state.requestCount = count; + setParallelCount: (count: number) => set(produce((state) => { + state.parallelCount = count; })), setRequestOptions: (key: keyof ImageRequest, value: any) => { @@ -127,14 +127,12 @@ export const useImageCreate = create(devtools((set, get) => ({ // this is a computed value, just adding the tags to the request builtRequest: () => { - console.log('builtRequest'); const state = get(); const requestOptions = state.requestOptions; const tags = state.tags; // join all the tags with a comma and add it to the prompt const prompt = `${requestOptions.prompt} ${tags.join(',')}`; - console.log('builtRequest return1'); const request = { ...requestOptions, @@ -146,22 +144,15 @@ export const useImageCreate = create(devtools((set, get) => ({ // TODO check this request.save_to_disk_path = null; } - console.log('builtRequest return2'); // if we arent using face correction clear the face correction if(!state.uiOptions.isCheckUseFaceCorrection){ request.use_face_correction = null; } - console.log('builtRequest return3'); // if we arent using upscaling clear the upscaling if(!state.uiOptions.isCheckedUseUpscaling){ request.use_upscale = null; } - // const request = { - // ...requestOptions, - // prompt - // } - console.log('builtRequest return last'); return request; }, diff --git a/ui/server.py b/ui/server.py index e3b2fba0..1e903875 100644 --- a/ui/server.py +++ b/ui/server.py @@ -17,9 +17,10 @@ OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder from fastapi import FastAPI, HTTPException from starlette.responses import FileResponse, StreamingResponse from pydantic import BaseModel -import logging # this is needed for development. from fastapi.middleware.cors import CORSMiddleware +import logging + from sd_internal import Request, Response app = FastAPI() @@ -71,7 +72,6 @@ class SetAppConfigRequest(BaseModel): @app.get('/') def read_root(): headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} - #return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers) return FileResponse(os.path.join(SD_UI_DIR,'frontend/dist/index.html'), headers=headers) # then get the js files @@ -84,6 +84,7 @@ def read_scripts(): def read_styles(): return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/index.css')) + @app.get('/ping') async def ping(): global model_loaded, model_is_loading @@ -216,7 +217,7 @@ def read_modifiers(): @app.get('/modifiers.json') def read_modifiers(): - return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json')) + return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/modifiers.json')) @app.get('/output_dir') def read_home_dir():