# Total number of elements across every parameter tensor of `model`.
pytorch_total_params = sum(map(lambda param: param.numel(), model.parameters()))
# Same total, restricted to parameters whose `requires_grad` flag is set.
pytorch_total_trainable_params = sum(
    map(
        lambda param: param.numel(),
        filter(lambda param: param.requires_grad, model.parameters()),
    )
)
# Load the model
from transformers import BartForConditionalGeneration
from transformers import T5ForConditionalGeneration
def cal(model):
    """Return the number of trainable parameter elements in ``model``.

    Iterates ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is set.
    """
    trainable_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()
    return trainable_count
# Print the trainable-parameter count of several pretrained seq2seq checkpoints.
#
# Reference values recorded in the original notes
# ("L" presumably denotes the layer count; exact count; rounded size):
#   bart-base    6L   139420416   139M
#   bart-large  12L   406291456   406M
#   t5-small     6L    60506624   60.5M
#   t5-base     12L   222903552   223M
#   t5-large    24L   737668096   738M
#
# The second entry previously re-loaded "facebook/bart-base" although its
# recorded numbers (12L / 406M) describe bart-large; it now loads bart-large.
for model_cls, checkpoint, label in (
    (BartForConditionalGeneration, "facebook/bart-base", "bart-base"),
    (BartForConditionalGeneration, "facebook/bart-large", "bart-large"),
    (T5ForConditionalGeneration, "t5-small", "t5-small"),
    (T5ForConditionalGeneration, "t5-base", "t5-base"),
    (T5ForConditionalGeneration, "t5-large", "t5-large"),
):
    # `model` stays bound to the last checkpoint loaded (t5-large), exactly as
    # in the unrolled version of this script.
    model = model_cls.from_pretrained(checkpoint)
    print(label)
    print(cal(model))
往现有 tokenizer 里添加一些特殊 token:
# Register '[EOT]' as a special token on the existing tokenizer. The return
# value is presumably the number of tokens actually added (per the name).
num_added_toks = tokenizer.add_tokens(['[EOT]'], special_tokens=True)
# Resize the model's token embeddings to match the tokenizer's new length.
model.resize_token_embeddings(len(tokenizer))
# The tokenizer has to be saved if it is going to be reused later.
# NOTE(review): <output_dir> is a placeholder, not valid Python; replace it
# with a real directory path before running.
tokenizer.save_pretrained(<output_dir>)
# (c) Meta Platforms, Inc. and affiliates.
# https://pytorch.org/blog/understanding-gpu-memory-1/
import logging
import socket
from datetime import datetime, timedelta
import torch
from torchvision import models
# Root logging configuration: INFO level, "LEVEL:timestamp message" lines.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(levelname)s:%(asctime)s %(message)s",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# strftime pattern used to timestamp the snapshot file name.
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

# Keep a max of 100,000 alloc/free events in the recorded history
# leading up to the snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100_000
def start_record_memory_history() -> None:
    """Start recording CUDA memory history for a later snapshot.

    When CUDA is unavailable this only logs a message. Otherwise recording
    is enabled with at most MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT entries.
    """
    if torch.cuda.is_available():
        logger.info("Starting snapshot record_memory_history")
        torch.cuda.memory._record_memory_history(
            max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
        )
    else:
        logger.info("CUDA unavailable. Not recording memory history")
def stop_record_memory_history() -> None:
    """Stop recording CUDA memory history.

    When CUDA is unavailable this only logs a message. Otherwise recording
    is switched off by passing ``enabled=None``.
    """
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        logger.info("Stopping snapshot record_memory_history")
        torch.cuda.memory._record_memory_history(enabled=None)
    else:
        logger.info("CUDA unavailable. Not recording memory history")
def export_memory_snapshot() -> None:
    """Dump the recorded CUDA memory snapshot to a local pickle file.

    The file is named ``<hostname>_<timestamp>.pickle``. When CUDA is
    unavailable nothing is written. A failure while dumping is logged as an
    error and not re-raised.
    """
    if not torch.cuda.is_available():
        logger.info("CUDA unavailable. Not exporting memory snapshot")
        return

    # Build the file name prefix from the host name and the current time.
    host_name = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_prefix = f"{host_name}_{timestamp}"

    try:
        logger.info(f"Saving snapshot to local file: {file_prefix}.pickle")
        torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
    except Exception as e:
        # Best effort: report the failure and fall through to the end.
        logger.error(f"Failed to capture memory snapshot {e}")
# Simple Resnet50 example to demonstrate how to capture memory visuals.
def run_resnet50(num_iters=5, device="cuda:0"):
    """Train a ResNet-50 for a few steps while recording CUDA memory history.

    Simple example to demonstrate how to capture memory visuals: recording
    is started, ``num_iters`` training iterations run, a snapshot file is
    exported, and recording is stopped.

    Args:
        num_iters: Number of forward/backward/optimizer iterations to run.
        device: Device the model and the input tensor are placed on.
    """
    model = models.resnet50().to(device=device)
    inputs = torch.randn(1, 3, 224, 224, device=device)
    # Random targets with the same shape as the model output.
    labels = torch.rand_like(model(inputs))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Start recording memory snapshot history
    start_record_memory_history()
    try:
        for _ in range(num_iters):
            pred = model(inputs)
            loss_fn(pred, labels).backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        # Create the memory snapshot file
        export_memory_snapshot()
    finally:
        # Always switch recording off, even when an iteration raised, so the
        # history recording is not left enabled for the rest of the process.
        stop_record_memory_history()
if __name__ == "__main__":
    # Script entry point: run the ResNet-50 example with its default arguments.
    run_resnet50()
同时 profile CPU 和 GPU 显存:
# (c) Meta Platforms, Inc. and affiliates.
# https://pytorch.org/blog/understanding-gpu-memory-1/
import logging
import socket
from datetime import datetime, timedelta
import torch
from torch.autograd.profiler import record_function
from torchvision import models
# Root logging configuration: INFO level, "LEVEL:timestamp message" lines.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(levelname)s:%(asctime)s %(message)s",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# strftime pattern used to timestamp the exported trace file names.
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"
def trace_handler(prof: torch.profiler.profile):
    """Export the profiler results as a Chrome trace and a memory timeline.

    Both files share a ``<hostname>_<timestamp>`` prefix: the trace is
    written to ``<prefix>.json.gz`` and the memory timeline for device
    ``cuda:0`` to ``<prefix>.html``.

    Args:
        prof: The profiler instance whose results are exported.
    """
    # Shared prefix for both output files.
    host_name = socket.gethostname()
    timestamp = datetime.now().strftime(TIME_FORMAT_STR)
    file_prefix = f"{host_name}_{timestamp}"

    # Trace file first, then the memory timeline.
    prof.export_chrome_trace(f"{file_prefix}.json.gz")
    prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")
def run_resnet50(num_iters=5, device="cuda:0"):
    """Train a ResNet-50 for a few steps under ``torch.profiler``.

    CPU and CUDA activity are profiled with shapes, memory and stacks
    recorded; ``trace_handler`` is registered as the ``on_trace_ready``
    callback.

    Args:
        num_iters: Number of forward/backward/optimizer iterations to run.
        device: Device the model and the input tensor are placed on.
    """
    model = models.resnet50().to(device=device)
    inputs = torch.randn(1, 3, 224, 224, device=device)
    # Random targets with the same shape as the model output.
    labels = torch.rand_like(model(inputs))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()

    profiled_activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    # No wait or warm-up steps: one cycle of 6 active steps.
    profile_schedule = torch.profiler.schedule(
        wait=0, warmup=0, active=6, repeat=1
    )

    with torch.profiler.profile(
        activities=profiled_activities,
        schedule=profile_schedule,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        on_trace_ready=trace_handler,
    ) as prof:
        for _ in range(num_iters):
            # Advance the profiler schedule at the start of every iteration.
            prof.step()
            # Each phase gets its own label via record_function.
            with record_function("## forward ##"):
                pred = model(inputs)
            with record_function("## backward ##"):
                loss_fn(pred, labels).backward()
            with record_function("## optimizer ##"):
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
if __name__ == "__main__":
    # First pass serves as a warm-up.
    run_resnet50()
    # Second pass is the actual ResNet-50 run.
    run_resnet50()
各型号 GPU 对比
查看 Python 进程的调用栈(使用 py-spy)
pip install py-spy
py-spy dump --pid 1199
输出如下:
Process 1199: /usr/bin/python3.10 -u torch_main.py
Python v3.10.14 (/usr/bin/python3.10)
Thread 0x7F62A2C43740 (active): "MainThread"
_wait_for_tstate_lock (threading.py:1116)
join (threading.py:1096)
main (torch_main.py:776)
<module> (torch_main.py:785)
Thread 0xAABBBCC (idle): "Thread-1"
wait (threading.py:324)
wait (threading.py:607)
run (threading.py:1376)
_bootstrap_inner (threading.py:1016)
_bootstrap (threading.py:973)
Thread 0xAAAAA (idle): "Thread-3 (process)"
wait (threading.py:320)
get (queue.py:171)
process (abase_writer.py:73)
run (threading.py:953)
_bootstrap_inner (threading.py:1016)
_bootstrap (threading.py:973)
Thread 0xA992ACDA (idle): "Thread-4 (process)"
wait (threading.py:320)
get (queue.py:171)
process (abase_writer.py:73)
run (threading.py:953)
_bootstrap_inner (threading.py:1016)
_bootstrap (threading.py:973)
Thread 0xAFF11AA (active): "Thread-5 (read_file)"
get_seq (ecom_seq_reader.py:200)
read_file (torch_main.py:494)
run (threading.py:953)
_bootstrap_inner (threading.py:1016)
_bootstrap (threading.py:973)
Thread 0x9922BCDA (idle): "Thread-6"
wait (threading.py:324)
wait (threading.py:607)
run (tqdm/_monitor.py:60)
_bootstrap_inner (threading.py:1016)
_bootstrap (threading.py:973)