merve HF staff commited on
Commit
dc4cc7d
β€’
1 Parent(s): f1b118e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from gradio import FlaggingCallback
4
+ from gradio.components import IOComponent
5
+
6
+ from datasets import load_dataset
7
+ from typing import List, Optional, Any
8
+ import argilla as rg
9
+ import os
10
+
11
+ def load_data():
12
+ ds = load_dataset("merve/turkish_instructions", split="train", streaming=True)
13
+ sample = next(iter(ds))
14
+
15
+ return sample
16
+
17
+
18
+ def create_record(sample, feedback):
19
+ status = "Validated" if feedback == "Doğru" else "Default"
20
+ #sample = next(iter(ds))
21
+ fields = {
22
+ "talimat": sample["talimat"],
23
+ "input": sample["giriş"],
24
+ "response": sample["Γ§Δ±ktΔ±"]
25
+ }
26
+
27
+ # the label will come from the flag object in Gradio
28
+ label = "True"
29
+
30
+ record = rg.TextClassificationRecord(
31
+ inputs=fields,
32
+ annotation=label,
33
+ status=status,
34
+ metadata={"feedback": feedback}
35
+ )
36
+
37
+ print(record)
38
+ return record
39
+
40
+
41
+
42
+
43
+ class ArgillaLogger(FlaggingCallback):
44
+ def __init__(self, api_url, api_key, dataset_name):
45
+ rg.init(api_url=api_url, api_key=api_key)
46
+ self.dataset_name = dataset_name
47
+ def setup(self, components: List[IOComponent], flagging_dir: str):
48
+ pass
49
+ def flag(
50
+ self,
51
+ flag_data: List[Any],
52
+ flag_option: Optional[str] = None,
53
+ flag_index: Optional[int] = None,
54
+ username: Optional[str] = None,
55
+ ) -> int:
56
+ text = flag_data[0]
57
+ inference = flag_data[1]
58
+ rg.log(name=self.dataset_name, records=create_record(text, flag_option))
59
+
60
+
61
+
62
+
63
+ gr.Interface(
64
+ title = "ALPACA Veriseti DΓΌzeltme ArayΓΌzΓΌ",
65
+ description = "",
66
+ allow_flagging="manual",
67
+ flagging_callback=ArgillaLogger(
68
+ api_url="https://sandbox.argilla.io",
69
+ api_key=os.getenv("TEAM_API_KEY"),
70
+ dataset_name="alpaca-flags"
71
+ ),
72
+ flagging_options=["Doğru", "Yanlış", "Belirsiz"]
73
+ ).launch()