cdancette commited on
Commit
882bd8f
·
1 Parent(s): de6c079
Files changed (2) hide show
  1. app.py +0 -14
  2. inference.py +3 -5
app.py CHANGED
@@ -616,24 +616,10 @@ def build_demo() -> gr.Blocks:
616
  outputs=[main_prediction, prediction_probs],
617
  )
618
 
619
- gr.Markdown(
620
- """
621
- ### Notes
622
-
623
- - Configure the `HF_TOKEN` secret in your Space to load private checkpoints
624
- and datasets from the `raidium` organisation.
625
- - When masks are available in the dataset sample, their contours are drawn on the
626
- image for visual reference using OpenCV.
627
- - Uploaded images must be single-channel arrays. Multi-channel inputs are
628
- converted to grayscale automatically.
629
- """
630
- )
631
-
632
  return demo
633
 
634
 
635
  demo = build_demo()
636
 
637
-
638
  if __name__ == "__main__":
639
  demo.launch()
 
616
  outputs=[main_prediction, prediction_probs],
617
  )
618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  return demo
620
 
621
 
622
  demo = build_demo()
623
 
 
624
  if __name__ == "__main__":
625
  demo.launch()
inference.py CHANGED
@@ -82,11 +82,9 @@ def prepare_mask_for_model(mask: Any) -> Optional[torch.Tensor]:
82
  if mask_arr.size == 0:
83
  return None
84
 
85
- if mask_arr.ndim == 3:
86
- tensor = mask_transform(mask_arr.transpose(2, 0, 1))
87
- # Match the shape produced in simple_test.py so the model receives
88
- # (batch, height, width, channels) style tensors.
89
- tensor = tensor.transpose(1, 3).transpose(1, 2)
90
  else:
91
  tensor = mask_transform(torch.tensor([mask_arr]))
92
  tensor = tensor.unsqueeze(0)
 
82
  if mask_arr.size == 0:
83
  return None
84
 
85
+ if mask_arr.ndim == 3: # (H, W, slices)
86
+ tensor = mask_transform(mask_arr.transpose(2, 0, 1)) # (1, slices, H, W)
87
+ tensor = tensor.transpose(1, 3).transpose(1, 2) #
 
 
88
  else:
89
  tensor = mask_transform(torch.tensor([mask_arr]))
90
  tensor = tensor.unsqueeze(0)