|
|
|
""" |
|
Main application file for SHASHA AI (Gradio).

Latest change: enlarge the logo to 120 × 120 px.
|
""" |
|
|
|
import gradio as gr |
|
from typing import Optional, Dict, List, Tuple, Any |
|
import os |
|
|
|
|
|
from constants import ( |
|
HTML_SYSTEM_PROMPT, |
|
TRANSFORMERS_JS_SYSTEM_PROMPT, |
|
AVAILABLE_MODELS, |
|
DEMO_LIST, |
|
) |
|
from hf_client import get_inference_client |
|
from tavily_search import enhance_query_with_search |
|
from utils import ( |
|
extract_text_from_file, |
|
extract_website_content, |
|
apply_search_replace_changes, |
|
history_to_messages, |
|
history_to_chatbot_messages, |
|
remove_code_block, |
|
parse_transformers_js_output, |
|
format_transformers_js_output, |
|
) |
|
from deploy import send_to_sandbox |
|
|
|
# Chat history: an ordered list of (user_query, assistant_output) turns.
History = List[Tuple[str, str]]
# A model descriptor taken from AVAILABLE_MODELS; this file reads its "name" and "id" keys.
Model = Dict[str, Any]

# Choices offered by the "Target Language" dropdown. The identifiers appear to
# mirror gr.Code's language names (NOTE(review): confirm against the installed
# Gradio version before adding entries).
SUPPORTED_LANGUAGES = [
    # General-purpose, markup and configuration languages
    "python", "c", "cpp", "markdown", "latex", "json", "html", "css",
    "javascript", "jinja2", "typescript", "yaml", "dockerfile", "shell",
    "r", "sql",
    # SQL dialects
    "sql-msSQL", "sql-mySQL", "sql-mariaDB", "sql-sqlite", "sql-cassandra",
    "sql-plSQL", "sql-hive", "sql-pgSQL", "sql-gql", "sql-gpSQL",
    "sql-sparkSQL", "sql-esper",
]
|
|
|
def get_model_details(name: str) -> Optional[Model]:
    """Find a model descriptor by its display name.

    Args:
        name: The value shown in the model dropdown.

    Returns:
        The first entry of AVAILABLE_MODELS whose "name" equals *name*,
        or None when no entry matches.
    """
    for candidate in AVAILABLE_MODELS:
        if candidate["name"] == name:
            return candidate
    return None
|
|
|
# Marks history entries that hold an error instead of generated code.
# NOTE(review): "β" looks like a mis-decoded emoji (the file appears to have been
# re-saved with the wrong text encoding); it is kept unchanged so error entries
# keep their current text. The same constant is used to write and to detect errors.
_ERROR_MARK = "β"


def _system_prompt_for(language: str) -> str:
    """Return the system prompt used to steer the model toward *language*."""
    if language == "html":
        return HTML_SYSTEM_PROMPT
    if language == "transformers.js":
        return TRANSFORMERS_JS_SYSTEM_PROMPT
    return f"You are an expert {language} developer. Write clean, idiomatic {language} code."


def _provider_for(model_id: str) -> str:
    """Map *model_id* to the provider name passed to get_inference_client."""
    if model_id.startswith("openai/") or model_id in {"gpt-4", "gpt-3.5-turbo"}:
        return "openai"
    if model_id.startswith(("gemini/", "google/")):
        return "gemini"
    if model_id.startswith("fireworks-ai/"):
        return "fireworks-ai"
    return "auto"


def _build_context(query: str, file: Optional[str], website_url: Optional[str]) -> str:
    """Append truncated attached-file text and website text to *query*."""
    context = query
    if file:
        context += f"\n\n[Attached file]\n{extract_text_from_file(file)[:5000]}"
    if website_url:
        website_text = extract_website_content(website_url)
        # Text starting with "Error" is treated as a failed fetch and kept out of the prompt.
        if not website_text.startswith("Error"):
            context += f"\n\n[Website content]\n{website_text[:8000]}"
    return context


def _postprocess(content: str, language: str, history: History) -> Tuple[str, str]:
    """Turn the raw model reply into ``(code, preview_html)``."""
    if language == "transformers.js":
        files = parse_transformers_js_output(content)
        code = format_transformers_js_output(files)
        return code, send_to_sandbox(files.get("index.html", ""))

    cleaned = remove_code_block(content)
    # If the previous turn produced code (not an error), the reply is applied to
    # that code via apply_search_replace_changes; otherwise it is used as-is.
    if history and not history[-1][1].startswith(_ERROR_MARK):
        code = apply_search_replace_changes(history[-1][1], cleaned)
    else:
        code = cleaned
    preview = send_to_sandbox(code) if language == "html" else ""
    return code, preview


def generation_code(
    query: Optional[str],
    file: Optional[str],
    website_url: Optional[str],
    current_model: Model,
    enable_search: bool,
    language: str,
    history: Optional[History],
) -> Tuple[str, History, str, List[Dict[str, str]]]:
    """Generate (or revise) code for one request and record the turn in history.

    Args:
        query: The user's prompt; None is treated as "".
        file: Path of an attached file whose extracted text (first 5000 chars) is added.
        website_url: URL whose extracted content (first 8000 chars) is added.
        current_model: Model descriptor; its "id" selects the model and provider.
        enable_search: Forwarded to enhance_query_with_search.
        language: Target language; "html" and "transformers.js" get dedicated
            system prompts and sandbox previews.
        history: Previous (query, output) turns; None is treated as [].

    Returns:
        ``(code, new_history, preview_html, chatbot_messages)``. On any failure
        the code and preview are "" and an error turn is appended to the history.
    """
    query = query or ""
    history = history or []
    try:
        model_id = current_model["id"]
        msgs = history_to_messages(history, _system_prompt_for(language))
        context = _build_context(query, file, website_url)
        msgs.append({"role": "user", "content": enhance_query_with_search(context, enable_search)})

        client = get_inference_client(model_id, _provider_for(model_id))
        resp = client.chat.completions.create(
            model=model_id, messages=msgs, max_tokens=16000, temperature=0.1
        )
        content = resp.choices[0].message.content
        # Without this check, a None content would crash remove_code_block with an opaque TypeError.
        if not content:
            raise ValueError("The model returned an empty response.")
        # Post-processing stays inside the try, so a parsing or sandbox failure is
        # reported in the history instead of escaping the event handler.
        code, preview = _postprocess(content, language, history)
    except Exception as e:
        err = f"{_ERROR_MARK} **Error:**\n```\n{e}\n```"
        # Build a new list (like the success path) rather than mutating the caller's list.
        failed_hist = history + [(query, err)]
        return "", failed_hist, "", history_to_chatbot_messages(failed_hist)

    new_hist = history + [(query, code)]
    return code, new_hist, preview, history_to_chatbot_messages(new_hist)
|
|
|
|
|
# Extra stylesheet passed to gr.Blocks(css=...). The #main_title, #subtitle and
# #gen_btn selectors target the elem_id values given to components in the UI.
CUSTOM_CSS = """

body{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;}

#main_title{text-align:center;font-size:2.5rem;margin-top:.5rem;}

#subtitle{text-align:center;color:#4a5568;margin-bottom:2rem;}

.gradio-container{background-color:#f7fafc;}

#gen_btn{box-shadow:0 4px 6px rgba(0,0,0,0.1);}

"""



# Optional header logo (relative path); the UI renders it only when this file exists.
LOGO_PATH = "assets/logo.png"
|
|
|
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"),
               css=CUSTOM_CSS,
               title="Shasha AI") as demo:
    # Per-session state: the (query, output) history and the selected model descriptor.
    history_state = gr.State([])
    initial_model = AVAILABLE_MODELS[0]
    model_state = gr.State(initial_model)

    # Header. The logo is optional and is only shown when the asset file exists.
    if os.path.exists(LOGO_PATH):
        gr.Image(value=LOGO_PATH, height=120, width=120,
                 show_label=False, container=False)

    gr.Markdown("# π Shasha AI", elem_id="main_title")
    gr.Markdown("Your AI partner for generating, modifying, and understanding code.", elem_id="subtitle")

    with gr.Row():
        # Left column: model choice, inputs and controls.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Model")
            model_dd = gr.Dropdown([m["name"] for m in AVAILABLE_MODELS],
                                   value=initial_model["name"], label="AI Model")

            gr.Markdown("### 2. Provide Context")
            with gr.Tabs():
                with gr.Tab("π Prompt"):
                    prompt_in = gr.Textbox(lines=7, placeholder="Describe your request...", show_label=False)
                with gr.Tab("π File"):
                    file_in = gr.File(type="filepath")
                with gr.Tab("π Website"):
                    url_in = gr.Textbox(placeholder="https://example.com")

            gr.Markdown("### 3. Configure Output")
            lang_dd = gr.Dropdown(SUPPORTED_LANGUAGES, value="html", label="Target Language")
            search_chk = gr.Checkbox(label="Enable Web Search")
            with gr.Row():
                clr_btn = gr.Button("Clear Session", variant="secondary")
                gen_btn = gr.Button("Generate Code", variant="primary", elem_id="gen_btn")

        # Right column: generated code, live preview and conversation history.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("π» Code"):
                    # Starts as "html" to match lang_dd's default; kept in sync below.
                    code_out = gr.Code(language="html", interactive=True)
                with gr.Tab("ποΈ Live Preview"):
                    preview_out = gr.HTML()
                with gr.Tab("π History"):
                    chat_out = gr.Chatbot(type="messages")

    # Resolve the dropdown name to a descriptor, falling back to the first model.
    model_dd.change(lambda n: get_model_details(n) or initial_model,
                    inputs=[model_dd], outputs=[model_state])

    # Keep the code viewer's syntax highlighting in step with the target
    # language; previously it stayed "html" whatever language was generated.
    lang_dd.change(lambda lang: gr.update(language=lang),
                   inputs=[lang_dd], outputs=[code_out], queue=False)

    gen_btn.click(
        fn=generation_code,
        inputs=[prompt_in, file_in, url_in, model_state, search_chk, lang_dd, history_state],
        outputs=[code_out, history_state, preview_out, chat_out],
    )

    # Reset inputs, history and outputs (the selected model is left unchanged).
    clr_btn.click(
        lambda: ("", None, "", [], "", "", []),
        outputs=[prompt_in, file_in, url_in, history_state, code_out, preview_out, chat_out],
        queue=False,
    )
|
|
|
if __name__ == "__main__":
    # Turn on Gradio's request queue, then start the server.
    queued_app = demo.queue()
    queued_app.launch()
|
|