AItool committed on
Commit e82baba · verified · 1 Parent(s): a951453

Update app.py

Files changed (1)
  1. app.py +58 -83
app.py CHANGED
@@ -1,94 +1,69 @@
  import torch
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer
- import shutil, os
- shutil.rmtree(os.path.expanduser("~/.cache/huggingface"), ignore_errors=True)
- shutil.rmtree(os.path.expanduser("~/.cache/torch"), ignore_errors=True)
-
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- MODEL_OPTIONS = [
-     "Helsinki-NLP (Tira ondo)",  # Round-trip OPUS-MT en→es→en
-     "FLAN-T5-base (Google gaizki xamar)"
- ]
- # Cache
- CACHE = {}
-
- # --- FLAN loader (Google-style Euskera correction) ---
- def load_flan():
-     if "flan" not in CACHE:
-         tok = AutoTokenizer.from_pretrained("google/flan-t5-base")
-         mdl = AutoModelForSeq2SeqLM.from_pretrained(
-             "google/flan-t5-base",
-             low_cpu_mem_usage=True,
-             torch_dtype="auto"
-         ).to(DEVICE)
-         CACHE["flan"] = (mdl, tok)
-     return CACHE["flan"]
-
- def run_flan(sentence: str) -> str:
-     model, tok = load_flan()
-     prompt = f"Euskara zuzen gramatikalki eta idatzi modu naturalean: {sentence}"
-     inputs = tok(prompt, return_tensors="pt").to(DEVICE)
-     with torch.no_grad():
-         out = model.generate(**inputs, max_new_tokens=96, num_beams=4)
-     return tok.decode(out[0], skip_special_tokens=True).strip()
-
- # --- Euskera round-trip loader ---
- def load_euskera():
-     if "eus" not in CACHE:
-         tok1 = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-eu-es")
-         mdl1 = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-eu-es").to(DEVICE)
-         tok2 = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-eu")
-         mdl2 = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-es-eu").to(DEVICE)
-         CACHE["eus"] = (mdl1, tok1, mdl2, tok2)
-     return CACHE["eus"]
-
- def run_roundtrip(sentence: str) -> str:
-     mdl1, tok1, mdl2, tok2 = load_euskera()
-     # Euskera → Spanish
-     inputs = tok1(sentence, return_tensors="pt").to(DEVICE)
-     es_tokens = mdl1.generate(**inputs, max_length=128, num_beams=4)
-     spanish = tok1.decode(es_tokens[0], skip_special_tokens=True)
-     # Spanish → Euskera
-     inputs2 = tok2(spanish, return_tensors="pt").to(DEVICE)
-     eu_tokens = mdl2.generate(**inputs2, max_length=128, num_beams=4)
-     euskera = tok2.decode(eu_tokens[0], skip_special_tokens=True)
-     return euskera.strip()
-
- # --- Dispatcher ---
- def polish(sentence: str, choice: str) -> str:
-     if not sentence.strip():
-         return ""
-     if choice.startswith("FLAN"):
-         return run_flan(sentence)
-     elif choice.startswith("Helsinki"):
-         return run_roundtrip(sentence)
-     else:
-         return "Unknown option."
-
- # --- Gradio UI ---
- with gr.Blocks(title="HizkuntzLagun: AI Euskera Zuzendu (CPU enabled)") as demo:
-     gr.Image(
-         value="banner.png",
-         show_label=False,
-         elem_id="banner",
-         height=200
-     )
-     gr.Markdown("### HizkuntzLagun: AI Euskera Zuzedu\n")
-     gr.Markdown(
-         """
-         > ⚡ **Oharra:**
-         > Tresna honek doako, CPU‑lagunko AI ereduak erabiltzen ditu.
-         > Azkarra eta eskuragarria izateko diseinatuta dago — ez beti perfektua.
-         > Zuzenketa azkarrak bai, ez analisi gramatikal sakonak.
-         > Edozein unetan erabil dezakezu — eguneroko zuzenketa txiki batek saihesten du esaldi traketsen lotsa.
-         """)
-     inp = gr.Textbox(lines=3, label="Idatzi Euskeraz esaldi bat, adibidez Gaur Koldo ikusi nuen.", placeholder="Idatzi Euskeraz esaldi bat...")
-     choice = gr.Dropdown(choices=MODEL_OPTIONS, value="Helsinki-NLP (Tira ondo)", label="Metodoa")
-     btn = gr.Button("Euskera zuzendu")
-     out = gr.Textbox(label="Erantzuna")
-     btn.click(polish, inputs=[inp, choice], outputs=out)
-
- if __name__ == "__main__":
-     demo.launch()
 
  import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  import gradio as gr

+ # Supported models (text-only for now)
+ MODEL_OPTIONS = {
+     "Phi-3.5 Mini Instruct": "microsoft/Phi-3.5-mini-instruct",
+     "Phi-3.5 MoE Instruct": "microsoft/Phi-3.5-MoE-instruct",
+     "Phi-3 Mini 4K Instruct": "microsoft/Phi-3-mini-4k-instruct",
+     "Phi-3 Mini 128K Instruct": "microsoft/Phi-3-mini-128k-instruct"
+ }
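+ # Rough sizes, for orientation: the Mini variants have ~3.8B parameters, while
+ # Phi-3.5-MoE has ~42B total (about 6.6B active per token), far more than a
+ # typical free CPU instance can hold in float32.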

+ # Cache for loaded models
+ loaded_models = {}

+ # Load model/tokenizer on demand
+ def load_model(model_id):
+     if model_id not in loaded_models:
+         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             trust_remote_code=True,
+             torch_dtype=torch.float32
+         )
+         model.eval()
+         loaded_models[model_id] = (tokenizer, model)
+     return loaded_models[model_id]
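+ # Note: loaded_models never evicts anything, and float32 weights take ~4 bytes
+ # per parameter (~15 GB for a 3.8B-parameter model), so every model switch
+ # keeps the previous model in RAM for the life of the process.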

+ # Chat function
+ def chat_with_model(user_input, model_choice):
+     model_id = MODEL_OPTIONS[model_choice]
+     tokenizer, model = load_model(model_id)
+
+     messages = [{"role": "user", "content": user_input}]
+     inputs = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt"
+     ).to("cpu")

+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=100,
+             do_sample=False,   # greedy decoding; temperature/top_p have no effect unless do_sample=True
+             temperature=0.7,
+             top_p=0.9
+         )

+     # Decode only the newly generated tokens (slice off the prompt)
+     response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+     return response.strip()

+ # Gradio UI
+ with gr.Blocks(title="Phi-3 Instruct Explorer") as demo:
+     gr.Markdown("## 🧠 Phi-3 Instruct Explorer\nSwitch between Phi-3 instruct models and test responses on CPU.")
+     with gr.Row():
+         model_choice = gr.Dropdown(label="Choose a model", choices=list(MODEL_OPTIONS.keys()), value="Phi-3.5 Mini Instruct")
+     with gr.Row():
+         user_input = gr.Textbox(label="Your message", placeholder="Ask me anything...")
+     with gr.Row():
+         output = gr.Textbox(label="Model response")
+     with gr.Row():
+         submit = gr.Button("Generate")
+
+     submit.click(fn=chat_with_model, inputs=[user_input, model_choice], outputs=output)
+
+ demo.launch()
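
A side note on the new generate() call: with do_sample=False decoding is greedy, so the temperature and top_p arguments are inert. Below is a minimal sketch of a variant that actually exercises them; generate_reply is a hypothetical helper, not part of this commit, and it reuses a tokenizer/model pair as returned by load_model:

import torch

def generate_reply(tokenizer, model, user_input, sample=False):
    # Build chat-formatted inputs exactly as chat_with_model does
    messages = [{"role": "user", "content": user_input}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to("cpu")
    with torch.no_grad():
        if sample:
            # temperature/top_p only take effect when do_sample=True
            outputs = model.generate(**inputs, max_new_tokens=100,
                                     do_sample=True, temperature=0.7, top_p=0.9)
        else:
            # Deterministic greedy decoding
            outputs = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    # Strip the prompt tokens and decode only the new reply
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:],
                            skip_special_tokens=True).strip()

Flipping sample toggles between reproducible greedy output and more varied sampled output without touching the cached models.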