Reduce VRAM consumption by swapping `cuda()` and `to(torch.bfloat16)`

#2
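Calling `.cuda()` on the full-precision model first copies the fp32 weights (4 bytes per parameter) onto the GPU before the cast, so peak VRAM during loading is roughly the size of the fp32 model. Casting to `torch.bfloat16` while the weights are still in CPU RAM means only the half-size (2 bytes per parameter) copy is ever transferred. A minimal sketch of the effect, using a toy `nn.Sequential` as a hypothetical stand-in for the model that app.py actually loads:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the OCR model; the real model is far larger,
# which is where the ordering of the two calls actually matters.
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)])

# Old order: fp32 weights land on the GPU first and are cast there,
# so peak VRAM at load time is roughly the full fp32 model size.
# model_gpu = model.cuda().to(torch.bfloat16)

# New order: the bf16 cast happens in CPU RAM, and only the half-size
# weights are transferred, roughly halving peak VRAM at load time.
model_gpu = model.to(torch.bfloat16).cuda()  # requires a CUDA device
```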
Files changed (1)
  1. app.py +1 -1
app.py CHANGED
@@ -43,7 +43,7 @@ def process_ocr_task(image, model_size, task_type, ref_text):
   return "Please upload an image first.", None
 
   print("🚀 Moving model to GPU...")
-  model_gpu = model.cuda().to(torch.bfloat16)
+  model_gpu = model.to(torch.bfloat16).cuda()
   print("✅ Model is on GPU.")
 
   with tempfile.TemporaryDirectory() as output_path: