Update TextGen/router.py
Browse files- TextGen/router.py +22 -23
TextGen/router.py
CHANGED
@@ -15,11 +15,11 @@ from langchain_google_genai import (
|
|
15 |
)
|
16 |
from TextGen import app
|
17 |
from gradio_client import Client
|
18 |
-
|
19 |
|
20 |
class Message(BaseModel):
|
21 |
npc: str | None = None
|
22 |
-
|
23 |
|
24 |
class VoiceMessage(BaseModel):
|
25 |
npc: str | None = None
|
@@ -42,27 +42,26 @@ main_npcs={
|
|
42 |
class Generate(BaseModel):
|
43 |
text:str
|
44 |
|
45 |
-
def generate_text(
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
return Generate(text=llm_response)
|
66 |
|
67 |
|
68 |
|
|
|
15 |
)
|
16 |
from TextGen import app
|
17 |
from gradio_client import Client
|
18 |
+
from typing import List
|
19 |
|
20 |
class Message(BaseModel):
|
21 |
npc: str | None = None
|
22 |
+
messages: List[str] | None = None
|
23 |
|
24 |
class VoiceMessage(BaseModel):
|
25 |
npc: str | None = None
|
|
|
42 |
class Generate(BaseModel):
|
43 |
text:str
|
44 |
|
45 |
+
def generate_text(messages: List[str]):
|
46 |
+
print(messages)
|
47 |
+
promptmessages[-1]
|
48 |
+
prompt = PromptTemplate(template=prompt, input_variables=['Prompt'])
|
49 |
+
|
50 |
+
# Initialize the LLM
|
51 |
+
llm = ChatGoogleGenerativeAI(
|
52 |
+
model="gemini-pro",
|
53 |
+
safety_settings={
|
54 |
+
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
55 |
+
},
|
56 |
+
)
|
57 |
+
|
58 |
+
llmchain = LLMChain(
|
59 |
+
prompt=prompt,
|
60 |
+
llm=llm
|
61 |
+
)
|
62 |
+
|
63 |
+
llm_response = llmchain.run({"Prompt": prompt})
|
64 |
+
return Generate(text=llm_response)
|
|
|
65 |
|
66 |
|
67 |
|