Joffrey Thomas committed on
Commit
1fd2857
·
1 Parent(s): da13aa0

change to Mistral

Browse files
Files changed (3) hide show
  1. TextGen/gemini.py +20 -14
  2. TextGen/router.py +69 -54
  3. requirements.txt +1 -1
TextGen/gemini.py CHANGED
@@ -1,18 +1,21 @@
1
- import google.generativeai as genai
2
- from google.api_core import retry
3
  import os
4
  import pathlib
5
  import textwrap
6
  import json
7
 
8
- GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
9
 
10
- genai.configure(api_key=GOOGLE_API_KEY)
11
-
12
- model = genai.GenerativeModel(model_name='models/gemini-2.5-pro')
13
 
14
  def generate_story(available_items):
15
- response = model.generate_content(f"""
 
 
 
 
 
16
  You are game master in a roguelike game. You should help me create an consise objective for the player to fulfill in the current level.
17
  Our protagonist, A girl with a red hood is wandering inside a dungeon.
18
  You are ONLY allowed to only use the following elements and not modify either the map or the enemies:
@@ -22,9 +25,11 @@ def generate_story(available_items):
22
  Example 2 : Loot two treasure chest then defeat the boss $boss_name to exit.
23
  A object of type NPC can be talked to but can't move and can't be fought, an ennemy type object can be fought but not talked to. A boss can be fought but not talked to
24
  To escape, the player will have to pass by the portal.
25
-
26
- """, request_options={'retry': retry.Retry()})
27
- story = response.text
 
 
28
  return story
29
 
30
  def place_objects(available_items,story,myMap):
@@ -67,12 +72,13 @@ def place_objects(available_items,story,myMap):
67
  """
68
 
69
  print(prompt)
70
- response = model.generate_content(
71
- prompt,
72
- generation_config={'response_mime_type':'application/json'}
 
73
  )
74
  try:
75
- map_placements=json.dumps(json.loads(response.text), indent=4)
76
  except:
77
  map_placements="error"
78
  return map_placements
 
1
+ from mistralai import Mistral
 
2
  import os
3
  import pathlib
4
  import textwrap
5
  import json
6
 
7
+ MISTRAL_API_KEY = os.environ['MISTRAL_API_KEY']
8
 
9
+ client = Mistral(api_key=MISTRAL_API_KEY)
10
+ model = "mistral-large-latest"
 
11
 
12
  def generate_story(available_items):
13
+ response = client.chat.complete(
14
+ model=model,
15
+ messages=[
16
+ {
17
+ "role": "user",
18
+ "content": f"""
19
  You are game master in a roguelike game. You should help me create an consise objective for the player to fulfill in the current level.
20
  Our protagonist, A girl with a red hood is wandering inside a dungeon.
21
  You are ONLY allowed to only use the following elements and not modify either the map or the enemies:
 
25
  Example 2 : Loot two treasure chest then defeat the boss $boss_name to exit.
26
  A object of type NPC can be talked to but can't move and can't be fought, an ennemy type object can be fought but not talked to. A boss can be fought but not talked to
27
  To escape, the player will have to pass by the portal.
28
+ """
29
+ }
30
+ ]
31
+ )
32
+ story = response.choices[0].message.content
33
  return story
34
 
35
  def place_objects(available_items,story,myMap):
 
72
  """
73
 
74
  print(prompt)
75
+ response = client.chat.complete(
76
+ model=model,
77
+ messages=[{"role": "user", "content": prompt}],
78
+ response_format={"type": "json_object"}
79
  )
80
  try:
81
+ map_placements=json.dumps(json.loads(response.choices[0].message.content), indent=4)
82
  except:
83
  map_placements="error"
84
  return map_placements
TextGen/router.py CHANGED
@@ -1,22 +1,16 @@
1
  import os
2
  import time
3
  from io import BytesIO
4
- from langchain_core.pydantic_v1 import BaseModel, Field
5
  from fastapi import FastAPI, HTTPException, Query, Request
6
  from fastapi.responses import StreamingResponse,Response
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
9
- from langchain.chains import LLMChain
10
- from langchain.prompts import PromptTemplate
11
  from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
12
  from TextGen.gemini import generate_story,place_objects,generate_map_markdown
13
  #from TextGen.diffusion import generate_image
14
  #from coqui import predict
15
- from langchain_google_genai import (
16
- ChatGoogleGenerativeAI,
17
- HarmBlockThreshold,
18
- HarmCategory,
19
- )
20
  from TextGen import app
21
  from gradio_client import Client, handle_file
22
  from typing import List
@@ -24,29 +18,60 @@ from elevenlabs.client import ElevenLabs
24
  from elevenlabs import Voice, VoiceSettings, stream
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  Eleven_client = ElevenLabs(
28
  api_key=os.environ["ELEVEN_API_KEY"], # Defaults to ELEVEN_API_KEY
29
  )
30
 
31
 
32
  Last_message=None
33
- class PlayLastMusic(BaseModel):
34
- '''plays the lastest created music '''
35
- Desicion: str = Field(
36
- ..., description="Yes or No"
37
- )
38
-
39
- class CreateLyrics(BaseModel):
40
- f'''create some Lyrics for a new music'''
41
- Desicion: str = Field(
42
- ..., description="Yes or No"
43
- )
44
-
45
- class CreateNewMusic(BaseModel):
46
- f'''create a new music with the Lyrics previously computed'''
47
- Name: str = Field(
48
- ..., description="tags to describe the new music"
49
- )
50
 
51
  class SongRequest(BaseModel):
52
  prompt: str | None = None
@@ -120,7 +145,7 @@ def generate_text(messages: List[str], npc:str):
120
  system_prompt=general_npc_prompt+"/n "+main_npc_system_prompts[npc]
121
  else:
122
  system_prompt="you're a character in a video game. Play along."
123
- print(system_prompt)
124
  new_messages=[{"role": "user", "content": system_prompt}]
125
  for index, message in enumerate(messages):
126
  if index%2==0:
@@ -128,43 +153,33 @@ def generate_text(messages: List[str], npc:str):
128
  else:
129
  new_messages.append({"role": "assistant", "content": message})
130
  print(new_messages)
131
- # Initialize the LLM
132
- llm = ChatGoogleGenerativeAI(
133
- model="gemini-2.5-pro",
134
- max_output_tokens=100,
135
- temperature=1,
136
- safety_settings={
137
- HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
138
- HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
139
- HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
140
- HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE
141
- },
142
- )
143
- if npc=="bard":
144
- llm = llm.bind_tools([PlayLastMusic,CreateNewMusic,CreateLyrics])
145
 
146
- llm_response = llm.invoke(new_messages)
 
 
 
 
 
 
 
 
 
 
147
  print(llm_response)
148
- return Generate(text=llm_response.content)
 
149
 
150
 
151
  def inference_model(system_messsage, prompt):
152
-
153
  new_messages=[{"role": "user", "content": system_messsage},{"role": "user", "content": prompt}]
154
- llm = ChatGoogleGenerativeAI(
155
- model="gemini-2.5-pro",
156
- max_output_tokens=100,
 
157
  temperature=1,
158
- safety_settings={
159
- HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
160
- HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
161
- HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
162
- HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE
163
- },
164
  )
165
- llm_response = llm.invoke(new_messages)
166
  print(llm_response)
167
- return Generate(text=llm_response.content)
168
 
169
  @app.get("/", tags=["Home"])
170
  def api_home(request: Request):
 
1
  import os
2
  import time
3
  from io import BytesIO
4
+ from pydantic import BaseModel, Field
5
  from fastapi import FastAPI, HTTPException, Query, Request
6
  from fastapi.responses import StreamingResponse,Response
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
9
+ from mistralai import Mistral
 
10
  from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
11
  from TextGen.gemini import generate_story,place_objects,generate_map_markdown
12
  #from TextGen.diffusion import generate_image
13
  #from coqui import predict
 
 
 
 
 
14
  from TextGen import app
15
  from gradio_client import Client, handle_file
16
  from typing import List
 
18
  from elevenlabs import Voice, VoiceSettings, stream
19
 
20
 
21
+ mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
22
+ mistral_model = "mistral-large-latest"
23
+
24
# Tool (function-calling) schemas exposed to the "bard" NPC so the model can
# trigger music actions. Uses the JSON-schema function format accepted by the
# Mistral chat API (`tools=` parameter of `client.chat.complete`).
# NOTE(review): the "Desicion" property name is misspelled, but it is part of
# the tool-call contract presumably read by the caller handling tool calls —
# renaming it would break that code, so it is kept as-is.
bard_tools = [
    {
        "type": "function",
        "function": {
            "name": "PlayLastMusic",
            # fixed typo: "lastest" -> "latest"
            "description": "plays the latest created music",
            "parameters": {
                "type": "object",
                "properties": {
                    "Desicion": {"type": "string", "description": "Yes or No"}
                },
                "required": ["Desicion"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "CreateLyrics",
            "description": "create some Lyrics for a new music",
            "parameters": {
                "type": "object",
                "properties": {
                    "Desicion": {"type": "string", "description": "Yes or No"}
                },
                "required": ["Desicion"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "CreateNewMusic",
            "description": "create a new music with the Lyrics previously computed",
            "parameters": {
                "type": "object",
                "properties": {
                    "Name": {
                        "type": "string",
                        "description": "tags to describe the new music",
                    }
                },
                "required": ["Name"],
            },
        },
    },
]
68
+
69
  Eleven_client = ElevenLabs(
70
  api_key=os.environ["ELEVEN_API_KEY"], # Defaults to ELEVEN_API_KEY
71
  )
72
 
73
 
74
  Last_message=None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  class SongRequest(BaseModel):
77
  prompt: str | None = None
 
145
  system_prompt=general_npc_prompt+"/n "+main_npc_system_prompts[npc]
146
  else:
147
  system_prompt="you're a character in a video game. Play along."
148
+ print(system_prompt)
149
  new_messages=[{"role": "user", "content": system_prompt}]
150
  for index, message in enumerate(messages):
151
  if index%2==0:
 
153
  else:
154
  new_messages.append({"role": "assistant", "content": message})
155
  print(new_messages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
+ kwargs = {
158
+ "model": mistral_model,
159
+ "messages": new_messages,
160
+ "max_tokens": 100,
161
+ "temperature": 1,
162
+ }
163
+ if npc == "bard":
164
+ kwargs["tools"] = bard_tools
165
+ kwargs["tool_choice"] = "auto"
166
+
167
+ llm_response = mistral_client.chat.complete(**kwargs)
168
  print(llm_response)
169
+ content = llm_response.choices[0].message.content or ""
170
+ return Generate(text=content)
171
 
172
 
173
def inference_model(system_messsage, prompt):
    """Run a one-shot completion against the Mistral chat API.

    Parameters
    ----------
    system_messsage : str
        Instructions framing the model's behaviour.
        (Parameter name keeps the original spelling for caller compatibility.)
    prompt : str
        The user prompt to answer.

    Returns
    -------
    Generate
        Response wrapper holding the model's reply text.
    """
    # Fix: the framing message is a system prompt, so send it with the
    # "system" role instead of a second "user" turn.
    new_messages = [
        {"role": "system", "content": system_messsage},
        {"role": "user", "content": prompt},
    ]
    llm_response = mistral_client.chat.complete(
        model=mistral_model,
        messages=new_messages,
        max_tokens=100,
        temperature=1,
    )
    print(llm_response)
    # Guard against None content (e.g. if the model answers with a tool
    # call) — mirrors the `or ""` guard used by generate_text above.
    content = llm_response.choices[0].message.content or ""
    return Generate(text=content)
183
 
184
  @app.get("/", tags=["Home"])
185
  def api_home(request: Request):
requirements.txt CHANGED
@@ -2,7 +2,7 @@ fastapi==0.99.1
2
  uvicorn
3
  requests
4
  langchain
5
- langchain_google_genai
6
  Pillow
7
  gradio_client
8
  TTS
 
2
  uvicorn
3
  requests
4
  langchain
5
+ mistralai
6
  Pillow
7
  gradio_client
8
  TTS