加入收藏 | 设为首页 | 会员中心 | 我要投稿 李大同 (https://www.lidatong.com.cn/)- 科技、建站、经验、云计算、5G、大数据,站长网!
当前位置: 首页 > 大数据 > 正文

for-loop – 在torch7中初始化张量的快速方法

发布时间:2020-12-14 21:48:03 所属栏目:大数据 来源:网络整理
导读:我需要在torch7中初始化具有索引相关函数的3D张量,即 func = function(i,j,k) --i,j is the index of an element in the tensor return i*j*k --do operations within func which're dependent of i,jend 然后我初始化一个像这样的3D张量A: for i=1,A:size(
我需要在torch7中初始化具有索引相关函数的3D张量,即

func = function(i,j,k)  --i,j is the index of an element in the tensor
    return i*j*k        --do operations within func which're dependent of i,j
end

然后我初始化一个像这样的3D张量A:

for i=1,A:size(1) do
    for j=1,A:size(2) do
        for k=1,A:size(3) do
            A[{i,k}] = func(i,k)
        end
    end
end

但是这段代码运行得非常慢,我发现它占总运行时间的92%.有没有更有效的方法来初始化火炬7中的3D张量?

解决方法

请参阅 Tensor:apply的文档

These functions apply a function to each element of the tensor on
which the method is called (self). These methods are much faster than
using a for loop in Lua.

docs中的示例基于其索引i(在内存中)初始化2D数组.下面是3维的扩展示例,低于N-D张量的扩展示例.在我的机器上使用apply方法要快得多:

require 'torch'

A = torch.Tensor(100,100,1000)
B = torch.Tensor(100,1000)

function func(i,k) 
    return i*j*k    
end

t = os.clock()
for i=1,k}] = i * j * k
        end
    end
end
print("Original time:",os.difftime(os.clock(),t))

t = os.clock()
function forindices(A,func)
  local i = 1
  local j = 1
  local k = 0
  local d3 = A:size(3)
  local d2 = A:size(2) 
  return function()
    k = k + 1
    if k > d3 then
      k = 1
      j = j + 1
      if j > d2 then
        j = 1
        i = i + 1
      end
    end
    return func(i,k)
  end
end

B:apply(forindices(A,func))
print("Apply method:",t))

编辑

这适用于任何Tensor对象:

function tabulate(A,f)
  local idx = {}
  local ndims = A:dim()
  local dim = A:size()
  idx[ndims] = 0
  for i=1,(ndims - 1) do
    idx[i] = 1
  end
  return A:apply(function()
    for i=ndims,-1 do
      idx[i] = idx[i] + 1
      if idx[i] <= dim[i] then
        break
      end
      idx[i] = 1
    end
    return f(unpack(idx))
  end)
end

-- usage for 3D case.
tabulate(A,function(i,k) return i * j * k end)

(编辑:李大同)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

    推荐文章
      热点阅读