Spaces:
Running
Running
import { useState, useEffect, useCallback } from 'react'; | |
import { Loader2 } from 'lucide-react'; | |
import ModelSelector from './ModelSelector'; | |
import PromptInput from './PromptInput'; | |
import AspectRatioSelector from './AspectRatioSelector'; | |
import ImageCountSlider from './ImageCountSlider'; | |
import { Button } from '../ui/button'; | |
import { v4 as uuidv4 } from 'uuid'; | |
/** One generated (or placeholder) image, referenced by URL. */
type Image = {
  url: string;
};

/** A group of images produced by a single generation request. */
type Batch = {
  /** Numeric id, or a string such as the client-side placeholder UUID. */
  id: string | number;
  prompt: string;
  width: number;
  height: number;
  model: string;
  images: Image[];
  /** Lifecycle marker; this form emits 'pending' for placeholders and 'error' on failure. */
  status?: string;
  /** Client-generated id linking a streamed result back to its placeholder. */
  tempId?: string;
};
/** Pixel size requested from the generation API for one aspect-ratio option. */
interface Dimensions {
  width: number;
  height: number;
}

/** Aspect-ratio option names understood by this form. */
type AspectRatio = 'square' | 'landscape' | 'portrait';

// Kept at module level: a render-local table listed in handleGenerate's deps would
// recreate the callback (and re-bind the keydown listener) on every render.
const ASPECT_RATIOS: Record<AspectRatio, Dimensions> = {
  square: { width: 1024, height: 1024 },
  landscape: { width: 1216, height: 832 },
  portrait: { width: 832, height: 1216 },
};

/** Narrows an arbitrary selector value to a known aspect-ratio key. */
function isAspectRatio(value: string): value is AspectRatio {
  return Object.prototype.hasOwnProperty.call(ASPECT_RATIOS, value);
}

/** Dimensions for `ratio`, falling back to square for unrecognised values. */
function getDimensions(ratio: string): Dimensions {
  return isAspectRatio(ratio) ? ASPECT_RATIOS[ratio] : ASPECT_RATIOS.square;
}

/**
 * Maps stored batch dimensions back to an aspect-ratio option. Any size not listed
 * in ASPECT_RATIOS (including squares other than 1024x1024) maps to 'square'.
 */
function getAspectRatioFromDimensions(width: number, height: number): AspectRatio {
  const keys = Object.keys(ASPECT_RATIOS) as AspectRatio[];
  const match = keys.find(
    (key) => ASPECT_RATIOS[key].width === width && ASPECT_RATIOS[key].height === height,
  );
  return match ?? 'square';
}

type JsonParseResult = { ok: true; value: unknown } | { ok: false };

/** JSON.parse that reports failure instead of throwing. */
function tryParseJson(text: string): JsonParseResult {
  try {
    return { ok: true, value: JSON.parse(text) as unknown };
  } catch {
    return { ok: false };
  }
}

/**
 * Emits every complete JSON message held in `buffer` and returns the unparsed rest.
 * The buffer is first tried as one document (this also covers pretty-printed JSON).
 * If that fails and the buffer ends with a newline, it is tried as newline-delimited JSON.
 * Otherwise the text is kept until more data arrives.
 */
function drainJsonMessages(buffer: string, onMessage: (message: unknown) => void): string {
  const text = buffer.trim();
  if (text === '') return '';

  const whole = tryParseJson(text);
  if (whole.ok) {
    onMessage(whole.value);
    return '';
  }

  if (!buffer.endsWith('\n')) return buffer;

  const values: unknown[] = [];
  for (const line of text.split('\n')) {
    if (line.trim() === '') continue;
    const parsed = tryParseJson(line);
    // A line that does not parse means a message is still incomplete; wait for more data.
    if (!parsed.ok) return buffer;
    values.push(parsed.value);
  }
  for (const value of values) onMessage(value);
  return '';
}

/**
 * Reads a streamed response body as JSON messages, calling `onMessage` for each.
 * Chunk boundaries do not line up with message boundaries, so text is buffered and
 * only parsed once it forms complete JSON.
 *
 * @throws Error if the stream ends with text that never became valid JSON, or if
 *   reading the stream or `onMessage` throws.
 */
async function readJsonStream(
  body: ReadableStream<Uint8Array>,
  onMessage: (message: unknown) => void,
): Promise<void> {
  const reader = body.getReader();
  // One decoder for the whole stream so multi-byte UTF-8 characters split across
  // chunks are decoded correctly.
  const decoder = new TextDecoder();
  let buffer = '';
  try {
    for (;;) {
      const { done, value } = await reader.read();
      buffer += done ? decoder.decode() : decoder.decode(value, { stream: true });
      buffer = drainJsonMessages(buffer, onMessage);
      if (done) break;
    }
  } catch (err) {
    // Stop the transfer if we bail out early; ignore failures from cancelling.
    await reader.cancel().catch(() => undefined);
    throw err;
  }
  if (buffer.trim() !== '') {
    throw new Error('Received an incomplete response from the server');
  }
}

