minor optimizations

This commit is contained in:
SaiD 2025-12-13 13:01:21 +05:30
parent 5d58695eb4
commit 9750472027
2 changed files with 30 additions and 32 deletions

View File

@ -8,6 +8,7 @@ import android.util.Log
import com.example.livingai.R import com.example.livingai.R
import com.example.livingai.domain.camera.* import com.example.livingai.domain.camera.*
import com.example.livingai.domain.model.camera.* import com.example.livingai.domain.model.camera.*
import com.example.livingai.utils.SignedMask
import com.example.livingai.utils.SilhouetteManager import com.example.livingai.utils.SilhouetteManager
import org.tensorflow.lite.Interpreter import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.FileUtil import org.tensorflow.lite.support.common.FileUtil
@ -247,8 +248,8 @@ class MockPoseAnalyzer : PoseAnalyzer {
val reference = SilhouetteManager.getWeightedMask(input.orientation) val reference = SilhouetteManager.getWeightedMask(input.orientation)
?: return Instruction("Silhouette missing", isValid = false) ?: return Instruction("Silhouette missing", isValid = false)
val refH = reference.size val refH = reference.mask.size
val refW = reference[0].size val refW = reference.mask[0].size
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// 2. Crop only animal region (BIG WIN) // 2. Crop only animal region (BIG WIN)
@ -319,18 +320,17 @@ class MockPoseAnalyzer : PoseAnalyzer {
*/ */
private fun calculateSignedSimilarityFast( private fun calculateSignedSimilarityFast(
mask: ByteArray, mask: ByteArray,
reference: Array<FloatArray> reference: SignedMask
): Float { ): Float {
var score = 0f var score = 0f
var maxScore = 0f val maxScore = reference.maxValue
var idx = 0 var idx = 0
for (y in reference.indices) { for (y in reference.mask.indices) {
val row = reference[y] val row = reference.mask[y]
for (x in row.indices) { for (x in row.indices) {
val r = row[x] val r = row[x]
if (r > 0f) maxScore += r
score += mask[idx++] * r score += mask[idx++] * r
} }
} }

View File

@ -9,15 +9,20 @@ import com.example.livingai.R
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import kotlin.math.min import kotlin.math.min
data class SignedMask(
val mask: Array<FloatArray>,
val maxValue: Float
)
object SilhouetteManager { object SilhouetteManager {
private val originals = ConcurrentHashMap<String, Bitmap>() private val originals = ConcurrentHashMap<String, Bitmap>()
private val invertedPurple = ConcurrentHashMap<String, Bitmap>() private val invertedPurple = ConcurrentHashMap<String, Bitmap>()
private val weightedMasks = ConcurrentHashMap<String, Array<FloatArray>>() private val weightedMasks = ConcurrentHashMap<String, SignedMask>()
fun getOriginal(name: String): Bitmap? = originals[name] fun getOriginal(name: String): Bitmap? = originals[name]
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name] fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name]
fun getWeightedMask(name: String): Array<FloatArray>? = weightedMasks[name] fun getWeightedMask(name: String): SignedMask? = weightedMasks[name]
fun initialize(context: Context, width: Int, height: Int) { fun initialize(context: Context, width: Int, height: Int) {
val resources = context.resources val resources = context.resources
@ -88,7 +93,7 @@ object SilhouetteManager {
bitmap: Bitmap, bitmap: Bitmap,
fadeInside: Int = 10, fadeInside: Int = 10,
fadeOutside: Int = 20 fadeOutside: Int = 20
): Array<FloatArray> { ): SignedMask {
val w = bitmap.width val w = bitmap.width
val h = bitmap.height val h = bitmap.height
@ -99,24 +104,19 @@ object SilhouetteManager {
fun idx(x: Int, y: Int) = y * w + x fun idx(x: Int, y: Int) = y * w + x
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// 1. Binary mask: inside = 1, outside = 0 // 1. Binary mask
// Assumption: NON-transparent pixels are inside the object
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val inside = IntArray(w * h) val inside = IntArray(w * h)
for (i in pixels.indices) { for (i in pixels.indices) {
val alpha = pixels[i] ushr 24 inside[i] = if ((pixels[i] ushr 24) > 0) 1 else 0
inside[i] = if (alpha > 0) 1 else 0
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// 2. Distance to nearest OUTSIDE pixel (for inside pixels) // 2. Distance transform (inside → outside)
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val distInside = IntArray(w * h) { Int.MAX_VALUE } val distInside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) { for (i in inside.indices) if (inside[i] == 0) distInside[i] = 0
if (inside[i] == 0) distInside[i] = 0
}
// forward pass
for (y in 0 until h) { for (y in 0 until h) {
for (x in 0 until w) { for (x in 0 until w) {
val i = idx(x, y) val i = idx(x, y)
@ -127,7 +127,6 @@ object SilhouetteManager {
} }
} }
// backward pass
for (y in h - 1 downTo 0) { for (y in h - 1 downTo 0) {
for (x in w - 1 downTo 0) { for (x in w - 1 downTo 0) {
val i = idx(x, y) val i = idx(x, y)
@ -139,14 +138,11 @@ object SilhouetteManager {
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// 3. Distance to nearest INSIDE pixel (for outside pixels) // 3. Distance transform (outside → inside)
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val distOutside = IntArray(w * h) { Int.MAX_VALUE } val distOutside = IntArray(w * h) { Int.MAX_VALUE }
for (i in inside.indices) { for (i in inside.indices) if (inside[i] == 1) distOutside[i] = 0
if (inside[i] == 1) distOutside[i] = 0
}
// forward pass
for (y in 0 until h) { for (y in 0 until h) {
for (x in 0 until w) { for (x in 0 until w) {
val i = idx(x, y) val i = idx(x, y)
@ -157,7 +153,6 @@ object SilhouetteManager {
} }
} }
// backward pass
for (y in h - 1 downTo 0) { for (y in h - 1 downTo 0) {
for (x in w - 1 downTo 0) { for (x in w - 1 downTo 0) {
val i = idx(x, y) val i = idx(x, y)
@ -169,30 +164,33 @@ object SilhouetteManager {
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
// 4. Build signed weight map [-1, +1] // 4. Build signed mask + track max value
// -------------------------------------------------------------------- // --------------------------------------------------------------------
val result = Array(h) { FloatArray(w) } val result = Array(h) { FloatArray(w) }
var maxValue = Float.NEGATIVE_INFINITY
for (y in 0 until h) { for (y in 0 until h) {
for (x in 0 until w) { for (x in 0 until w) {
val i = idx(x, y) val i = idx(x, y)
val weight = if (inside[i] == 1) { val weight = if (inside[i] == 1) {
// Inside: +1 → 0
val d = distInside[i] val d = distInside[i]
if (d >= fadeInside) 1f if (d >= fadeInside) 1f
else d.toFloat() / fadeInside else d.toFloat() / fadeInside
} else { } else {
// Outside: 0 → -1
val d = distOutside[i] val d = distOutside[i]
val v = -(d.toFloat() / fadeOutside) (-d.toFloat() / fadeOutside).coerceAtLeast(-1f)
v.coerceAtLeast(-1f)
} }
result[y][x] = weight result[y][x] = weight
if (weight > maxValue) maxValue = weight
} }
} }
return result return SignedMask(
mask = result,
maxValue = maxValue
)
} }
} }