Spaces:
Running
Running
Upload 3 files
Browse files- app.py +1 -0
- prompt_generator.py +3 -3
app.py
CHANGED
|
@@ -48,6 +48,7 @@ logger = logging.getLogger(__name__)
|
|
| 48 |
|
| 49 |
# Constants
|
| 50 |
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
|
|
|
|
| 51 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
|
| 52 |
|
| 53 |
# PyTorch settings for better performance and determinism
|
|
|
|
| 48 |
|
| 49 |
# Constants
|
| 50 |
IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
|
| 51 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 52 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
|
| 53 |
|
| 54 |
# PyTorch settings for better performance and determinism
|
prompt_generator.py
CHANGED
|
@@ -119,8 +119,8 @@ def load_model():
|
|
| 119 |
_model = AutoModelForCausalLM.from_pretrained(
|
| 120 |
model_path,
|
| 121 |
torch_dtype=torch_dtype,
|
| 122 |
-
|
| 123 |
-
|
| 124 |
low_cpu_mem_usage=True,
|
| 125 |
)
|
| 126 |
|
|
@@ -277,7 +277,7 @@ masterpiece, best quality, highresなどの品質に関連するタグは後工
|
|
| 277 |
logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
|
| 278 |
|
| 279 |
# 生成
|
| 280 |
-
logger.info("before
|
| 281 |
with torch.no_grad():
|
| 282 |
generated_ids = model.generate(
|
| 283 |
input_ids=inputs,
|
|
|
|
| 119 |
_model = AutoModelForCausalLM.from_pretrained(
|
| 120 |
model_path,
|
| 121 |
torch_dtype=torch_dtype,
|
| 122 |
+
device_map=device_map,
|
| 123 |
+
use_cache=True,
|
| 124 |
low_cpu_mem_usage=True,
|
| 125 |
)
|
| 126 |
|
|
|
|
| 277 |
logger.warning(f"Input tokens were too many and have been truncated to {max_input_length}")
|
| 278 |
|
| 279 |
# 生成
|
| 280 |
+
logger.info("before torch.no_grad")
|
| 281 |
with torch.no_grad():
|
| 282 |
generated_ids = model.generate(
|
| 283 |
input_ids=inputs,
|