File size: 5,441 Bytes
14b8ec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e7319ea5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Optional\n",
    "from langchain.callbacks.manager import CallbackManagerForLLMRun\n",
    "from pydantic import Field\n",
    "from gradio_client import Client, handle_file\n",
    "\n",
    "# Use local BaseChatModel if available, else fallback to langchain_core\n",
    "try:\n",
    "    from BaseChatModel import BaseChatModel\n",
    "except ImportError:\n",
    "    from langchain_core.language_models.chat_models import BaseChatModel\n",
    "\n",
    "try:\n",
    "    from langchain_core.messages.base import BaseMessage\n",
    "except ImportError:\n",
    "    from langchain.schema import BaseMessage\n",
    "\n",
    "try:\n",
    "    from langchain_core.messages import AIMessage\n",
    "except ImportError:\n",
    "    from langchain.schema import AIMessage\n",
    "\n",
    "try:\n",
    "    from langchain_core.outputs import ChatResult\n",
    "except ImportError:\n",
    "    from langchain.schema import ChatResult\n",
    "\n",
    "try:\n",
    "    from langchain_core.outputs import ChatGeneration\n",
    "except ImportError:\n",
    "    from langchain.schema import ChatGeneration\n",
    "\n",
    "\n",
    "class GradioChatModel(BaseChatModel):\n",
    "    \"\"\"LangChain chat model that forwards a prompt and an image to a Gradio Space.\n",
    "\n",
    "    Each call sends one text prompt and one image to the Space's ``/run_example``\n",
    "    endpoint via ``Client.predict`` and wraps the reply in an ``AIMessage``.\n",
    "    \"\"\"\n",
    "\n",
    "    client: Optional[Client] = Field(default=None, description=\"Gradio client for API communication\")\n",
    "\n",
    "    def __init__(self, client: Optional[Client] = None, **kwargs):\n",
    "        \"\"\"Create the model, building a client for ``apjanco/fantastic-futures`` if none is given.\n",
    "\n",
    "        Args:\n",
    "            client: Pre-built Gradio client. When ``None``, a new ``Client`` is\n",
    "                created here, which contacts the Space over the network.\n",
    "            **kwargs: Forwarded to ``BaseChatModel``.\n",
    "        \"\"\"\n",
    "        super().__init__(**kwargs)\n",
    "        if client is None:\n",
    "            client = Client(\"apjanco/fantastic-futures\")\n",
    "        # Assigned via object.__setattr__, bypassing the model's own __setattr__\n",
    "        # (and any pydantic validation it would perform).\n",
    "        object.__setattr__(self, 'client', client)\n",
    "\n",
    "    @property\n",
    "    def _llm_type(self) -> str:\n",
    "        return \"gradio_chat_model\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _content_text(content) -> str:\n",
    "        \"\"\"Return the plain text of a message ``content`` (a str or a list of content blocks).\"\"\"\n",
    "        if isinstance(content, str):\n",
    "            return content\n",
    "        if isinstance(content, list):\n",
    "            parts = []\n",
    "            for block in content:\n",
    "                if isinstance(block, str):\n",
    "                    parts.append(block)\n",
    "                elif isinstance(block, dict) and block.get(\"type\") == \"text\":\n",
    "                    parts.append(str(block.get(\"text\", \"\")))\n",
    "            return \"\".join(parts)\n",
    "        return str(content) if content else \"\"\n",
    "\n",
    "    def _generate(\n",
    "        self,\n",
    "        messages: List[BaseMessage],\n",
    "        stop: Optional[List[str]] = None,\n",
    "        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
    "        **kwargs,\n",
    "    ) -> ChatResult:\n",
    "        \"\"\"Send a prompt and an image taken from ``messages`` to the Gradio endpoint.\n",
    "\n",
    "        The prompt is the text of the first message with non-empty content\n",
    "        (default ``\"Hello!!\"``); the image is the ``image`` attribute of the last\n",
    "        message that has a truthy one (default: a sample image URL).\n",
    "\n",
    "        Args:\n",
    "            messages: Conversation to take the prompt and image from.\n",
    "            stop: Optional stop sequences. They are not sent to the endpoint;\n",
    "                instead the reply is cut at the earliest occurrence of any of them.\n",
    "            run_manager: LangChain callback manager (unused).\n",
    "            **kwargs: Extra invocation arguments forwarded by LangChain (unused);\n",
    "                accepted so extra ``invoke()`` keyword arguments do not raise TypeError.\n",
    "\n",
    "        Returns:\n",
    "            A ``ChatResult`` holding one ``ChatGeneration`` that wraps an ``AIMessage``.\n",
    "        \"\"\"\n",
    "        prompt = None\n",
    "        image_url = None\n",
    "        for msg in messages:\n",
    "            # Prompt: the first message whose content has non-empty text.\n",
    "            if prompt is None:\n",
    "                text = self._content_text(getattr(msg, \"content\", None))\n",
    "                if text:\n",
    "                    prompt = text\n",
    "            # Image: the last message carrying one, even if its text is empty.\n",
    "            image = getattr(msg, \"image\", None)\n",
    "            if image:\n",
    "                image_url = image\n",
    "        if prompt is None:\n",
    "            prompt = \"Hello!!\"\n",
    "        if image_url is None:\n",
    "            # fallback image\n",
    "            image_url = 'https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png'\n",
    "\n",
    "        image_file = handle_file(image_url)\n",
    "        response = self.client.predict(\n",
    "            image=image_file,\n",
    "            model_id='nanonets/Nanonets-OCR-s',\n",
    "            prompt=prompt,\n",
    "            api_name=\"/run_example\"\n",
    "        )\n",
    "        # The response may be a string or a dict with a \"message\" key.\n",
    "        if isinstance(response, dict) and \"message\" in response:\n",
    "            content = str(response[\"message\"])\n",
    "        else:\n",
    "            content = str(response)\n",
    "        # Apply stop sequences client-side. Truncating at each one in turn leaves\n",
    "        # the text cut at the earliest occurrence of any of them.\n",
    "        for token in stop or []:\n",
    "            if token:\n",
    "                idx = content.find(token)\n",
    "                if idx != -1:\n",
    "                    content = content[:idx]\n",
    "        message = AIMessage(content=content)\n",
    "        return ChatResult(generations=[ChatGeneration(message=message)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e9a50bd3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded as API: https://apjanco-fantastic-futures.hf.space ✔\n",
      "content='This is an icon of a bus.'\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    from langchain_core.messages import HumanMessage\n",
    "except ImportError:\n",
    "    from langchain.schema import HumanMessage\n",
    "\n",
    "\n",
    "class HumanMessageWithImage(HumanMessage):\n",
    "    \"\"\"``HumanMessage`` that also carries an ``image`` URL or path.\n",
    "\n",
    "    ``GradioChatModel`` looks for an ``image`` attribute on each message and\n",
    "    sends it to the Gradio endpoint alongside the prompt.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, content, image=None, **kwargs):\n",
    "        super().__init__(content=content, **kwargs)\n",
    "        # Not a declared field: set after construction as an extra attribute.\n",
    "        self.image = image\n",
    "\n",
    "\n",
    "# Build the model (this connects to the Gradio Space) and ask about a sample image.\n",
    "custom_llm = GradioChatModel()\n",
    "image_url = \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\"\n",
    "prompt = \"what is this?\"\n",
    "msg = HumanMessageWithImage(content=prompt, image=image_url)\n",
    "\n",
    "# invoke() takes a list of messages and returns the model's reply message.\n",
    "result = custom_llm.invoke([msg])\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "920ba4af",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}