workable plan

This commit is contained in:
caranicas 2022-09-24 13:47:30 -04:00
parent d1e792686e
commit 22a0b3be45
11 changed files with 370 additions and 288 deletions

View File

@ -1,28 +0,0 @@
import { useInfiniteQuery } from "@tanstack/react-query";
export const useAyncGeneratorQuery = (key: string, asyncGeneratorFn: () => Generator<string[], string[], unknown>) => {
const queryFn = React.useCallback((_: any, params: CustomQueryParams) => {
if (!params) { return Promise.resolve([]) }
return Promise.resolve(params.eventData);
}, [])
const { data, fetchMore } = useInfiniteQuery<string[], string, CustomQueryParams>(key, queryFn as any, { getFetchMore: () => ({ eventData: [] }) })
const customStatus = React.useRef<Status>('success');
React.useEffect(() => {
(async function doReceive() {
try {
for await (let data of asyncGeneratorFn()) {
fetchMore({ eventData: data });
}
} catch (e) {
customStatus.current = 'error';
}
})();
}, [asyncGeneratorFn, fetchMore])
return { status: customStatus.current, data };
}

View File

@ -1,21 +0,0 @@
type CallbackType = (data: string[]) => void;
type InitCallbackType = (cb: CallbackType) => void;
export const useCallbackQuery = (key: string, initCallbackQuery: InitCallbackType) => {
const queryFn = React.useCallback((_: any, params: CustomQueryParams) => {
if (!params) { return Promise.resolve([]) }
return Promise.resolve(params.eventData);
}, [])
const { data, fetchMore } = useInfiniteQuery<string[], string, CustomQueryParams>(key, queryFn as any, { getFetchMore: () => ({ eventData: [] }) })
const callback = React.useCallback((data) => { fetchMore({ eventData: data }) }, [fetchMore]);
React.useEffect(() => { initCallbackQuery(callback); }, [callback, initCallbackQuery])
const customStatus = React.useRef<Status>('success');
return { status: customStatus.current, data };
}

View File

@ -1,19 +0,0 @@
import { useQuery, useQueryClient } from 'react-query'
export const useEventSourceQuery = (queryKey, url, eventName) => {
const queryClient = useQueryClient()
const fetchData = () => {
const evtSource = new EventSource(url)
evtSource.addEventListener(eventName, (event) => {
const eventData = event.data && JSON.parse(event.data)
if (eventData) {
queryClient.setQueryData(queryKey, eventData)
}
})
}
return useQuery(queryKey, fetchData)
}

View File

@ -1,47 +0,0 @@
import * as React from 'react';
import { useInfiniteQuery } from 'react-query';
interface CustomQueryParams {
eventData: string[];
}
type Status = "success" | "loading" | "error";
export const useEventSourceQuery = (key: string, url: string, eventName: string) => {
const eventSource = React.useRef<EventSource>(new EventSource(url));
const queryFn = React.useCallback((_: any, params: CustomQueryParams) => {
if (!params) { return Promise.resolve([]) }
return Promise.resolve(params.eventData);
}, [])
const { data, fetchMore } = useInfiniteQuery<string[], string, CustomQueryParams>(key, queryFn as any, { getFetchMore: () => ({ eventData: [] }) })
const customStatus = React.useRef<Status>('success');
React.useEffect(() => {
const evtSource = eventSource.current;
const onEvent = function (ev: MessageEvent | Event) {
if (!e.data) {
return;
}
// Let's assume here we receive multiple data, ie. e.data is an array.
fetchMore({ eventData: e.data });
}
const onError = () => { customStatus.current = 'error' };
evtSource.addEventListener(eventName, onEvent);
evtSource.addEventListener('error', onError);
return () => {
evtSource.removeEventListener(eventName, onEvent);
evtSource.removeEventListener('error', onError);
}
}, [url, eventName, fetchMore])
return { status: customStatus.current, data };
}

View File

