optimizations
This commit is contained in:
parent
36c122d9a7
commit
5d58695eb4
|
|
@ -221,7 +221,6 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
|
|||
data class Detection(val label: String, val confidence: Float, val bounds: RectF)
|
||||
}
|
||||
|
||||
|
||||
class MockPoseAnalyzer : PoseAnalyzer {
|
||||
|
||||
private val segmenter by lazy {
|
||||
|
|
@ -233,173 +232,114 @@ class MockPoseAnalyzer : PoseAnalyzer {
|
|||
|
||||
override suspend fun analyze(input: PipelineInput): Instruction {
|
||||
|
||||
val detectionResult =
|
||||
input.previousDetectionResult
|
||||
?: return Instruction("No detection result", isValid = false)
|
||||
val detection = input.previousDetectionResult
|
||||
?: return Instruction("No detection", isValid = false)
|
||||
|
||||
if (!detectionResult.isAnimalDetected || detectionResult.animalBounds == null) {
|
||||
return Instruction("Animal not detected", isValid = false)
|
||||
}
|
||||
val bounds = detection.animalBounds
|
||||
?: return Instruction("Animal not detected", isValid = false)
|
||||
|
||||
val image =
|
||||
input.image
|
||||
?: return Instruction("No image", isValid = false)
|
||||
val image = input.image
|
||||
?: return Instruction("No image", isValid = false)
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 1. Get reference signed-weight silhouette (Array<FloatArray>)
|
||||
// 1. Reference silhouette (FloatArray)
|
||||
// --------------------------------------------------------------------
|
||||
val weightedSilhouette: Array<FloatArray> =
|
||||
SilhouetteManager.getWeightedMask(input.orientation)
|
||||
?: return Instruction(
|
||||
"Silhouette not found for ${input.orientation}",
|
||||
isValid = false
|
||||
)
|
||||
val reference = SilhouetteManager.getWeightedMask(input.orientation)
|
||||
?: return Instruction("Silhouette missing", isValid = false)
|
||||
|
||||
val refHeight = weightedSilhouette.size
|
||||
val refWidth = weightedSilhouette[0].size
|
||||
val refH = reference.size
|
||||
val refW = reference[0].size
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 2. Scale input image to match reference resolution
|
||||
// 2. Crop only animal region (BIG WIN)
|
||||
// --------------------------------------------------------------------
|
||||
val scaledImage = Bitmap.createScaledBitmap(
|
||||
val cropped = Bitmap.createBitmap(
|
||||
image,
|
||||
refWidth,
|
||||
refHeight,
|
||||
true
|
||||
bounds.left.toInt(),
|
||||
bounds.top.toInt(),
|
||||
bounds.width().toInt(),
|
||||
bounds.height().toInt()
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 3. Get animal segmentation → BINARY mask (0 / 1)
|
||||
// --------------------------------------------------------------------
|
||||
val animalMask: Array<IntArray> =
|
||||
getAnimalSegmentationMask(scaledImage)
|
||||
?: return Instruction("Segmentation failed", isValid = false)
|
||||
val scaled = Bitmap.createScaledBitmap(cropped, refW, refH, true)
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 4. Compute signed-weight alignment score
|
||||
// 3. Get binary segmentation mask (ByteArray)
|
||||
// --------------------------------------------------------------------
|
||||
val similarity = calculateSignedSimilarity(
|
||||
animalMask,
|
||||
weightedSilhouette
|
||||
)
|
||||
val mask = getAnimalMaskFast(scaled)
|
||||
?: return Instruction("Segmentation failed", isValid = false)
|
||||
|
||||
Log.d("MockPoseAnalyzer", "Signed Similarity = $similarity")
|
||||
|
||||
val isValid = similarity >= 0.40f
|
||||
val msg = if (isValid) {
|
||||
"Pose Correct"
|
||||
} else {
|
||||
"Pose Incorrect (Score: %.2f)".format(similarity)
|
||||
}
|
||||
// --------------------------------------------------------------------
|
||||
// 4. Fast signed similarity
|
||||
// --------------------------------------------------------------------
|
||||
val score = calculateSignedSimilarityFast(mask, reference)
|
||||
|
||||
val valid = score >= 0.40f
|
||||
return Instruction(
|
||||
message = msg,
|
||||
isValid = isValid,
|
||||
result = detectionResult
|
||||
message = if (valid) "Pose Correct" else "Pose Incorrect (%.2f)".format(score),
|
||||
isValid = valid,
|
||||
result = detection
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Converts segmentation output to a binary mask:
|
||||
* 1 = animal present
|
||||
* 0 = background
|
||||
* ML Kit → Binary mask (0/1) using ByteArray
|
||||
*/
|
||||
private suspend fun getAnimalSegmentationMask(
|
||||
bitmap: Bitmap
|
||||
): Array<IntArray>? {
|
||||
private suspend fun getAnimalMaskFast(bitmap: Bitmap): ByteArray? =
|
||||
suspendCancellableCoroutine { cont ->
|
||||
|
||||
val maskBitmap = getAnimalSegmentation(bitmap) ?: return null
|
||||
val image = InputImage.fromBitmap(bitmap, 0)
|
||||
segmenter.process(image)
|
||||
.addOnSuccessListener { result ->
|
||||
val mask = result.foregroundConfidenceMask
|
||||
?: run {
|
||||
cont.resume(null)
|
||||
return@addOnSuccessListener
|
||||
}
|
||||
|
||||
val w = maskBitmap.width
|
||||
val h = maskBitmap.height
|
||||
|
||||
val pixels = IntArray(w * h)
|
||||
maskBitmap.getPixels(pixels, 0, w, 0, 0, w, h)
|
||||
|
||||
val result = Array(h) { IntArray(w) }
|
||||
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val alpha = pixels[y * w + x] ushr 24
|
||||
result[y][x] = if (alpha > 0) 1 else 0
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// REAL segmentation using ML Kit
|
||||
// ----------------------------------------------------------------------
|
||||
private suspend fun getAnimalSegmentation(bitmap: Bitmap): Bitmap = suspendCancellableCoroutine { continuation ->
|
||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||
segmenter.process(inputImage)
|
||||
.addOnSuccessListener { result ->
|
||||
val mask = result.foregroundConfidenceMask
|
||||
if (mask != null) {
|
||||
val maskBitmap = Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888)
|
||||
val buffer = mask
|
||||
buffer.rewind()
|
||||
for (y in 0 until bitmap.height) {
|
||||
for (x in 0 until bitmap.width) {
|
||||
val confidence = buffer.get()
|
||||
val alpha = (confidence * 255).toInt()
|
||||
maskBitmap.setPixel(x, y, Color.argb(alpha, 255, 255, 255))
|
||||
}
|
||||
}
|
||||
continuation.resume(maskBitmap)
|
||||
} else {
|
||||
Log.e("MockPoseAnalyzer", "Segmentation result null")
|
||||
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
|
||||
}
|
||||
}
|
||||
.addOnFailureListener { e ->
|
||||
Log.e("MockPoseAnalyzer", "Segmentation failed", e)
|
||||
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Weighted Jaccard Similarity
|
||||
// mask = predicted (Bitmap)
|
||||
// weightMap = ground truth silhouette weights (Bitmap)
|
||||
// ----------------------------------------------------------------------
|
||||
val out = ByteArray(bitmap.width * bitmap.height)
|
||||
|
||||
for (i in out.indices) {
|
||||
out[i] = if (buffer.get() > 0.5f) 1 else 0
|
||||
}
|
||||
|
||||
cont.resume(out)
|
||||
}
|
||||
.addOnFailureListener {
|
||||
cont.resume(null)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Signed alignment similarity
|
||||
* Range: [-1, 1]
|
||||
* Signed similarity (flat arrays, cache-friendly)
|
||||
* Range [-1, 1]
|
||||
*/
|
||||
private fun calculateSignedSimilarity(
|
||||
mask: Array<IntArray>, // 0 / 1
|
||||
reference: Array<FloatArray> // [-1, +1]
|
||||
private fun calculateSignedSimilarityFast(
|
||||
mask: ByteArray,
|
||||
reference: Array<FloatArray>
|
||||
): Float {
|
||||
|
||||
var score = 0f
|
||||
var maxScore = 0f
|
||||
var idx = 0
|
||||
|
||||
val h = reference.size
|
||||
val w = reference[0].size
|
||||
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val r = reference[y][x]
|
||||
val m = mask[y][x].toFloat()
|
||||
|
||||
// accumulate positive reference mass only
|
||||
if (r > 0f) {
|
||||
maxScore += r
|
||||
}
|
||||
|
||||
score += m * r
|
||||
for (y in reference.indices) {
|
||||
val row = reference[y]
|
||||
for (x in row.indices) {
|
||||
val r = row[x]
|
||||
if (r > 0f) maxScore += r
|
||||
score += mask[idx++] * r
|
||||
}
|
||||
}
|
||||
|
||||
return if (maxScore == 0f) 0f else score / maxScore
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
class DefaultCaptureHandler : CaptureHandler {
|
||||
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData {
|
||||
val image = input.image ?: throw IllegalStateException("Image cannot be null during capture")
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ class CameraViewModel(
|
|||
val uiState: StateFlow<CameraUiState> = _uiState.asStateFlow()
|
||||
private val tilt = tiltSensorManager.tilt
|
||||
|
||||
private var frameCounter = 0
|
||||
private val frameSkipInterval = 5
|
||||
|
||||
init {
|
||||
tiltSensorManager.start()
|
||||
}
|
||||
|
|
@ -49,6 +52,9 @@ class CameraViewModel(
|
|||
image: Bitmap,
|
||||
deviceOrientation: Int
|
||||
) {
|
||||
frameCounter = (frameCounter + 1) % frameSkipInterval
|
||||
if (frameCounter != 0) return
|
||||
|
||||
viewModelScope.launch {
|
||||
val currentTilt = tilt.value
|
||||
val input = PipelineInput(
|
||||
|
|
|
|||
Loading…
Reference in New Issue