diegomiranda committed
Commit 4ce6bad
1 Parent(s): 5bab83b

Update README.md

Files changed (1)
  1. README.md +50 -0
README.md CHANGED
@@ -16,6 +16,56 @@ thumbnail: https://h2o.ai/etc.clientlibs/h2o/clientlibs/clientlib-site/resources
  This model was trained using [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
  - Base model: [EleutherAI/pythia-70m-deduped-v0](https://huggingface.co/EleutherAI/pythia-70m-deduped-v0)

+ ## Usage on CPU
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ def generate_response(prompt, model_name):
+     # Load the tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         use_fast=True,
+         trust_remote_code=True,
+     )
+
+     # Load the model in float32 and keep it on the CPU
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float32,
+         device_map={"": "cpu"},
+         trust_remote_code=True,
+     )
+     model.cpu().eval()
+
+     inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cpu")
+
+     # Deterministic beam-search decoding
+     tokens = model.generate(
+         input_ids=inputs["input_ids"],
+         attention_mask=inputs["attention_mask"],
+         min_new_tokens=2,
+         max_new_tokens=500,
+         do_sample=False,
+         num_beams=2,
+         temperature=float(0.0),
+         repetition_penalty=float(1.0),
+         renormalize_logits=True,
+     )[0]
+
+     # Keep only the newly generated tokens, dropping the prompt
+     tokens = tokens[inputs["input_ids"].shape[1]:]
+     answer = tokenizer.decode(tokens, skip_special_tokens=True)
+
+     return answer
+ ```
+
+ After defining the function, set the model name and the prompt (from another file, for example) and generate a response:
+
+ ```python
+ model_name = "diegomiranda/text-to-cypher"
+ # The question in the prompt is in Portuguese: "Return the tax-law cases based on law 939 of 1992?"
+ prompt = "Create a Cypher statement to answer the following question:Retorne os processos de Direito Tributário que se baseiam em lei 939 de 1992?<|endoftext|>"
+ response = generate_response(prompt, model_name)
+ print(response)
+ ```
+

  ## Usage
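
The same CPU generation can likely also be run through the Transformers `pipeline` API, which wraps the tokenizer and model loading shown above. The snippet below is a minimal sketch and not part of the commit; it reuses the model name and prompt from the diff and keeps the same deterministic decoding settings.

```python
from transformers import pipeline

# Build a CPU text-generation pipeline for the same model (device=-1 selects the CPU).
generator = pipeline(
    "text-generation",
    model="diegomiranda/text-to-cypher",
    device=-1,
    trust_remote_code=True,
)

prompt = "Create a Cypher statement to answer the following question:Retorne os processos de Direito Tributário que se baseiam em lei 939 de 1992?<|endoftext|>"

# Deterministic decoding, mirroring the parameters used in generate_response above.
result = generator(
    prompt,
    max_new_tokens=500,
    do_sample=False,
    num_beams=2,
    return_full_text=False,
)
print(result[0]["generated_text"])
```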