Merge pull request #6 from caranicas/beta-react-parallel-logic

updated parallel logic
This commit is contained in:
caranicas 2022-09-14 13:25:53 -04:00 committed by GitHub
commit 45e05b0891
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 77 additions and 33 deletions

View File

@ -22,7 +22,10 @@ export const healthPing = async () => {
* the local list of modifications * the local list of modifications
*/ */
export const loadModifications = async () => { export const loadModifications = async () => {
const response = await fetch(`${API_URL}/modifiers.json`); const url = `${API_URL}/modifiers.json`;
console.log('loadModifications', url);
const response = await fetch(url);
const data = await response.json(); const data = await response.json();
return data; return data;
} }
@ -40,6 +43,9 @@ export const getSaveDirectory = async () => {
export const MakeImageKey = 'MakeImage'; export const MakeImageKey = 'MakeImage';
export const doMakeImage = async (reqBody: ImageRequest) => { export const doMakeImage = async (reqBody: ImageRequest) => {
const {seed, num_outputs} = reqBody;
console.log('doMakeImage', seed, num_outputs);
const res = await fetch(`${API_URL}/image`, { const res = await fetch(`${API_URL}/image`, {
method: 'POST', method: 'POST',
headers: { headers: {

View File

@ -23,8 +23,8 @@ const IMAGE_DIMENSIONS = [
function SettingsList() { function SettingsList() {
const requestCount = useImageCreate((state) => state.requestCount); const parallelCount = useImageCreate((state) => state.parallelCount);
const setRequestCount = useImageCreate((state) => state.setRequestCount); const setParallelCount = useImageCreate((state) => state.setParallelCount);
const setRequestOption = useImageCreate((state) => state.setRequestOptions); const setRequestOption = useImageCreate((state) => state.setRequestOptions);
@ -134,9 +134,9 @@ function SettingsList() {
Number of images to make:{" "} Number of images to make:{" "}
<input <input
type="number" type="number"
value={requestCount} value={num_outputs}
onChange={(e) => onChange={(e) =>
setRequestCount(parseInt(e.target.value, 10)) setRequestOption("num_outputs", parseInt(e.target.value, 10))
} }
size={4} size={4}
/> />
@ -145,9 +145,9 @@ function SettingsList() {
Generate in parallel: Generate in parallel:
<input <input
type="number" type="number"
value={num_outputs} value={parallelCount}
onChange={(e) => onChange={(e) =>
setRequestOption("num_outputs", parseInt(e.target.value, 10)) setParallelCount(parseInt(e.target.value, 10))
} }
size={4} size={4}
/> />

View File

@ -1,24 +1,71 @@
import React, {useEffect, useState}from "react"; import React, {useEffect, useState}from "react";
import { useImageCreate } from "../../../store/imageCreateStore"; import { useImageCreate } from "../../../store/imageCreateStore";
// import { useImageDisplay } from "../../../store/imageDisplayStore";
import { useImageQueue } from "../../../store/imageQueueStore"; import { useImageQueue } from "../../../store/imageQueueStore";
// import { doMakeImage } from "../../../api";
import {v4 as uuidv4} from 'uuid'; import {v4 as uuidv4} from 'uuid';
import { useRandomSeed } from "../../../utils";
export default function MakeButton() { export default function MakeButton() {
const parallelCount = useImageCreate((state) => state.parallelCount);
const builtRequest = useImageCreate((state) => state.builtRequest); const builtRequest = useImageCreate((state) => state.builtRequest);
const addNewImage = useImageQueue((state) => state.addNewImage); const addNewImage = useImageQueue((state) => state.addNewImage);
const makeImage = () => { const makeImages = () => {
// todo turn this into a loop and adjust the parallel count
// // the request that we have built
const req = builtRequest(); const req = builtRequest();
addNewImage(uuidv4(), req) // the actual number of request we will make
let requests = [];
// the number of images we will make
let { num_outputs } = req;
// if making fewer images than the parallel count
// then it is only 1 request
if( parallelCount > num_outputs ) {
requests.push(num_outputs);
}
else {
// while we have at least 1 image to make
while (num_outputs >= 1) {
// subtract the parallel count from the number of images to make
num_outputs -= parallelCount;
// if we are still 0 or greater we can make the full parallel count
if(num_outputs <= 0) {
requests.push(parallelCount)
}
// otherwise we can only make the remaining images
else {
requests.push(Math.abs(num_outputs))
}
}
}
// make the requests
requests.forEach((num, index) => {
// get the seed we want to use
let seed = req.seed;
if(index !== 0) {
// we want to use a random seed for subsequent requests
seed = useRandomSeed();
}
// add the request to the queue
addNewImage(uuidv4(), {
...req,
// updated the number of images to make
num_outputs: num,
// update the seed
seed: seed
})
});
}; };
return ( return (
<button onClick={makeImage}>Make</button> <button onClick={makeImages}>Make</button>
); );
} }

View File

@ -14,7 +14,6 @@ export default function CurrentImage() {
const {id, options} = useImageQueue((state) => state.firstInQueue()); const {id, options} = useImageQueue((state) => state.firstInQueue());
console.log('CurrentImage id', id) console.log('CurrentImage id', id)
const removeFirstInQueue = useImageQueue((state) => state.removeFirstInQueue); const removeFirstInQueue = useImageQueue((state) => state.removeFirstInQueue);
const { status, data } = useQuery( const { status, data } = useQuery(

View File

@ -35,11 +35,11 @@ export type ImageRequest = {
}; };
interface ImageCreateState { interface ImageCreateState {
requestCount: number; parallelCount: number;
requestOptions: ImageRequest; requestOptions: ImageRequest;
tags: string[]; tags: string[];
setRequestCount: (count: number) => void; setParallelCount: (count: number) => void;
setRequestOptions: (key: keyof ImageRequest, value: any) => void; setRequestOptions: (key: keyof ImageRequest, value: any) => void;
getValueForRequestKey: (key: keyof ImageRequest) => any; getValueForRequestKey: (key: keyof ImageRequest) => any;
@ -67,7 +67,7 @@ interface ImageCreateState {
// @ts-ignore // @ts-ignore
export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({ export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({
requestCount: 1, parallelCount: 1,
requestOptions:{ requestOptions:{
prompt: 'a photograph of an astronaut riding a horse', prompt: 'a photograph of an astronaut riding a horse',
@ -90,8 +90,8 @@ export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({
tags: [] as string[], tags: [] as string[],
setRequestCount: (count: number) => set(produce((state) => { setParallelCount: (count: number) => set(produce((state) => {
state.requestCount = count; state.parallelCount = count;
})), })),
setRequestOptions: (key: keyof ImageRequest, value: any) => { setRequestOptions: (key: keyof ImageRequest, value: any) => {
@ -127,14 +127,12 @@ export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({
// this is a computed value, just adding the tags to the request // this is a computed value, just adding the tags to the request
builtRequest: () => { builtRequest: () => {
console.log('builtRequest');
const state = get(); const state = get();
const requestOptions = state.requestOptions; const requestOptions = state.requestOptions;
const tags = state.tags; const tags = state.tags;
// join all the tags with a comma and add it to the prompt // join all the tags with a comma and add it to the prompt
const prompt = `${requestOptions.prompt} ${tags.join(',')}`; const prompt = `${requestOptions.prompt} ${tags.join(',')}`;
console.log('builtRequest return1');
const request = { const request = {
...requestOptions, ...requestOptions,
@ -146,22 +144,15 @@ export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({
// TODO check this // TODO check this
request.save_to_disk_path = null; request.save_to_disk_path = null;
} }
console.log('builtRequest return2');
// if we arent using face correction clear the face correction // if we arent using face correction clear the face correction
if(!state.uiOptions.isCheckUseFaceCorrection){ if(!state.uiOptions.isCheckUseFaceCorrection){
request.use_face_correction = null; request.use_face_correction = null;
} }
console.log('builtRequest return3');
// if we arent using upscaling clear the upscaling // if we arent using upscaling clear the upscaling
if(!state.uiOptions.isCheckedUseUpscaling){ if(!state.uiOptions.isCheckedUseUpscaling){
request.use_upscale = null; request.use_upscale = null;
} }
// const request = {
// ...requestOptions,
// prompt
// }
console.log('builtRequest return last');
return request; return request;
}, },

View File

@ -17,9 +17,10 @@ OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from starlette.responses import FileResponse, StreamingResponse from starlette.responses import FileResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
import logging
# this is needed for development. # this is needed for development.
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import logging
from sd_internal import Request, Response from sd_internal import Request, Response
app = FastAPI() app = FastAPI()
@ -71,7 +72,6 @@ class SetAppConfigRequest(BaseModel):
@app.get('/') @app.get('/')
def read_root(): def read_root():
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
#return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers)
return FileResponse(os.path.join(SD_UI_DIR,'frontend/dist/index.html'), headers=headers) return FileResponse(os.path.join(SD_UI_DIR,'frontend/dist/index.html'), headers=headers)
# then get the js files # then get the js files
@ -84,6 +84,7 @@ def read_scripts():
def read_styles(): def read_styles():
return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/index.css')) return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/index.css'))
@app.get('/ping') @app.get('/ping')
async def ping(): async def ping():
global model_loaded, model_is_loading global model_loaded, model_is_loading
@ -216,7 +217,7 @@ def read_modifiers():
@app.get('/modifiers.json') @app.get('/modifiers.json')
def read_modifiers(): def read_modifiers():
return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json')) return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/modifiers.json'))
@app.get('/output_dir') @app.get('/output_dir')
def read_home_dir(): def read_home_dir():