optimizations

This commit is contained in:
SaiD 2025-12-13 12:15:03 +05:30
parent 36c122d9a7
commit 5d58695eb4
2 changed files with 72 additions and 126 deletions

View File

@ -221,7 +221,6 @@ class TFLiteObjectDetector(context: Context) : ObjectDetector {
data class Detection(val label: String, val confidence: Float, val bounds: RectF)
}
class MockPoseAnalyzer : PoseAnalyzer {
private val segmenter by lazy {
@ -233,173 +232,114 @@ class MockPoseAnalyzer : PoseAnalyzer {
override suspend fun analyze(input: PipelineInput): Instruction {
val detectionResult =
input.previousDetectionResult
?: return Instruction("No detection result", isValid = false)
val detection = input.previousDetectionResult
?: return Instruction("No detection", isValid = false)
if (!detectionResult.isAnimalDetected || detectionResult.animalBounds == null) {
return Instruction("Animal not detected", isValid = false)
}
val bounds = detection.animalBounds
?: return Instruction("Animal not detected", isValid = false)
val image =
input.image
?: return Instruction("No image", isValid = false)
val image = input.image
?: return Instruction("No image", isValid = false)
// --------------------------------------------------------------------
// 1. Get reference signed-weight silhouette (Array<FloatArray>)
// 1. Reference silhouette (FloatArray)
// --------------------------------------------------------------------
val weightedSilhouette: Array<FloatArray> =
SilhouetteManager.getWeightedMask(input.orientation)
?: return Instruction(
"Silhouette not found for ${input.orientation}",
isValid = false
)
val reference = SilhouetteManager.getWeightedMask(input.orientation)
?: return Instruction("Silhouette missing", isValid = false)
val refHeight = weightedSilhouette.size
val refWidth = weightedSilhouette[0].size
val refH = reference.size
val refW = reference[0].size
// --------------------------------------------------------------------
// 2. Scale input image to match reference resolution
// 2. Crop only animal region (BIG WIN)
// --------------------------------------------------------------------
val scaledImage = Bitmap.createScaledBitmap(
val cropped = Bitmap.createBitmap(
image,
refWidth,
refHeight,
true
bounds.left.toInt(),
bounds.top.toInt(),
bounds.width().toInt(),
bounds.height().toInt()
)
// --------------------------------------------------------------------
// 3. Get animal segmentation → BINARY mask (0 / 1)
// --------------------------------------------------------------------
val animalMask: Array<IntArray> =
getAnimalSegmentationMask(scaledImage)
?: return Instruction("Segmentation failed", isValid = false)
val scaled = Bitmap.createScaledBitmap(cropped, refW, refH, true)
// --------------------------------------------------------------------
// 4. Compute signed-weight alignment score
// 3. Get binary segmentation mask (ByteArray)
// --------------------------------------------------------------------
val similarity = calculateSignedSimilarity(
animalMask,
weightedSilhouette
)
val mask = getAnimalMaskFast(scaled)
?: return Instruction("Segmentation failed", isValid = false)
Log.d("MockPoseAnalyzer", "Signed Similarity = $similarity")
val isValid = similarity >= 0.40f
val msg = if (isValid) {
"Pose Correct"
} else {
"Pose Incorrect (Score: %.2f)".format(similarity)
}
// --------------------------------------------------------------------
// 4. Fast signed similarity
// --------------------------------------------------------------------
val score = calculateSignedSimilarityFast(mask, reference)
val valid = score >= 0.40f
return Instruction(
message = msg,
isValid = isValid,
result = detectionResult
message = if (valid) "Pose Correct" else "Pose Incorrect (%.2f)".format(score),
isValid = valid,
result = detection
)
}
/**
* Converts segmentation output to a binary mask:
* 1 = animal present
* 0 = background
* ML Kit Binary mask (0/1) using ByteArray
*/
private suspend fun getAnimalSegmentationMask(
bitmap: Bitmap
): Array<IntArray>? {
private suspend fun getAnimalMaskFast(bitmap: Bitmap): ByteArray? =
suspendCancellableCoroutine { cont ->
val maskBitmap = getAnimalSegmentation(bitmap) ?: return null
val image = InputImage.fromBitmap(bitmap, 0)
segmenter.process(image)
.addOnSuccessListener { result ->
val mask = result.foregroundConfidenceMask
?: run {
cont.resume(null)
return@addOnSuccessListener
}
val w = maskBitmap.width
val h = maskBitmap.height
val pixels = IntArray(w * h)
maskBitmap.getPixels(pixels, 0, w, 0, 0, w, h)
val result = Array(h) { IntArray(w) }
for (y in 0 until h) {
for (x in 0 until w) {
val alpha = pixels[y * w + x] ushr 24
result[y][x] = if (alpha > 0) 1 else 0
}
}
return result
}
// ----------------------------------------------------------------------
// REAL segmentation using ML Kit
// ----------------------------------------------------------------------
private suspend fun getAnimalSegmentation(bitmap: Bitmap): Bitmap = suspendCancellableCoroutine { continuation ->
val inputImage = InputImage.fromBitmap(bitmap, 0)
segmenter.process(inputImage)
.addOnSuccessListener { result ->
val mask = result.foregroundConfidenceMask
if (mask != null) {
val maskBitmap = Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888)
val buffer = mask
buffer.rewind()
for (y in 0 until bitmap.height) {
for (x in 0 until bitmap.width) {
val confidence = buffer.get()
val alpha = (confidence * 255).toInt()
maskBitmap.setPixel(x, y, Color.argb(alpha, 255, 255, 255))
}
}
continuation.resume(maskBitmap)
} else {
Log.e("MockPoseAnalyzer", "Segmentation result null")
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
}
}
.addOnFailureListener { e ->
Log.e("MockPoseAnalyzer", "Segmentation failed", e)
continuation.resume(Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888))
}
}
// ----------------------------------------------------------------------
// Weighted Jaccard Similarity
// mask = predicted (Bitmap)
// weightMap = ground truth silhouette weights (Bitmap)
// ----------------------------------------------------------------------
val out = ByteArray(bitmap.width * bitmap.height)
for (i in out.indices) {
out[i] = if (buffer.get() > 0.5f) 1 else 0
}
cont.resume(out)
}
.addOnFailureListener {
cont.resume(null)
}
}
/**
* Signed alignment similarity
* Range: [-1, 1]
* Signed similarity (flat arrays, cache-friendly)
* Range [-1, 1]
*/
private fun calculateSignedSimilarity(
mask: Array<IntArray>, // 0 / 1
reference: Array<FloatArray> // [-1, +1]
private fun calculateSignedSimilarityFast(
mask: ByteArray,
reference: Array<FloatArray>
): Float {
var score = 0f
var maxScore = 0f
var idx = 0
val h = reference.size
val w = reference[0].size
for (y in 0 until h) {
for (x in 0 until w) {
val r = reference[y][x]
val m = mask[y][x].toFloat()
// accumulate positive reference mass only
if (r > 0f) {
maxScore += r
}
score += m * r
for (y in reference.indices) {
val row = reference[y]
for (x in row.indices) {
val r = row[x]
if (r > 0f) maxScore += r
score += mask[idx++] * r
}
}
return if (maxScore == 0f) 0f else score / maxScore
}
}
class DefaultCaptureHandler : CaptureHandler {
override suspend fun capture(input: PipelineInput, detectionResult: DetectionResult): CaptureData {
val image = input.image ?: throw IllegalStateException("Image cannot be null during capture")

View File

@ -25,6 +25,9 @@ class CameraViewModel(
val uiState: StateFlow<CameraUiState> = _uiState.asStateFlow()
private val tilt = tiltSensorManager.tilt
private var frameCounter = 0
private val frameSkipInterval = 5
init {
tiltSensorManager.start()
}
@ -49,6 +52,9 @@ class CameraViewModel(
image: Bitmap,
deviceOrientation: Int
) {
frameCounter = (frameCounter + 1) % frameSkipInterval
if (frameCounter != 0) return
viewModelScope.launch {
val currentTilt = tilt.value
val input = PipelineInput(