Page Menu
Home
Phabricator
Search
Configure Global Search
Log In
Files
F4845394
utils.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
4 KB
Subscribers
None
utils.py
View Options
# Copyright (c) Meta, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from
collections
import
defaultdict
from
contextlib
import
contextmanager
import
math
import
os
import
tempfile
import
typing
as
tp
import
torch
from
torch.nn
import
functional
as
F
from
torch.utils.data
import
Subset
def unfold(a, kernel_size, stride):
    """Slice the last dimension of `a` into (possibly overlapping) frames.

    Given input of size [*OT, T], return a strided view of size [*OT, F, K]
    where K is `kernel_size` and consecutive frames start `stride` samples
    apart. The input is first passed through `F.pad` on the right of the last
    dimension so that `F = ceil(T / stride)` frames of K samples each can be
    read.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *leading_dims, length = a.shape
    n_frames = math.ceil(length / stride)
    # Number of samples required for the last frame to be complete.
    tgt_length = (n_frames - 1) * stride + kernel_size
    padded = F.pad(a, (0, tgt_length - length))

    *outer_strides, inner_stride = padded.stride()
    assert inner_stride == 1, 'data should be contiguous'

    # Frame axis advances by `stride` samples, sample axis by one.
    out_size = [*leading_dims, n_frames, kernel_size]
    out_strides = [*outer_strides, stride, 1]
    return padded.as_strided(out_size, out_strides)
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Trim the last dimension of `tensor` around its center to match `reference`.

    `reference` is either a tensor, whose last dimension gives the target
    length, or directly the target length as a number. When the number of
    samples to drop is odd, the extra sample is removed on the right side.
    The input is returned unchanged when the lengths already match.

    Raises:
        ValueError: if `tensor` is shorter than the target length.
    """
    ref_size: int = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
    delta = tensor.size(-1) - ref_size
    if delta < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
    if delta == 0:
        return tensor

    trim_left = delta // 2
    trim_right = delta - trim_left  # gets the extra sample when delta is odd
    return tensor[..., trim_left:-trim_right]
def pull_metric(history: tp.List[dict], name: str):
    """Collect one metric from every entry of `history`.

    `name` is a dot separated path (e.g. `"valid.loss"`): each part is used
    in turn as a key to descend into the nested entry. Returns the list of
    values found, one per entry and in the same order.
    """
    def _lookup(entry):
        # Walk down the nested containers, one key per path component.
        node = entry
        for key in name.split("."):
            node = node[key]
        return node

    return [_lookup(entry) for entry in history]
def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.

    Builds and returns a function that can be called repeatedly with a dict
    of metrics (and an optional `weight` for that update). Each call folds
    the new values into the running statistics and returns a dict with the
    current average of every metric seen so far.

    For each key, both the weighted sum of values and the sum of weights are
    decayed by `beta` before the new sample is added; the average is their
    ratio. With `beta=1` nothing is decayed and this is a plain weighted mean.
    """
    # Normalisation term: decayed sum of the weights, per metric.
    fix: tp.Dict[str, float] = defaultdict(float)
    # Decayed sum of the weighted metric values, per metric.
    total: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        for key, value in metrics.items():
            decayed_total = total[key] * beta
            decayed_fix = fix[key] * beta
            total[key] = decayed_total + weight * float(value)
            fix[key] = decayed_fix + weight

        averaged = {}
        for key, tot in total.items():
            averaged[key] = tot / fix[key]
        return averaged

    return _update
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Format a byte count `num` as a human readable string.

    The value is divided by 1024 until its magnitude is below 1024, and the
    matching binary prefix (Ki, Mi, ...) is inserted before `suffix`.
    Anything too large for 'Zi' is reported in 'Yi'.
    Taken from https://stackoverflow.com/a/1094933
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    for unit in prefixes:
        if abs(num) < 1024.0:
            break
        num /= 1024.0
    else:
        # Ran out of prefixes: whatever is left is expressed in Yi.
        return "%.1f%s%s" % (num, 'Yi', suffix)
    return "%3.1f%s%s" % (num, unit, suffix)
@contextmanager
def temp_filenames(count: int, delete=True):
    """Context manager yielding a list of `count` temporary file names.

    Each file is created on disk (empty) through `tempfile.NamedTemporaryFile`
    and closed immediately, so only the paths are handed to the caller.

    Args:
        count: number of temporary files to create.
        delete: if True (default), the files are removed when the context
            exits, whether normally or through an exception. If False, the
            files are left in place and the caller owns their cleanup.

    Yields:
        The list of created file names.
    """
    names = []
    try:
        for _ in range(count):
            # Close the handle explicitly: only the name is needed, and an
            # unclosed handle keeps a file descriptor open until it is
            # garbage collected. `delete=False` keeps the file on disk.
            with tempfile.NamedTemporaryFile(delete=False) as handle:
                names.append(handle.name)
        yield names
    finally:
        if delete:
            for name in names:
                try:
                    os.unlink(name)
                except FileNotFoundError:
                    # The caller already removed or moved this file; keep
                    # going so the remaining files still get cleaned up.
                    pass
def random_subset(dataset, max_samples: int, seed: int = 42):
    """Return a reproducible random subset of at most `max_samples` items.

    If `dataset` already holds no more than `max_samples` items it is
    returned as is. Otherwise a permutation of its indices is drawn from a
    `torch.Generator` seeded with `seed`, and a `Subset` over the first
    `max_samples` indices of that permutation is returned.
    """
    if len(dataset) <= max_samples:
        return dataset

    rng = torch.Generator().manual_seed(seed)
    shuffled = torch.randperm(len(dataset), generator=rng)
    kept_indices = shuffled[:max_samples].tolist()
    return Subset(dataset, kept_indices)
class DummyPoolExecutor:
    """In-process stand-in exposing the `submit` / context manager interface
    of a pool executor, without spawning any worker.

    Submitted calls are not executed at submission time: they run in the
    calling thread when `result()` is invoked on the returned object.
    """

    class DummyResult:
        """Deferred call: stores a function with its arguments for later."""

        def __init__(self, func, *args, **kwargs):
            self.func, self.args, self.kwargs = func, args, kwargs

        def result(self):
            """Call the stored function with the stored arguments and return
            its value. The call is performed on every invocation."""
            target = self.func
            return target(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        """`workers` is accepted for interface compatibility and ignored."""

    def submit(self, func, *args, **kwargs):
        """Wrap the call in a `DummyResult` without running it."""
        deferred = DummyPoolExecutor.DummyResult(func, *args, **kwargs)
        return deferred

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Nothing to release; returning None lets exceptions propagate.
        return None
File Metadata
Details
Attached
Mime Type
text/x-python
Expires
Sun, Nov 24, 18:52 (21 h, 9 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
1326436
Default Alt Text
utils.py (4 KB)
Attached To
R350 av_svc
Event Timeline
Log In to Comment