Joffrey Thomas commited on
Commit ·
1fd2857
1
Parent(s): da13aa0
change to Mistral
Browse files- TextGen/gemini.py +20 -14
- TextGen/router.py +69 -54
- requirements.txt +1 -1
TextGen/gemini.py
CHANGED
|
@@ -1,18 +1,21 @@
|
|
| 1 |
-
|
| 2 |
-
from google.api_core import retry
|
| 3 |
import os
|
| 4 |
import pathlib
|
| 5 |
import textwrap
|
| 6 |
import json
|
| 7 |
|
| 8 |
-
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
model = genai.GenerativeModel(model_name='models/gemini-2.5-pro')
|
| 13 |
|
| 14 |
def generate_story(available_items):
|
| 15 |
-
response =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
You are game master in a roguelike game. You should help me create an consise objective for the player to fulfill in the current level.
|
| 17 |
Our protagonist, A girl with a red hood is wandering inside a dungeon.
|
| 18 |
You are ONLY allowed to only use the following elements and not modify either the map or the enemies:
|
|
@@ -22,9 +25,11 @@ def generate_story(available_items):
|
|
| 22 |
Example 2 : Loot two treasure chest then defeat the boss $boss_name to exit.
|
| 23 |
A object of type NPC can be talked to but can't move and can't be fought, an ennemy type object can be fought but not talked to. A boss can be fought but not talked to
|
| 24 |
To escape, the player will have to pass by the portal.
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
return story
|
| 29 |
|
| 30 |
def place_objects(available_items,story,myMap):
|
|
@@ -67,12 +72,13 @@ def place_objects(available_items,story,myMap):
|
|
| 67 |
"""
|
| 68 |
|
| 69 |
print(prompt)
|
| 70 |
-
response =
|
| 71 |
-
|
| 72 |
-
|
|
|
|
| 73 |
)
|
| 74 |
try:
|
| 75 |
-
map_placements=json.dumps(json.loads(response.
|
| 76 |
except:
|
| 77 |
map_placements="error"
|
| 78 |
return map_placements
|
|
|
|
| 1 |
+
from mistralai import Mistral
|
|
|
|
| 2 |
import os
|
| 3 |
import pathlib
|
| 4 |
import textwrap
|
| 5 |
import json
|
| 6 |
|
| 7 |
+
MISTRAL_API_KEY = os.environ['MISTRAL_API_KEY']
|
| 8 |
|
| 9 |
+
client = Mistral(api_key=MISTRAL_API_KEY)
|
| 10 |
+
model = "mistral-large-latest"
|
|
|
|
| 11 |
|
| 12 |
def generate_story(available_items):
|
| 13 |
+
response = client.chat.complete(
|
| 14 |
+
model=model,
|
| 15 |
+
messages=[
|
| 16 |
+
{
|
| 17 |
+
"role": "user",
|
| 18 |
+
"content": f"""
|
| 19 |
You are game master in a roguelike game. You should help me create an consise objective for the player to fulfill in the current level.
|
| 20 |
Our protagonist, A girl with a red hood is wandering inside a dungeon.
|
| 21 |
You are ONLY allowed to only use the following elements and not modify either the map or the enemies:
|
|
|
|
| 25 |
Example 2 : Loot two treasure chest then defeat the boss $boss_name to exit.
|
| 26 |
A object of type NPC can be talked to but can't move and can't be fought, an ennemy type object can be fought but not talked to. A boss can be fought but not talked to
|
| 27 |
To escape, the player will have to pass by the portal.
|
| 28 |
+
"""
|
| 29 |
+
}
|
| 30 |
+
]
|
| 31 |
+
)
|
| 32 |
+
story = response.choices[0].message.content
|
| 33 |
return story
|
| 34 |
|
| 35 |
def place_objects(available_items,story,myMap):
|
|
|
|
| 72 |
"""
|
| 73 |
|
| 74 |
print(prompt)
|
| 75 |
+
response = client.chat.complete(
|
| 76 |
+
model=model,
|
| 77 |
+
messages=[{"role": "user", "content": prompt}],
|
| 78 |
+
response_format={"type": "json_object"}
|
| 79 |
)
|
| 80 |
try:
|
| 81 |
+
map_placements=json.dumps(json.loads(response.choices[0].message.content), indent=4)
|
| 82 |
except:
|
| 83 |
map_placements="error"
|
| 84 |
return map_placements
|
TextGen/router.py
CHANGED
|
@@ -1,22 +1,16 @@
|
|
| 1 |
import os
|
| 2 |
import time
|
| 3 |
from io import BytesIO
|
| 4 |
-
from
|
| 5 |
from fastapi import FastAPI, HTTPException, Query, Request
|
| 6 |
from fastapi.responses import StreamingResponse,Response
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
| 9 |
-
from
|
| 10 |
-
from langchain.prompts import PromptTemplate
|
| 11 |
from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
|
| 12 |
from TextGen.gemini import generate_story,place_objects,generate_map_markdown
|
| 13 |
#from TextGen.diffusion import generate_image
|
| 14 |
#from coqui import predict
|
| 15 |
-
from langchain_google_genai import (
|
| 16 |
-
ChatGoogleGenerativeAI,
|
| 17 |
-
HarmBlockThreshold,
|
| 18 |
-
HarmCategory,
|
| 19 |
-
)
|
| 20 |
from TextGen import app
|
| 21 |
from gradio_client import Client, handle_file
|
| 22 |
from typing import List
|
|
@@ -24,29 +18,60 @@ from elevenlabs.client import ElevenLabs
|
|
| 24 |
from elevenlabs import Voice, VoiceSettings, stream
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
Eleven_client = ElevenLabs(
|
| 28 |
api_key=os.environ["ELEVEN_API_KEY"], # Defaults to ELEVEN_API_KEY
|
| 29 |
)
|
| 30 |
|
| 31 |
|
| 32 |
Last_message=None
|
| 33 |
-
class PlayLastMusic(BaseModel):
|
| 34 |
-
'''plays the lastest created music '''
|
| 35 |
-
Desicion: str = Field(
|
| 36 |
-
..., description="Yes or No"
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
-
class CreateLyrics(BaseModel):
|
| 40 |
-
f'''create some Lyrics for a new music'''
|
| 41 |
-
Desicion: str = Field(
|
| 42 |
-
..., description="Yes or No"
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
class CreateNewMusic(BaseModel):
|
| 46 |
-
f'''create a new music with the Lyrics previously computed'''
|
| 47 |
-
Name: str = Field(
|
| 48 |
-
..., description="tags to describe the new music"
|
| 49 |
-
)
|
| 50 |
|
| 51 |
class SongRequest(BaseModel):
|
| 52 |
prompt: str | None = None
|
|
@@ -120,7 +145,7 @@ def generate_text(messages: List[str], npc:str):
|
|
| 120 |
system_prompt=general_npc_prompt+"/n "+main_npc_system_prompts[npc]
|
| 121 |
else:
|
| 122 |
system_prompt="you're a character in a video game. Play along."
|
| 123 |
-
print(system_prompt)
|
| 124 |
new_messages=[{"role": "user", "content": system_prompt}]
|
| 125 |
for index, message in enumerate(messages):
|
| 126 |
if index%2==0:
|
|
@@ -128,43 +153,33 @@ def generate_text(messages: List[str], npc:str):
|
|
| 128 |
else:
|
| 129 |
new_messages.append({"role": "assistant", "content": message})
|
| 130 |
print(new_messages)
|
| 131 |
-
# Initialize the LLM
|
| 132 |
-
llm = ChatGoogleGenerativeAI(
|
| 133 |
-
model="gemini-2.5-pro",
|
| 134 |
-
max_output_tokens=100,
|
| 135 |
-
temperature=1,
|
| 136 |
-
safety_settings={
|
| 137 |
-
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
| 138 |
-
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
| 139 |
-
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
| 140 |
-
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE
|
| 141 |
-
},
|
| 142 |
-
)
|
| 143 |
-
if npc=="bard":
|
| 144 |
-
llm = llm.bind_tools([PlayLastMusic,CreateNewMusic,CreateLyrics])
|
| 145 |
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
print(llm_response)
|
| 148 |
-
|
|
|
|
| 149 |
|
| 150 |
|
| 151 |
def inference_model(system_messsage, prompt):
|
| 152 |
-
|
| 153 |
new_messages=[{"role": "user", "content": system_messsage},{"role": "user", "content": prompt}]
|
| 154 |
-
|
| 155 |
-
model=
|
| 156 |
-
|
|
|
|
| 157 |
temperature=1,
|
| 158 |
-
safety_settings={
|
| 159 |
-
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
| 160 |
-
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
| 161 |
-
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
| 162 |
-
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE
|
| 163 |
-
},
|
| 164 |
)
|
| 165 |
-
llm_response = llm.invoke(new_messages)
|
| 166 |
print(llm_response)
|
| 167 |
-
return Generate(text=llm_response.content)
|
| 168 |
|
| 169 |
@app.get("/", tags=["Home"])
|
| 170 |
def api_home(request: Request):
|
|
|
|
| 1 |
import os
|
| 2 |
import time
|
| 3 |
from io import BytesIO
|
| 4 |
+
from pydantic import BaseModel, Field
|
| 5 |
from fastapi import FastAPI, HTTPException, Query, Request
|
| 6 |
from fastapi.responses import StreamingResponse,Response
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
| 9 |
+
from mistralai import Mistral
|
|
|
|
| 10 |
from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
|
| 11 |
from TextGen.gemini import generate_story,place_objects,generate_map_markdown
|
| 12 |
#from TextGen.diffusion import generate_image
|
| 13 |
#from coqui import predict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from TextGen import app
|
| 15 |
from gradio_client import Client, handle_file
|
| 16 |
from typing import List
|
|
|
|
| 18 |
from elevenlabs import Voice, VoiceSettings, stream
|
| 19 |
|
| 20 |
|
| 21 |
+
mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
|
| 22 |
+
mistral_model = "mistral-large-latest"
|
| 23 |
+
|
| 24 |
+
bard_tools = [
|
| 25 |
+
{
|
| 26 |
+
"type": "function",
|
| 27 |
+
"function": {
|
| 28 |
+
"name": "PlayLastMusic",
|
| 29 |
+
"description": "plays the lastest created music",
|
| 30 |
+
"parameters": {
|
| 31 |
+
"type": "object",
|
| 32 |
+
"properties": {
|
| 33 |
+
"Desicion": {"type": "string", "description": "Yes or No"}
|
| 34 |
+
},
|
| 35 |
+
"required": ["Desicion"]
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"type": "function",
|
| 41 |
+
"function": {
|
| 42 |
+
"name": "CreateLyrics",
|
| 43 |
+
"description": "create some Lyrics for a new music",
|
| 44 |
+
"parameters": {
|
| 45 |
+
"type": "object",
|
| 46 |
+
"properties": {
|
| 47 |
+
"Desicion": {"type": "string", "description": "Yes or No"}
|
| 48 |
+
},
|
| 49 |
+
"required": ["Desicion"]
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"type": "function",
|
| 55 |
+
"function": {
|
| 56 |
+
"name": "CreateNewMusic",
|
| 57 |
+
"description": "create a new music with the Lyrics previously computed",
|
| 58 |
+
"parameters": {
|
| 59 |
+
"type": "object",
|
| 60 |
+
"properties": {
|
| 61 |
+
"Name": {"type": "string", "description": "tags to describe the new music"}
|
| 62 |
+
},
|
| 63 |
+
"required": ["Name"]
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
Eleven_client = ElevenLabs(
|
| 70 |
api_key=os.environ["ELEVEN_API_KEY"], # Defaults to ELEVEN_API_KEY
|
| 71 |
)
|
| 72 |
|
| 73 |
|
| 74 |
Last_message=None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
class SongRequest(BaseModel):
|
| 77 |
prompt: str | None = None
|
|
|
|
| 145 |
system_prompt=general_npc_prompt+"/n "+main_npc_system_prompts[npc]
|
| 146 |
else:
|
| 147 |
system_prompt="you're a character in a video game. Play along."
|
| 148 |
+
print(system_prompt)
|
| 149 |
new_messages=[{"role": "user", "content": system_prompt}]
|
| 150 |
for index, message in enumerate(messages):
|
| 151 |
if index%2==0:
|
|
|
|
| 153 |
else:
|
| 154 |
new_messages.append({"role": "assistant", "content": message})
|
| 155 |
print(new_messages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
+
kwargs = {
|
| 158 |
+
"model": mistral_model,
|
| 159 |
+
"messages": new_messages,
|
| 160 |
+
"max_tokens": 100,
|
| 161 |
+
"temperature": 1,
|
| 162 |
+
}
|
| 163 |
+
if npc == "bard":
|
| 164 |
+
kwargs["tools"] = bard_tools
|
| 165 |
+
kwargs["tool_choice"] = "auto"
|
| 166 |
+
|
| 167 |
+
llm_response = mistral_client.chat.complete(**kwargs)
|
| 168 |
print(llm_response)
|
| 169 |
+
content = llm_response.choices[0].message.content or ""
|
| 170 |
+
return Generate(text=content)
|
| 171 |
|
| 172 |
|
| 173 |
def inference_model(system_messsage, prompt):
|
|
|
|
| 174 |
new_messages=[{"role": "user", "content": system_messsage},{"role": "user", "content": prompt}]
|
| 175 |
+
llm_response = mistral_client.chat.complete(
|
| 176 |
+
model=mistral_model,
|
| 177 |
+
messages=new_messages,
|
| 178 |
+
max_tokens=100,
|
| 179 |
temperature=1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
)
|
|
|
|
| 181 |
print(llm_response)
|
| 182 |
+
return Generate(text=llm_response.choices[0].message.content)
|
| 183 |
|
| 184 |
@app.get("/", tags=["Home"])
|
| 185 |
def api_home(request: Request):
|
requirements.txt
CHANGED
|
@@ -2,7 +2,7 @@ fastapi==0.99.1
|
|
| 2 |
uvicorn
|
| 3 |
requests
|
| 4 |
langchain
|
| 5 |
-
|
| 6 |
Pillow
|
| 7 |
gradio_client
|
| 8 |
TTS
|
|
|
|
| 2 |
uvicorn
|
| 3 |
requests
|
| 4 |
langchain
|
| 5 |
+
mistralai
|
| 6 |
Pillow
|
| 7 |
gradio_client
|
| 8 |
TTS
|