euler314 commited on
Commit
2e76447
Β·
verified Β·
1 Parent(s): 28d6753

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -57
app.py CHANGED
@@ -349,8 +349,11 @@ Here's the complete Manim code:
349
  # Get the current model name and base URL
350
  model_name = models["model_name"]
351
 
 
 
 
352
  # Convert message to the appropriate format based on model category
353
- config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
354
  category = config.get("category", "Other")
355
 
356
  if category == "OpenAI":
@@ -370,16 +373,22 @@ Here's the complete Manim code:
370
  else:
371
  client = models["openai_client"]
372
 
373
- # For OpenAI models, we need role-based messages
 
 
 
 
 
 
374
  messages = [
375
- {"role": "system", "content": "You are an expert in Manim animations."},
376
  {"role": "user", "content": prompt}
377
  ]
378
 
379
  # Create params
380
  params = {
381
  "messages": messages,
382
- "model": model_name
383
  }
384
 
385
  # Add token parameter
@@ -2148,7 +2157,8 @@ class MyScene(Scene):
2148
 
2149
  # Get model details
2150
  model_name = st.session_state.custom_model
2151
- config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
 
2152
  category = config.get("category", "Other")
2153
 
2154
  if category == "OpenAI":
@@ -2176,7 +2186,7 @@ class MyScene(Scene):
2176
  # Prepare parameters based on model configuration
2177
  params = {
2178
  "messages": [
2179
- {"role": "system", "content": "You are a helpful assistant."},
2180
  {"role": "user", "content": "Hello, this is a connection test."}
2181
  ],
2182
  "model": full_model_name
@@ -2318,45 +2328,88 @@ class MyScene(Scene):
2318
 
2319
  # Add a refresh button to update model connection
2320
  if st.button("πŸ”„ Refresh Model Connection", key="refresh_model_connection"):
2321
- if st.session_state.ai_models and 'client' in st.session_state.ai_models:
2322
  try:
2323
- # Test connection with minimal prompt
2324
- from azure.ai.inference.models import UserMessage
2325
  model_name = st.session_state.custom_model
 
 
 
2326
 
2327
- # Prepare parameters
2328
- messages = [UserMessage("Hello")]
2329
- api_params, config = prepare_api_params(messages, model_name)
2330
-
2331
- # Check if we need a new client with specific API version
2332
- if config["api_version"] and config["api_version"] != st.session_state.ai_models.get("api_version"):
2333
- # Create version-specific client if needed
2334
  token = get_secret("github_token_api")
2335
- from azure.ai.inference import ChatCompletionsClient
2336
- from azure.core.credentials import AzureKeyCredential
2337
 
2338
- client = ChatCompletionsClient(
2339
- endpoint=st.session_state.ai_models["endpoint"],
2340
- credential=AzureKeyCredential(token),
2341
- api_version=config["api_version"]
 
2342
  )
2343
- response = client.complete(**api_params)
2344
 
2345
- # Update session state with the new client
2346
- st.session_state.ai_models["client"] = client
2347
- st.session_state.ai_models["api_version"] = config["api_version"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2348
  else:
2349
- response = st.session_state.ai_models["client"].complete(**api_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2350
 
2351
- st.success(f"βœ… Connection to {model_name} successful!")
2352
- st.session_state.ai_models["model_name"] = model_name
2353
-
2354
  except Exception as e:
2355
  st.error(f"❌ Connection error: {str(e)}")
2356
  st.info("Please try the Debug Connection section to re-initialize the API connection.")
2357
 
2358
  # AI code generation
2359
- if st.session_state.ai_models and "client" in st.session_state.ai_models:
2360
  st.markdown("<div class='card'>", unsafe_allow_html=True)
2361
  st.markdown("#### Generate Animation from Description")
2362
  st.write("Describe the animation you want to create, or provide partial code to complete.")
@@ -2389,35 +2442,60 @@ class MyScene(Scene):
2389
  if code_input:
2390
  with st.spinner("AI is generating your animation code..."):
2391
  try:
2392
- # Get the client and model name
2393
- client = st.session_state.ai_models["client"]
2394
  model_name = st.session_state.ai_models["model_name"]
 
 
 
2395
 
2396
  # Create the prompt
2397
  prompt = f"""Write a complete Manim animation scene based on this code or idea:
2398
- {code_input}
2399
-
2400
- The code should be a complete, working Manim animation that includes:
2401
- - Proper Scene class definition
2402
- - Constructor with animations
2403
- - Proper use of self.play() for animations
2404
- - Proper wait times between animations
2405
-
2406
- Here's the complete Manim code:
2407
- """
2408
-
2409
- # Prepare API parameters
2410
- from azure.ai.inference.models import UserMessage
2411
- messages = [UserMessage(prompt)]
2412
- api_params, config = prepare_api_params(messages, model_name)
2413
-
2414
- # Make the API call with proper parameters
2415
- response = client.complete(**api_params)
2416
 
2417
- # Process the response
2418
- if response and response.choices and len(response.choices) > 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2419
  completed_code = response.choices[0].message.content
 
 
 
 
2420
 
 
 
 
 
 
 
 
 
 
 
2421
  # Extract code from markdown if present
2422
  if "```python" in completed_code:
2423
  completed_code = completed_code.split("```python")[1].split("```")[0]
@@ -2427,10 +2505,10 @@ class MyScene(Scene):
2427
  # Add Scene class if missing
2428
  if "Scene" not in completed_code:
2429
  completed_code = f"""from manim import *
2430
-
2431
- class MyScene(Scene):
2432
- def construct(self):
2433
- {completed_code}"""
2434
 
2435
  # Store the generated code
2436
  st.session_state.generated_code = completed_code
 
349
  # Get the current model name and base URL
350
  model_name = models["model_name"]
351
 
352
+ # Handle model name - extract base name if it has a prefix
353
+ base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
354
+
355
  # Convert message to the appropriate format based on model category
356
+ config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])
357
  category = config.get("category", "Other")
358
 
359
  if category == "OpenAI":
 
373
  else:
374
  client = models["openai_client"]
375
 
376
+ # Add openai/ prefix if not present
377
+ if "/" not in model_name:
378
+ full_model_name = f"openai/{model_name}"
379
+ else:
380
+ full_model_name = model_name
381
+
382
+ # For OpenAI models, use developer role instead of system
383
  messages = [
384
+ {"role": "developer", "content": "You are an expert in Manim animations."},
385
  {"role": "user", "content": prompt}
386
  ]
387
 
388
  # Create params
389
  params = {
390
  "messages": messages,
391
+ "model": full_model_name
392
  }
393
 
394
  # Add token parameter
 
2157
 
2158
  # Get model details
2159
  model_name = st.session_state.custom_model
2160
+ base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
2161
+ config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])
2162
  category = config.get("category", "Other")
