File size: 2,989 Bytes
0891679
5881efa
cca515d
2a63a7e
6d2113a
cca515d
81969cf
02b9873
2a63a7e
81969cf
cca515d
02b9873
cca515d
02b9873
 
 
 
 
 
 
 
cca515d
02b9873
5881efa
 
 
 
 
0891679
 
4a320f9
6d2113a
 
 
 
 
 
 
 
 
 
 
 
 
6b92c38
5881efa
02b9873
 
5881efa
 
 
02b9873
 
5881efa
2a63a7e
5881efa
 
6d2113a
 
 
5881efa
0891679
 
 
6d2113a
5881efa
6d2113a
 
 
 
5881efa
 
0891679
 
cca515d
81969cf
cca515d
0891679
5881efa
2a63a7e
6f0b822
 
 
 
 
 
 
 
cca515d
4a320f9
2a63a7e
 
 
 
 
02b9873
 
 
 
2a63a7e
 
4f2c36e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import { useId, useState } from "react"
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
import { useLocalStorage } from 'react-use';

import { Collection, Image } from "@/utils/type"
import list_styles from "@/assets/list_styles.json"
import { useCollection } from "@/components/modal/useCollection";

/**
 * Hook backing the image-generation input.
 *
 * Holds the selected style in component state and the prompt in the
 * react-query cache (key `["prompt"]`), and exposes a `submit` mutation that:
 *  1. prepends a loading placeholder to the cached `["collections"]` images,
 *  2. POSTs the styled prompt to `/api`,
 *  3. replaces the placeholder with the returned image, or attaches the
 *     error message to it when the request fails,
 *  4. on success, appends the new image id to the `my-own-generations`
 *     local-storage entry and passes it to `setOpen`.
 *
 * @returns `prompt`/`setPrompt`, `style`/`setStyle`, the `submit` trigger with
 *   its `loading` flag, the `hasMadeFirstGeneration` flag and the raw
 *   `list_styles` array.
 */
export const useInputGeneration = () => {
  const { setOpen } = useCollection();
  const client = useQueryClient()
  const [myGenerationsId, setGenerationsId] = useLocalStorage<string[]>('my-own-generations', []);
  const [style, setStyle] = useState<string>(list_styles[0].name)

  // The prompt lives in the query cache; the query function only supplies the
  // empty default and every automatic refetch is switched off.
  const { data: prompt } = useQuery(["prompt"], () => {
    return ''
  }, {
    refetchOnWindowFocus: false,
    refetchOnMount: false,
    refetchOnReconnect: false,
    initialData: ''
  })
  const setPrompt = (str: string) => client.setQueryData(["prompt"], () => str)

  // Same cache-backed pattern for the "first generation done" flag.
  const { data: hasMadeFirstGeneration } = useQuery(["firstGenerationDone"], () => {
    return false
  }, {
    refetchOnWindowFocus: false,
    refetchOnMount: false,
    refetchOnReconnect: false,
    initialData: false
  })
  const setFirstGenerationDone = () => client.setQueryData(["firstGenerationDone"], () => true)

  // Applies `patch` to the cached image list. When the cache holds no image
  // list yet the entry is returned untouched instead of throwing on
  // `undefined`. The cache shape is not declared in this file, hence `any`.
  const patchCollectionImages = (patch: (images: Image[]) => Image[]) =>
    client.setQueryData(["collections"], (old: any) =>
      Array.isArray(old?.images)
        ? { ...old, images: patch(old.images as Image[]) }
        : old
    )

  // Replaces the placeholder identified by `placeholderId`. If it is no
  // longer in the list the list is left as-is (writing to index -1 would only
  // add a stray property to the array).
  const settlePlaceholder = (placeholderId: string, resolve: (current: Image) => Image) =>
    patchCollectionImages((images) => {
      const index = images.findIndex((item: Image) => item.id === placeholderId)
      if (index === -1) return images
      const next = [...images]
      next[index] = resolve(next[index])
      return next
    })

  const { mutate: submit, isLoading: loading } = useMutation(
    ["generation"],
    async () => {
      // Random client-side id, used only to find the placeholder again.
      const id = Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15)
      if (!hasMadeFirstGeneration) setFirstGenerationDone()

      patchCollectionImages((images) => [{
        id,
        loading: true,
        blob: {
          type: "image/png",
          data: new ArrayBuffer(0),
        },
        prompt
      }, ...images])

      const findStyle = list_styles.find((item) => item.name === style)
      const promptText = prompt ?? ''

      // eslint-disable-next-line @typescript-eslint/no-explicit-any -- response schema is not declared in this file
      let data: any
      try {
        const response = await fetch("/api", {
          method: "POST",
          body: JSON.stringify({
            // A replacer function is used so that `$&`, `$1`, `$$`… typed by
            // the user are inserted literally instead of being interpreted
            // as replacement patterns.
            inputs: findStyle?.prompt.replace("{prompt}", () => promptText) ?? promptText,
            negative_prompt: findStyle?.negative_prompt ?? "",
          }),
        })
        data = await response.json()
      } catch (error: unknown) {
        // Network failure or non-JSON body: surface it on the placeholder so
        // it does not stay in the loading state, then let the mutation reject
        // exactly as it did before.
        const message = error instanceof Error ? error.message : String(error)
        settlePlaceholder(id, (current) => ({ ...current, error: message }))
        throw error
      }

      const succeeded = Boolean(data?.ok)

      settlePlaceholder(id, (current) => succeeded
        ? data.blob as Image
        : { ...current, error: data?.message })

      if (!succeeded) return null

      setGenerationsId([...(myGenerationsId ?? []), data?.blob?.id])
      setOpen(data?.blob?.id)
      return data
    }
  )

  return {
    prompt,
    setPrompt,
    loading,
    submit,
    hasMadeFirstGeneration,
    list_styles,
    style,
    setStyle
  }

}