@ -79,7 +79,9 @@ export const doMakeImage = async (reqBody: ImageRequest) => {
}, },
body: JSON.stringify(reqBody), body: JSON.stringify(reqBody),
}); });
console.log('doMakeImage= GOT RESPONSE', res);
const data = await res.json(); // const data = await res.json();
return data; // return data;
return res;
}; };

View File

@ -1,5 +1,5 @@
import React from "react"; import React from "react";
import { API_URL } from "../../../../api"; import { API_URL } from "../../../../../../api";
const url = `${API_URL}/ding.mp3`; const url = `${API_URL}/ding.mp3`;

View File

@ -1,49 +1,108 @@
/* eslint-disable @typescript-eslint/no-unnecessary-type-assertion */
/* eslint-disable @typescript-eslint/prefer-ts-expect-error */
/* eslint-disable @typescript-eslint/naming-convention */ /* eslint-disable @typescript-eslint/naming-convention */
import React from "react"; import React, { useEffect } from "react";
import { useImageCreate, ImageRequest } from "../../../../../stores/imageCreateStore"; import { useImageCreate, ImageRequest } from "../../../../../stores/imageCreateStore";
import { useImageQueue } from "../../../../../stores/imageQueueStore"; import { useImageQueue } from "../../../../../stores/imageQueueStore";
import {
FetchingStates,
useImageFetching
} from "../../../../../stores/imageFetchingStore";
import { v4 as uuidv4 } from "uuid"; import { v4 as uuidv4 } from "uuid";
import { useRandomSeed } from "../../../../../utils"; import { useRandomSeed } from "../../../../../utils";
import { doMakeImage } from "../../../../../api";
import { import {
MakeButtonStyle, // @ts-expect-error MakeButtonStyle, // @ts-expect-error
} from "./makeButton.css.ts"; } from "./makeButton.css.ts";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import AudioDing from "./audioDing";
import { parse } from "node:path/win32";
export default function MakeButton() { export default function MakeButton() {
const { t } = useTranslation(); const { t } = useTranslation();
const parallelCount = useImageCreate((state) => state.parallelCount); const parallelCount = useImageCreate((state) => state.parallelCount);
const builtRequest = useImageCreate((state) => state.builtRequest); const builtRequest = useImageCreate((state) => state.builtRequest);
const addNewImage = useImageQueue((state) => state.addNewImage);
const hasQueue = useImageQueue((state) => state.hasQueuedImages());
const isRandomSeed = useImageCreate((state) => state.isRandomSeed()); const isRandomSeed = useImageCreate((state) => state.isRandomSeed());
const setRequestOption = useImageCreate((state) => state.setRequestOptions); const setRequestOption = useImageCreate((state) => state.setRequestOptions);
// const makeImages = () => { const addNewImage = useImageQueue((state) => state.addNewImage);
// // potentially update the seed const hasQueue = useImageQueue((state) => state.hasQueuedImages());
// if (isRandomSeed) { const { id, options } = useImageQueue((state) => state.firstInQueue());
// // update the seed for the next time we click the button
// setRequestOption("seed", useRandomSeed());
// }
// // the request that we have built const setStatus = useImageFetching((state) => state.setStatus);
// const req = builtRequest(); const appendData = useImageFetching((state) => state.appendData);
// }; const parseRequest = async (reader: ReadableStreamDefaultReader<Uint8Array>) => {
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
const queueImageRequest = (req: ImageRequest) => { if (done as boolean) {
console.log("DONE");
setStatus(FetchingStates.COMPLETE);
break;
}
const jsonStr = decoder.decode(value);
try {
const update = JSON.parse(jsonStr);
if (update.status === "progress") {
console.log("PROGRESS");
setStatus(FetchingStates.PROGRESSING);
}
else if (update.status === "succeeded") {
console.log("succeeded");
setStatus(FetchingStates.SUCCEEDED);
// appendData(update.data);
}
else {
console.log("extra?", update);
// appendData(update.data);
}
}
catch (e) {
console.log('PARSE ERRROR')
console.log(e)
debugger;
// appendData(update.data);
}
}
}
const startStream = async (req: ImageRequest) => {
const streamReq = {
...req,
// stream_image_progress: false,
};
console.log("testStream", streamReq);
try {
const res = await doMakeImage(streamReq);
// @ts-expect-error
const reader = res.body.getReader();
void parseRequest(reader);
} catch (e) {
console.log('e');
}
}
const queueImageRequest = async (req: ImageRequest) => {
// the actual number of request we will make // the actual number of request we will make
const requests = []; const requests = [];
// the number of images we will make // the number of images we will make
let { num_outputs } = req; let { num_outputs } = req;
// if making fewer images than the parallel count
// then it is only 1 request
if (parallelCount > num_outputs) { if (parallelCount > num_outputs) {
requests.push(num_outputs); requests.push(num_outputs);
} else { } else {
@ -63,7 +122,6 @@ export default function MakeButton() {
} }
} }
// make the requests
requests.forEach((num, index) => { requests.forEach((num, index) => {
// get the seed we want to use // get the seed we want to use
let seed = req.seed; let seed = req.seed;
@ -82,63 +140,42 @@ export default function MakeButton() {
}); });
} }
const testStream = async (req: ImageRequest) => { const makeImageQueue = async () => {
const streamReq = {
...req,
stream_progress_updates: true,
// stream_image_progress: false,
session_id: uuidv4(),
};
console.log("testStream", streamReq);
try {
const res = await fetch('http://localhost:9000/image', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify(streamReq)
});
console.log('res', res);
const reader = res.body.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
console.log(text);
}
} catch (e) {
console.log(e);
debugger;
}
}
const makeImages = () => {
// potentially update the seed // potentially update the seed
if (isRandomSeed) { if (isRandomSeed) {
// update the seed for the next time we click the button // update the seed for the next time we click the button
setRequestOption("seed", useRandomSeed()); setRequestOption("seed", useRandomSeed());
} }
// the request that we have built // the request that we have built
const req = builtRequest(); const req = builtRequest();
await queueImageRequest(req);
//queueImageRequest(req); // void startStream(req);
void testStream(req);
}; };
useEffect(() => {
const makeImages = async (options: ImageRequest) => {
// potentially update the seed
await startStream(options);
}
if (hasQueue) {
makeImages(options).catch((e) => {
console.log('HAS QUEUE ERROR');
console.log(e);
});
}
}, [hasQueue, id, options, startStream]);
return ( return (
<button <button
className={MakeButtonStyle} className={MakeButtonStyle}
onClick={makeImages} onClick={() => {
setStatus(FetchingStates.FETCHING);
void makeImageQueue();
}}
disabled={hasQueue} disabled={hasQueue}
> >
{t("home.make-img-btn")} {t("home.make-img-btn")}

View File

@ -1,18 +1,24 @@
/* eslint-disable @typescript-eslint/strict-boolean-expressions */
import React, { useEffect, useState, useRef } from "react"; import React, { useEffect, useState, useRef } from "react";
import { useImageQueue } from "../../../stores/imageQueueStore"; import { useImageQueue } from "../../../stores/imageQueueStore";
import { ImageRequest, useImageCreate } from "../../../stores/imageCreateStore"; import { useImageFetching } from "../../../stores/imageFetchingStore";
import { useImageDisplay } from "../../../stores/imageDisplayStore";
import { useQuery, useQueryClient } from "@tanstack/react-query"; // import { ImageRequest, useImageCreate } from "../../../stores/imageCreateStore";
import { // import { useQuery, useQueryClient } from "@tanstack/react-query";
doMakeImage,
MakeImageKey,
ImageReturnType,
ImageOutput,
} from "../../../api";
import AudioDing from "./audioDing";
// import {
// API_URL,
// doMakeImage,
// MakeImageKey,
// ImageReturnType,
// ImageOutput,
// } from "../../../api";
// import AudioDing from "../creationPanel/basicCreation/makeButton/audioDing";
// import GeneratedImage from "../../molecules/generatedImage"; // import GeneratedImage from "../../molecules/generatedImage";
// import DrawImage from "../../molecules/drawImage"; // import DrawImage from "../../molecules/drawImage";
@ -22,132 +28,207 @@ import CompletedImages from "./completedImages";
import { import {
displayPanel, displayPanel,
displayContainer, // displayContainer,
previousImages, // previousImages,
// @ts-expect-error // @ts-expect-error
} from "./displayPanel.css.ts"; } from "./displayPanel.css.ts";
export interface CompletedImagesType { // export interface CompletedImagesType {
id: string; // id: string;
data: string; // data: string;
info: ImageRequest; // info: ImageRequest;
} // }
const idDelim = "_batch"; const idDelim = "_batch";
export default function DisplayPanel() { export default function DisplayPanel() {
const dingRef = useRef<HTMLAudioElement>(null); // const dingRef = useRef<HTMLAudioElement>(null);
const isSoundEnabled = useImageCreate((state) => state.isSoundEnabled()); // const isSoundEnabled = useImageCreate((state) => state.isSoundEnabled());
// @ts-expect-error // // @ts-expect-error
const { id, options } = useImageQueue((state) => state.firstInQueue()); // const { id, options } = useImageQueue((state) => state.firstInQueue());
const removeFirstInQueue = useImageQueue((state) => state.removeFirstInQueue); // const removeFirstInQueue = useImageQueue((state) => state.removeFirstInQueue);
const [currentImage, setCurrentImage] = useState<CompletedImagesType | null>( // const [currentImage, setCurrentImage] = useState<CompletedImagesType | null>(
null // null
); // );
const [isEnabled, setIsEnabled] = useState(false); // const [isEnabled, setIsEnabled] = useState(false);
const [isLoading, setIsLoading] = useState(true); // const [isLoading, setIsLoading] = useState(true);
const { status, data } = useQuery( // const { status, data } = useQuery(
[MakeImageKey, id], // [MakeImageKey, id],
async () => await doMakeImage(options), // async () => await doMakeImage(options),
{ // {
enabled: isEnabled, // enabled: isEnabled,
} // }
); // );
// const { status, data } = useEventSourceQuery(
// MakeImageKey,
// // [MakeImageKey, id],
// // async () => await doMakeImage(options),
// // {
// // enabled: isEnabled,
// // }
// );
// update the enabled state when the id changes // update the enabled state when the id changes
useEffect(() => { // useEffect(() => {
setIsEnabled(void 0 !== id); // setIsEnabled(void 0 !== id);
}, [id]); // }, [id]);
// const _handleStreamData = async (res: typeof ReadableStream) => {
// console.log("_handleStreamData");
// let reader;
// // @ts-expect-error
// if (res.body.locked) {
// console.log("locked");
// }
// else {
// reader = res.body.getReader();
// }
// console.log("reader", reader);
// const decoder = new TextDecoder();
// while (true) {
// const { done, value } = await reader.read();
// const text = decoder.decode(value);
// console.log("DECODE", done);
// console.log(text);
// if (text.status === "progress") {
// console.log("PROGRESS");
// }
// else if (text.status === "succeeded") {
// console.log("succeeded");
// }
// else {
// console.log("extra?")
// }
// console.log("-----------------");
// if (done as boolean) {
// reader.releaseLock();
// break;
// }
// }
// };
// useEffect(() => {
// const fetch = async () => {
// const res = await doMakeImage(options);
// void _handleStreamData(res);
// }
// if (isEnabled) {
// console.log('isEnabled');
// debugger;
// fetch()
// .catch((err) => {
// console.error(err);
// });
// }
// }, [isEnabled, options, _handleStreamData]);
// helper for the loading state to be enabled aware // helper for the loading state to be enabled aware
useEffect(() => { // useEffect(() => {
if (isEnabled && status === "loading") { // if (isEnabled && status === "loading") {
setIsLoading(true); // setIsLoading(true);
} else { // } else {
setIsLoading(false); // setIsLoading(false);
} // }
}, [isEnabled, status]); // }, [isEnabled, status]);
// this is where there loading actually happens // this is where there loading actually happens
useEffect(() => { // useEffect(() => {
// query is done // console.log('DISPLATPANEL: status', status);
if (status === "success") { // console.log('DISPLATPANEL: data', data);
// check to make sure that the image was created
if (data.status === "succeeded") { // // query is done
if (isSoundEnabled) { // if (status === "success") {
// not awaiting the promise or error handling // // check to make sure that the image was created
void dingRef.current?.play();
} // void _handleStreamData(data);
removeFirstInQueue();
} // // if (data.status === "succeeded") {
} // // if (isSoundEnabled) {
}, [status, data, removeFirstInQueue, dingRef, isSoundEnabled]); // // // not awaiting the promise or error handling
// // void dingRef.current?.play();
// // }
// // removeFirstInQueue();
// // }
// }
// }, [status, data, removeFirstInQueue, dingRef, isSoundEnabled, _handleStreamData]);
/* COMPLETED IMAGES */ /* COMPLETED IMAGES */
const queryClient = useQueryClient(); // const queryClient = useQueryClient();
const [completedImages, setCompletedImages] = useState<CompletedImagesType[]>( // const [completedImages, setCompletedImages] = useState<CompletedImagesType[]>(
[] // []
); // );
const completedIds = useImageQueue((state) => state.completedImageIds); // const completedIds = useImageQueue((state) => state.completedImageIds);
const clearCachedIds = useImageQueue((state) => state.clearCachedIds); // const clearCachedIds = useImageQueue((state) => state.clearCachedIds);
// this is where we generate the list of completed images // this is where we generate the list of completed images
useEffect(() => { // useEffect(() => {
const completedQueries = completedIds.map((id) => { // const completedQueries = completedIds.map((id) => {
const imageData = queryClient.getQueryData([MakeImageKey, id]); // const imageData = queryClient.getQueryData([MakeImageKey, id]);
return imageData; // return imageData;
}) as ImageReturnType[]; // }) as ImageReturnType[];
if (completedQueries.length > 0) { // if (completedQueries.length > 0) {
// map the completedImagesto a new array // // map the completedImagesto a new array
// and then set the state // // and then set the state
const temp = completedQueries // const temp = completedQueries
.map((query, index) => { // .map((query, index) => {
if (void 0 !== query) { // if (void 0 !== query) {
return query.output.map((data: ImageOutput, index: number) => { // return query.output.map((data: ImageOutput, index: number) => {
return { // return {
id: `${completedIds[index]}${idDelim}-${data.seed}-${index}`, // id: `${completedIds[index]}${idDelim}-${data.seed}-${index}`,
data: data.data, // data: data.data,
info: { ...query.request, seed: data.seed }, // info: { ...query.request, seed: data.seed },
}; // };
}); // });
} // }
}) // })
.flat() // .flat()
.reverse() // .reverse()
.filter((item) => void 0 !== item) as CompletedImagesType[]; // remove undefined items // .filter((item) => void 0 !== item) as CompletedImagesType[]; // remove undefined items
setCompletedImages(temp); // setCompletedImages(temp);
// could move this to the useEffect for completedImages // // could move this to the useEffect for completedImages
if (temp.length > 0) { // if (temp.length > 0) {
setCurrentImage(temp[0]); // setCurrentImage(temp[0]);
} else { // } else {
setCurrentImage(null); // setCurrentImage(null);
} // }
} else { // } else {
setCompletedImages([]); // setCompletedImages([]);
setCurrentImage(null); // setCurrentImage(null);
} // }
}, [setCompletedImages, setCurrentImage, queryClient, completedIds]); // }, [setCompletedImages, setCurrentImage, queryClient, completedIds]);
// this is how we remove them // this is how we remove them
const removeImages = () => { // const removeImages = () => {
completedIds.forEach((id) => { // completedIds.forEach((id) => {
queryClient.removeQueries([MakeImageKey, id]); // queryClient.removeQueries([MakeImageKey, id]);
}); // });
clearCachedIds(); // clearCachedIds();
}; // };
return ( return (
<div className={displayPanel}> <div className={displayPanel}>
<AudioDing ref={dingRef}></AudioDing> DISPLAY
{/* <AudioDing ref={dingRef}></AudioDing>
<div className={displayContainer}> <div className={displayContainer}>
<CurrentDisplay <CurrentDisplay
isLoading={isLoading} isLoading={isLoading}
@ -160,7 +241,7 @@ export default function DisplayPanel() {
images={completedImages} images={completedImages}
setCurrentDisplay={setCurrentImage} setCurrentDisplay={setCurrentImage}
></CompletedImages> ></CompletedImages>
</div> </div> */}
</div> </div>
); );
} }

View File

@ -70,7 +70,7 @@ export interface ImageRequest {
init_image: undefined | string; init_image: undefined | string;
prompt_strength: undefined | number; prompt_strength: undefined | number;
sampler: typeof SAMPLER_OPTIONS[number]; sampler: typeof SAMPLER_OPTIONS[number];
stream_progress_updates: boolean; stream_progress_updates: true;
stream_image_progress: boolean; stream_image_progress: boolean;
} }
@ -136,7 +136,7 @@ export const useImageCreate = create<ImageCreateState>(
show_only_filtered_image: true, show_only_filtered_image: true,
init_image: undefined, init_image: undefined,
sampler: "plms", sampler: "plms",
stream_progress_updates: false, stream_progress_updates: true,
stream_image_progress: false stream_image_progress: false
}, },

