File size: 1,966 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Utility that calls OpenAI's Dall-E Image Generator."""
from typing import Any, Dict, List, Optional

from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain.utils import get_from_dict_or_env


class DallEAPIWrapper(BaseModel):
    """Wrapper for OpenAI's DALL-E Image Generator.

    https://platform.openai.com/docs/guides/images/generations?context=node

    Usage instructions:

    1. `pip install openai`
    2. save your OPENAI_API_KEY in an environment variable

    Both the legacy ``openai<1.0`` SDK (``openai.Image.create``) and the
    ``openai>=1.0`` SDK (``OpenAI(...).images.generate``) are supported; the
    validator selects whichever API the installed package exposes.
    """

    client: Any  #: :meta private:
    openai_api_key: Optional[str] = None
    n: int = 1
    """Number of images to generate"""
    size: str = "1024x1024"
    """Size of image to generate"""
    separator: str = "\n"
    """Separator to use when multiple URLs are returned."""
    model: Optional[str] = None
    """Model to use for image generation."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and python package exist in the environment.

        The API key is taken from ``openai_api_key`` or, failing that, the
        ``OPENAI_API_KEY`` environment variable. An image-generation client
        is then stored in ``values["client"]``.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        try:
            import openai
        except ImportError as e:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from e

        if hasattr(openai, "OpenAI"):
            # openai>=1.0 removed the module-level ``openai.Image`` resource
            # (it is only a proxy that raises on use). Build a client scoped
            # to this wrapper's key instead of mutating module-global state.
            values["client"] = openai.OpenAI(api_key=openai_api_key).images
        else:
            # openai<1.0: the legacy resource reads the module-global key.
            openai.api_key = openai_api_key
            values["client"] = openai.Image
        return values

    @staticmethod
    def _extract_image_urls(response: Any) -> List[str]:
        """Return the non-empty image URLs contained in an API response.

        Handles both the legacy SDK's dict-like response (``response["data"]``
        holding dicts with a ``"url"`` key) and the v1 SDK's object response
        (``response.data`` holding objects with a ``url`` attribute).
        """
        data = response["data"] if isinstance(response, dict) else response.data
        urls: List[str] = []
        for item in data:
            if isinstance(item, dict):
                url = item.get("url")
            else:
                url = getattr(item, "url", None)
            # Entries without a URL (e.g. base64 payloads) are skipped rather
            # than crashing the join below.
            if url:
                urls.append(url)
        return urls

    def run(self, query: str) -> str:
        """Run query through OpenAI and parse result.

        Args:
            query: Text prompt describing the image(s) to generate.

        Returns:
            The generated image URLs joined by ``separator``, or the string
            ``"No image was generated"`` when the response holds no URL.
        """
        params: Dict[str, Any] = {"prompt": query, "n": self.n, "size": self.size}
        if self.model is not None:
            # Omit instead of sending ``null`` so the API picks its default.
            params["model"] = self.model

        # openai>=1.0 exposes ``images.generate``; the legacy resource uses
        # ``create``.
        generate = getattr(self.client, "generate", None) or self.client.create
        response = generate(**params)

        image_urls = self.separator.join(self._extract_image_urls(response))
        return image_urls if image_urls else "No image was generated"