2163
 
2164
  if category == "OpenAI":
 
2186
  # Prepare parameters based on model configuration
2187
  params = {
2188
  "messages": [
2189
+ {"role": "developer", "content": "You are a helpful assistant."},
2190
  {"role": "user", "content": "Hello, this is a connection test."}
2191
  ],
2192
  "model": full_model_name
 
2328
 
2329
  # Add a refresh button to update model connection
2330
  if st.button("πŸ”„ Refresh Model Connection", key="refresh_model_connection"):
2331
+ if st.session_state.ai_models:
2332
  try:
2333
+ # Get model details
 
2334
  model_name = st.session_state.custom_model
2335
+ base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
2336
+ config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])
2337
+ category = config.get("category", "Other")
2338
 
2339
+ if category == "OpenAI":
2340
+ # Get token
 
 
 
 
 
2341
  token = get_secret("github_token_api")
 
 
2342
 
2343
+ # Create new OpenAI client
2344
+ from openai import OpenAI
2345
+ client = OpenAI(
2346
+ base_url="https://models.github.ai/inference",
2347
+ api_key=token
2348
  )
 
2349
 
2350
+ # Add openai/ prefix if not present
2351
+ if "/" not in model_name:
2352
+ full_model_name = f"openai/{model_name}"
2353
+ else:
2354
+ full_model_name = model_name
2355
+
2356
+ # Test with minimal prompt
2357
+ response = client.chat.completions.create(
2358
+ messages=[
2359
+ {"role": "developer", "content": "You are a helpful assistant."},
2360
+ {"role": "user", "content": "Hello, this is a test."}
2361
+ ],
2362
+ model=full_model_name,
2363
+ **{config["param_name"]: config[config["param_name"]]}
2364
+ )
2365
+
2366
+ # Update session state
2367
+ st.session_state.ai_models = {
2368
+ "openai_client": client,
2369
+ "model_name": full_model_name,
2370
+ "endpoint": "https://models.github.ai/inference",
2371
+ "last_loaded": datetime.now().isoformat(),
2372
+ "category": category
2373
+ }
2374
+
2375
+ st.success(f"βœ… Connection to {full_model_name} refreshed successfully!")
2376
  else:
