MohamedRashad committed
Commit 383f826
1 Parent(s): 9312c4b

Refactor model ID handling in app.py and update requirements.txt

Files changed (1)
  app.py  +32 -32
app.py CHANGED
@@ -22,44 +22,44 @@ def load_model_a(model_id):
     global tokenizer_a, model_a
     tokenizer_a = AutoTokenizer.from_pretrained(model_id)
     print(f"model A: {tokenizer_a.eos_token}")
-    model_a = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-        attn_implementation="flash_attention_2",
-        trust_remote_code=True,
-    ).eval()
-    # try:
-    # except:
-    #     print(f"Using default attention implementation in {model_id}")
-    #     model_a = AutoModelForCausalLM.from_pretrained(
-    #         model_id,
-    #         torch_dtype=torch.bfloat16,
-    #         device_map="auto",
-    #         trust_remote_code=True,
-    #     ).eval()
+    try:
+        model_a = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            attn_implementation="flash_attention_2",
+            trust_remote_code=True,
+        ).eval()
+    except:
+        print(f"Using default attention implementation in {model_id}")
+        model_a = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        ).eval()
     return gr.update(label=model_id)
 
 def load_model_b(model_id):
     global tokenizer_b, model_b
     tokenizer_b = AutoTokenizer.from_pretrained(model_id)
     print(f"model B: {tokenizer_b.eos_token}")
-    model_b = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-        attn_implementation="flash_attention_2",
-        trust_remote_code=True,
-    ).eval()
-    # try:
-    # except:
-    #     print(f"Using default attention implementation in {model_id}")
-    #     model_b = AutoModelForCausalLM.from_pretrained(
-    #         model_id,
-    #         torch_dtype=torch.bfloat16,
-    #         device_map="auto",
-    #         trust_remote_code=True,
-    #     ).eval()
+    try:
+        model_b = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            attn_implementation="flash_attention_2",
+            trust_remote_code=True,
+        ).eval()
+    except:
+        print(f"Using default attention implementation in {model_id}")
+        model_b = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        ).eval()
     return gr.update(label=model_id)
 
 @spaces.GPU()
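
The diff swaps an unconditional flash_attention_2 load (plus a stale commented-out fallback) for a live try/except that retries with the default attention backend. A minimal standalone sketch of that pattern, assuming a transformers release recent enough to accept the attn_implementation keyword; the function name load_causal_lm and the narrowed except Exception are illustrative choices, not part of the commit:

import torch
from transformers import AutoModelForCausalLM

def load_causal_lm(model_id: str):
    # Hypothetical helper sketching the commit's fallback pattern.
    # Keyword arguments shared by both load attempts.
    common = dict(
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    try:
        # flash_attention_2 needs the flash-attn package and a supported GPU;
        # from_pretrained raises when the backend is unavailable.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            attn_implementation="flash_attention_2",
            **common,
        )
    except Exception:
        # Retry without pinning an attention implementation.
        print(f"Using default attention implementation in {model_id}")
        model = AutoModelForCausalLM.from_pretrained(model_id, **common)
    return model.eval()

Catching Exception rather than the commit's bare except: lets KeyboardInterrupt and SystemExit still propagate; otherwise the control flow matches load_model_a and load_model_b above.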