View File

@ -0,0 +1,71 @@
import create from "zustand";
import produce from "immer";
export const FetchingStates = {
IDLE: "IDLE",
FETCHING: "FETCHING",
PROGRESSING: "PROGRESSING",
SUCCEEDED: "SUCCEEDED",
COMPLETE: "COMPLETE",
ERROR: "ERROR",
} as const;
interface ImageFetchingState {
status: typeof FetchingStates[keyof typeof FetchingStates];
step: number;
totalSteps: number;
data: string;
appendData: (data: string) => void;
reset: () => void;
setStatus: (status: typeof FetchingStates[keyof typeof FetchingStates]) => void;
setStep: (step: number) => void;
setTotalSteps: (totalSteps: number) => void;
}
export const useImageFetching = create<ImageFetchingState>((set) => ({
status: FetchingStates.IDLE,
step: 0,
totalSteps: 0,
data: '',
// use produce to make sure we don't mutate state
appendData: (data: string) => {
set(
produce((state: ImageFetchingState) => {
// eslint-disable-next-line @typescript-eslint/restrict-plus-operands
state.data += data;
})
);
},
reset: () => {
set(
produce((state: ImageFetchingState) => {
state.status = FetchingStates.IDLE;
state.step = 0;
state.totalSteps = 0;
state.data = '';
})
);
},
setStatus: (status: typeof FetchingStates[keyof typeof FetchingStates]) => {
set(
produce((state: ImageFetchingState) => {
state.status = status;
})
);
},
setStep: (step: number) => {
set(
produce((state: ImageFetchingState) => {
state.step = step;
})
);
},
setTotalSteps: (totalSteps: number) => {
set(
produce((state: ImageFetchingState) => {
state.totalSteps = totalSteps;
})
);
},
}));

View File

@ -370,7 +370,13 @@ def do_mk_img(req: Request):
if req.stream_progress_updates: if req.stream_progress_updates:
n_steps = opt_ddim_steps if req.init_image is None else t_enc n_steps = opt_ddim_steps if req.init_image is None else t_enc
progress = {"step": i, "total_steps": n_steps} progress = {
"status": "progress",
#"progress": (i + 1) / n_steps,
"progress": {
"step": i, "total_steps": n_steps
}
}
if req.stream_image_progress and i % 5 == 0: if req.stream_image_progress and i % 5 == 0:
partial_images = [] partial_images = []