minor optimizations
This commit is contained in:
parent
5d58695eb4
commit
9750472027
|
|
@ -8,6 +8,7 @@ import android.util.Log
|
|||
import com.example.livingai.R
|
||||
import com.example.livingai.domain.camera.*
|
||||
import com.example.livingai.domain.model.camera.*
|
||||
import com.example.livingai.utils.SignedMask
|
||||
import com.example.livingai.utils.SilhouetteManager
|
||||
import org.tensorflow.lite.Interpreter
|
||||
import org.tensorflow.lite.support.common.FileUtil
|
||||
|
|
@ -247,8 +248,8 @@ class MockPoseAnalyzer : PoseAnalyzer {
|
|||
val reference = SilhouetteManager.getWeightedMask(input.orientation)
|
||||
?: return Instruction("Silhouette missing", isValid = false)
|
||||
|
||||
val refH = reference.size
|
||||
val refW = reference[0].size
|
||||
val refH = reference.mask.size
|
||||
val refW = reference.mask[0].size
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 2. Crop only animal region (BIG WIN)
|
||||
|
|
@ -319,18 +320,17 @@ class MockPoseAnalyzer : PoseAnalyzer {
|
|||
*/
|
||||
private fun calculateSignedSimilarityFast(
|
||||
mask: ByteArray,
|
||||
reference: Array<FloatArray>
|
||||
reference: SignedMask
|
||||
): Float {
|
||||
|
||||
var score = 0f
|
||||
var maxScore = 0f
|
||||
val maxScore = reference.maxValue
|
||||
var idx = 0
|
||||
|
||||
for (y in reference.indices) {
|
||||
val row = reference[y]
|
||||
for (y in reference.mask.indices) {
|
||||
val row = reference.mask[y]
|
||||
for (x in row.indices) {
|
||||
val r = row[x]
|
||||
if (r > 0f) maxScore += r
|
||||
score += mask[idx++] * r
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,15 +9,20 @@ import com.example.livingai.R
|
|||
import java.util.concurrent.ConcurrentHashMap
|
||||
import kotlin.math.min
|
||||
|
||||
data class SignedMask(
|
||||
val mask: Array<FloatArray>,
|
||||
val maxValue: Float
|
||||
)
|
||||
|
||||
object SilhouetteManager {
|
||||
|
||||
private val originals = ConcurrentHashMap<String, Bitmap>()
|
||||
private val invertedPurple = ConcurrentHashMap<String, Bitmap>()
|
||||
private val weightedMasks = ConcurrentHashMap<String, Array<FloatArray>>()
|
||||
private val weightedMasks = ConcurrentHashMap<String, SignedMask>()
|
||||
|
||||
fun getOriginal(name: String): Bitmap? = originals[name]
|
||||
fun getInvertedPurple(name: String): Bitmap? = invertedPurple[name]
|
||||
fun getWeightedMask(name: String): Array<FloatArray>? = weightedMasks[name]
|
||||
fun getWeightedMask(name: String): SignedMask? = weightedMasks[name]
|
||||
|
||||
fun initialize(context: Context, width: Int, height: Int) {
|
||||
val resources = context.resources
|
||||
|
|
@ -88,7 +93,7 @@ object SilhouetteManager {
|
|||
bitmap: Bitmap,
|
||||
fadeInside: Int = 10,
|
||||
fadeOutside: Int = 20
|
||||
): Array<FloatArray> {
|
||||
): SignedMask {
|
||||
|
||||
val w = bitmap.width
|
||||
val h = bitmap.height
|
||||
|
|
@ -99,24 +104,19 @@ object SilhouetteManager {
|
|||
fun idx(x: Int, y: Int) = y * w + x
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 1. Binary mask: inside = 1, outside = 0
|
||||
// Assumption: NON-transparent pixels are inside the object
|
||||
// 1. Binary mask
|
||||
// --------------------------------------------------------------------
|
||||
val inside = IntArray(w * h)
|
||||
for (i in pixels.indices) {
|
||||
val alpha = pixels[i] ushr 24
|
||||
inside[i] = if (alpha > 0) 1 else 0
|
||||
inside[i] = if ((pixels[i] ushr 24) > 0) 1 else 0
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 2. Distance to nearest OUTSIDE pixel (for inside pixels)
|
||||
// 2. Distance transform (inside → outside)
|
||||
// --------------------------------------------------------------------
|
||||
val distInside = IntArray(w * h) { Int.MAX_VALUE }
|
||||
for (i in inside.indices) {
|
||||
if (inside[i] == 0) distInside[i] = 0
|
||||
}
|
||||
for (i in inside.indices) if (inside[i] == 0) distInside[i] = 0
|
||||
|
||||
// forward pass
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val i = idx(x, y)
|
||||
|
|
@ -127,7 +127,6 @@ object SilhouetteManager {
|
|||
}
|
||||
}
|
||||
|
||||
// backward pass
|
||||
for (y in h - 1 downTo 0) {
|
||||
for (x in w - 1 downTo 0) {
|
||||
val i = idx(x, y)
|
||||
|
|
@ -139,14 +138,11 @@ object SilhouetteManager {
|
|||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 3. Distance to nearest INSIDE pixel (for outside pixels)
|
||||
// 3. Distance transform (outside → inside)
|
||||
// --------------------------------------------------------------------
|
||||
val distOutside = IntArray(w * h) { Int.MAX_VALUE }
|
||||
for (i in inside.indices) {
|
||||
if (inside[i] == 1) distOutside[i] = 0
|
||||
}
|
||||
for (i in inside.indices) if (inside[i] == 1) distOutside[i] = 0
|
||||
|
||||
// forward pass
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val i = idx(x, y)
|
||||
|
|
@ -157,7 +153,6 @@ object SilhouetteManager {
|
|||
}
|
||||
}
|
||||
|
||||
// backward pass
|
||||
for (y in h - 1 downTo 0) {
|
||||
for (x in w - 1 downTo 0) {
|
||||
val i = idx(x, y)
|
||||
|
|
@ -169,30 +164,33 @@ object SilhouetteManager {
|
|||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// 4. Build signed weight map [-1, +1]
|
||||
// 4. Build signed mask + track max value
|
||||
// --------------------------------------------------------------------
|
||||
val result = Array(h) { FloatArray(w) }
|
||||
var maxValue = Float.NEGATIVE_INFINITY
|
||||
|
||||
for (y in 0 until h) {
|
||||
for (x in 0 until w) {
|
||||
val i = idx(x, y)
|
||||
|
||||
val weight = if (inside[i] == 1) {
|
||||
// Inside: +1 → 0
|
||||
val d = distInside[i]
|
||||
if (d >= fadeInside) 1f
|
||||
else d.toFloat() / fadeInside
|
||||
} else {
|
||||
// Outside: 0 → -1
|
||||
val d = distOutside[i]
|
||||
val v = -(d.toFloat() / fadeOutside)
|
||||
v.coerceAtLeast(-1f)
|
||||
(-d.toFloat() / fadeOutside).coerceAtLeast(-1f)
|
||||
}
|
||||
|
||||
result[y][x] = weight
|
||||
if (weight > maxValue) maxValue = weight
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
return SignedMask(
|
||||
mask = result,
|
||||
maxValue = maxValue
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue