Update app.py
app.py
CHANGED
@@ -349,8 +349,11 @@ Here's the complete Manim code:
     # Get the current model name and base URL
     model_name = models["model_name"]

+    # Handle model name - extract base name if it has a prefix
+    base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
+
     # Convert message to the appropriate format based on model category
+    config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])
     category = config.get("category", "Other")

     if category == "OpenAI":
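For readers following the diff: the new lookup assumes a module-level MODEL_CONFIGS dict keyed by base model name, with a "default" fallback and per-model category, param_name, and api_version fields that later hunks read. The real table is defined elsewhere in app.py; the sketch below only illustrates the assumed shape, and the model names and values are placeholders.

# Illustrative sketch of the MODEL_CONFIGS shape the new lookup relies on.
# The actual table lives elsewhere in app.py; names and values here are placeholders.
MODEL_CONFIGS = {
    "gpt-4o-mini": {                      # assumed OpenAI-category entry
        "category": "OpenAI",
        "param_name": "max_completion_tokens",
        "max_completion_tokens": 4000,
        "api_version": None,
    },
    "default": {                          # fallback for unknown base names
        "category": "Other",
        "param_name": "max_tokens",
        "max_tokens": 4000,
        "api_version": None,
    },
}

model_name = "openai/gpt-4o-mini"
base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])
print(config["category"])  # -> OpenAI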
@@ -370,16 +373,22 @@ Here's the complete Manim code:
         else:
             client = models["openai_client"]

+        # Add openai/ prefix if not present
+        if "/" not in model_name:
+            full_model_name = f"openai/{model_name}"
+        else:
+            full_model_name = model_name
+
+        # For OpenAI models, use developer role instead of system
         messages = [
+            {"role": "developer", "content": "You are an expert in Manim animations."},
             {"role": "user", "content": prompt}
         ]

         # Create params
         params = {
             "messages": messages,
+            "model": full_model_name
         }

         # Add token parameter
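The OpenAI branch above boils down to three steps: prefix the model id with "openai/" if it is missing, send a "developer" message where a "system" message used to go, and pull the token-limit parameter name and value from the model's config entry. A hypothetical helper restating that branch (build_openai_params is not a function in app.py, just an illustration) could look like this:

def build_openai_params(prompt: str, model_name: str, config: dict) -> dict:
    # Hypothetical restatement of the OpenAI branch above; not part of app.py.
    # Add the "openai/" prefix expected by the inference endpoint, if missing.
    full_model_name = model_name if "/" in model_name else f"openai/{model_name}"

    params = {
        "messages": [
            # Newer OpenAI models accept "developer" where "system" used to go.
            {"role": "developer", "content": "You are an expert in Manim animations."},
            {"role": "user", "content": prompt},
        ],
        "model": full_model_name,
    }
    # Token-limit parameter name and value both come from the model's config entry.
    params[config["param_name"]] = config[config["param_name"]]
    return params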
@@ -2148,7 +2157,8 @@ class MyScene(Scene):

     # Get model details
     model_name = st.session_state.custom_model
+    base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
+    config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])
     category = config.get("category", "Other")

     if category == "OpenAI":
@@ -2176,7 +2186,7 @@ class MyScene(Scene):
     # Prepare parameters based on model configuration
     params = {
         "messages": [
+            {"role": "developer", "content": "You are a helpful assistant."},
             {"role": "user", "content": "Hello, this is a connection test."}
         ],
         "model": full_model_name
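The connection-test hunks assume the OpenAI SDK pointed at the GitHub Models endpoint, as the refresh handler below spells out. Stripped of the Streamlit plumbing, and assuming the token is available in a GITHUB_TOKEN environment variable (an assumption for this sketch; the app reads it via get_secret("github_token_api")), the same test looks like:

import os
from openai import OpenAI

# Standalone version of the connection test; assumes GITHUB_TOKEN holds a
# GitHub Models token (the app itself uses get_secret("github_token_api")).
client = OpenAI(
    base_url="https://models.github.ai/inference",
    api_key=os.environ["GITHUB_TOKEN"],
)

response = client.chat.completions.create(
    model="openai/gpt-4o-mini",  # illustrative model id
    messages=[
        {"role": "developer", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, this is a connection test."},
    ],
)
print(response.choices[0].message.content)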
@@ -2318,45 +2328,88 @@ class MyScene(Scene):

     # Add a refresh button to update model connection
     if st.button("🔄 Refresh Model Connection", key="refresh_model_connection"):
+        if st.session_state.ai_models:
             try:
+                # Get model details
                 model_name = st.session_state.custom_model
+                base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
+                config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])
+                category = config.get("category", "Other")

+                if category == "OpenAI":
+                    # Get token
                     token = get_secret("github_token_api")

+                    # Create new OpenAI client
+                    from openai import OpenAI
+                    client = OpenAI(
+                        base_url="https://models.github.ai/inference",
+                        api_key=token
                     )

+                    # Add openai/ prefix if not present
+                    if "/" not in model_name:
+                        full_model_name = f"openai/{model_name}"
+                    else:
+                        full_model_name = model_name
+
+                    # Test with minimal prompt
+                    response = client.chat.completions.create(
+                        messages=[
+                            {"role": "developer", "content": "You are a helpful assistant."},
+                            {"role": "user", "content": "Hello, this is a test."}
+                        ],
+                        model=full_model_name,
+                        **{config["param_name"]: config[config["param_name"]]}
+                    )
+
+                    # Update session state
+                    st.session_state.ai_models = {
+                        "openai_client": client,
+                        "model_name": full_model_name,
+                        "endpoint": "https://models.github.ai/inference",
+                        "last_loaded": datetime.now().isoformat(),
+                        "category": category
+                    }
+
+                    st.success(f"✅ Connection to {full_model_name} refreshed successfully!")
                 else:
+                    # Test connection with minimal prompt for Azure models
+                    from azure.ai.inference.models import UserMessage
+
+                    # Prepare parameters
+                    messages = [UserMessage("Hello")]
+                    api_params, config = prepare_api_params(messages, model_name)
+
+                    # Check if we need a new client with specific API version
+                    if config["api_version"] and config["api_version"] != st.session_state.ai_models.get("api_version"):
+                        # Create version-specific client if needed
+                        token = get_secret("github_token_api")
+                        from azure.ai.inference import ChatCompletionsClient
+                        from azure.core.credentials import AzureKeyCredential
+
+                        client = ChatCompletionsClient(
+                            endpoint=st.session_state.ai_models["endpoint"],
+                            credential=AzureKeyCredential(token),
+                            api_version=config["api_version"]
+                        )
+                        response = client.complete(**api_params)
+
+                        # Update session state with the new client
+                        st.session_state.ai_models["client"] = client
+                        st.session_state.ai_models["api_version"] = config["api_version"]
+                    else:
+                        response = st.session_state.ai_models["client"].complete(**api_params)
+
+                    st.success(f"✅ Connection to {model_name} successful!")
+                    st.session_state.ai_models["model_name"] = model_name

             except Exception as e:
                 st.error(f"❌ Connection error: {str(e)}")
                 st.info("Please try the Debug Connection section to re-initialize the API connection.")

     # AI code generation
+    if st.session_state.ai_models:
         st.markdown("<div class='card'>", unsafe_allow_html=True)
         st.markdown("#### Generate Animation from Description")
         st.write("Describe the animation you want to create, or provide partial code to complete.")
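The Azure path above leans on prepare_api_params(messages, model_name), which is defined elsewhere in app.py and is not shown in this diff. From the way its return values are used here (the dict is unpacked into client.complete(**api_params), and the returned config carries param_name and api_version), a plausible sketch is the following; the real implementation may differ.

def prepare_api_params(messages, model_name):
    # Plausible sketch only; the real helper lives elsewhere in app.py.
    # Resolve the per-model config the same way the hunks above do.
    base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
    config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])

    # Keyword arguments later unpacked into client.complete(**api_params).
    api_params = {
        "messages": messages,
        "model": model_name,
        config["param_name"]: config[config["param_name"]],
    }
    return api_params, config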
@@ -2389,35 +2442,60 @@ class MyScene(Scene):
     if code_input:
         with st.spinner("AI is generating your animation code..."):
             try:
+                # Get model details
                 model_name = st.session_state.ai_models["model_name"]
+                base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
+                config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])
+                category = config.get("category", "Other")

                 # Create the prompt
                 prompt = f"""Write a complete Manim animation scene based on this code or idea:
+{code_input}
+
+The code should be a complete, working Manim animation that includes:
+- Proper Scene class definition
+- Constructor with animations
+- Proper use of self.play() for animations
+- Proper wait times between animations
+
+Here's the complete Manim code:
+"""

+                if category == "OpenAI":
+                    # Use OpenAI client
+                    client = st.session_state.ai_models["openai_client"]
+
+                    # Use developer role instead of system
+                    messages = [
+                        {"role": "developer", "content": "You are an expert in Manim animations."},
+                        {"role": "user", "content": prompt}
+                    ]
+
+                    # Create params
+                    params = {
+                        "messages": messages,
+                        "model": model_name,
+                        config["param_name"]: config[config["param_name"]]
+                    }
+
+                    # Call API
+                    response = client.chat.completions.create(**params)
                     completed_code = response.choices[0].message.content
+                else:
+                    # Use Azure client for non-OpenAI models
+                    from azure.ai.inference.models import UserMessage
+                    client = st.session_state.ai_models["client"]

+                    # Convert message format for Azure
+                    messages = [UserMessage(prompt)]
+                    api_params, _ = prepare_api_params(messages, model_name)
+
+                    # Make API call with Azure client
+                    response = client.complete(**api_params)
+                    completed_code = response.choices[0].message.content
+
+                # Process the response
+                if completed_code:
                     # Extract code from markdown if present
                     if "```python" in completed_code:
                         completed_code = completed_code.split("```python")[1].split("```")[0]
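One detail in the params dict above is easy to miss: config["param_name"]: config[config["param_name"]] uses an expression as a dict key, so the token-limit parameter name is decided at runtime. Under an assumed config entry it expands like this:

# Illustration of the dynamic dict key used in `params` above.
config = {"param_name": "max_tokens", "max_tokens": 4000}  # assumed entry

params = {
    "model": "openai/gpt-4o-mini",  # illustrative
    config["param_name"]: config[config["param_name"]],
}
print(params)  # {'model': 'openai/gpt-4o-mini', 'max_tokens': 4000}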
@@ -2427,10 +2505,10 @@ class MyScene(Scene):
                     # Add Scene class if missing
                     if "Scene" not in completed_code:
                         completed_code = f"""from manim import *
+
+class MyScene(Scene):
+    def construct(self):
+        {completed_code}"""

                     # Store the generated code
                     st.session_state.generated_code = completed_code
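The post-processing in this last hunk does two things: strip a ```python fence from the model's reply and wrap bare snippets in a Scene subclass. A standalone variation of that logic, written as a hypothetical helper (not part of app.py) that also re-indents the snippet under construct() with textwrap, could look like:

import textwrap

def ensure_scene(completed_code: str) -> str:
    # Hypothetical helper, not part of app.py: same idea as the hunk above,
    # plus explicit re-indentation of the snippet under construct().
    if "```python" in completed_code:
        completed_code = completed_code.split("```python")[1].split("```")[0]

    if "Scene" not in completed_code:
        body = textwrap.indent(textwrap.dedent(completed_code).strip(), " " * 8)
        completed_code = (
            "from manim import *\n"
            "\n"
            "class MyScene(Scene):\n"
            "    def construct(self):\n"
            + body
        )
    return completed_code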