/**
 * Returns the `batch` field of a streamed message when it is present and truthy.
 * NOTE(review): the payload is not validated; it is assumed to match Batch.
 */
function extractBatch(message: unknown): Batch | null {
  if (typeof message !== 'object' || message === null) return null;
  const { batch } = message as { batch?: unknown };
  return batch ? (batch as Batch) : null;
}

/**
 * Form for requesting image generations.
 *
 * On submit (button or Ctrl/Cmd+Enter) it reports a placeholder batch via
 * `onGenerate(batch, true)`. It then POSTs to `/api/generate-image` and reports each
 * streamed `batch` via `onGenerate(batch, false)`, tagged with the placeholder's
 * `tempId`. On failure it reports the placeholder again with `status: 'error'`.
 * When `remixBatch` changes, the form fields are pre-filled from it.
 */
export default function GeneratorForm({ onGenerate, remixBatch }: { onGenerate: (batch: Batch, isPlaceholder: boolean) => void, remixBatch: Batch | null }) {
  const [prompt, setPrompt] = useState('');
  const [model, setModel] = useState('runware:100@1'); // FLUX SCHNELL as default
  const [aspectRatio, setAspectRatio] = useState('square');
  const [imageCount, setImageCount] = useState(1);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);

  // Copy the settings of a batch the user chose to remix into the form.
  useEffect(() => {
    if (remixBatch) {
      setPrompt(remixBatch.prompt);
      setModel(remixBatch.model);
      setAspectRatio(getAspectRatioFromDimensions(remixBatch.width, remixBatch.height));
      setImageCount(remixBatch.images.length);
    }
  }, [remixBatch]);

  const handleGenerate = useCallback(async () => {
    // The keyboard shortcut bypasses the disabled button, so guard here as well.
    if (isLoading) return;
    setError(null);

    const placeholderId = uuidv4();
    const { width, height } = getDimensions(aspectRatio);
    const placeholderBatch: Batch = {
      id: placeholderId,
      prompt,
      width,
      height,
      model,
      // Distinct objects per slot rather than one shared reference.
      images: Array.from({ length: imageCount }, () => ({ url: '/placeholder-image.png' })),
      status: 'pending',
      tempId: placeholderId,
    };
    onGenerate(placeholderBatch, true);

    setIsLoading(true);
    try {
      const response = await fetch('/api/generate-image', {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          prompt,
          width,
          height,
          model,
          number_results: imageCount,
          placeholderId,
        }),
      });
      if (!response.ok) {
        throw new Error('Failed to generate image');
      }
      if (!response.body) {
        throw new Error('Failed to read response');
      }

      let batchesReceived = 0;
      await readJsonStream(response.body, (message) => {
        const batch = extractBatch(message);
        if (batch) {
          batchesReceived += 1;
          onGenerate({ ...batch, tempId: placeholderId }, false);
        }
      });
      // Without any batch the placeholder would otherwise stay 'pending' forever.
      if (batchesReceived === 0) {
        throw new Error('No images were returned');
      }
    } catch (err) {
      console.error('Error generating image:', err);
      setError(err instanceof Error ? err.message : 'An unknown error occurred');
      onGenerate({ ...placeholderBatch, status: 'error' }, false);
    } finally {
      setIsLoading(false);
    }
  }, [aspectRatio, prompt, model, imageCount, onGenerate, isLoading]);

  // Global Ctrl/Cmd+Enter shortcut to submit the form.
  useEffect(() => {
    const handleKeyDown = (event: KeyboardEvent) => {
      if ((event.ctrlKey || event.metaKey) && event.key === 'Enter') {
        event.preventDefault();
        void handleGenerate();
      }
    };
    document.addEventListener('keydown', handleKeyDown);
    return () => {
      document.removeEventListener('keydown', handleKeyDown);
    };
  }, [handleGenerate]);

  return (
    <div className="layout-content-container flex flex-col w-full md:w-80">
      <PromptInput value={prompt} onChange={setPrompt} />
      <ModelSelector value={model} onChange={setModel} />
      <AspectRatioSelector value={aspectRatio} onChange={setAspectRatio} />
      <ImageCountSlider value={imageCount} onChange={setImageCount} />
      {error && (
        <div className="px-4 py-2 mb-3 text-red-500 bg-red-100 dark:bg-red-900 dark:text-red-100 rounded-md">
          {error}
        </div>
      )}
      <div className="flex px-4 py-3">
        <Button
          variant="outline"
          className="w-full justify-center bg-white dark:bg-gray-800 text-[#141414] dark:text-white font-bold"
          onClick={() => void handleGenerate()}
          disabled={isLoading}
        >
          {isLoading ? (
            <>
              <Loader2 className="mr-2 size-4 animate-spin" />
              Generating...
            </>
          ) : (
            'Generate'
          )}
        </Button>
      </div>
    </div>
  );
}