forked from extern/easydiffusion
Merge pull request #6 from caranicas/beta-react-parallel-logic
updated parallel logic
This commit is contained in:
commit
45e05b0891
@ -22,7 +22,10 @@ export const healthPing = async () => {
|
|||||||
* the local list of modifications
|
* the local list of modifications
|
||||||
*/
|
*/
|
||||||
export const loadModifications = async () => {
|
export const loadModifications = async () => {
|
||||||
const response = await fetch(`${API_URL}/modifiers.json`);
|
const url = `${API_URL}/modifiers.json`;
|
||||||
|
|
||||||
|
console.log('loadModifications', url);
|
||||||
|
const response = await fetch(url);
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
@ -40,6 +43,9 @@ export const getSaveDirectory = async () => {
|
|||||||
export const MakeImageKey = 'MakeImage';
|
export const MakeImageKey = 'MakeImage';
|
||||||
export const doMakeImage = async (reqBody: ImageRequest) => {
|
export const doMakeImage = async (reqBody: ImageRequest) => {
|
||||||
|
|
||||||
|
const {seed, num_outputs} = reqBody;
|
||||||
|
console.log('doMakeImage', seed, num_outputs);
|
||||||
|
|
||||||
const res = await fetch(`${API_URL}/image`, {
|
const res = await fetch(`${API_URL}/image`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
|
@ -23,8 +23,8 @@ const IMAGE_DIMENSIONS = [
|
|||||||
|
|
||||||
function SettingsList() {
|
function SettingsList() {
|
||||||
|
|
||||||
const requestCount = useImageCreate((state) => state.requestCount);
|
const parallelCount = useImageCreate((state) => state.parallelCount);
|
||||||
const setRequestCount = useImageCreate((state) => state.setRequestCount);
|
const setParallelCount = useImageCreate((state) => state.setParallelCount);
|
||||||
const setRequestOption = useImageCreate((state) => state.setRequestOptions);
|
const setRequestOption = useImageCreate((state) => state.setRequestOptions);
|
||||||
|
|
||||||
|
|
||||||
@ -134,9 +134,9 @@ function SettingsList() {
|
|||||||
Number of images to make:{" "}
|
Number of images to make:{" "}
|
||||||
<input
|
<input
|
||||||
type="number"
|
type="number"
|
||||||
value={requestCount}
|
value={num_outputs}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
setRequestCount(parseInt(e.target.value, 10))
|
setRequestOption("num_outputs", parseInt(e.target.value, 10))
|
||||||
}
|
}
|
||||||
size={4}
|
size={4}
|
||||||
/>
|
/>
|
||||||
@ -145,9 +145,9 @@ function SettingsList() {
|
|||||||
Generate in parallel:
|
Generate in parallel:
|
||||||
<input
|
<input
|
||||||
type="number"
|
type="number"
|
||||||
value={num_outputs}
|
value={parallelCount}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
setRequestOption("num_outputs", parseInt(e.target.value, 10))
|
setParallelCount(parseInt(e.target.value, 10))
|
||||||
}
|
}
|
||||||
size={4}
|
size={4}
|
||||||
/>
|
/>
|
||||||
|
@ -1,24 +1,71 @@
|
|||||||
import React, {useEffect, useState}from "react";
|
import React, {useEffect, useState}from "react";
|
||||||
|
|
||||||
import { useImageCreate } from "../../../store/imageCreateStore";
|
import { useImageCreate } from "../../../store/imageCreateStore";
|
||||||
// import { useImageDisplay } from "../../../store/imageDisplayStore";
|
|
||||||
import { useImageQueue } from "../../../store/imageQueueStore";
|
import { useImageQueue } from "../../../store/imageQueueStore";
|
||||||
// import { doMakeImage } from "../../../api";
|
|
||||||
import {v4 as uuidv4} from 'uuid';
|
import {v4 as uuidv4} from 'uuid';
|
||||||
|
|
||||||
|
import { useRandomSeed } from "../../../utils";
|
||||||
|
|
||||||
export default function MakeButton() {
|
export default function MakeButton() {
|
||||||
|
|
||||||
|
const parallelCount = useImageCreate((state) => state.parallelCount);
|
||||||
const builtRequest = useImageCreate((state) => state.builtRequest);
|
const builtRequest = useImageCreate((state) => state.builtRequest);
|
||||||
const addNewImage = useImageQueue((state) => state.addNewImage);
|
const addNewImage = useImageQueue((state) => state.addNewImage);
|
||||||
|
|
||||||
const makeImage = () => {
|
const makeImages = () => {
|
||||||
// todo turn this into a loop and adjust the parallel count
|
|
||||||
//
|
// the request that we have built
|
||||||
const req = builtRequest();
|
const req = builtRequest();
|
||||||
addNewImage(uuidv4(), req)
|
// the actual number of request we will make
|
||||||
|
let requests = [];
|
||||||
|
// the number of images we will make
|
||||||
|
let { num_outputs } = req;
|
||||||
|
|
||||||
|
// if making fewer images than the parallel count
|
||||||
|
// then it is only 1 request
|
||||||
|
if( parallelCount > num_outputs ) {
|
||||||
|
requests.push(num_outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
else {
|
||||||
|
// while we have at least 1 image to make
|
||||||
|
while (num_outputs >= 1) {
|
||||||
|
// subtract the parallel count from the number of images to make
|
||||||
|
num_outputs -= parallelCount;
|
||||||
|
|
||||||
|
// if we are still 0 or greater we can make the full parallel count
|
||||||
|
if(num_outputs <= 0) {
|
||||||
|
requests.push(parallelCount)
|
||||||
|
}
|
||||||
|
// otherwise we can only make the remaining images
|
||||||
|
else {
|
||||||
|
requests.push(Math.abs(num_outputs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// make the requests
|
||||||
|
requests.forEach((num, index) => {
|
||||||
|
|
||||||
|
// get the seed we want to use
|
||||||
|
let seed = req.seed;
|
||||||
|
if(index !== 0) {
|
||||||
|
// we want to use a random seed for subsequent requests
|
||||||
|
seed = useRandomSeed();
|
||||||
|
}
|
||||||
|
// add the request to the queue
|
||||||
|
addNewImage(uuidv4(), {
|
||||||
|
...req,
|
||||||
|
// updated the number of images to make
|
||||||
|
num_outputs: num,
|
||||||
|
// update the seed
|
||||||
|
seed: seed
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<button onClick={makeImage}>Make</button>
|
<button onClick={makeImages}>Make</button>
|
||||||
);
|
);
|
||||||
}
|
}
|
@ -14,7 +14,6 @@ export default function CurrentImage() {
|
|||||||
const {id, options} = useImageQueue((state) => state.firstInQueue());
|
const {id, options} = useImageQueue((state) => state.firstInQueue());
|
||||||
console.log('CurrentImage id', id)
|
console.log('CurrentImage id', id)
|
||||||
|
|
||||||
|
|
||||||
const removeFirstInQueue = useImageQueue((state) => state.removeFirstInQueue);
|
const removeFirstInQueue = useImageQueue((state) => state.removeFirstInQueue);
|
||||||
|
|
||||||
const { status, data } = useQuery(
|
const { status, data } = useQuery(
|
||||||
|
@ -35,11 +35,11 @@ export type ImageRequest = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
interface ImageCreateState {
|
interface ImageCreateState {
|
||||||
requestCount: number;
|
parallelCount: number;
|
||||||
requestOptions: ImageRequest;
|
requestOptions: ImageRequest;
|
||||||
tags: string[];
|
tags: string[];
|
||||||
|
|
||||||
setRequestCount: (count: number) => void;
|
setParallelCount: (count: number) => void;
|
||||||
setRequestOptions: (key: keyof ImageRequest, value: any) => void;
|
setRequestOptions: (key: keyof ImageRequest, value: any) => void;
|
||||||
getValueForRequestKey: (key: keyof ImageRequest) => any;
|
getValueForRequestKey: (key: keyof ImageRequest) => any;
|
||||||
|
|
||||||
@ -67,7 +67,7 @@ interface ImageCreateState {
|
|||||||
// @ts-ignore
|
// @ts-ignore
|
||||||
export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({
|
export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({
|
||||||
|
|
||||||
requestCount: 1,
|
parallelCount: 1,
|
||||||
|
|
||||||
requestOptions:{
|
requestOptions:{
|
||||||
prompt: 'a photograph of an astronaut riding a horse',
|
prompt: 'a photograph of an astronaut riding a horse',
|
||||||
@ -90,8 +90,8 @@ export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({
|
|||||||
|
|
||||||
tags: [] as string[],
|
tags: [] as string[],
|
||||||
|
|
||||||
setRequestCount: (count: number) => set(produce((state) => {
|
setParallelCount: (count: number) => set(produce((state) => {
|
||||||
state.requestCount = count;
|
state.parallelCount = count;
|
||||||
})),
|
})),
|
||||||
|
|
||||||
setRequestOptions: (key: keyof ImageRequest, value: any) => {
|
setRequestOptions: (key: keyof ImageRequest, value: any) => {
|
||||||
@ -127,14 +127,12 @@ export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({
|
|||||||
// this is a computed value, just adding the tags to the request
|
// this is a computed value, just adding the tags to the request
|
||||||
builtRequest: () => {
|
builtRequest: () => {
|
||||||
|
|
||||||
console.log('builtRequest');
|
|
||||||
const state = get();
|
const state = get();
|
||||||
const requestOptions = state.requestOptions;
|
const requestOptions = state.requestOptions;
|
||||||
const tags = state.tags;
|
const tags = state.tags;
|
||||||
|
|
||||||
// join all the tags with a comma and add it to the prompt
|
// join all the tags with a comma and add it to the prompt
|
||||||
const prompt = `${requestOptions.prompt} ${tags.join(',')}`;
|
const prompt = `${requestOptions.prompt} ${tags.join(',')}`;
|
||||||
console.log('builtRequest return1');
|
|
||||||
|
|
||||||
const request = {
|
const request = {
|
||||||
...requestOptions,
|
...requestOptions,
|
||||||
@ -146,22 +144,15 @@ export const useImageCreate = create<ImageCreateState>(devtools((set, get) => ({
|
|||||||
// TODO check this
|
// TODO check this
|
||||||
request.save_to_disk_path = null;
|
request.save_to_disk_path = null;
|
||||||
}
|
}
|
||||||
console.log('builtRequest return2');
|
|
||||||
// if we arent using face correction clear the face correction
|
// if we arent using face correction clear the face correction
|
||||||
if(!state.uiOptions.isCheckUseFaceCorrection){
|
if(!state.uiOptions.isCheckUseFaceCorrection){
|
||||||
request.use_face_correction = null;
|
request.use_face_correction = null;
|
||||||
}
|
}
|
||||||
console.log('builtRequest return3');
|
|
||||||
// if we arent using upscaling clear the upscaling
|
// if we arent using upscaling clear the upscaling
|
||||||
if(!state.uiOptions.isCheckedUseUpscaling){
|
if(!state.uiOptions.isCheckedUseUpscaling){
|
||||||
request.use_upscale = null;
|
request.use_upscale = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
// const request = {
|
|
||||||
// ...requestOptions,
|
|
||||||
// prompt
|
|
||||||
// }
|
|
||||||
console.log('builtRequest return last');
|
|
||||||
return request;
|
return request;
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -17,9 +17,10 @@ OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
|||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from starlette.responses import FileResponse, StreamingResponse
|
from starlette.responses import FileResponse, StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import logging
|
|
||||||
# this is needed for development.
|
# this is needed for development.
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
import logging
|
||||||
|
|
||||||
from sd_internal import Request, Response
|
from sd_internal import Request, Response
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@ -71,7 +72,6 @@ class SetAppConfigRequest(BaseModel):
|
|||||||
@app.get('/')
|
@app.get('/')
|
||||||
def read_root():
|
def read_root():
|
||||||
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||||
#return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers)
|
|
||||||
return FileResponse(os.path.join(SD_UI_DIR,'frontend/dist/index.html'), headers=headers)
|
return FileResponse(os.path.join(SD_UI_DIR,'frontend/dist/index.html'), headers=headers)
|
||||||
|
|
||||||
# then get the js files
|
# then get the js files
|
||||||
@ -84,6 +84,7 @@ def read_scripts():
|
|||||||
def read_styles():
|
def read_styles():
|
||||||
return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/index.css'))
|
return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/index.css'))
|
||||||
|
|
||||||
|
|
||||||
@app.get('/ping')
|
@app.get('/ping')
|
||||||
async def ping():
|
async def ping():
|
||||||
global model_loaded, model_is_loading
|
global model_loaded, model_is_loading
|
||||||
@ -216,7 +217,7 @@ def read_modifiers():
|
|||||||
|
|
||||||
@app.get('/modifiers.json')
|
@app.get('/modifiers.json')
|
||||||
def read_modifiers():
|
def read_modifiers():
|
||||||
return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'))
|
return FileResponse(os.path.join(SD_UI_DIR, 'frontend/dist/modifiers.json'))
|
||||||
|
|
||||||
@app.get('/output_dir')
|
@app.get('/output_dir')
|
||||||
def read_home_dir():
|
def read_home_dir():
|
||||||
|
Loading…
Reference in New Issue
Block a user