2377
+ # Test connection with minimal prompt for Azure models
2378
+ from azure.ai.inference.models import UserMessage
2379
+
2380
+ # Prepare parameters
2381
+ messages = [UserMessage("Hello")]
2382
+ api_params, config = prepare_api_params(messages, model_name)
2383
+
2384
+ # Check if we need a new client with specific API version
2385
+ if config["api_version"] and config["api_version"] != st.session_state.ai_models.get("api_version"):
2386
+ # Create version-specific client if needed
2387
+ token = get_secret("github_token_api")
2388
+ from azure.ai.inference import ChatCompletionsClient
2389
+ from azure.core.credentials import AzureKeyCredential
2390
+
2391
+ client = ChatCompletionsClient(
2392
+ endpoint=st.session_state.ai_models["endpoint"],
2393
+ credential=AzureKeyCredential(token),
2394
+ api_version=config["api_version"]
2395
+ )
2396
+ response = client.complete(**api_params)
2397
+
2398
+ # Update session state with the new client
2399
+ st.session_state.ai_models["client"] = client
2400
+ st.session_state.ai_models["api_version"] = config["api_version"]
2401
+ else:
2402
+ response = st.session_state.ai_models["client"].complete(**api_params)
2403
+
2404
+ st.success(f"βœ… Connection to {model_name} successful!")
2405
+ st.session_state.ai_models["model_name"] = model_name
2406
 
 
 
 
2407
  except Exception as e:
2408
  st.error(f"❌ Connection error: {str(e)}")
2409
  st.info("Please try the Debug Connection section to re-initialize the API connection.")
2410
 
2411
  # AI code generation
2412
+ if st.session_state.ai_models:
2413
  st.markdown("<div class='card'>", unsafe_allow_html=True)
2414
  st.markdown("#### Generate Animation from Description")
2415
  st.write("Describe the animation you want to create, or provide partial code to complete.")
 
2442
  if code_input:
2443
  with st.spinner("AI is generating your animation code..."):
2444
  try:
2445
+ # Get model details
 
2446
  model_name = st.session_state.ai_models["model_name"]
2447
+ base_model_name = model_name.split('/')[-1] if '/' in model_name else model_name
2448
+ config = MODEL_CONFIGS.get(base_model_name, MODEL_CONFIGS["default"])
2449
+ category = config.get("category", "Other")
2450
 
2451
  # Create the prompt
2452
  prompt = f"""Write a complete Manim animation scene based on this code or idea:
2453
+ {code_input}
2454
+
2455
+ The code should be a complete, working Manim animation that includes:
2456
+ - Proper Scene class definition
2457
+ - Constructor with animations
2458
+ - Proper use of self.play() for animations
2459
+ - Proper wait times between animations
2460
+
2461
+ Here's the complete Manim code:
2462
+ """
 
 
 
 
 
 
 
 
2463
 
2464
+ if category == "OpenAI":
2465
+ # Use OpenAI client
2466
+ client = st.session_state.ai_models["openai_client"]
2467
+
2468
+ # Use developer role instead of system
2469
+ messages = [
2470
+ {"role": "developer", "content": "You are an expert in Manim animations."},
2471
+ {"role": "user", "content": prompt}
2472
+ ]
2473
+
2474
+ # Create params
2475
+ params = {
2476
+ "messages": messages,
2477
+ "model": model_name,
2478
+ config["param_name"]: config[config["param_name"]]
2479
+ }
2480
+
2481
+ # Call API
2482
+ response = client.chat.completions.create(**params)
2483
  completed_code = response.choices[0].message.content
2484
+ else:
2485
+ # Use Azure client for non-OpenAI models
2486
+ from azure.ai.inference.models import UserMessage
2487
+ client = st.session_state.ai_models["client"]
2488
 
2489
+ # Convert message format for Azure
2490
+ messages = [UserMessage(prompt)]
2491
+ api_params, _ = prepare_api_params(messages, model_name)
2492
+
2493
+ # Make API call with Azure client
2494
+ response = client.complete(**api_params)
2495
+ completed_code = response.choices[0].message.content
2496
+
2497
+ # Process the response
2498
+ if completed_code:
2499
  # Extract code from markdown if present
2500
  if "```python" in completed_code:
2501
  completed_code = completed_code.split("```python")[1].split("```")[0]
 
2505
  # Add Scene class if missing
2506
  if "Scene" not in completed_code:
2507
  completed_code = f"""from manim import *
2508
+
2509
+ class MyScene(Scene):
2510
+ def construct(self):
2511
+ {completed_code}"""
2512
 
2513
  # Store the generated code
2514
  st.session_state.generated_code = completed_code